This may be a very easy question. I've just started with PyTorch Lightning and can't figure out how to get the output of my model after training.
I'm interested in the predictions for both y_train and y_test as an array of some sort (a PyTorch tensor, or a NumPy array in a later step), so that I can plot them next to the labels in separate scripts.
```python
import torch
import pytorch_lightning as pl
# Dataset, Net, train_tensor, val_tensor and the settings below are defined elsewhere in my script.
dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0, max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)
```
My LightningModule defines the following methods (bodies omitted):
```python
import pytorch_lightning as pl

class Net(pl.LightningModule):
    def __init__(self, random_inputs): ...
    def forward(self, x): ...
    def train_dataloader(self): ...
    def val_dataloader(self): ...
    def training_step(self, batch, batch_nb): ...
    def training_epoch_end(self, outputs): ...
    def validation_step(self, batch, batch_nb): ...
    def validation_epoch_end(self, outputs): ...
    def configure_optimizers(self): ...
```
Do I need a specific predict function or is there any already implemented way I don't see?
I disagree with the other answers: the OP's question appears to be about how to use a model trained in Lightning to get predictions in general, rather than for a specific step in the training pipeline. In that case, a user shouldn't need to go anywhere near a `Trainer` object. Trainers are not intended to be used for general prediction, so the other answers encourage an anti-pattern (carrying a trainer object around every time we want to make a prediction) for anyone who reads them in the future.
Instead of using the `Trainer`, we can get predictions straight from the Lightning module we have defined: if `model = Net(...)` is my (trained) instance of the Lightning module, then getting predictions on inputs `x` is achieved simply by calling `model(x)`, as long as the `forward` method has been implemented (overridden) on the Lightning module, which is required for `model(x)` to work.
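A minimal sketch of collecting predictions for a whole dataset, such as the y_train or y_test predictions the question asks for. It assumes each batch is an `(inputs, labels)` pair; `TinyNet` and `predict` are hypothetical names, and a plain `nn.Module` stands in for `Net`. Because a LightningModule is a subclass of `torch.nn.Module`, the same loop should apply to it unchanged. `model.eval()` and `torch.no_grad()` are not needed for `model(x)` to run, but they switch layers such as dropout to inference behaviour and skip building the autograd graph.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyNet(nn.Module):
    """Hypothetical stand-in for the question's `Net`; a LightningModule behaves the same here."""

    def __init__(self, feature_len: int):
        super().__init__()
        self.layer = nn.Linear(feature_len, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


def predict(model: nn.Module, loader: DataLoader) -> np.ndarray:
    """Call model(x) on every batch and stack the outputs into one NumPy array."""
    model.eval()  # inference behaviour for dropout / batch-norm layers
    outputs = []
    with torch.no_grad():  # no gradients are needed for prediction
        for x, _ in loader:  # assumes batches are (inputs, labels) pairs
            outputs.append(model(x))  # __call__ -> forward
    return torch.cat(outputs).cpu().numpy()


if __name__ == "__main__":
    torch.manual_seed(0)
    feature_len = 4
    x_train, y_train = torch.randn(32, feature_len), torch.randn(32, 1)
    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=8, shuffle=False)
    model = TinyNet(feature_len)
    y_train_pred = predict(model, train_loader)
    print(y_train_pred.shape)  # (32, 1)
```

Keep `shuffle=False` on the loader you predict with so the predictions line up with the labels when plotting. If the model lives on a GPU, move each `x` to the model's device before calling it, and call `model.train()` before resuming training.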
In contrast, `Trainer.predict()` is not the intended means of obtaining predictions from your trained model in general. The `Trainer` API provides methods to `tune`, `fit` and `test` your LightningModule as part of your training pipeline, and it looks to me as though the `predict` method is provided for ad-hoc predictions on separate dataloaders as part of less 'standard' training steps.
The OP's question ("Do I need a specific predict function or is there any already implemented way I don't see?") suggests that they're not familiar with how the `forward()` method works in PyTorch, since they ask whether there's already a prediction method they can't see. A full answer therefore also needs to explain where `forward()` fits into the prediction process:
`model(x)` works because Lightning modules are subclasses of `torch.nn.Module`, which implements the magic method `__call__()`, so we can call a module instance as if it were a function. `__call__()` in turn calls `forward()`, which is why we need to override that method in our Lightning module.
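To make the `__call__()`/`forward()` relationship concrete, here is a toy, pure-Python imitation of the calling convention. `MiniModule`, `Doubler` and `Forgetful` are hypothetical classes, not PyTorch code: the real `nn.Module.__call__` does more work (running hooks, for instance), but it likewise ends up in `forward()`, and a module that never overrides `forward()` raises `NotImplementedError` when called.

```python
class MiniModule:
    """Toy imitation of torch.nn.Module's calling convention (not PyTorch's real code)."""

    def forward(self, *args, **kwargs):
        # Subclasses are expected to override this, as with nn.Module.
        raise NotImplementedError("subclasses must implement forward()")

    def __call__(self, *args, **kwargs):
        # model(x) is looked up on the type: type(model).__call__(model, x),
        # and this __call__ simply delegates to forward().
        return self.forward(*args, **kwargs)


class Doubler(MiniModule):
    def forward(self, x):
        return 2 * x


class Forgetful(MiniModule):
    pass  # never overrides forward()


if __name__ == "__main__":
    print(Doubler()(21))  # 42
    try:
        Forgetful()(1)
    except NotImplementedError as err:
        print("error:", err)
```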
NB: because `forward()` is only one piece of the logic that runs when we call `model(x)` (`__call__()` also runs any hooks registered on the module, for example), it is always recommended to use `model(x)` rather than `model.forward(x)` for prediction unless you have a specific reason to deviate.
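The difference described in the NB can be observed with a real forward hook. This sketch uses a bare `nn.Linear` (an arbitrary choice, not taken from the question) and `nn.Module.register_forward_hook`: the hook fires for `model(x)` but not for `model.forward(x)`, even though both return the same tensor. A LightningModule inherits this behaviour from `nn.Module`.

```python
import torch
from torch import nn

calls = []  # records each time the hook fires


def log_hook(module, inputs, output):
    # Forward hooks receive (module, positional inputs, output); returning None keeps the output.
    calls.append(type(module).__name__)


model = nn.Linear(3, 1)
model.register_forward_hook(log_hook)
x = torch.randn(2, 3)

with torch.no_grad():
    via_call = model(x)              # goes through nn.Module.__call__, so the hook runs
    via_forward = model.forward(x)   # bypasses __call__, so the hook is skipped

print(calls)  # ['Linear']
```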