4.5 读取和存储
读写Tensor
可以直接使用save
函数和load
函数分别存储和读取Tensor
。save
使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save
可以保存各种对象,包括模型、张量和字典等。而load
使用pickle unpickle工具将pickle的对象文件反序列化为内存。
下面的例子创建了Tensor
变量x
,并将其存在文件名同为x.pt
的文件里。
1 | import torch |
然后将数据从存储的文件读回内存。
1 | x2 = torch.load('x.pt') |
输出:
1 | tensor([1., 1., 1.]) |
还可以存储一个Tensor
列表并读回内存。
1 | y = torch.zeros(4) |
输出:
1 | [tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])] |
存储并读取一个从字符串映射到Tensor
的字典。
1 | torch.save({'x': x, 'y': y}, 'xy_dict.pt') |
输出:
1 | {'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])} |
读写模型
state_dict
在PyTorch中,Module
的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()
访问)。state_dict
是一个从参数名称隐射到参数Tesnor
的字典对象。
1 | class MLP(nn.Module): |
输出:
1 | OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678], |
注意,只有具有可学习参数的层(卷积层、线性层等)才有state_dict
中的条目。优化器(optim
)也有一个state_dict
,其中包含关于优化器状态以及所使用的超参数的信息。
1 | optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) |
输出:
1 | {'param_groups': [{'dampening': 0, |
保存和加载模型
PyTorch中保存和加载训练模型有两种常见的方法:
- 仅保存和加载模型参数(
state_dict
); - 保存和加载整个模型。
无论是保存整个模型还是只保存模型参数,在加载模型时都需要在代码中定义出模型的结构。
1. 保存和加载state_dict
(推荐方式)
保存:
1 | torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth |
加载:
1 | model = TheModelClass(*args, **kwargs) |
2. 保存和加载整个模型
保存:
1 | torch.save(model, PATH) |
加载:
1 | model = torch.load(PATH) |
注意,模型类必须一致 。保存整个模型时,模型类的定义会被序列化。因此,在加载模型时,模型类的定义必须与保存时完全一致。如果模型类的定义发生了变化(例如修改了模型的结构或方法或模型类的名称),可能会导致加载失败或行为异常。保存整个模型时,模型的结构、参数、优化器状态等都会被保存下来,适合需要完整恢复训练场景的情况。但是灵活性较低 ,由于保存的是整个模型对象,因此在加载时要求模型类的定义必须与保存时一致。如果模型类的定义发生变化,可能会导致加载失败。
采用方法一来实验一下:
1 | X = torch.randn(2, 3) |
输出:
1 | tensor([[1], |
因为这net
和net2
都有同样的模型参数,那么对同一个输入X
的计算结果将会是一样的。上面的输出也验证了这一点。