神剑山庄资源网 Design By www.hcban.com
先说结论
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
import torch import torch.nn as nn import torchvision import numpy as np from torchsummary import summary # Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,"\t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
定义了一个类BaseNet并实例化该类:
net=BaseNet()
保存net时报错 object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
class BaseNet(object): def __init__(self):
把类定义改为
class BaseNet(nn.Module): def __init__(self): super(BaseNet, self).__init__()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。
神剑山庄资源网 Design By www.hcban.com
神剑山庄资源网
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件!
如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
神剑山庄资源网 Design By www.hcban.com
暂无解决pytorch 的state_dict()拷贝问题的评论...
更新日志
2024年10月04日
2024年10月04日
- 邓丽君.1983-漫步人生路(2024环球MQA-UHQCD限量版)【环球】【WAV+CUE】
- 陈容森.1996-情断【上华】【WAV+CUE】
- 裘海正.1994-爱我的人和我爱的人【上华】【WAV+CUE】
- 庾澄庆.1988-错过的爱【福茂】【WAV+CUE】
- 班得瑞原装进口《第十三张新世纪专辑:旭日之丘》1CD[APE/CUE分轨][329.6MB]
- 纯音入心系列纯音乐《古筝|佛蕴|境法禅心意法自然》1CD[MP3][350MB]
- 纪钧瀚《钢琴阅读时光 雨中书店聆听轻音乐》[320K/MP3][203.44MB]
- 群星.1981-名曲65(2014环球复黑王·百代篇)【EMI百代】【WAV+CUE】
- 陈淑桦.1990-娃娃的故事【柯达】【WAV+CUE】
- 戴梅君.2011-问签诗【美华】【WAV+CUE】
- 戴梅君.2011-问签诗【美华】【WAV+CUE】
- 李国祥.1995-九五变奏【嘉音】【WAV+CUE】
- 许景淳.1992-你来自何方【全美唱片】【WAV+CUE】
- 石欣卉.2007-剧欣卉集·完整电视剧主题精丫华纳】【WAV+CUE】
- 群星.2005-LOVE情歌集VOL.5.2CD【正东】【WAV+CUE】