方式1:
model_save.py代码如下:
1 | import torch |
model_load.py代码如下:1
2
3
4
5
6
7
8import torch
# 使用这种加载方式时需要将网络的类引用进来
from model_save import *
# 加载模型
model = torch.load("vgg16_method1.pth")
vgg16_method1.pth是模型结构+模型参数。
这种方式保存了整个网络的结构和其中的参数,所以保存的模型比较大。
方式2:1
2
3
4
5
6import torch
import torchvision
# 保存模型,只保存参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
model_load.py代码如下:1
2
3
4
5
6
7
8
9import torch
# 导入vgg16网络
vgg16 = torchvision.models.vgg16(pretrained=False)
# 加载模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
vgg16_method2.pth只有模型参数。
这种方式以字典的形式来保存网络的参数,因为只有参数没有网络,所以保存的模型较小。官方推荐使用。