程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

Keras中的MultiStepLR

发布于2019-12-22 21:20     阅读(1625)     评论(0)     点赞(30)     收藏(4)


Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:

1.代码

  1. from tensorflow.python.keras.callbacks import Callback
  2. from tensorflow.python.keras import backend as K
  3. import numpy as np
  4. import argparse
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument('--lr_decay_epochs', type=list, default=[2, 5, 7], help="For MultiFactorScheduler step")
  7. parser.add_argument('--lr_decay_factor', type=float, default=0.1)
  8. args, _ = parser.parse_known_args()
  9. def get_lr_scheduler(args):
  10. lr_scheduler = MultiStepLR(args=args)
  11. return lr_scheduler
  12. class MultiStepLR(Callback):
  13. """Learning rate scheduler.
  14. Arguments:
  15. args: parser_setting
  16. verbose: int. 0: quiet, 1: update messages.
  17. """
  18. def __init__(self, args, verbose=0):
  19. super(MultiStepLR, self).__init__()
  20. self.args = args
  21. self.steps = args.lr_decay_epochs
  22. self.factor = args.lr_decay_factor
  23. self.verbose = verbose
  24. def on_epoch_begin(self, epoch, logs=None):
  25. if not hasattr(self.model.optimizer, 'lr'):
  26. raise ValueError('Optimizer must have a "lr" attribute.')
  27. lr = self.schedule(epoch)
  28. if not isinstance(lr, (float, np.float32, np.float64)):
  29. raise ValueError('The output of the "schedule" function '
  30. 'should be float.')
  31. K.set_value(self.model.optimizer.lr, lr)
  32. print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
  33. if self.verbose > 0:
  34. print('\nEpoch %05d: MultiStepLR reducing learning '
  35. 'rate to %s.' % (epoch + 1, lr))
  36. def schedule(self, epoch):
  37. lr = K.get_value(self.model.optimizer.lr)
  38. for i in range(len(self.steps)):
  39. if epoch == self.steps[i]:
  40. lr = lr * self.factor
  41. return lr

2.调用(callbacks里append这个lr_scheduler,fit_generator里callbacks传入这个变量)

  1. callbacks = []
  2. lr_scheduler = get_lr_scheduler(args=args)
  3. callbacks.append(lr_scheduler)
  4. ...
  5. model.fit_generator(train_generator,
  6. steps_per_epoch=train_generator.samples // args.batch_size,
  7. validation_data=test_generator,
  8. validation_steps=test_generator.samples // args.batch_size,
  9. workers=args.num_workers,
  10. callbacks=callbacks, # 你的callbacks, 包含了lr_scheduler
  11. epochs=args.epochs,
  12. )

大家可以拿去用~

发布了127 篇原创文章 · 获赞 914 · 访问量 134万+


所属网站分类: 技术文章 > 博客

作者:短发越来越短

链接:https://www.pythonheidong.com/blog/article/182094/89c8188be6a4b625d5e8/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

30 0
收藏该文
已收藏

评论内容:(最多支持255个字符)