Source code for ckpt_keras_utils

"""
CKPT KERAS UTILS

CANDLE checkpoint/restart utilities for Keras

Hyperparameters that affect CANDLE checkpoint/restart:

ckpt_restart_mode :  "off" | "auto" | "required"
    If 'auto' or 'required', automatically try to restart from most recent
    (highest epoch) model.h5.
    'required' will fail if a model cannot be found.
    Default: "auto"

ckpt_save_best : boolean
    If True, save whenever save_best_metric has improved.
    Default: True

ckpt_save_best_metric : string
    Required when ckpt_save_best=True, else unused.
    The metric in the Keras epoch logs to track for improvement.
    Default: "val_loss"

ckpt_skip_epochs : integer
    Number of initial epochs to skip before writing checkpoints
    Default: 0

ckpt_save_interval: integer
    Save whenever epoch % save_interval == 0.
    Set save_interval=0 to disable these checkpoints (save nothing).
    Default: 1 (save everything)

ckpt_checksum : boolean
    If True, compute a checksum for the model
    and store it in the JSON.
    Also, confirm checksum at restart time.
    Default: False

ckpt_keep_mode : string
    "linear" or "exponential" ("exponential" is not yet implemented)
    Default: "linear"

ckpt_keep_limit : integer greater than zero
    Maximum number of checkpoints to keep.
    This can be set lower to reduce disk usage.
    Default: 1000000

ckpt_directory: string
    The top directory to use.
    Default: "./save"
    Typical user values:
    "/tmp/user/ckpts": i.e., the user intends to move these checkpoints
                       elsewhere manually.
    "/other-fs/user/ckpts": i.e., the working filesystem differs from the
                            filesystem used for checkpoints.

ckpt_metadata : string
    Arbitrary string to add to the JSON file regarding
    job ID, hardware location, etc.
    May be None or an empty string.
    Default: None
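
For reference, an illustrative gParameters fragment using these keys
(the values are placeholders; most shown here are simply the defaults):

    gParameters = {
        "epochs": 100,
        "ckpt_restart_mode": "auto",
        "ckpt_save_best": True,
        "ckpt_save_best_metric": "val_loss",
        "ckpt_skip_epochs": 0,
        "ckpt_save_interval": 1,
        "ckpt_checksum": False,
        "ckpt_keep_mode": "linear",
        "ckpt_keep_limit": 1000000,
        "ckpt_directory": "./save",
    }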

Usage:

  Add before training:

    initial_epoch = 0
    J = candle.restart(gParameters, model)
    if J is not None:
        initial_epoch = J['epoch']

  Set up a callback for checkpoints:

    ckpt = candle.CandleCheckpointCallback(gParameters)
    history = model.fit(epochs=gParameters['epochs'],
                        initial_epoch=initial_epoch,
                        ...
                        callbacks=[... , ckpt])

  Optionally, log a final report:

    ckpt.report_final()

Controlling restart:

Most restart control options are in gParameters.

Normally, restart() looks at the soft link ckpts/last,
which should point to a good ckpt directory number in epochs/*
and restarts from there.

To roll back, simply re-link ckpts/last to point to
a different directory number in epochs/* .
Any later epochs will simply be overwritten
(and a debug message will be reported).
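
For example, rolling back to epoch 5 could look like the following
(an illustrative sketch, assuming the default ckpt_directory "./save"
and that epochs/005 exists):

    import os
    link = os.path.abspath("./save/ckpts/last")
    target = os.path.abspath("./save/ckpts/epochs/005")
    os.remove(link)            # remove the old soft link
    os.symlink(target, link)   # re-point ckpts/last at epoch 005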

If ckpts/last is missing or a broken link,
restart() will start from scratch.

Keep policy:

The ckpt_keep settings only apply to the current run.
Checkpoints from prior runs will never be deleted by clean().
You may simply remove any of them.
Normally you will not want to remove the one pointed to by ckpts/last,
but if you do, restart() will simply start from scratch.

Logging:

A log of ckpt operations is in ckpt_directory/ckpt.log

"""

import json
import os
import shutil
import time

from pathlib import PosixPath

from helper_utils import set_up_logger, str2bool
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, ModelCheckpoint


class MultiGPUCheckpoint(ModelCheckpoint):

    def set_model(self, model):
        if isinstance(model.layers[-2], Model):
            self.model = model.layers[-2]
        else:
            self.model = model

class ParamRequired:
    """ Indicates that the user params must contain this key """
    pass

class CandleCheckpointCallback(Callback):
    """
    Keras Callback for CANDLE-compliant Benchmarks
    to use for checkpointing.
    Creates a JSON file alongside the weights and optimizer checkpoints
    that includes important metadata, particularly for restarting and
    tracking complex workflows.
    """

    def __init__(self, gParameters, logger="DEFAULT", verbose=True):
        """
        Parameters
        ----------
        gParameters : dict
            The CANDLE parameter dictionary; scanned by scan_params()
            for the ckpt_* settings described in the module docstring.
        logger : Logger
            The logger to use.
            May be None to disable or "DEFAULT" to use the default.
        verbose : boolean
            If True, more verbose logging.
            Passed to helper_utils.set_up_logger(verbose) for this logger.
        """
        self.logger = logger
        if self.logger == "DEFAULT":
            import logging
            self.logger = logging.getLogger("CandleCheckpointCallback")
            set_up_logger("save/ckpt.log", self.logger,
                          verbose=verbose,
                          fmt_line="%(asctime)s CandleCheckpoint: %(message)s")
        self.scan_params(gParameters)
        # List of epoch integers this instance has written.
        # Sorted from smallest to largest.
        self.epochs = []
        # The best epoch wrt metric.  Do not delete this!
        self.epoch_best = 0
        self.report_initial()

    def report_initial(self):
        """ Simply report that we are ready to run """
        self.info("Callback initialized.")
        if self.save_interval == 0:
            self.info("Checkpoint save interval == 0 " +
                      "-> checkpoints are disabled.")
            return  # Skip the rest of this output
        if self.metadata is not None:
            self.info("metadata='%s'" % self.metadata)
        if self.save_best:
            self.info("save_best_metric='%s'" % self.save_best_metric)
        self.info("PWD: " + os.getcwd())
        self.info("ckpt_directory: %s" %
                  PosixPath(self.ckpt_directory).resolve())

    def scan_params(self, gParameters):
        """ Simply translate gParameters into instance fields """
        self.epoch_max = param(gParameters, "epochs",
                               ParamRequired(), ParamType.INTEGER_NN)
        self.skip_epochs = param(gParameters, "ckpt_skip_epochs",
                                 0, ParamType.INTEGER_NN)
        self.ckpt_directory = param(gParameters, "ckpt_directory",
                                    "./save", ParamType.STRING)
        self.save_best = param(gParameters, "ckpt_save_best",
                               True, ParamType.BOOLEAN)
        self.save_best_metric = param(gParameters, "ckpt_save_best_metric",
                                      None, ParamType.STRING)
        self.best_metric_last = param(gParameters, "ckpt_best_metric_last",
                                      None, ParamType.FLOAT)
        if self.best_metric_last is None:
            import math
            self.best_metric_last = math.inf
        self.save_interval = param(gParameters, "ckpt_save_interval",
                                   1, ParamType.INTEGER_NN)
        self.save_weights_only = param(gParameters, "ckpt_save_weights_only",
                                       True, ParamType.BOOLEAN)
        self.checksum_enabled = param(gParameters, "ckpt_checksum",
                                      False, ParamType.BOOLEAN)
        self.keep_mode = param(gParameters, "ckpt_keep_mode",
                               "linear", ParamType.STRING,
                               allowed=[None, "all", "linear"])
        self.keep_limit = param(gParameters, "ckpt_keep_limit",
                                1000000, ParamType.INTEGER_GZ)
        self.metadata = param(gParameters, "metadata",
                              None, ParamType.STRING)
        self.timestamp_last = param(gParameters, "ckpt_timestamp_last",
                                    None, ParamType.STRING)
        self.cwd = os.getcwd()

    def on_epoch_end(self, epoch, logs=None):
        """
        Note: We immediately increment epoch
        from index-from-0 to index-from-1
        to match the TensorFlow output.
        Normally, ckpts/best is the best saved state,
        and ckpts/last is the last saved state.
        Procedure:
        1. Write current state to ckpts/work
        2. Rename ckpts/work to ckpts/epochs/NNN
        3. If best, link ckpts/best to ckpts/epochs/NNN
        4. Link ckpts/last to ckpts/epochs/NNN
        5. Clean up old ckpts according to keep policy
        """
        epoch += 1
        dir_root = PosixPath(self.ckpt_directory).resolve()
        dir_work = dir_root / "ckpts/work"
        dir_best = dir_root / "ckpts/best"  # a soft link
        dir_last = dir_root / "ckpts/last"  # a soft link
        dir_epochs = dir_root / "ckpts/epochs"
        dir_this = dir_epochs / ("%03i" % epoch)
        if not self.save_check(logs, epoch):
            return
        if os.path.exists(dir_this):
            self.debug("remove: '%s'" % self.relpath(dir_this))
            shutil.rmtree(dir_this)
        os.makedirs(dir_epochs, exist_ok=True)
        os.makedirs(dir_work, exist_ok=True)
        self.write_model(dir_work, epoch)
        self.debug("rename: '%s' -> '%s'" %
                   (self.relpath(dir_work), self.relpath(dir_this)))
        os.rename(dir_work, dir_this)
        self.epochs.append(epoch)
        if self.epoch_best == epoch:
            self.symlink(dir_this, dir_best)
        self.symlink(dir_this, dir_last)
        self.clean(epoch)

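    # Note: on_epoch_end() above calls self.symlink(), which is not shown in
    # this listing.  The following is a minimal illustrative sketch of such a
    # helper (not necessarily the upstream implementation): it logs the
    # operation, removes any existing link at dst, and re-points dst at src.
    def symlink(self, src, dst):
        """ Like os.symlink(), but overwrites an existing dst and logs it """
        self.debug("linking: '%s' -> '%s'" %
                   (self.relpath(dst), self.relpath(src)))
        if os.path.lexists(dst):
            os.remove(dst)
        os.symlink(src, dst)
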
    def save_check(self, logs, epoch):
        """
        Make sure we want to save this epoch
        based on the model metrics in given logs
        Also updates epoch_best if appropriate
        """
        if self.save_interval == 0:
            return False  # Checkpoints are disabled.
        # skip early epochs to improve speed
        if epoch < self.skip_epochs:
            self.debug("model saving disabled until epoch %d" %
                       self.skip_epochs)
            return False
        # Do this first- it may set epoch_best:
        if self.save_check_best(logs, epoch):
            # The model improved- save!
            self.epoch_best = epoch
            return True
        if epoch == self.epoch_max:
            self.info("writing final epoch %i ..." % epoch)
            return True  # Final epoch - save!
        if epoch % self.save_interval == 0:
            return True  # We are on the save_interval: save!
        # else- not saving:
        self.debug("not writing this epoch.")
        return False

    def save_check_best(self, logs, epoch):
        if not self.save_best:
            return False
        if self.save_best_metric not in logs.keys():
            raise Exception(("CandleCheckpointCallback: " +
                             "save_best_metric='%s' " +
                             "not in list of model metrics: %s") %
                            (self.save_best_metric, str(logs.keys())))
        # Known metrics and direction of progress
        known_metrics = {"loss": "-",
                         "accuracy": "+",
                         "val_loss": "-",
                         "val_accuracy": "+",
                         "lr": "-"}
        if self.save_best_metric not in known_metrics.keys():
            raise Exception(("CandleCheckpointCallback: " +
                             "save_best_metric='%s' " +
                             "not in list of known_metrics: %s") %
                            (self.save_best_metric, str(known_metrics.keys())))
        # Logging:
        if logs[self.save_best_metric] < self.best_metric_last:
            symbol = "<"
        elif logs[self.save_best_metric] > self.best_metric_last:
            symbol = ">"
        else:
            symbol = "="
        self.debug("metrics: %s: current=%f %s last=%f" %
                   (self.save_best_metric, logs[self.save_best_metric],
                    symbol, self.best_metric_last))
        # Check for improvement:
        improved = False  # did the metric improve this epoch?
        if known_metrics[self.save_best_metric] == "-":
            if logs[self.save_best_metric] < self.best_metric_last:
                improved = True
        elif known_metrics[self.save_best_metric] == "+":
            if logs[self.save_best_metric] > self.best_metric_last:
                improved = True
        else:
            assert False
        if improved:
            self.best_metric_last = logs[self.save_best_metric]
            self.epoch_best = epoch
            return True
        return False

    def write_model(self, dir_work, epoch):
        """ Do the I/O, report stats
            dir_work: A PosixPath
        """
        model_file = dir_work / "model.h5"
        self.debug("writing model to: '%s'" % self.relpath(model_file))
        start = time.time()
        # TODO: Implement save_weights_only
        self.model.save(model_file)  # save_format="h5"
        stop = time.time()
        duration = stop - start
        stats = os.stat(model_file)
        MB = stats.st_size / (1024 * 1024)
        rate = MB / duration
        self.debug("model wrote: %0.3f MB in %0.3f seconds (%0.2f MB/s)." %
                   (MB, duration, rate))
        self.checksum(dir_work)
        self.write_json(dir_work / "ckpt-info.json", epoch)

    def checksum(self, dir_work):
        """ Simple checksum dispatch
            dir_work: A PosixPath
        """
        if self.checksum_enabled:
            self.cksum_model = checksum_file(self.logger,
                                             dir_work / "model.h5")
        else:
            self.cksum_model = "__DISABLED__"

    def write_json(self, jsonfile, epoch):
        from datetime import datetime
        now = datetime.now()
        D = {}
        D["epoch"] = epoch
        D["save_best_metric"] = self.save_best_metric
        D["best_metric_last"] = self.best_metric_last
        D["model_file"] = "model.h5"
        D["checksum"] = self.cksum_model
        D["timestamp"] = now.strftime("%Y-%m-%d %H:%M:%S")
        if self.timestamp_last is None:
            time_elapsed = "__FIRST__"
        else:
            time_elapsed = (now - self.timestamp_last).total_seconds()
        self.timestamp_last = now
        D["time_elapsed"] = time_elapsed
        D["metadata"] = self.metadata
        with open(jsonfile, "w") as fp:
            json.dump(D, fp)
            fp.write("\n")

    def clean(self, epoch_now):
        """
        Clean old epoch directories
        in accordance with ckpt_keep policies.
        Return number of checkpoints kept and deleted
        """
        deleted = 0
        kept = 0
        # Consider most recent epochs first:
        for epoch in reversed(self.epochs):
            self.debug("clean(): checking epoch directory: %i" % epoch)
            if not self.keep(epoch, epoch_now, kept):
                deleted += 1
                self.delete(epoch)
                self.debug("clean(): deleted epoch: %i" % epoch)
            else:
                kept += 1
        return (kept, deleted)

    def keep(self, epoch, epoch_now, kept):
        """
        kept: Number of epochs already kept
        return True if we are keeping this epoch, else False
        """
        if epoch == epoch_now:
            # We just wrote this!
            self.debug("keep(): epoch is latest: %i" % epoch)
            return True
        if self.epoch_best == epoch:
            # This is the best epoch
            self.debug("keep(): epoch is best: %i" % epoch)
            return True
        if kept < self.keep_limit:
            self.debug("keep(): epoch count is < limit %i" % self.keep_limit)
            return True
        # No reason to keep this: delete it:
        return False

    def delete(self, epoch):
        dir_old = "save/ckpts/epochs/%03i" % epoch
        if os.path.exists(dir_old):
            self.debug("removing: '%s'" % dir_old)
            shutil.rmtree(dir_old)
        else:
            self.info("checkpoint for epoch=%i disappeared!" % epoch)
        self.epochs.remove(epoch)

    def relpath(self, p):
        return p.relative_to(self.cwd)

    def info(self, message):
        if self.logger is not None:
            self.logger.info(message)

    def debug(self, message):
        if self.logger is not None:
            self.logger.debug(message)

    def on_train_end(self, logs=None):
        self.report_final()

    def report_final(self):
        self.info("checkpoints kept: %i" % len(self.epochs))
        self.info("checkpoints list: %s" % str(self.epochs))

def restart(gParameters, model, verbose=True):
    """
    Possibly restarts model from CheckpointCallback
    according to given settings and the ckpt-info.json

    return
        The JSON dict if the restart happened or
        None if the restart did not happen.
    """
    import logging
    logger = logging.getLogger("Candle.restart")
    directory = param(gParameters, "ckpt_directory", "./save")
    set_up_logger(directory + "/ckpt.log", logger,
                  verbose=verbose,
                  fmt_line="%(asctime)s CANDLE restart(): %(message)s")

    param_ckpt_mode = param(gParameters, "ckpt_restart_mode", "auto",
                            allowed=["off", "auto", "required"])
    if param_ckpt_mode == "off":
        return None

    dir_last = "save/ckpts/last"
    model_file = dir_last + "/model.h5"
    if not os.path.exists(model_file):
        if param_ckpt_mode == "required":
            raise Exception("ckpt_restart_mode=='required' but no checkpoint " +
                            "could be found!")
        # We must be under AUTO - proceed without restart
        assert param_ckpt_mode == "auto"
        return None
    logger.info("restarting: '%s'" % model_file)
    result = restart_json(gParameters, logger, dir_last)
    logger.info("restarting: epoch=%i timestamp=%s",
                result["epoch"], result["timestamp"])
    start = time.time()
    stats = os.stat(model_file)
    MB = stats.st_size / (1024 * 1024)
    model.load_weights(model_file)
    stop = time.time()
    duration = stop - start
    rate = MB / duration
    logger.info("restarting: model read: %0.3f MB in %0.3f seconds (%0.2f MB/s).",
                MB, duration, rate)
    return result

def restart_json(gParameters, logger, directory):
    json_file = directory + "/ckpt-info.json"
    if not os.path.exists(json_file):
        msg = ("restart_json(): in: %s model exists but not json!" %
               directory)
        logger.info(msg)
        if not disabled(gParameters, "require_json"):
            raise Exception(msg)
    with open(json_file) as fp:
        J = json.load(fp)
    # print(str(J))
    logger.debug("ckpt-info.json contains:")
    logger.debug(json.dumps(J, indent=2))
    logger.info("restarting from epoch: %i" % J["epoch"])

    if param(gParameters, "ckpt_checksum", False, ParamType.BOOLEAN):
        checksum = checksum_file(logger, directory + "/model.h5")
        if checksum != J["checksum"]:
            raise Exception("checksum mismatch! directory: %s" % directory)

    return J

from enum import Enum, unique, auto


@unique
class ParamType(Enum):
    """ Possible gParameters types """
    STRING = auto()
    BOOLEAN = auto()
    INTEGER = auto()
    INTEGER_NN = auto()  # integer: non-negative
    INTEGER_GZ = auto()  # integer: greater-than-zero
    FLOAT = auto()
    FLOAT_NN = auto()

def enabled(gParameters, key):
    """ Is this parameter set to True? """
    return key in gParameters and gParameters[key]


def disabled(gParameters, key):
    """ Is this parameter set to False? """
    return key in gParameters and not gParameters[key]

def param(gParameters, key, dflt, type_=ParamType.STRING, allowed=None):
    """ Pull key from parameters with type checks and conversions """
    if key in gParameters:
        result = gParameters[key]
    else:
        if isinstance(dflt, ParamRequired):
            raise Exception("param key must be provided: '%s'" % key)
        result = dflt
    result = param_type_check(key, result, type_)
    param_allowed(key, result, allowed)
    return result

def param_type_check(key, value, type_):
    """
    Check that value is convertable to given type:
    if not, raise TypeError
    Return the value as converted to given type
    """
    if value is None:
        return value
    if type_ is ParamType.STRING:
        return str(value)
    if type_ is ParamType.BOOLEAN:
        return param_type_check_bool(key, value)
    if type_ is ParamType.INTEGER or \
       type_ is ParamType.INTEGER_NN or \
       type_ is ParamType.INTEGER_GZ:
        return param_type_check_int(key, value, type_)
    if type_ is ParamType.FLOAT or \
       type_ is ParamType.FLOAT_NN:
        return param_type_check_float(key, value, type_)
    raise TypeError("param_type_check(): unknown type: '%s'" % str(type_))

def param_type_check_bool(key, value):
    if isinstance(value, bool):
        return value
    try:
        v = str2bool(value)
    except TypeError:
        raise TypeError("parameter: '%s' is '%s' but must be a %s" %
                        (key, str(value), str(ParamType.BOOLEAN)))
    return v

def param_type_check_int(key, value, type_):
    if isinstance(value, int):
        result = value
    else:
        try:
            result = int(value)
        except (TypeError, ValueError):  # catch malformed strings too
            raise TypeError("parameter: '%s' is '%s' but must be a %s" %
                            (key, str(value), str(type_)))
    if type_ == ParamType.INTEGER_NN:
        if result < 0:
            raise TypeError(("parameter: '%s' is '%s' " +
                             "but must be non-negative") %
                            (key, str(value)))
    if type_ == ParamType.INTEGER_GZ:
        if result <= 0:
            raise TypeError(("parameter: '%s' is '%s' " +
                             "but must be greater-than-zero") %
                            (key, str(value)))
    return result

def param_type_check_float(key, value, type_):
    if isinstance(value, float):
        result = value
    else:
        try:
            result = float(value)
        except (TypeError, ValueError):  # catch malformed strings too
            raise TypeError("parameter: '%s' is '%s' but must be a %s" %
                            (key, str(value), str(type_)))
    if type_ == ParamType.FLOAT_NN:
        if result < 0:
            raise TypeError(("parameter: '%s' is '%s' " +
                             "but must be non-negative") %
                            (key, str(value)))
    return result

def checksum_file(logger, filename):
    """ Read file, compute checksum, return it as a string. """
    import zlib
    start = time.time()
    chunk_size = 10 * 1024 * 1024  # 10 MB
    total = 0
    with open(filename, "rb") as fp:
        checksum = 0
        while True:
            chunk = fp.read(chunk_size)
            if not chunk:
                break
            total += len(chunk)
            checksum = zlib.crc32(chunk, checksum)
    stop = time.time()
    MB = total / (1024 * 1024)
    duration = stop - start
    rate = MB / duration
    logger.info("checksummed: %0.3f MB in %.3f seconds (%.2f MB/s)." %
                (MB, duration, rate))
    return str(checksum)

def param_allowed(key, value, allowed):
    """
    Check that the value is in the list of allowed values
    If allowed is None, there is no check, simply success
    """
    if allowed is None:
        return
    if value not in allowed:
        raise ValueError(("hyperparameter '%s'='%s' is not in the " +
                          "list of allowed values: %s") %
                         (key, value, str(allowed)))

def ckpt_parser(parser):
    # global
    parser.add_argument("--ckpt_restart_mode", type=str,
                        default='auto',
                        choices=['off', 'auto', 'required'],
                        help="Mode to restart from a saved checkpoint file, " +
                             "choices are 'off', 'auto', 'required'")
    parser.add_argument("--ckpt_checksum", type=str2bool,
                        default=False,
                        help="Checksum the restart file after read+write")
    parser.add_argument("--ckpt_skip_epochs", type=int,
                        default=0,
                        help="Number of epochs to skip before saving epochs")
    parser.add_argument("--ckpt_directory", type=str,
                        default='./save',
                        help="Base directory in which to save checkpoints")
    # saving
    parser.add_argument("--ckpt_save_best", type=str2bool,
                        default=True,
                        help="Toggle saving best model")
    parser.add_argument("--ckpt_save_best_metric", type=str,
                        default="val_loss",
                        help="Metric for determining when to save best model")
    parser.add_argument("--ckpt_save_weights_only", type=str2bool,
                        default=False,
                        help="Toggle saving only weights (not optimizer) (NYI)")
    parser.add_argument("--ckpt_save_interval", type=int,
                        default=1,
                        help="Epoch interval to save checkpoints. " +
                             "Set to 0 to disable writing checkpoints")
    # keeping
    parser.add_argument("--ckpt_keep_mode",
                        choices=['linear', 'exponential'],
                        help="Checkpoint saving mode. " +
                             "Choices are 'linear' or 'exponential' ")
    parser.add_argument("--ckpt_keep_limit", type=int,
                        default=1000000,
                        help="Limit checkpoints to keep")
    return parser

def ckpt_defs(defs):
    # defs is an existing list
    # global
    new_defs = [
        {'name': 'ckpt_restart_mode',
         'type': str,
         'default': 'auto',
         'choices': ['off', 'auto', 'required'],
         'help': 'Mode to restart from a saved checkpoint file'},
        {'name': 'ckpt_checksum',
         'type': str2bool,
         'default': False,
         'help': 'Checksum the restart file after read+write'},
        {'name': 'ckpt_skip_epochs',
         'type': int,
         'default': 0,
         'help': 'Number of epochs to skip before saving epochs'},
        {'name': 'ckpt_directory',
         'type': str,
         'default': './save',
         'help': 'Base directory in which to save checkpoints'},
        # saving
        {'name': 'ckpt_save_best',
         'type': str2bool,
         'default': True,
         'help': 'Toggle saving best model'},
        {'name': 'ckpt_save_best_metric',
         'type': str,
         'default': 'val_loss',
         'help': 'Metric for determining when to save best model'},
        {'name': 'ckpt_save_weights_only',
         'type': str2bool,
         'default': False,
         'help': 'Toggle saving only weights (not optimizer) (NYI)'},
        {'name': 'ckpt_save_interval',
         'type': int,
         'default': 1,
         'help': 'Interval to save checkpoints'},
        # keeping
        {'name': 'ckpt_keep_mode',
         'choices': ['linear', 'exponential'],
         'help': 'Checkpoint saving mode. ' +
                 "choices are 'linear' or 'exponential' "},
        {'name': 'ckpt_keep_limit',
         'type': int,
         'default': 1000000,
         'help': 'Limit checkpoints to keep'}
    ]
    defs = defs + new_defs
    return defs