程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

python面试(5)

函数(0)

标签  

函数(0)

列表(0)

日期归档  

tensorflow入门教程(四十八)人体姿态检测(五)

发布于2019-08-22 16:30     阅读(1591)     评论(0)     点赞(6)     收藏(3)


#
#作者:韦访
#博客:https://blog.csdn.net/rookie_wei
#微信:1007895847
#添加微信的备注一下是CSDN的
#欢迎大家一起学习
#

------韦访 20190704

7、定义网络

继续往下分析,

  # Define the model for multi-GPU training.
  # Split each dequeued batch into `args.gpus` equal parts, one per GPU tower.
  q_inp_split = tf.split(q_inp, args.gpus)
  q_heat_split = tf.split(q_heat, args.gpus)
  q_vect_split = tf.split(q_vect, args.gpus)

  output_vectmap = []
  output_heatmap = []
  losses = []
  last_losses_l1 = []
  last_losses_l2 = []
  outputs = []

  # Build one tower per GPU; towers after the first reuse the variables.
  for gpu_id in range(args.gpus):
      tower_device = tf.DeviceSpec(device_type="GPU", device_index=gpu_id)
      reuse_variables = gpu_id > 0
      with tf.device(tower_device), \
              tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
          # Resolve the network, its pretrained-weights path and the name of
          # its last layer from the `--model` argument.
          net, pretrain_path, last_layer = get_network(args.model, q_inp_split[gpu_id])
          # An explicit checkpoint overrides the default pretrained path.
          if args.checkpoint:
              pretrain_path = args.checkpoint

          # Last-stage outputs: vector map L and heat map S.
          vect, heat = net.loss_last()
          output_vectmap.append(vect)
          output_heatmap.append(heat)
          # Final network output.
          outputs.append(net.get_output())

          # Per-stage outputs (L and S) used for intermediate supervision.
          l1s, l2s = net.loss_l1_l2()
          # L2 loss of every stage against this tower's share of the labels.
          for stage_idx, (stage_l1, stage_l2) in enumerate(zip(l1s, l2s)):
              loss_l1 = tf.nn.l2_loss(tf.concat(stage_l1, axis=0) - q_vect_split[gpu_id],
                                      name='loss_l1_stage%d_tower%d' % (stage_idx, gpu_id))
              loss_l2 = tf.nn.l2_loss(tf.concat(stage_l2, axis=0) - q_heat_split[gpu_id],
                                      name='loss_l2_stage%d_tower%d' % (stage_idx, gpu_id))
              losses.append(tf.reduce_mean([loss_l1, loss_l2]))

          # After the loop, loss_l1 / loss_l2 hold the losses of the last stage.
          last_losses_l1.append(loss_l1)
          last_losses_l2.append(loss_l2)

  # Concatenate the tower outputs back into one batch.
  outputs = tf.concat(outputs, axis=0)

