PyTorch 模型保存Tips
PyTorch 模型保存方式(.pt, .pth, .pkl, .p)
最近在投稿论文,准备release code的时候,发现自己的pre-trained model 文件异常的大,竟然有2.5G,就觉得不对劲,要研究一下是不是由于模型保存的时候有什么问题造成的。
一、模型不同后缀名的区别
常见的pytorch模型保存的后缀名有.pt, .pth, .pkl,其并不是在格式上有区别,只是后缀名不同而已,也仅此而已。在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同。
在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。
二、模型保存与调用方式一
只保存模型参数,不保存模型结构。
保存:
torch.save(model.state_dict(), mymodel.pth)
# 只保存模型的权重参数,不保存模型结构
调用
model = My_model(*args, **kwargs) # 需要重新定义模型结构
model.load_state_dict(torch.load(mymodel.pth))
# 这里根据模型结构,调用存储的模型参数
model.eval()
三、模型保存与调用方式二
保存整个模型,包括模型结构+模型参数
保存
torch.save(model, mymodel.pth) #保存整个model的状态
调用
model = torch.load(mymodel.pth)
# 这里不需要重构模型结构,直接load
model.eval()
PyTorch中保存的模型文件.pth深入解析
探究一下,我们通常保存的模型文件 .pth 文件内部是什么?
一、 .pth 文件详解
如上文所述,在pytorch进行模型保存的时候,一般有两种保存方式,一种是保存这个模型,另一种是只保存模型的参数。保存的模型参数实际上是一个字典类型,通过key-value的形式来存储模型的所有参数。
1.1 .pth 文件基本信息查看。
import torch
pthfile = '/home/workspace/baseline_ckpt.pth'
net = torch.load(pthfile)
print(type(net)) # 类型是dict
print(len(net)) # 长度是4, 即存在4个key-value键值对
for k in net.keys():
print(k) # 查看四个键,分别是 model, optimizer, scheduler, iteraion
1.2 模型的四个键值详解
(1) net[‘model’]
print(net['model']) # 返回一个OrderedDict对象
for key, value in net['model'].items():
print(key, values.size(), sep='')
''' 运行结果 '''
module.backbone.body.stem.conv1.weight torch.Size([64, 3, 7, 7])
module.backbone.body.stem.bn1.weight torch.Size([64])
module.backbone.body.stem.bn1.bias torch.Size([64])
module.backbone.body.stem.bn1.running_mean torch.Size([64])
module.backbone.body.stem.bn1.running_var torch.Size([64])
module.backbone.body.layer1.0.downsample.0.weight torch.Size([256, 64, 1, 1])
module.backbone.body.layer1.0.downsample.1.weight torch.Size([256])
.....
总结:键model 所对应的是值是一个OrderedDict,而这个OrderedDict字典里边又存储着所有的每一层的参数名称以及对应的参数值。
需要注意的是,这里参数名称之所以很长,如:
module.backbone.body.stem.conv1.weight 是因为搭建网络结构的时候采用了组件式的设计,即整个模型里面构造了一个backbone的容器组件,backbone里面又构造了一个body容器组件,body里面又构造了一个stem容器,stem里面的第一个卷积层的权重。
(2)net[‘optimizer’]
# print(net["optimizer"]) # 返回的是一个一般的字典 Dict 对象
for key,value in net["optimizer"].items():
print(key,type(value),sep=" ")
'''运行结果为 '''
state <class 'dict'>
param_groups <class 'list'>
'''
发现这个这个字典只有两个key,一个是state,一个是param_groups
其中state所对应的值又是一个字典类型,
param_groups对应的值是一个列表
'''
先看一下**net[“optimizer”] [‘param_groups’] **这个列表里面放了一下啥:
groups=net["optimizer"]["param_groups"]
print(groups)
print(len(groups)) # 返回115.即在这个模型中,共有115组
'''
[{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061240]},
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061960]},
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644062248]},
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644077336]},
.
.
.
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061960]},
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566103171936]},
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566103172008]}]
这个列表的长度为115,每一个元素又是一个字典。
'''
再看一下**net[“optimizer”] [“states”] **这个字典里面放了啥:
state=net["optimizer"]["state"]
print(len(state)) # 返回115.即在这个模型中,state共有115组
for key,value in state.items():
print(key,type(value),sep=" ")
'''
140566644061240 <class 'dict'>
140566644061960 <class 'dict'>
140566644062248 <class 'dict'>
140566644077336 <class 'dict'>
.
.
.
140566103171936 <class 'dict'>
140566103172008 <class 'dict'>
这个字典的长度是115,而且和前面的param_groups有着对应关系,每一个元素的键值就是param_groups中每一个元素的params。
'''
(3) net [‘scheduler’]
scheduler=net["scheduler"] # 返回的依然是一个字典
print(len(scheduler)) # 字典的长度为 7
print(scheduler)
'''
{'milestones': (70000, 90000),
'gamma': 0.1,
'warmup_factor': 0.3333333333333333,
'warmup_iters': 500,
'warmup_method': 'linear',
'base_lrs': [0.005, 0.005, 0.005, 0.01, ......, 0.005, 0.005, 0.005, 0.005, 0.01],
'last_epoch': 99999}
继续看一下这个base_lrs的信息
'''
print(len(scheduler["base_lrs"])) # 返回115,→115个数组成的一个列表
(4) net[‘iteration’]
print(net["iteration"]) # 返回 9999 ,它是一个具体的数字
模型参数量与保存文件体积对应关系(估值)
Model | num_parameter | file_size |
---|---|---|
CNN from Cole.2017 | 0.8M | 54.9M |
SFCN | 3M | 11.28M |
ScaleDense | 102.38M | 2.5G |
3D-ResNet-18 | 33.2M | ~500M |
3D-ResNet-50 | 46.2M | ~500M |