What are transforms in PyTorch used for?

carioka88 · Apr 24, 2018

I am new to PyTorch and not an expert in CNNs. I have built a working classifier by following the PyTorch tutorial, but I don't really understand what happens when the data is loaded.

They do some data augmentation and normalisation for training, but when I try to modify the parameters, the code does not work.

```python
from torchvision import transforms

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
```

Am I extending my training dataset? I don't see the data augmentation.

Why does data loading stop working if I modify the value in transforms.RandomResizedCrop(224)?

Do I also need to transform the test dataset?

I am a bit confused by these data transformations.

Answer

layog · Apr 24, 2018

transforms.Compose simply chains together the transforms passed to it: they are applied to the input one after another, in the order they are listed.
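To make the "one after another" behaviour concrete, here is a minimal pure-Python sketch of the idea. The helper name `compose` is hypothetical and this is not torchvision's source, but transforms.Compose likewise feeds the output of each callable into the next one, so the order of the list matters.

```python
from typing import Any, Callable, Sequence


def compose(steps: Sequence[Callable[[Any], Any]]) -> Callable[[Any], Any]:
    """Hypothetical helper: return a function that applies `steps` left to right."""
    def apply(x: Any) -> Any:
        for step in steps:
            x = step(x)  # the output of one step becomes the input of the next
        return x
    return apply


def add_one(x: int) -> int:
    return x + 1


def double(x: int) -> int:
    return x * 2


print(compose([add_one, double])(3))  # (3 + 1) * 2 = 8
print(compose([double, add_one])(3))  # 3 * 2 + 1 = 7
```

Swapping two steps changes the result, which is why, for example, ToTensor has to come before Normalize in the pipelines above.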

Train transforms

  1. transforms.RandomResizedCrop(224): This crops a random region of the input image, with a random size and aspect ratio, and resizes that region to (224, 224). The region might come from the top-left, the bottom-right, or anywhere in between, so this step is data augmentation. The crop size also fixes the size of every training image, so it has to match the input size your model expects: a fully-connected layer built for one feature-map size will not accept another. Changing this value is therefore not advised.
  2. transforms.RandomHorizontalFlip(): Once we have our (224, 224) image, it is flipped horizontally with probability 0.5 (the default). This is another part of data augmentation.
  3. transforms.ToTensor(): This converts the PIL image (height x width x channels, values 0 to 255) to a PyTorch float tensor of shape (channels, height, width) with values in [0, 1].
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): This scales the input channel by channel as (x - mean) / std. These are the standard per-channel ImageNet mean and standard deviation used with torchvision's pretrained models, so changing them is also not advised.

Validation transforms

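  1. transforms.Resize(256): First, the input image is resized so that its shorter edge is 256 pixels. The aspect ratio is preserved, so the result is (256, 256) only if the input was square.
  2. transforms.CenterCrop(224): Crops the central (224, 224) region of the resized image.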

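The remaining transforms (ToTensor and Normalize) are the same as for training. Nothing here is random, so the same image always produces the same tensor. Your test set should go through these same validation transforms.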

P.S.: You can read more about these transformations in the official docs