from typing import Dict, Optional
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback


def clr_check_args(args: Dict) -> bool:
    """
    Checks that all arguments required to configure a cyclical learning rate are present.
    """
    req_keys = ["clr_mode", "clr_base_lr", "clr_max_lr", "clr_gamma"]
    keys_present = True
    for key in req_keys:
        if key not in args.keys():
            keys_present = False
    return keys_present
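

# A minimal usage sketch (the hyperparameter dictionary below is hypothetical,
# not part of the original module): clr_check_args() only verifies that the
# four CLR keys are present.
#
#     args = {"clr_mode": "trng1", "clr_base_lr": 1e-4,
#             "clr_max_lr": 1e-3, "clr_gamma": 0.999994}
#     clr_check_args(args)                   # True
#     clr_check_args({"clr_mode": "trng1"})  # False: three keys missing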


def clr_set_args(args: Dict) -> Dict:
    """
    Builds the keyword arguments for the CLR Keras callback from ``args``.
    Returns a disabled configuration if the CLR specification is incomplete.
    """
    req_keys = ["clr_mode", "clr_base_lr", "clr_max_lr", "clr_gamma"]
    exclusive_keys = ["warmup_lr", "reduce_lr"]
    keys_present = True
    for key in req_keys:
        if key not in args.keys():
            keys_present = False
    if keys_present and args["clr_mode"] is not None:
        clr_keras_kwargs = {
            "mode": args["clr_mode"],
            "base_lr": args["clr_base_lr"],
            "max_lr": args["clr_max_lr"],
            "gamma": args["clr_gamma"],
        }
        # CLR cannot be combined with warmup or LR reduction; disable them.
        for ex_key in exclusive_keys:
            if ex_key in args.keys():
                if args[ex_key] is True:
                    print("Key", ex_key, "conflicts with CLR, setting it to False")
                    args[ex_key] = False
    else:
        print("Incomplete CLR specification: will run without CLR")
        clr_keras_kwargs = {"mode": None, "base_lr": 0.1, "max_lr": 0.1, "gamma": 0.1}
    return clr_keras_kwargs
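

# A hedged sketch of how clr_set_args() is typically driven (argument values
# are hypothetical): a complete specification yields the kwargs expected by
# clr_callback(), and conflicting schedule flags are switched off in place.
#
#     args = {"clr_mode": "exp", "clr_base_lr": 1e-4, "clr_max_lr": 1e-3,
#             "clr_gamma": 0.999994, "warmup_lr": True}
#     kwargs = clr_set_args(args)
#     # kwargs == {"mode": "exp", "base_lr": 1e-4, "max_lr": 1e-3,
#     #            "gamma": 0.999994}; args["warmup_lr"] is now False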


def clr_callback(
    mode: Optional[str] = None,
    base_lr: float = 1e-4,
    max_lr: float = 1e-3,
    gamma: float = 0.999994,
) -> Callback:
    """
    Creates a Keras callback for a cyclical learning rate.
    """
    if mode == "trng1":
        clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, mode="triangular")
    elif mode == "trng2":
        clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, mode="triangular2")
    elif mode == "exp":
        # Other gamma values used in practice: 0.99994, 0.99999994, 0.999994.
        clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, mode="exp_range", gamma=gamma)
    else:
        # Pass the mode through unchanged; CyclicLR raises KeyError for anything
        # other than 'triangular', 'triangular2', or 'exp_range'.
        clr = CyclicLR(base_lr=base_lr, max_lr=max_lr, mode=mode)
    return clr
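

# Example wiring (sketch only; `model`, `x_train`, and `y_train` are assumed to
# be defined elsewhere): build the callback from one of the shorthand modes
# ("trng1", "trng2", "exp") and pass it to Keras fit().
#
#     clr = clr_callback(mode="trng2", base_lr=1e-4, max_lr=1e-3)
#     model.fit(x_train, y_train, epochs=10, callbacks=[clr])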


class CyclicLR(Callback):
    """
    This callback implements a cyclical learning rate policy (CLR). The
    method cycles the learning rate between two boundaries with some constant
    frequency.

    - base_lr: initial learning rate, which is the lower boundary in the cycle.
    - max_lr: upper boundary in the cycle. Functionally, it defines the cycle
      amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr
      and some scaling of the amplitude; therefore max_lr may not actually be
      reached, depending on the scaling function.
    - step_size: number of training iterations per half cycle. The authors
      suggest setting step_size to 2-8 x the training iterations per epoch.
    - mode: one of {'triangular', 'triangular2', 'exp_range'}. Default
      'triangular'. Values correspond to the policies detailed below. If
      scale_fn is not None, this argument is ignored.
    - gamma: constant in the 'exp_range' scaling function: gamma**(cycle iterations).
    - scale_fn: custom scaling policy defined by a single-argument lambda
      function, where 0 <= scale_fn(x) <= 1 for all x >= 0. The mode parameter
      is ignored.
    - scale_mode: {'cycle', 'iterations'}. Defines whether scale_fn is evaluated
      on cycle number or cycle iterations (training iterations since the start
      of the cycle). Default is 'cycle'.

    The amplitude of the cycle can be scaled on a per-iteration or per-cycle basis.
    This class has three built-in policies, as put forth in the paper:

    - "triangular": a basic triangular cycle with no amplitude scaling.
    - "triangular2": a basic triangular cycle that scales the initial amplitude by half each cycle.
    - "exp_range": a cycle that scales the initial amplitude by gamma**(cycle iterations) at each cycle iteration.

    For more detail, please see the paper.

    # Example for CIFAR-10 w/ batch size 100:

    .. code-block::

        clr = CyclicLR(base_lr=0.001, max_lr=0.006,
                       step_size=2000., mode='triangular')
        model.fit(X_train, Y_train, callbacks=[clr])

    The class also supports custom scaling functions:

    .. code-block::

        clr_fn = lambda x: 0.5*(1+np.sin(x*np.pi/2.))
        clr = CyclicLR(base_lr=0.001, max_lr=0.006,
                       step_size=2000., scale_fn=clr_fn,
                       scale_mode='cycle')
        model.fit(X_train, Y_train, callbacks=[clr])

    # References

    - [Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186)
    """

    def __init__(
        self,
        base_lr=0.001,
        max_lr=0.006,
        step_size=2000.0,
        mode="triangular",
        gamma=1.0,
        scale_fn=None,
        scale_mode="cycle",
    ):
        super(CyclicLR, self).__init__()

        if mode not in ["triangular", "triangular2", "exp_range"]:
            raise KeyError(
                "mode must be one of 'triangular', 'triangular2', or 'exp_range'"
            )
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.gamma = gamma
        if scale_fn is None:
            if self.mode == "triangular":
                self.scale_fn = lambda x: 1.0
                self.scale_mode = "cycle"
            elif self.mode == "triangular2":
                self.scale_fn = lambda x: 1 / (2.0 ** (x - 1))
                self.scale_mode = "cycle"
            elif self.mode == "exp_range":
                self.scale_fn = lambda x: gamma**x
                self.scale_mode = "iterations"
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        self.clr_iterations = 0.0
        self.trn_iterations = 0.0
        self.history = {}
        self._reset()

    def _reset(self, new_base_lr=None, new_max_lr=None, new_step_size=None):
        """Resets cycle iterations.
        Optional boundary/step size adjustment.
        """
        if new_base_lr is not None:
            self.base_lr = new_base_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_step_size is not None:
            self.step_size = new_step_size
        self.clr_iterations = 0.0
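
    # Usage note (sketch, not in the original code): _reset() can be called
    # between successive model.fit() runs to restart the cycle, optionally with
    # new boundaries, e.g.
    #
    #     clr._reset(new_base_lr=5e-5, new_max_lr=5e-4, new_step_size=4000.0)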

    def clr(self):
        # cycle is the current cycle number; (1 - x) ramps linearly 0 -> 1 -> 0
        # across each cycle of 2 * step_size iterations.
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        if self.scale_mode == "cycle":
            return self.base_lr + (self.max_lr - self.base_lr) * np.maximum(
                0, (1 - x)
            ) * self.scale_fn(cycle)
        else:
            return self.base_lr + (self.max_lr - self.base_lr) * np.maximum(
                0, (1 - x)
            ) * self.scale_fn(self.clr_iterations)
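
    # Worked trace of clr() for the default 'triangular' policy with
    # base_lr=0.001, max_lr=0.006, step_size=2000 (values exact): the rate
    # starts at base_lr, peaks at max_lr after one half cycle, and returns to
    # base_lr after a full cycle.
    #
    #     iteration   cycle   x     lr
    #     0           1       1.0   0.001
    #     1000        1       0.5   0.0035
    #     2000        1       0.0   0.006
    #     4000        2       1.0   0.001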

    def on_train_begin(self, logs=None):
        logs = logs or {}

        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1
        K.set_value(self.model.optimizer.lr, self.clr())

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault("iterations", []).append(self.trn_iterations)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs["lr"] = K.get_value(self.model.optimizer.lr)