'DataLoader' object does not support indexing

Farshid Rayhan · Jul 1, 2019

I downloaded the ImageNet dataset through torchvision's datasets.ImageNet API by setting download=True, but I cannot get any data out of the DataLoader.

The error says "'DataLoader' object does not support indexing"

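import torch
from torchvision import datasets

# trainset is already a DataLoader wrapping the ImageNet Dataset
trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
# this wraps that DataLoader in a second DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)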

As a simple test, I just tried to run the following:

trainloader[0]

The root directory is laid out like this:

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg

The documentation on the official website doesn't say anything else: https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet

What am I doing wrong?

Answer

Szymon Maszke · Jul 2, 2019

Well, the answer is pretty simple (besides the error mentioned in the other answer).

DataLoader has no __getitem__ method (see the source code for yourself).

It is meant for iterating over data (or batches of data), not for random access. If you want to access a specific element, index the torch.utils.data.Dataset directly; in your case:

import torchvision

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train')
trainset[0]  # (PIL image, class index) of the first training sample
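To see the difference without the full ImageNet download, here is a small sketch. SquaresDataset is a hypothetical map-style Dataset made up for illustration: a map-style Dataset only needs __len__ and __getitem__, which is exactly what makes dataset[i] work, while the DataLoader wrapped around it supports iteration but not subscripting.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style Dataset: item i is the pair (i, i**2) as tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return torch.tensor(idx), torch.tensor(idx ** 2)


dataset = SquaresDataset(6)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

x, y = dataset[3]              # random access goes through Dataset.__getitem__
print(x.item(), y.item())      # 3 9

print(hasattr(loader, "__getitem__"))  # False: DataLoader supports iteration, not indexing
try:
    loader[0]
except TypeError as err:
    print("TypeError:", err)   # exact wording depends on the Python version
```

torchvision.datasets.ImageNet follows the same map-style protocol, which is why trainset[0] works while trainloader[0] fails.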

Getting a batch

If you want a single batch, iterate over the loader and break after the first one:

for batch in trainloader:
    print(batch)  # or anything else you want to do
    break
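An equivalent idiom is to build an iterator and call next() on it. A minimal sketch, using a small TensorDataset as a stand-in for a real dataset (an assumption for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# iter() starts a fresh pass over the data; next() pulls one batch from it.
first_batch = next(iter(loader))
print(first_batch)  # [tensor([0, 1, 2, 3])]

# Keep the same iterator around to continue where you left off.
it = iter(loader)
batch0 = next(it)
batch1 = next(it)
print(batch1)  # [tensor([4, 5, 6, 7])]
```

Each call to iter(loader) starts a new pass, so with shuffle=True it draws a new order, and with num_workers > 0 it starts new worker processes; reuse one iterator if you need several consecutive batches.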

DataLoader generates the indices it loads through a sampler (sequential by default, a random permutation with shuffle=True, or any sampler you pass in; see samplers), hence there is no __getitem__, as it wouldn't make sense for this object.
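The samplers can be inspected on their own to see which indices a DataLoader would request. This sketch uses SequentialSampler (what shuffle=False selects), RandomSampler (what shuffle=True selects) and BatchSampler (how indices are grouped when batch_size is set); the five-element list is just a placeholder data source.

```python
import torch
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

data = list(range(5))  # any sized object works as a sampler's data source

# shuffle=False (the default) uses sequential indices ...
print(list(SequentialSampler(data)))  # [0, 1, 2, 3, 4]

# ... shuffle=True uses a random permutation (seeded here for repeatability).
gen = torch.Generator().manual_seed(0)
random_indices = list(RandomSampler(data, generator=gen))
print(sorted(random_indices))  # [0, 1, 2, 3, 4], just visited in a shuffled order

# With batch_size set, the indices are grouped into lists, one per batch.
print(list(BatchSampler(SequentialSampler(data), batch_size=2, drop_last=False)))
# [[0, 1], [2, 3], [4]]
```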

You could also subclass DataLoader and implement your own __getitem__ that does what you want (though that is more complicated).
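One possible sketch of that idea, under stated assumptions: IndexableDataLoader is a hypothetical name, loader[i] is taken to mean "the i-th batch", batch_size must be set (so the loader has a batch_sampler), and shuffle should be False, because with shuffle=True each call would draw a fresh random order. It bypasses worker processes and rebuilds the full list of index batches on every call, so it is only meant for occasional inspection.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class IndexableDataLoader(DataLoader):
    """Hypothetical DataLoader subclass whose loader[i] returns the i-th batch.

    Assumes batch_size is set (so self.batch_sampler exists) and shuffle=False.
    """

    def __getitem__(self, i):
        # Materialise the index batches the loader would produce and pick one.
        batch_indices = list(self.batch_sampler)[i]
        # Fetch samples straight from the Dataset and collate them like the loader does.
        samples = [self.dataset[j] for j in batch_indices]
        return self.collate_fn(samples)


dataset = TensorDataset(torch.arange(10))
loader = IndexableDataLoader(dataset, batch_size=4, shuffle=False)
print(loader[1])   # [tensor([4, 5, 6, 7])]
print(loader[-1])  # [tensor([8, 9])]
```

For most use cases, indexing the Dataset directly is simpler than this.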

Full example

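import torch
from torchvision import datasets, transforms

# torch.utils.data.Dataset object (the data is already on disk, so no download argument);
# ToTensor turns the PIL images into tensors so the default collation can batch them
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                             transform=transforms.ToTensor())
# torch.utils.data.DataLoader object
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)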

for batch in trainloader:
    print(batch)
    break

The loop above should print the first batch, whatever it contains.
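For a dataset of (image, label) pairs whose images are already tensors, the default collation returns a list [images, labels], each with a new leading batch dimension. Note that ImageNet without a transform yields PIL images, which the default collation cannot batch, so a transform such as transforms.ToTensor() is needed in practice. The sketch below assumes random 32x32 tensors in a TensorDataset as a stand-in for ImageNet:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for an image dataset: 8 fake RGB "images" of 32x32 plus integer labels.
images = torch.randn(8, 3, 32, 32)
labels = torch.arange(8) % 3
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=1, shuffle=False)
for batch_images, batch_labels in loader:
    # Default collation stacks samples along a new leading batch dimension.
    print(batch_images.shape)  # torch.Size([1, 3, 32, 32])
    print(batch_labels)        # tensor([0])
    break
```

With real ImageNet and batch_size > 1, the images also need a common size (for example via a resize/crop transform) before they can be stacked.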