如果你是土豪,有多块GPU,上面的代码就是满足你的,将队列划分成多块,再分给每块GPU去训练,我只有一个GPU。主要看get_network函数,

  def get_network(type, placeholder_input, sess_for_load=None, trainable=True):
      """Build the pose network selected by `type` and optionally load weights.

      Args:
          type: model name, e.g. 'cmu', 'vgg', 'mobilenet_thin',
              'mobilenet_v2_large'. (The name shadows the builtin but is kept
              for backward compatibility with keyword callers.)
          placeholder_input: input tensor, fed to the network as 'image'.
          sess_for_load: if not None, weights are loaded into this session.
          trainable: passed through to the network constructor.

      Returns:
          (net, pretrain_path_full, last_layer): the network object, the
          pretrained-weights path joined onto `_get_base_path()`, and the
          name template of the network's last layer.

      Raises:
          Exception: unknown model name, missing .npy weight file, or a
              checkpoint that fails to restore.
          KeyError: `sess_for_load` is given for a model that has no entry in
              the trained-checkpoint table below.
      """
      mobilenet_last = 'MConv_Stage6_L{aux}_5'
      cmu_last = 'Mconv7_stage6_L{aux}'
      v1_075 = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
      v1_100 = 'pretrained/mobilenet_v1_1.0_224_2017_06_14/mobilenet_v1_1.0_224.ckpt'
      v2_140 = 'pretrained/mobilenet_v2_1.4_224/mobilenet_v2_1.4_224.ckpt'
      v2_100 = 'pretrained/mobilenet_v2_1.0_224/mobilenet_v2_1.0_224.ckpt'
      v2_075 = 'pretrained/mobilenet_v2_0.75_224/mobilenet_v2_0.75_224.ckpt'
      v2_050 = 'pretrained/mobilenet_v2_0.5_224/mobilenet_v2_0.5_224.ckpt'

      # model name -> (network class, constructor kwargs, pretrained path, last layer)
      specs = {
          'mobilenet': (MobilenetNetwork, dict(conv_width=0.75, conv_width2=1.00), v1_075, mobilenet_last),
          'mobilenet_fast': (MobilenetNetwork, dict(conv_width=0.5, conv_width2=0.5), v1_075, mobilenet_last),
          'mobilenet_accurate': (MobilenetNetwork, dict(conv_width=1.00, conv_width2=1.00), v1_100, mobilenet_last),
          'mobilenet_thin': (MobilenetNetworkThin, dict(conv_width=0.75, conv_width2=0.50), v1_075, mobilenet_last),
          'mobilenet_v2_w1.4_r0.5': (Mobilenetv2Network, dict(conv_width=1.4, conv_width2=0.5), v2_140, mobilenet_last),
          'mobilenet_v2_w1.0_r1.0': (Mobilenetv2Network, dict(conv_width=1.0, conv_width2=1.0), v2_100, mobilenet_last),
          'mobilenet_v2_w1.0_r0.75': (Mobilenetv2Network, dict(conv_width=1.0, conv_width2=0.75), v2_100, mobilenet_last),
          'mobilenet_v2_w1.0_r0.5': (Mobilenetv2Network, dict(conv_width=1.0, conv_width2=0.5), v2_100, mobilenet_last),
          'mobilenet_v2_w0.75_r0.75': (Mobilenetv2Network, dict(conv_width=0.75, conv_width2=0.75), v2_075, mobilenet_last),
          'mobilenet_v2_1.4': (Mobilenetv2Network, dict(conv_width=1.4), v2_140, mobilenet_last),
          'mobilenet_v2_1.0': (Mobilenetv2Network, dict(conv_width=1.0), v2_100, mobilenet_last),
          'mobilenet_v2_0.75': (Mobilenetv2Network, dict(conv_width=0.75), v2_075, mobilenet_last),
          'mobilenet_v2_0.5': (Mobilenetv2Network, dict(conv_width=0.5), v2_050, mobilenet_last),
          'vgg': (CmuNetwork, {}, 'numpy/openpose_vgg16.npy', cmu_last),
      }
      # Aliases that share one configuration.
      for alias in ('mobilenet_v2_w1.4_r1.0', 'mobilenet_v2_large', 'mobilenet_v2_large_quantize'):  # m_v2_large
          specs[alias] = (Mobilenetv2Network, dict(conv_width=1.4, conv_width2=1.0), v2_140, mobilenet_last)
      for alias in ('mobilenet_v2_w0.5_r0.5', 'mobilenet_v2_small'):  # m_v2_fast
          specs[alias] = (Mobilenetv2Network, dict(conv_width=0.5, conv_width2=0.5), v2_050, mobilenet_last)
      for alias in ('cmu', 'openpose'):
          specs[alias] = (CmuNetwork, {}, 'numpy/openpose_coco.npy', cmu_last)
      for alias in ('cmu_quantize', 'openpose_quantize'):
          specs[alias] = (CmuNetwork, {}, 'train/cmu/bs8_lr0.0001_q_e80/model_latest-18000', cmu_last)

      try:
          net_cls, net_kwargs, pretrain_path, last_layer = specs[type]
      except (KeyError, TypeError):
          # TypeError covers unhashable arguments, which the old if/elif chain
          # also reported as an invalid name.
          raise Exception('Invalid Model Name.')

      net = net_cls({'image': placeholder_input}, trainable=trainable, **net_kwargs)
      pretrain_path_full = os.path.join(_get_base_path(), pretrain_path)

      if sess_for_load is not None:
          if type in ('cmu', 'vgg', 'openpose'):
              # These models are loaded from a .npy weight file.
              if not os.path.isfile(pretrain_path_full):
                  raise Exception('Model file doesn\'t exist, path=%s' % pretrain_path_full)
              net.load(pretrain_path_full, sess_for_load)
          else:
              # Some trained checkpoints live in a directory named after the
              # input resolution; fall back to '' when it cannot be formatted.
              try:
                  s = '%dx%d' % (placeholder_input.shape[2], placeholder_input.shape[1])
              except Exception:
                  s = ''
              ckpts = {
                  'mobilenet': 'trained/mobilenet_%s/model-246038' % s,
                  'mobilenet_thin': 'trained/mobilenet_thin_%s/model-449003' % s,
                  'mobilenet_fast': 'trained/mobilenet_fast_%s/model-189000' % s,
                  'mobilenet_accurate': 'trained/mobilenet_accurate/model-170000',
                  'mobilenet_v2_w1.4_r0.5': 'trained/mobilenet_v2_w1.4_r0.5/model_latest-380401',
                  'mobilenet_v2_large': 'trained/mobilenet_v2_w1.4_r1.0/model-570000',
                  'mobilenet_v2_small': 'trained/mobilenet_v2_w0.5_r0.5/model_latest-380401',
              }
              ckpt_path = os.path.join(_get_base_path(), ckpts[type])
              loader = tf.train.Saver()
              try:
                  loader.restore(sess_for_load, ckpt_path)
              except Exception as e:
                  raise Exception('Fail to load model files. \npath=%s\nerr=%s' % (ckpt_path, str(e)))

      return net, pretrain_path_full, last_layer

