How do I save a trained model in PyTorch?
I think it’s because torch.save() saves all the intermediate variables as well, like intermediate outputs used for backpropagation. But you only need to save the model parameters (weights, biases, etc.). Sometimes the former can be much larger than the latter.
I tested torch.save(model, f) and torch.save(model.state_dict(), f). The saved files have the same size. Now I am confused. Also, I found using pickle to save model.state_dict() extremely slow. I think the best way is to use torch.save(model.state_dict(), f), since you handle the creation of the model and torch handles the loading of the model weights, thus eliminating possible issues. Reference: discuss.pytorch.org/t/saving-torch-models/838/4
Seems like PyTorch has addressed this a bit more explicitly in their tutorials section; there’s lots of good info there that’s not listed in the answers here, including saving more than one model at a time and warm starting models.
@CharlieParker torch.save is based on pickle. The following is from the tutorial linked above: "[torch.save] will save the entire module using Python’s pickle module. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. The reason for this is because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. Because of this, your code can break in various ways when used in other projects or after refactors."
10 Answers
Found this page on their github repo:
Recommended approach for saving a model
There are two main approaches for serializing and restoring a model.
The first (recommended) saves and loads only the model parameters:
torch.save(the_model.state_dict(), PATH)
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
The second saves and loads the entire model:
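Following the whole-model pattern shown later in this thread, that looks like:

torch.save(the_model, PATH)

# Later:
the_model = torch.load(PATH)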
However, in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after some serious refactors.
See also: Save and Load the Model section from the official PyTorch tutorials.
According to @smth (discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/…), a model reloads in train mode by default, so you need to manually call the_model.eval() after loading if you are loading it for inference rather than resuming training.
The second method gives this error (stackoverflow.com/questions/53798009/…) on Windows 10; I wasn’t able to solve it.
With that approach, how do you keep track of the *args and **kwargs you need to pass in for the load case?
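One common pattern (just a sketch, not from the answer above; the key names and the constructor kwargs shown are hypothetical) is to store the constructor arguments in the checkpoint next to the weights:

import torch

# Hypothetical example: keep the constructor kwargs alongside the weights
checkpoint = {
    'model_kwargs': {'in_features': 5, 'out_features': 2},  # whatever TheModelClass needs
    'state_dict': the_model.state_dict(),
}
torch.save(checkpoint, PATH)

# Later: rebuild the model without having to remember the arguments yourself
checkpoint = torch.load(PATH)
the_model = TheModelClass(**checkpoint['model_kwargs'])
the_model.load_state_dict(checkpoint['state_dict'])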
@dontloo the_model = TheModelClass(*args, **kwargs). Running this command gives NameError: name ‘TheModelClass’ is not defined. How should I go about this?
Hi, could anyone tell me what the extension is for a model dict file (.pth?) and for an entire model file (.pkl)? Am I correct?
It depends on what you want to do.
Case #1: Save the model to use it yourself for inference: You save the model, you restore it, and then you change the model to evaluation mode. This is done because you usually have BatchNorm and Dropout layers that are in train mode by default on construction:
torch.save(model.state_dict(), filepath)

# Later, to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Case #2: Save model to resume training later: If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, epochs, score, etc. You would do it like this:
state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # ...
}
torch.save(state, filepath)
To resume training you would do things like state = torch.load(filepath), and then, to restore the state of each individual object, something like this:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
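Putting the pieces together, a minimal resume sketch (assuming the checkpoint at filepath was saved with the format above from a model of the same architecture; num_epochs and the loop body are placeholders):

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)  # must match the architecture that produced the checkpoint
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

state = torch.load(filepath)
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1

model.train()  # training mode, not model.eval()
for epoch in range(start_epoch, num_epochs):
    pass  # your training loop here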
Case #3: Model to be used by someone else with no access to your code: In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:
torch.save(model, filepath)

# Then later:
model = torch.load(filepath)
This way is still not bulletproof, and since PyTorch is still undergoing a lot of changes, I wouldn’t recommend it.
In Case #3, torch.load returns just an OrderedDict. How do you get the model in order to make predictions?
Hi, may I know how to do the mentioned "Case #2: Save model to resume training later"? I managed to load the checkpoint into the model, but then I was unable to resume training with something like model.to(device); model = train_model_epoch(model, criterion, optimizer, sched, epochs).
Hi, for case one, which is for inference, the official PyTorch docs say that you must save the optimizer state_dict for either inference or completing training: "When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains."
The pickle Python library implements binary protocols for serializing and de-serializing a Python object.
When you import torch (or when you use PyTorch), it imports pickle for you, and you don’t need to call pickle.dump() and pickle.load() directly, which are the methods to save and to load an object.
In fact, torch.save() and torch.load() will wrap pickle.dump() and pickle.load() for you.
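For example (a minimal sketch; the object and the file name 'obj.pt' are arbitrary):

import torch

# torch.save / torch.load round-trip an arbitrary picklable object,
# handling any tensors inside it specially.
obj = {'step': 10, 'weights': torch.ones(3)}
torch.save(obj, 'obj.pt')
restored = torch.load('obj.pt')
print(restored['step'], restored['weights'])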
The state_dict the other answer mentioned deserves a few more notes.

What state_dicts do we have inside PyTorch? There are actually two.

The PyTorch model is a torch.nn.Module, which has a model.parameters() call to get the learnable parameters (w and b). These learnable parameters, once randomly initialized, will update over time as we learn. They are the first state_dict.

The second state_dict is the optimizer state dict. Recall that the optimizer is used to improve our learnable parameters. The optimizer state_dict itself contains nothing learnable; it holds the optimizer’s internal state and the hyperparameters used.
Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
Let’s create a super simple model to explain this:
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")

print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
This code will output the following:
Model's state_dict:
weight 	 torch.Size([2, 5])
bias 	 torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, ...}]
Note that this is a minimal model. You may try a stack of sequential layers:
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.Conv2d(A, B, C),  # D_in, H, A, B, C, D_out are placeholder dimensions
    torch.nn.Linear(H, D_out),
)
Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model’s state_dict .
Non-learnable things belong to the optimizer object state_dict , which contains information about the optimizer’s state, as well as the hyperparameters used.
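To see this, compare a layer that has learnable parameters and registered buffers with one that has neither (a quick sketch):

import torch

bn = torch.nn.BatchNorm1d(4)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']

relu = torch.nn.ReLU()
print(list(relu.state_dict().keys()))
# [] -- no learnable parameters and no buffers, so nothing to save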
The rest of the story is the same: in the inference phase (the phase where we use the model after training), we predict based on the parameters we learned. So for inference, we just need to save the parameters with model.state_dict():
torch.save(model.state_dict(), filepath)
And to use it later:

model.load_state_dict(torch.load(filepath))
model.eval()
Note: Don’t forget the last line, model.eval(); it is crucial after loading the model.
Also, don’t try to save with torch.save(model.parameters(), filepath). model.parameters() is just a generator object.
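A quick sketch of why that fails (generators are not picklable):

import torch

model = torch.nn.Linear(5, 2)
print(type(model.parameters()))  # <class 'generator'>
# torch.save(model.parameters(), 'params.pt')  # would raise: cannot pickle 'generator' object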
On the other hand, torch.save(model, filepath) saves the model object itself, but keep in mind the model doesn’t have the optimizer’s state_dict. Check the other excellent answer by @Jadiel de Armas to see how to save the optimizer’s state dict.
Save and Load the Model
In this section we will look at how to persist model state with saving, loading and running model predictions.
import torch
import torchvision.models as models
Saving and Loading Model Weights
PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth 0%| | 0.00/528M [00:00To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Be sure to call the model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
Saving and Loading Models with Shapes
When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict() ) to the saving function:
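Following the tutorial’s naming, that looks like:

torch.save(model, 'model.pth')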
We can then load the model like this:
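model = torch.load('model.pth')  # same file name as assumed above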
This approach uses the Python pickle module when serializing the model, and thus relies on the actual class definition being available when loading the model.