I have a network which I want to train on some dataset (as an example, say CIFAR10). I can create a data loader object via
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
My question is as follows: Suppose I want to run several different training iterations. Say I want to train the network first on all images in odd positions, then on all images in even positions, and so on. To do that, I need to be able to access those images. Unfortunately, it seems that trainset does not allow such access. That is, trying to do trainset[:1000] or more generally trainset[mask] will throw an error.
I could do instead
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
and then
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
However, that will force me to create a new copy of the full dataset in each iteration (since I have already changed trainset.train_data, I will need to redefine trainset). Is there some way to avoid it?
Ideally, I would like to have something "equivalent" to
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
torch.utils.data.Subset is easier, supports shuffle, and doesn't require writing your own sampler:
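import torch
import torchvision
import torchvision.transforms as transforms
# ToTensor is needed so the DataLoader can collate the PIL images into batches
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transforms.ToTensor())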
evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
shuffle=True, num_workers=2)
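If you start from a boolean mask rather than a list of positions, as in the question, you can convert it to integer indices before passing it to Subset. Here is a minimal sketch of that, assuming a small TensorDataset as a stand-in for CIFAR10 so that it runs without a download; the fake data, the mask and the variable names are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for CIFAR10 (an assumption, to avoid the download):
# 10 fake "images" of shape 3x32x32 whose label equals their position.
images = torch.zeros(10, 3, 32, 32)
labels = torch.arange(10)
trainset = TensorDataset(images, labels)

# Hypothetical boolean mask selecting the odd positions.
mask = torch.arange(len(trainset)) % 2 == 1

# Subset expects integer indices, so convert the mask first.
indices = torch.nonzero(mask).flatten().tolist()
odd_subset = Subset(trainset, indices)

# num_workers=0 keeps the sketch single-process; raise it for real training.
loader = DataLoader(odd_subset, batch_size=4, shuffle=True, num_workers=0)

# Collect the labels that the loader yields over one pass.
seen = sorted(int(y) for _, batch_labels in loader for y in batch_labels)
print(len(odd_subset), seen)
```

Subset only keeps a reference to the original dataset (odd_subset.dataset) together with the index list, so building a new subset for each training round should not copy the image data.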