这里提供了很多个网络给我们选择,我们使用的是cmu网络,所以,用的是CmuNetwork类,这个类实现的就是我们论文里第3点讲的那个网络,

 

来看下代码怎么实现,

class CmuNetwork(network_base.BaseNetwork):

CmuNetwork类继承了network_base.BaseNetwork类,来看看network_base.BaseNetwork类的__init__函数做了什么,

  class BaseNetwork(object):
      """Base class for the pose networks; subclasses build the graph in setup()."""

      def __init__(self, inputs, trainable=True):
          """Store the inputs, initialise bookkeeping, then call self.setup().

          Args:
              inputs: mapping of input names to input nodes.
              trainable: if true, the resulting variables are set as trainable.
          """
          # The input nodes for this network
          self.inputs = inputs
          # The current list of terminal nodes
          self.terminals = []
          # Mapping from layer names to layers, seeded with the inputs
          self.layers = dict(inputs)
          # If true, the resulting variables are set as trainable
          self.trainable = trainable
          # Switch variable for dropout (defaults to 1.0 when not fed)
          dropout_default = tf.constant(1.0)
          self.use_dropout = tf.placeholder_with_default(
              dropout_default, shape=[], name='use_dropout')
          self.setup()

一些基本的初始化以后,调用setup函数,而这个setup函数主要在CmuNetwork类里实现的,代码如下,

  def setup(self):
      """Build the CMU pose network: a VGG19-based front end plus six stages.

      Every stage has two branches. Layers suffixed `_L1` produce the
      38-channel vector map L, and layers suffixed `_L2` produce the
      19-channel heat (confidence) map S. Stages 2-6 share one structure and
      are generated in a loop; layer names and creation order are the same
      as in the fully unrolled definition.
      """
      # (branch suffix, output channels): L1 -> vector map L, L2 -> heat map S.
      branches = (('L1', 38), ('L2', 19))

      # Uses the first 10 layers of VGG19 (through conv4_2); the layers after
      # them are adapted for this network.
      (self.feed('image')
           .normalize_vgg(name='preprocess')
           .conv(3, 3, 64, 1, 1, name='conv1_1')
           .conv(3, 3, 64, 1, 1, name='conv1_2')
           .max_pool(2, 2, 2, 2, name='pool1_stage1', padding='VALID')
           .conv(3, 3, 128, 1, 1, name='conv2_1')
           .conv(3, 3, 128, 1, 1, name='conv2_2')
           .max_pool(2, 2, 2, 2, name='pool2_stage1', padding='VALID')
           .conv(3, 3, 256, 1, 1, name='conv3_1')
           .conv(3, 3, 256, 1, 1, name='conv3_2')
           .conv(3, 3, 256, 1, 1, name='conv3_3')
           .conv(3, 3, 256, 1, 1, name='conv3_4')
           .max_pool(2, 2, 2, 2, name='pool3_stage1', padding='VALID')
           .conv(3, 3, 512, 1, 1, name='conv4_1')
           .conv(3, 3, 512, 1, 1, name='conv4_2')
           .conv(3, 3, 256, 1, 1, name='conv4_3_CPM')
           # conv4_4_CPM is the feature map F of the original image.
           .conv(3, 3, 128, 1, 1, name='conv4_4_CPM'))

      # Stage 1: compute L1 and S1 directly from F.
      for branch, out_channels in branches:
          (self.feed('conv4_4_CPM')
               .conv(3, 3, 128, 1, 1, name='conv5_1_CPM_%s' % branch)
               .conv(3, 3, 128, 1, 1, name='conv5_2_CPM_%s' % branch)
               .conv(3, 3, 128, 1, 1, name='conv5_3_CPM_%s' % branch)
               .conv(1, 1, 512, 1, 1, name='conv5_4_CPM_%s' % branch)
               .conv(1, 1, out_channels, 1, 1, relu=False, name='conv5_5_CPM_%s' % branch))

      # Stages 2-6: the input of each stage is the previous stage's L and S
      # concatenated with the feature map F.
      prev_l1, prev_l2 = 'conv5_5_CPM_L1', 'conv5_5_CPM_L2'
      for stage in range(2, 7):
          concat_name = 'concat_stage%d' % stage
          self.feed(prev_l1, prev_l2, 'conv4_4_CPM').concat(3, name=concat_name)
          for branch, out_channels in branches:
              net = self.feed(concat_name)
              # Five 7x7 convolutions, then two 1x1 convolutions.
              for conv_idx in range(1, 6):
                  net = net.conv(7, 7, 128, 1, 1,
                                 name='Mconv%d_stage%d_%s' % (conv_idx, stage, branch))
              net = net.conv(1, 1, 128, 1, 1, name='Mconv6_stage%d_%s' % (stage, branch))
              net.conv(1, 1, out_channels, 1, 1, relu=False,
                       name='Mconv7_stage%d_%s' % (stage, branch))
          prev_l1 = 'Mconv7_stage%d_L1' % stage
          prev_l2 = 'Mconv7_stage%d_L2' % stage

      # Final layer: concatenate S6 and L6 (heat map first, then vector map).
      with tf.variable_scope('Openpose'):
          (self.feed('Mconv7_stage6_L2',
                     'Mconv7_stage6_L1')
               .concat(3, name='concat_stage7'))

