In my PyTorch code, I initialize my model and optimizer like this:
model = MyModelClass(config, shape, x_tr_mean, x_tr_std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)
And here is the path to my checkpoint file.
checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth")
To resume training, I check whether the checkpoint file exists and, if it does, load it into the model and optimizer:
if os.path.exists(checkpoint_file):
    if config.resume:
        torch.load(checkpoint_file)
        model.load_state_dict(torch.load(checkpoint_file))
        optimizer.load_state_dict(torch.load(checkpoint_file))
Also, here's how I'm saving my model and optimizer.
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'iter_idx': iter_idx, 'best_va_acc': best_va_acc}, checkpoint_file)
For some reason I keep getting a strange error whenever I run this code.
model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"
Does anyone know why I'm getting this error?
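You saved the model parameters inside a dictionary, so torch.load returns that whole dictionary rather than a state_dict. That is why load_state_dict reports "model", "optimizer", "iter_idx" and "best_va_acc" as unexpected keys. Use the same keys you used when saving to pull the model and optimizer state_dicts out of the loaded checkpoint, like this: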
if os.path.exists(checkpoint_file):
    if config.resume:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
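To see why the traceback names exactly those keys, here is a minimal sketch in plain Python of the kind of key comparison that load_state_dict reports. The function name diff_state_dict_keys and the placeholder values are hypothetical and for illustration only; this is not PyTorch's actual implementation.

```python
def diff_state_dict_keys(expected_keys, loaded):
    """Return (missing, unexpected) key lists, preserving order.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    missing = [k for k in expected_keys if k not in loaded]
    unexpected = [k for k in loaded if k not in expected_keys]
    return missing, unexpected


# Keys the model expects, taken from the error message in the question.
expected = ["mean", "std", "attribute.weight", "attribute.bias"]

# Stand-in for what torch.load(checkpoint_file) returns: the outer dict
# that was passed to torch.save. Values are placeholders, not real tensors.
checkpoint = {
    "model": {name: None for name in expected},
    "optimizer": {},
    "iter_idx": 0,
    "best_va_acc": 0.0,
}

# Passing the whole checkpoint reproduces the mismatch from the traceback.
missing, unexpected = diff_state_dict_keys(expected, checkpoint)
print(missing)
print(unexpected)

# Passing checkpoint["model"] instead makes the keys line up.
missing_ok, unexpected_ok = diff_state_dict_keys(expected, checkpoint["model"])
```

The first call mirrors the failing line in the question, where the outer dictionary is handed to the model; the second mirrors the fix, where only the 'model' entry is.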
You can check the official tutorial on the PyTorch website for more info.
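The checkpoint in the question also stores iter_idx and best_va_acc, and those have to be read back by key as well when resuming. A minimal sketch, assuming the checkpoint layout from the question's torch.save call; the helper name restore_training_state is hypothetical, and it only assumes that model and optimizer expose load_state_dict:

```python
def restore_training_state(checkpoint, model, optimizer):
    """Load model/optimizer state and return (iter_idx, best_va_acc).

    Hypothetical helper; `checkpoint` is assumed to be the dict returned
    by torch.load(checkpoint_file), with the keys used in the question.
    """
    required = ("model", "optimizer", "iter_idx", "best_va_acc")
    absent = [k for k in required if k not in checkpoint]
    if absent:
        # Fail early with a clear message instead of a partial restore.
        raise KeyError(f"checkpoint is missing keys: {absent}")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["iter_idx"], checkpoint["best_va_acc"]
```

Inside the resume branch this could be called as iter_idx, best_va_acc = restore_training_state(torch.load(checkpoint_file), model, optimizer), so the training loop continues from the saved iteration and best validation accuracy.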