How do you alter the size of a PyTorch Dataset?

mikal94305 · Jul 1, 2017 · Viewed 8.2k times

Say I am loading MNIST from torchvision.datasets.MNIST, but I only want to load 10000 images in total. How would I slice the data to limit it to a given number of data points? I understand that the DataLoader is a generator yielding data in batches of the specified batch size, but how do you slice datasets?

from torch.utils.data import DataLoader
from torchvision import datasets

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

Answer

entrophy · Jul 8, 2017

It is important to note that when you create the DataLoader object, it doesn't immediately load all of your data (that would be impractical for large datasets). Instead, it gives you an iterable that you can loop over to access the samples, one batch at a time.

Unfortunately, DataLoader doesn't provide any way to control the number of samples you wish to extract. You will have to use the usual ways of slicing iterators.

The simplest approach (without any libraries) is to stop once the required number of samples has been reached.

nsamples = 10000
seen = 0
for image, label in train_loader:
    if seen >= nsamples:
        break
    seen += image.size(0)  # each iteration yields a whole batch, so count its samples

    # Your training code here.
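
One thing to watch for here: a DataLoader yields batches, so the number of iterations is not the number of samples. The torch-free sketch below illustrates this by counting the samples in each batch and stopping at the limit. Note that `fake_loader` and `take_samples` are hypothetical helpers made up for this illustration; `fake_loader` only mimics a DataLoader by yielding (images, labels) batches as plain lists.

```python
def fake_loader(n_items, batch_size):
    """Hypothetical stand-in for a DataLoader: yields (images, labels) batches."""
    for start in range(0, n_items, batch_size):
        stop = min(start + batch_size, n_items)
        images = list(range(start, stop))    # pretend these are images
        labels = [i % 10 for i in images]    # pretend these are labels
        yield images, labels


def take_samples(loader, nsamples):
    """Consume batches until at least `nsamples` samples have been seen."""
    seen = 0
    batches = 0
    for images, labels in loader:
        if seen >= nsamples:
            break
        seen += len(images)   # count samples, not iterations
        batches += 1
    return seen, batches


seen, batches = take_samples(fake_loader(60000, 64), 10000)
print(seen, batches)
```

With a batch size of 64 the loop stops after 157 batches (10048 samples), far fewer than 10000 iterations. The overshoot past 10000 happens because whole batches are consumed.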

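Or you could use itertools.islice. Because the loader yields batches rather than single samples, slice off the number of batches that adds up to roughly 10k samples, like so:

import itertools

nbatches = 10000 // args.batch_size  # rounds down to whole batches
for image, label in itertools.islice(train_loader, nbatches):
    pass  # Your training code here.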
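
If the goal is to shrink the dataset itself rather than cut the loop short, another option is torch.utils.data.Subset, which wraps a dataset and exposes only the indices you give it. This is a minimal sketch, not part of the answer above, and it assumes a PyTorch release that ships Subset. A small random TensorDataset stands in for MNIST so that nothing needs to be downloaded; the sizes (1000, 100, 32) are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for MNIST (assumption: random tensors, only to keep the sketch offline).
full = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

# Keep only the first 100 samples; any list of valid indices works.
small = Subset(full, list(range(100)))

# The loader now only ever sees the 100 selected samples.
loader = DataLoader(small, batch_size=32, shuffle=True)
total = sum(images.size(0) for images, labels in loader)
print(len(small), total)
```

With the real dataset, the same idea would be Subset(tr, list(range(10000))) passed to the DataLoader in place of tr.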