Preface
This article is reproduced from "PyTorch Deep Analysis: How to Save and Load PyTorch Models?"
It explains the methods for saving and loading PyTorch models in detail.
Contents
1. Three important functions you need to master
2.1 Introduction to state_dict
2.2 Saving and loading state_dict (training finished, no need to continue training)
2.3 Saving and loading the whole model (training finished, no need to continue training)
2.4 Saving and loading state_dict (training not finished, will continue training)
2.5 Saving multiple models in one file
2.6 Warm-starting your model using parameters from another model
2.7 Save on GPU, load on CPU
2.8 Save on GPU, load on GPU
2.9 Save on CPU, load on GPU
1. Three important functions you need to master
1) torch.save: saves a serialized object to disk. It uses Python's pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved with this function.
2) torch.load: uses pickle to deserialize a saved object file into memory. It also makes it convenient to load the data onto a specific device.
3) torch.nn.Module.load_state_dict(): loads a model's parameters from a state_dict.
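A minimal sketch tying the three functions together (the file name and the tiny model here are placeholders for illustration):

import torch
import torch.nn as nn

# A tiny placeholder model, just for this sketch
model = nn.Linear(4, 2)

# torch.save serializes any picklable object -- here, the model's state_dict
torch.save(model.state_dict(), "example_weights.pth")

# torch.load deserializes the file back into memory
state_dict = torch.load("example_weights.pth")

# load_state_dict copies the loaded parameters into the model
model.load_state_dict(state_dict)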
2 state_dict
2.1 Introduction to state_dict
In PyTorch, the learnable parameters (weights and biases) of a torch.nn.Module are contained in model.parameters(). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensors. Note: only layers with learnable parameters (e.g., linear layers) or registered buffers (e.g., batchnorm's running_mean) have entries in the state_dict. Optimizer objects (torch.optim) also have a state_dict, which stores the optimizer's state and the hyperparameters it was configured with.
Because a state_dict is a Python dictionary, it can easily be saved, loaded, and updated.
Let's get an intuitive feel for how state_dict is used through an example:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Output:
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])

Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
2.2 Saving and loading state_dict (training finished, no need to continue training)
Save:
torch.save(model.state_dict(), PATH)
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
The file is usually saved with a .pt or .pth extension.
Note:
- Use model.eval() to set the dropout and batch normalization layers to evaluation mode before running inference.
- The load_state_dict() function takes a dictionary object, not the path of the saved file. So model.load_state_dict(PATH) is wrong; it should be model.load_state_dict(torch.load(PATH)).
- If you want to save the model that performs best on the validation set, best_model_state = model.state_dict() is wrong, because this is a shallow copy: best_model_state will keep changing along with the subsequent training steps, and what you end up saving is the final, possibly overfit, model. The correct approach is best_model_state = deepcopy(model.state_dict()), as shown in the sketch after this list.
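A minimal sketch of that last point (the training/validation helpers and loop variables here are illustrative assumptions, not from the original):

from copy import deepcopy

import torch

best_val_loss = float('inf')
best_model_state = None

for epoch in range(num_epochs):
    train_one_epoch(model)      # hypothetical training helper
    val_loss = evaluate(model)  # hypothetical validation helper
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # deepcopy, so later training steps cannot overwrite this snapshot
        best_model_state = deepcopy(model.state_dict())

torch.save(best_model_state, PATH)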
2.3 Saving and loading the whole model (training finished, no need to continue training)
Save:
torch.save(model, PATH)
Load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
The file is usually saved with a .pt or .pth extension.
Note:
- Use model.eval() to set the dropout and batch normalization layers to evaluation mode before running inference.
2.4 Saving and loading state_dict (training not finished, will continue training)
Save:
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
The difference from 2.2 is that, in addition to the model's state_dict, you also save optimizer_state_dict, epoch, and loss, because to resume training you need to restore the optimizer's state, know which epoch you stopped at, and so on.
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
The difference from 2.2 is that, in addition to the model's state_dict, you also load optimizer_state_dict, epoch, and loss.
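To make the resume step concrete, here is a minimal sketch of continuing the training loop after the checkpoint above has been loaded (train_loader, criterion, and num_epochs are assumptions for illustration):

start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch

model.train()
for epoch in range(start_epoch, num_epochs):
    for inputs, targets in train_loader:  # hypothetical DataLoader
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)  # hypothetical loss function
        loss.backward()
        optimizer.step()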
2.5 Saving multiple models in one file
Save:
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)
The state_dicts of both models A and B, together with those of their optimizers, are stored in a single file.
Load:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
2.6 Warm-starting your model using parameters from another model
Sometimes when training a new, complex model, you need to load part of a pretrained model's weights. Even if only a few parameters are usable, they help warm-start the training process and can help the model converge faster.
If the state_dict you have is missing some keys, or contains extra keys, just set the strict parameter to False: load_state_dict will then load the matching keys and ignore the non-matching ones.
Save model A's state_dict:
torch.save(modelA.state_dict(), PATH)
Load into model B:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
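A detail worth knowing: load_state_dict returns a named tuple listing the keys that did not match, so you can check what was actually loaded. A small usage sketch:

modelB = TheModelBClass(*args, **kwargs)
result = modelB.load_state_dict(torch.load(PATH), strict=False)

# Keys present in modelB but absent from the file (these keep their initialization)
print("missing keys:", result.missing_keys)
# Keys present in the file but absent from modelB (these are ignored)
print("unexpected keys:", result.unexpected_keys)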
2.7 Save on GPU, load on CPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
In this case, the model's state_dict() was saved from the GPU, so a plain torch.load(PATH) would load the tensors back onto the GPU. Therefore, to load it onto the CPU, you need to pass map_location=torch.device('cpu').
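map_location also accepts a plain string, so the following is equivalent (a minor convenience, not a different mechanism):

model.load_state_dict(torch.load(PATH, map_location='cpu'))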
2.8 Save on GPU, load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
In this case, the model's state_dict() was saved from the GPU, and torch.load(PATH) loads the tensors directly onto the GPU, so no map_location argument is needed. However, because the model object itself is initialized on the CPU, you still need to call model.to(device) to move it to the GPU.
2.9 Save on CPU, load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
In this case, the model's state_dict() was saved from the CPU, so torch.load(PATH) would load the tensors onto the CPU. Therefore, to load it onto the GPU, you need to pass map_location="cuda:0". And because the model object itself is initialized on the CPU, you still need to call model.to(device).
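Putting the device handling together, a minimal end-to-end inference sketch (the input shape is an assumption chosen to match the example CNN from section 2.1):

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
model.eval()

input = torch.randn(1, 3, 32, 32)  # hypothetical input batch for the 32x32 RGB example model
input = input.to(device)           # inputs must be moved to the GPU as well
with torch.no_grad():
    output = model(input)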