怎么样?对比论文的图看,是不是茅塞顿开了?继续回到train.py的main函数,

8、学习率

得到网络以后,就是一些损失值的保存,有备注了就不讲了,继续往下看,

  with tf.device(tf.DeviceSpec(device_type="GPU")):
      # define loss
      # Sum of all per-stage (L1 + L2) losses, divided by the batch size.
      total_loss = tf.reduce_sum(losses) / args.batchsize
      # Last-stage L1 (vector map / PAF) loss, divided by the batch size.
      total_loss_ll_paf = tf.reduce_sum(last_losses_l1) / args.batchsize
      # Last-stage L2 (heat map) loss, divided by the batch size.
      total_loss_ll_heat = tf.reduce_sum(last_losses_l2) / args.batchsize
      # Last-stage L1 + L2 loss.
      total_loss_ll = tf.reduce_sum([total_loss_ll_paf, total_loss_ll_heat])

      # define optimizer
      # Learning-rate schedule. Number of steps per epoch:
      step_per_epoch = 121745 // args.batchsize
      global_step = tf.Variable(0, trainable=False)
      if ',' not in args.lr:
          # A single value: cosine decay from it over the whole training run.
          starter_learning_rate = float(args.lr)
          # learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
          #                                            decay_steps=10000, decay_rate=0.33, staircase=True)
          learning_rate = tf.train.cosine_decay(starter_learning_rate, global_step,
                                                args.max_epoch * step_per_epoch, alpha=0.0)
      else:
          # Comma-separated values: piecewise-constant schedule that moves to
          # the next rate every 5 epochs.
          lrs = [float(x) for x in args.lr.split(',')]
          # One boundary between each pair of consecutive rates. The previous
          # `for i, _ in range(len(lrs))` tried to unpack an int and raised
          # TypeError.
          boundaries = [step_per_epoch * 5 * i for i in range(1, len(lrs))]
          learning_rate = tf.train.piecewise_constant(global_step, boundaries, lrs)

