# [PyTorch Learning Notes] 7.1 Saving and Loading Models

## Serialization and Deserialization

• Serialization means writing in-memory data to disk as a binary byte stream. Saving a PyTorch model is serialization.

• Deserialization means reading a binary byte stream from disk back into memory to reconstruct the object. Loading a PyTorch model is deserialization.
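Since `torch.save` is built on top of Python's `pickle` module, the round trip can be illustrated with plain `pickle` (a minimal sketch; the dict contents are arbitrary):

```python
import pickle

data = {"epoch": 3, "loss": 0.25}

# serialization: in-memory object -> binary byte stream
blob = pickle.dumps(data)
print(type(blob))      # <class 'bytes'>

# deserialization: binary byte stream -> in-memory object
restored = pickle.loads(blob)
print(restored == data)  # True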

## torch.save

`torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)`

• obj: the object to save. It can be a model, or a dict. When saving a model you usually also need to save the optimizer state and the current epoch, so these are commonly wrapped together in a dict.
• f: the output file path
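For example, the model parameters, optimizer state, and epoch can be wrapped into one dict before saving (a sketch with a toy `nn.Linear` model; the names and file path are illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)  # toy model standing in for a real network
optimizer = optim.SGD(net.parameters(), lr=0.1)

# wrap everything needed to resume training into one dict
checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": 5}
torch.save(checkpoint, "./checkpoint.pkl")

loaded = torch.load("./checkpoint.pkl")
print(loaded["epoch"])  # 5
```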

### Saving the Whole Module

`torch.save(net, path)`

### Saving Only the Model Parameters

```
state_dict = net.state_dict()
torch.save(state_dict, path)
```

```
import torch
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2020)


net = LeNet2(classes=2019)

# "training": fill all parameters with a constant to simulate weight updates
print("Before training: ", net.features[0].weight[0, ...])
net.initialize()
print("After training: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# save the whole model
torch.save(net, path_model)

# save only the model parameters
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
```

Running the code above produces two files, `model.pkl` and `model_state_dict.pkl`, which contain the whole network and the network's parameters, respectively.

## torch.load

`torch.load(f, map_location=None, pickle_module, **pickle_load_args)`

• f: the file path
• map_location: specifies the device (CPU or GPU) onto which the stored tensors are mapped when loading.
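For instance, a tensor or model saved on a GPU machine can be loaded on a CPU-only machine by remapping its storage (a minimal sketch; the file path is illustrative):

```python
import torch

t = torch.ones(3)
torch.save(t, "./tensor.pkl")

# force all tensors onto the CPU, even if they were saved from a GPU
t_cpu = torch.load("./tensor.pkl", map_location="cpu")
print(t_cpu.device)  # cpu
```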

### Loading the Whole Module

```
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
```

Output:

```
LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)
```

### Loading Only the Model Parameters

```
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

net_new = LeNet2(classes=2019)
print("Before loading: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("After loading: ", net_new.features[0].weight[0, ...])
```

## Resuming Training from a Checkpoint

During training, save the model parameters and the optimizer state as a checkpoint every few epochs. If training is interrupted unexpectedly, the latest checkpoint can then be reloaded and training can continue from where it left off.

```
if (epoch + 1) % checkpoint_interval == 0:
    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
```

```
# simulate an unexpected interruption
if epoch > 5:
    print("Training interrupted unexpectedly...")
    break
```

```
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
```