方式1:
model_save.py代码如下:

1
2
3
4
5
6
7
8
import torch
import torchvision

# 导入vgg16网络
vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存模型
torch.save(vgg16, "vgg16_method1.pth")

model_load.py代码如下:

1
2
3
4
5
6
7
8
import torch

# 使用这种加载方式时需要将网络的类引用进来
from model_save import *

# 加载模型
model = torch.load("vgg16_method1.pth")


vgg16_method1.pth是模型结构+模型参数。
这种方式保存了整个网络的结构和其中的参数,所以保存的模型比较大。


方式2:

1
2
3
4
5
6
import torch
import torchvision

# 保存模型,只保存参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


model_load.py代码如下:
1
2
3
4
5
6
7
8
9
import torch

# 导入vgg16网络
vgg16 = torchvision.models.vgg16(pretrained=False)

# 加载模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")


vgg16_method2.pth只有模型参数。
这种方式以字典的形式来保存网络的参数,因为只有参数没有网络,所以保存的模型较小。官方推荐使用。