上面也是一些损失的计算,还有学习率的设置,

9、优化器

继续往下看,

  # Optimizer
  optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-8)
  # optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.8, use_locking=True, use_nesterov=True)

  # tf.GraphKeys.UPDATE_OPS is a built-in graph collection holding ops that
  # must run before each training step; tf.control_dependencies makes the ops
  # created inside its scope wait for them.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(total_loss, global_step, colocate_gradients_with_ops=True)
  logger.info('define model-')

  # define summary
  for summary_tag, summary_tensor in (
          ("loss", total_loss),
          ("loss_lastlayer", total_loss_ll),
          ("loss_lastlayer_paf", total_loss_ll_paf),
          ("loss_lastlayer_heat", total_loss_ll_heat),
  ):
      tf.summary.scalar(summary_tag, summary_tensor)
  tf.summary.scalar("queue_size", enqueuer.size())
  tf.summary.scalar("lr", learning_rate)
  merged_summary_op = tf.summary.merge_all()

 

这里就是设置优化器,用的是Adam优化算法,并定义了训练过程中要记录的各项summary,

10、定义占位符

继续往下看,

  # Placeholders for the validation losses and for the rendered sample images.
  valid_loss = tf.placeholder(tf.float32, shape=[])
  valid_loss_ll = tf.placeholder(tf.float32, shape=[])
  valid_loss_ll_paf = tf.placeholder(tf.float32, shape=[])
  valid_loss_ll_heat = tf.placeholder(tf.float32, shape=[])

  # 4 training samples and 12 validation samples, each a 640x640 RGB image.
  num_train_samples, num_valid_samples = 4, 12
  sample_train = tf.placeholder(tf.float32, shape=(num_train_samples, 640, 640, 3))
  sample_valid = tf.placeholder(tf.float32, shape=(num_valid_samples, 640, 640, 3))

  # Image and scalar summaries written at validation time.
  train_img = tf.summary.image('training sample', sample_train, num_train_samples)
  valid_img = tf.summary.image('validation sample', sample_valid, num_valid_samples)
  valid_loss_t = tf.summary.scalar("loss_valid", valid_loss)
  valid_loss_ll_t = tf.summary.scalar("loss_valid_lastlayer", valid_loss_ll)
  merged_validate_op = tf.summary.merge([train_img, valid_img, valid_loss_t, valid_loss_ll_t])

