example of doing simple prediction with pytorch-lightning

Luca · May 3, 2020 · Viewed 7.3k times

I have an existing model where I load some pre-trained weights and then do prediction (one image at a time) in pytorch. I am trying to basically convert it to a pytorch lightning module and am confused about a few things.

So currently, my __init__ method for the model looks like this:

self._load_config_file(cfg_file)
# just creates the pytorch network
self.create_network()  

self.load_weights(weights_file)

self.cuda(device=0)  # assumes GPU and uses one. This is probably suboptimal
self.eval()  # prediction mode

From what I can gather from the Lightning docs, I can do pretty much the same thing, except for the cuda() call. So something like:

self.create_network()

self.load_weights(weights_file)
self.freeze()  # prediction mode

So, my first question is whether this is the correct way to use Lightning. And how would Lightning know if it needs to use the GPU? I am guessing this needs to be specified somewhere.

Now, for the prediction, I have the following setup:

def infer(self, frame):
    img = transform(frame)  # apply some transformation to the input
    img = torch.from_numpy(img).float().unsqueeze(0).cuda(device=0)
    with torch.no_grad():
        output = self(img).cpu().numpy()
    return output

This is the bit that has me confused. Which methods do I need to override to make a Lightning-compatible prediction?

Also, at the moment, the input comes as a numpy array. Is that something the Lightning module can accept, or do things always have to go through some sort of dataloader?

At some point, I want to extend this model implementation to do training as well, so I want to make sure I do it right. But while most examples focus on training models, a simple example of just doing prediction at production time on a single image/data point would be useful.

I am using pytorch-lightning 0.7.5 with PyTorch 1.4.0 on GPU with CUDA 10.1.

Answer

jbencook · Jul 25, 2020

LightningModule is a subclass of torch.nn.Module, so the same model class will work for both inference and training. For that reason, you should probably call the cuda() and eval() methods outside of __init__.

Since it's just an nn.Module under the hood, once you've loaded your weights you don't need to override any methods to perform inference; simply call the model instance. Here's a toy example you can use:

import torchvision.models as models
from pytorch_lightning.core import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(pretrained=True, progress=False)
    
    def forward(self, x):
        return self.resnet(x)

model = MyModel().eval().cuda(device=0)

And then to actually run inference, you don't need a special method; just do something like:

for frame in video:  # `video` yields numpy frames, as in your setup
    img = transform(frame)  # your existing preprocessing
    img = torch.from_numpy(img).float().unsqueeze(0).cuda(0)
    with torch.no_grad():
        output = model(img).cpu().numpy()
    # Do something with the output

The main benefit of PyTorch Lightning is that you can also use the same class for training by implementing training_step(), configure_optimizers(), and train_dataloader() on that class. You can find a simple example of that in the PyTorch Lightning docs.