Best way to save a trained model in PyTorch?

Wasi Ahmad · Mar 9, 2017

I am looking for ways to save a trained model in PyTorch. So far, I have found two alternatives:

  1. torch.save() to save a model and torch.load() to load it.
  2. model.state_dict() to get a trained model's parameters (which are then saved with torch.save()) and model.load_state_dict() to load the saved parameters back into a model.

I came across this discussion, where approach 2 is recommended over approach 1.

My question is: why is the second approach preferred? Is it only because torch.nn modules have those two functions and we are encouraged to use them?

Answer

dontloo · May 6, 2017

I found this page in their GitHub repo; I'll just paste the content here.


Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
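A minimal, runnable sketch of the first approach, assuming a small hypothetical `TinyNet` class standing in for `TheModelClass` and an in-memory buffer (`io.BytesIO`) in place of `PATH`. It shows that the state dict is just a mapping from parameter names to tensors, so the class definition itself is never serialized.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


torch.manual_seed(0)
model = TinyNet()

# The state dict maps parameter names to tensors; no class information is stored.
state = model.state_dict()
print(list(state.keys()))  # ['fc.weight', 'fc.bias']

# Save only the parameters (a file path works the same way as this buffer).
buffer = io.BytesIO()
torch.save(state, buffer)

# Later: rebuild the model from code, then load the saved parameters into it.
buffer.seek(0)
restored = TinyNet()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # inference mode; matters for dropout and batch-norm layers

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)  # True
```

Because the class is rebuilt from code, the loading script must be able to construct `TinyNet` with the same layer names and shapes; otherwise `load_state_dict` raises an error listing the missing or unexpected keys or the size mismatches.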

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

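the_model = torch.load(PATH, weights_only=False)  # needed on PyTorch >= 2.6, where weights_only defaults to True; only use on trusted files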

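However, in this case the serialized data is bound to the specific classes and the exact directory structure used when the model was saved, because pickle does not store the model class itself, only the path to the module that defines it. As a result, loading can break in various ways when the file is used in other projects or after significant refactoring.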
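A minimal sketch of why the second approach is fragile, assuming PyTorch 1.13 or later (for the `weights_only` argument). A hypothetical in-memory module `old_project_models` stands in for the project file that defines the model class, and renaming the class inside it simulates a refactor. After the rename, the pickled full model can no longer be loaded, while the saved state dict still loads into the renamed class.

```python
import io
import sys
import types

import torch

# Hypothetical module standing in for the project file that defines the model class.
models = types.ModuleType("old_project_models")
sys.modules["old_project_models"] = models
exec(
    "import torch.nn as nn\n"
    "class TinyNet(nn.Module):\n"
    "    def __init__(self):\n"
    "        super().__init__()\n"
    "        self.fc = nn.Linear(4, 2)\n"
    "    def forward(self, x):\n"
    "        return self.fc(x)\n",
    models.__dict__,
)

model = models.TinyNet()

# Approach 2: pickle the whole model. Pickle records the reference
# "old_project_models.TinyNet", not the class's code.
full_buf = io.BytesIO()
torch.save(model, full_buf)

# Approach 1: save only the parameters.
state_buf = io.BytesIO()
torch.save(model.state_dict(), state_buf)

# Loading the full model works while the class is still at the recorded path.
# weights_only=False allows arbitrary unpickling: only use it on trusted files.
full_buf.seek(0)
reloaded = torch.load(full_buf, weights_only=False)
print(type(reloaded).__name__)  # TinyNet

# Simulate a refactor: the class is renamed inside its module.
models.RenamedNet = models.__dict__.pop("TinyNet")

full_buf.seek(0)
try:
    torch.load(full_buf, weights_only=False)
    full_model_loaded = True
except AttributeError as err:  # pickle can no longer find old_project_models.TinyNet
    print("full model failed to load:", err)
    full_model_loaded = False

# The state dict only holds tensors keyed by attribute names, so it still loads.
state_buf.seek(0)
refactored = models.RenamedNet()
refactored.load_state_dict(torch.load(state_buf))
print(full_model_loaded)  # False
```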