上面就是定义占位符了,又到了熟悉的配方,熟悉的味道了。

11、会话

继续看,

  # Saver used to write model checkpoints.
  saver = tf.train.Saver(max_to_keep=1000)
  # Create the session.
  config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
      logger.info('model weights initialization')
      sess.run(tf.global_variables_initializer())

      # Load weights: resume from a checkpoint directory if one was given,
      # otherwise start from the pretrained weights.
      if args.checkpoint and os.path.isdir(args.checkpoint):
          logger.info('Restore from checkpoint...')
          # loader = tf.train.Saver(net.restorable_variables())
          # loader.restore(sess, tf.train.latest_checkpoint(args.checkpoint))
          saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint))
          logger.info('Restore from checkpoint...Done')
      elif pretrain_path:
          logger.info('Restore pretrained weights... %s' % pretrain_path)
          if '.npy' in pretrain_path:
              # Weights stored in .npy format.
              net.load(pretrain_path, sess, False)
          else:
              try:
                  loader = tf.train.Saver(net.restorable_variables(only_backbone=False))
                  loader.restore(sess, pretrain_path)
              except Exception:
                  # Was a bare `except:`, which also swallowed
                  # KeyboardInterrupt / SystemExit. Fall back to restoring
                  # only the backbone layers.
                  logger.info('Restore only weights in backbone layers.')
                  loader = tf.train.Saver(net.restorable_variables())
                  loader.restore(sess, pretrain_path)
          logger.info('Restore pretrained weights...Done')

      logger.info('prepare file writer')
      file_writer = tf.summary.FileWriter(os.path.join(logpath, args.tag), sess.graph)

      # Start the input queue.
      logger.info('prepare coordinator')
      coord = tf.train.Coordinator()
      enqueuer.set_coordinator(coord)
      enqueuer.start()

      def _render_sample(image, heat, paf):
          """Return CocoPose.display_image output as a 640x640x3 float array."""
          rendered = CocoPose.display_image(image, heat, paf, as_numpy=True)
          rendered = cv2.resize(rendered, (640, 640))
          return rendered.reshape([640, 640, 3]).astype(float)

      logger.info('Training Started.')
      time_started = time.time()
      last_gs_num = last_gs_num2 = 0
      initial_gs_num = sess.run(global_step)

      last_log_epoch1 = last_log_epoch2 = -1
      while True:
          # One training step.
          _, gs_num = sess.run([train_op, global_step])
          # Current (fractional) epoch.
          curr_epoch = float(gs_num) / step_per_epoch

          # Stop once the requested number of epochs has been trained.
          if gs_num > step_per_epoch * args.max_epoch:
              break

          if gs_num - last_gs_num >= 500:
              # Log the training losses every 500 steps.
              train_loss, train_loss_ll, train_loss_ll_paf, train_loss_ll_heat, lr_val, summary = sess.run(
                  [total_loss, total_loss_ll, total_loss_ll_paf, total_loss_ll_heat, learning_rate, merged_summary_op])

              # log of training loss / accuracy
              batch_per_sec = (gs_num - initial_gs_num) / (time.time() - time_started)
              logger.info('epoch=%.2f step=%d, %0.4f examples/sec lr=%f, loss=%g, loss_ll=%g, loss_ll_paf=%g, loss_ll_heat=%g' % (
                  gs_num / step_per_epoch, gs_num, batch_per_sec * args.batchsize, lr_val,
                  train_loss, train_loss_ll, train_loss_ll_paf, train_loss_ll_heat))
              last_gs_num = gs_num

              if last_log_epoch1 < curr_epoch:
                  file_writer.add_summary(summary, curr_epoch)
                  last_log_epoch1 = curr_epoch

          if gs_num - last_gs_num2 >= 2000:
              # Every 2000 steps: save the weights and run validation.
              saver.save(sess, os.path.join(modelpath, args.tag, 'model_latest'), global_step=global_step)

              average_loss = average_loss_ll = average_loss_ll_paf = average_loss_ll_heat = 0
              total_cnt = 0

              # Read the validation set once and keep it cached afterwards.
              if not validation_cache:
                  for images_test, heatmaps, vectmaps in tqdm(df_valid.get_data()):
                      validation_cache.append((images_test, heatmaps, vectmaps))
                  df_valid.reset_state()
                  del df_valid
                  df_valid = None

              # log of test accuracy
              for images_test, heatmaps, vectmaps in validation_cache:
                  lss, lss_ll, lss_ll_paf, lss_ll_heat, vectmap_sample, heatmap_sample = sess.run(
                      [total_loss, total_loss_ll, total_loss_ll_paf, total_loss_ll_heat, output_vectmap, output_heatmap],
                      feed_dict={q_inp: images_test, q_vect: vectmaps, q_heat: heatmaps}
                  )
                  # Weight each batch's losses by its size.
                  batch_len = len(images_test)
                  average_loss += lss * batch_len
                  average_loss_ll += lss_ll * batch_len
                  average_loss_ll_paf += lss_ll_paf * batch_len
                  average_loss_ll_heat += lss_ll_heat * batch_len
                  total_cnt += batch_len

              logger.info('validation(%d) %s loss=%f, loss_ll=%f, loss_ll_paf=%f, loss_ll_heat=%f' % (
                  total_cnt, args.tag, average_loss / total_cnt, average_loss_ll / total_cnt,
                  average_loss_ll_paf / total_cnt, average_loss_ll_heat / total_cnt))
              last_gs_num2 = gs_num

              # Run inference on 4 recent training images plus the fixed
              # validation images, to render them for the summary.
              sample_image = [enqueuer.last_dp[0][i] for i in range(4)]
              outputMat = sess.run(
                  outputs,
                  feed_dict={q_inp: np.array((sample_image + val_image) * max(1, (args.batchsize // 16)))}
              )
              # Channels [0, 19) are the heat maps, the rest are the PAFs.
              pafMat, heatMat = outputMat[:, :, :, 19:], outputMat[:, :, :, :19]

              sample_results = [
                  _render_sample(image, heatMat[i], pafMat[i])
                  for i, image in enumerate(sample_image)
              ]
              # Validation images follow the training samples in the batch.
              offset = len(sample_image)
              test_results = [
                  _render_sample(image, heatMat[offset + i], pafMat[offset + i])
                  for i, image in enumerate(val_image)
              ]

              # save summary
              summary = sess.run(merged_validate_op, feed_dict={
                  valid_loss: average_loss / total_cnt,
                  valid_loss_ll: average_loss_ll / total_cnt,
                  valid_loss_ll_paf: average_loss_ll_paf / total_cnt,
                  valid_loss_ll_heat: average_loss_ll_heat / total_cnt,
                  sample_valid: test_results,
                  sample_train: sample_results
              })
              if last_log_epoch2 < curr_epoch:
                  file_writer.add_summary(summary, curr_epoch)
                  last_log_epoch2 = curr_epoch

      # Save the final weights once training has finished.
      saver.save(sess, os.path.join(modelpath, args.tag, 'model'), global_step=global_step)
  logger.info('optimization finished. %f' % (time.time() - time_started))

上面就是真正的训练了,不想解释了,懒,后续我会将带有注释的源码上传,你们自己看吧。下一讲,就分析怎么使用这个网络。

如果您感觉本篇博客对您有帮助,请打开支付宝,领个红包支持一下,祝您扫到99元,谢谢~~



所属网站分类: 技术文章 > 博客

作者:皇后娘娘别惹我

链接:https://www.pythonheidong.com/blog/article/52927/033428ab5080dd4edb53/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

6 0
收藏该文
已收藏

评论内容:(最多支持255个字符)