程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2024-11(2)

【PyTorch】state_dict详解

发布于2019-08-07 11:53     阅读(1120)     评论(0)     点赞(0)     收藏(2)


Introduce

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

Sample

通过一个简单的案例来输出state_dict字典对象中存放的变量

  1. #encoding:utf-8
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import torchvision
  6. import numpy as mp
  7. import matplotlib.pyplot as plt
  8. import torch.nn.functional as F
  9. #define model
  10. class TheModelClass(nn.Module):
  11. def __init__(self):
  12. super(TheModelClass,self).__init__()
  13. self.conv1=nn.Conv2d(3,6,5)
  14. self.pool=nn.MaxPool2d(2,2)
  15. self.conv2=nn.Conv2d(6,16,5)
  16. self.fc1=nn.Linear(16*5*5,120)
  17. self.fc2=nn.Linear(120,84)
  18. self.fc3=nn.Linear(84,10)
  19. def forward(self,x):
  20. x=self.pool(F.relu(self.conv1(x)))
  21. x=self.pool(F.relu(self.conv2(x)))
  22. x=x.view(-1,16*5*5)
  23. x=F.relu(self.fc1(x))
  24. x=F.relu(self.fc2(x))
  25. x=self.fc3(x)
  26. return x
  27. def main():
  28. # Initialize model
  29. model = TheModelClass()
  30. #Initialize optimizer
  31. optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
  32. #print model's state_dict
  33. print('Model.state_dict:')
  34. for param_tensor in model.state_dict():
  35. #打印 key value字典
  36. print(param_tensor,'\t',model.state_dict()[param_tensor].size())
  37. #print optimizer's state_dict
  38. print('Optimizer,s state_dict:')
  39. for var_name in optimizer.state_dict():
  40. print(var_name,'\t',optimizer.state_dict()[var_name])
  41. if __name__=='__main__':
  42. main()

 具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

  1. Model.state_dict:
  2. conv1.weight torch.Size([6, 3, 5, 5])
  3. conv1.bias torch.Size([6])
  4. conv2.weight torch.Size([16, 6, 5, 5])
  5. conv2.bias torch.Size([16])
  6. fc1.weight torch.Size([120, 400])
  7. fc1.bias torch.Size([120])
  8. fc2.weight torch.Size([84, 120])
  9. fc2.bias torch.Size([84])
  10. fc3.weight torch.Size([10, 84])
  11. fc3.bias torch.Size([10])
  12. Optimizer,s state_dict:
  13. state {}
  14. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

 



所属网站分类: 技术文章 > 博客

作者:无敌是多么寂寞

链接:https://www.pythonheidong.com/blog/article/10510/907c727e0b7f76f0fb44/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

0 0
收藏该文
已收藏

评论内容:(最多支持255个字符)