How do you alter the size of a PyTorch Dataset?
It is important to note that when you create a DataLoader object, it doesn't immediately load all of your data (that would be impractical for large datasets). It gives you an iterable that yields the data batch by batch.
Unfortunately, DataLoader doesn't provide any way to control the number of samples you wish to extract. You will have to use the usual ways of slicing iterators.
The simplest approach (without any extra libraries) is to stop once the required number of samples has been reached.
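nsamples = 10000
seen = 0  # number of samples processed so far
for image, label in train_loader:
    if seen >= nsamples:
        break
    # Your training code here.
    seen += image.size(0)  # the loader yields batches, so count samples per batch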
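Or you could use itertools.islice to take only the first part of the loader. Note that islice does not accept keyword arguments, and that it counts the batches the loader yields, so 10k samples has to be converted to a number of batches. Like so:
import itertools
nbatches = 10000 // train_loader.batch_size  # islice counts batches, not samples
for image, label in itertools.islice(train_loader, nbatches):
    pass  # your training code here.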
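Because the loader yields batches, capping at an exact sample count takes a little bookkeeping. Here is a minimal sketch of one way to do it. The helper name take_samples is hypothetical (it is not part of PyTorch), and it assumes each batch is an (inputs, labels) pair whose elements support len() and slicing, which holds for tensors as well as for the plain lists used here as a stand-in for a DataLoader.

```python
def take_samples(batches, nsamples):
    """Yield (inputs, labels) batches until exactly nsamples samples were yielded.

    Hypothetical helper: assumes every batch is an (inputs, labels) pair whose
    elements support len() and slicing, as tensors and lists do.
    """
    remaining = nsamples
    if remaining <= 0:
        return
    for inputs, labels in batches:
        if len(inputs) > remaining:
            # Trim the final batch so the running total is exactly nsamples.
            inputs, labels = inputs[:remaining], labels[:remaining]
        yield inputs, labels
        remaining -= len(inputs)
        if remaining <= 0:
            # Stop before asking the loader for another batch.
            return


# Stand-in for a DataLoader (assumption): 5 batches of 4 samples each.
fake_loader = [([i] * 4, [0] * 4) for i in range(5)]
sizes = [len(x) for x, _ in take_samples(fake_loader, 10)]
print(sizes)
```

With a real loader you would write something like for image, label in take_samples(train_loader, 10000). If the loader shuffles, the samples you get differ from epoch to epoch.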
Another quick way of slicing a dataset is torch.utils.data.random_split() (supported in PyTorch v0.4.1+). It randomly splits a dataset into non-overlapping new datasets of the given lengths.
So we can have something like the following:
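import torch
from torch.utils.data import DataLoader
from torchvision import datasets
# transform, args, kwargs, tr_split_len and te_split_len are assumed to be defined elsewhere.
tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
# random_split returns one dataset per requested length; keep only the first one.
part_tr = torch.utils.data.random_split(tr, [tr_split_len, len(tr)-tr_split_len])[0]
part_te = torch.utils.data.random_split(te, [te_split_len, len(te)-te_split_len])[0]
train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(part_te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)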
Here you can set tr_split_len and te_split_len to the required split lengths for the training and testing datasets respectively.
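To try the random_split approach without downloading MNIST, here is a small self-contained sketch. The TensorDataset of random tensors is only a stand-in for a real dataset, split_len and the batch size are arbitrary illustrative values, and the seeded generator argument (available in more recent PyTorch releases than v0.4.1) is an optional addition that makes the split repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in dataset (assumption): 1,000 random "images" with integer labels.
full = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

split_len = 100  # hypothetical size of the part we want to keep
generator = torch.Generator().manual_seed(0)  # fixed seed so the split is repeatable
part, rest = random_split(full, [split_len, len(full) - split_len], generator=generator)

loader = DataLoader(part, batch_size=32, shuffle=True)
n_batches = len(loader)                         # number of batches per epoch
n_samples = sum(x.size(0) for x, _ in loader)   # total samples seen in one epoch
print(len(part), len(rest), n_batches, n_samples)
```

The lengths passed to random_split must add up to the size of the dataset, which is why the remainder is passed as the second length even when it is discarded.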
You can also use torch.utils.data.Subset(), e.g. for the first 10,000 elements:
import torch
import torch.utils.data as data_utils
# tr is the full training dataset (e.g. the MNIST dataset created earlier).
indices = torch.arange(10000)
tr_10k = data_utils.Subset(tr, indices)
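Subset accepts any sequence of indices, so the same idea covers both "the first N items" and "N random items". The following is a minimal sketch on a toy TensorDataset (a stand-in, not the MNIST data above). The sizes and the seed are arbitrary.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset (assumption): 50 items where the label equals the item's index.
dataset = TensorDataset(torch.arange(50).float().unsqueeze(1), torch.arange(50))

# First 10 items, in their original order.
first_10 = Subset(dataset, range(10))

# 10 distinct items chosen at random, with a fixed seed for repeatability.
g = torch.Generator().manual_seed(0)
random_indices = torch.randperm(len(dataset), generator=g)[:10].tolist()
random_10 = Subset(dataset, random_indices)

print(len(first_10), len(random_10))
```

Unlike random_split, Subset lets you choose exactly which items are kept, which is useful when the selection has to be deterministic or follow a rule such as keeping only certain classes.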