I have an existing model where I load some pre-trained weights and then do prediction (one image at a time) in PyTorch. I am trying to convert it to a PyTorch Lightning module and am confused about a few things.
So currently, my __init__ method for the model looks like this:
self._load_config_file(cfg_file)
# just creates the pytorch network
self.create_network()
self.load_weights(weights_file)
self.cuda(device=0) # assumes GPU and uses one. This is probably suboptimal
self.eval() # prediction mode
From what I can gather from the Lightning docs, I can do pretty much the same, except for the cuda() call. So something like:
self.create_network()
self.load_weights(weights_file)
self.freeze() # prediction mode
So, my first question is whether this is the correct way to use lightning? How would lightning know if it needs to use the GPU? I am guessing this needs to be specified somewhere.
Now, for the prediction, I have the following setup:
def infer(self, frame):
    img = transform(frame) # apply some transformation to the input
    img = torch.from_numpy(img).float().unsqueeze(0).cuda(device=0)
    with torch.no_grad():
        output = self.__call__(Variable(img)).data.cpu().numpy()
    return output
This is the bit that has me confused. Which functions do I need to override to make a lightning compatible prediction?
Also, at the moment, the input comes as a numpy array. Is that something that would be possible from the lightning module or do things always have to use some sort of a dataloader?
At some point, I want to extend this model implementation to do training as well, so I want to make sure I do it right. Most examples focus on training models, so a simple example of just doing prediction at production time on a single image/data point might be useful.
I am using Lightning 0.7.5 with PyTorch 1.4.0 on a GPU with CUDA 10.1.
LightningModule is a subclass of torch.nn.Module, so the same model class will work for both inference and training. For that reason, you should probably call the cuda() and eval() methods outside of __init__.
Since it's just an nn.Module under the hood, once you've loaded your weights you don't need to override any methods to perform inference; simply call the model instance. Here's a toy example you can use:
import torchvision.models as models
from pytorch_lightning.core import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(pretrained=True, progress=False)

    def forward(self, x):
        return self.resnet(x)

model = MyModel().eval().cuda(device=0)
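If the weights come from a file rather than from torchvision's pretrained download, the loading step can live outside __init__ as well. Below is a minimal sketch, assuming the checkpoint is a plain state_dict written with torch.save. The TinyNet class and the temporary weights.pt path are hypothetical stand-ins, and a plain nn.Module is used so the snippet runs without Lightning installed; a LightningModule should behave the same way here since it is an nn.Module.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in for the pre-trained weights: write a state_dict to a temp file.
weights_file = os.path.join(tempfile.mkdtemp(), "weights.pt")
torch.save(TinyNet().state_dict(), weights_file)

# Construction stays device-agnostic; loading the weights, moving to the
# device and switching to eval mode all happen outside __init__.
model = TinyNet()
state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(device).eval()
```

Loading with map_location="cpu" first and then calling .to(device) means the same code path works on machines with and without a GPU.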
And then to actually run inference you don't need a method, just do something like:
import torch

with torch.no_grad():  # gradients are not needed for inference
    for frame in video:
        img = transform(frame)
        img = torch.from_numpy(img).float().unsqueeze(0).cuda(0)
        output = model(img).cpu().numpy()
        # Do something with the output
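On the question of numpy input: no dataloader is needed for a single image. The sketch below wraps the steps above in a small helper that takes a numpy array and returns a numpy array. The predict function, the HxWxC input layout, and the tiny Sequential model are all assumptions made for illustration; the helper reads the device from the model's parameters instead of hard-coding cuda(0), so it should work for any nn.Module, including a LightningModule.

```python
import numpy as np
import torch
import torch.nn as nn


def predict(model, frame):
    """Run one forward pass on a single image given as an HxWxC numpy array."""
    # Use whatever device the model's weights already live on.
    device = next(model.parameters()).device
    # HWC -> CHW (assumed layout), then add a batch dimension of size 1.
    chw = np.ascontiguousarray(frame.transpose(2, 0, 1))
    img = torch.from_numpy(chw).float().unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)
    return output.cpu().numpy()


# Hypothetical tiny model, only here to make the sketch runnable.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
).eval()

frame = np.random.rand(32, 32, 3).astype(np.float32)
scores = predict(model, frame)  # numpy array with one row of 5 scores
```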
The main benefit of PyTorch Lightning is that you can also use the same class for training by implementing training_step(), configure_optimizers() and train_dataloader() on that class. You can find a simple example of that in the PyTorch Lightning docs.