Published 2019-12-22 21:20
Keras has no built-in multi-step learning-rate scheduler (like PyTorch's MultiStepLR), so here is one I wrote myself:
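The rule behind such a scheduler is simple: at each milestone epoch the current learning rate is multiplied by the decay factor, so after passing k milestones the rate is lr0 * factor**k. A quick sketch with illustrative numbers (lr0 = 0.1, factor = 0.1, milestones [2, 5, 7] — the same defaults used below):

```python
lr0, factor, milestones = 0.1, 0.1, [2, 5, 7]

def lr_at(epoch):
    # number of milestones already reached by this epoch
    k = sum(1 for m in milestones if m <= epoch)
    return lr0 * factor ** k

print([round(lr_at(e), 6) for e in range(9)])
# epochs 0-1: 0.1, epochs 2-4: 0.01, epochs 5-6: 0.001, epochs 7+: 0.0001
```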
1. Code

```python
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import backend as K
import numpy as np
import argparse


parser = argparse.ArgumentParser()
# Note: type=list would split a command-line string into single characters;
# use nargs='+' with type=int to accept a list of milestone epochs.
parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[2, 5, 7],
                    help="epochs at which the learning rate is decayed")
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
args, _ = parser.parse_known_args()


def get_lr_scheduler(args):
    lr_scheduler = MultiStepLR(args=args)
    return lr_scheduler


class MultiStepLR(Callback):
    """Multi-step learning rate scheduler.

    Arguments:
        args: parsed argparse namespace (lr_decay_epochs, lr_decay_factor)
        verbose: int. 0: quiet, 1: update messages.
    """

    def __init__(self, args, verbose=0):
        super(MultiStepLR, self).__init__()
        self.args = args
        self.steps = args.lr_decay_epochs
        self.factor = args.lr_decay_factor
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.lr, lr)
        print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
        if self.verbose > 0:
            print('\nEpoch %05d: MultiStepLR reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def schedule(self, epoch):
        # multiply the current lr by the decay factor when a milestone is reached
        lr = K.get_value(self.model.optimizer.lr)
        if epoch in self.steps:
            lr = lr * self.factor
        return lr
```
2. Usage (append the lr_scheduler to your callbacks list and pass that list to fit_generator)
```python
callbacks = []
lr_scheduler = get_lr_scheduler(args=args)
callbacks.append(lr_scheduler)

...
model.fit_generator(train_generator,
                    steps_per_epoch=train_generator.samples // args.batch_size,
                    validation_data=test_generator,
                    validation_steps=test_generator.samples // args.batch_size,
                    workers=args.num_workers,
                    callbacks=callbacks,  # your callbacks, including lr_scheduler
                    epochs=args.epochs,
                    )
```
Feel free to use it~
Author: 短发越来越短
Link: https://www.pythonheidong.com/blog/article/182094/89c8188be6a4b625d5e8/
Source: python黑洞网