Source code for candle.ckpt_keras_utils

"""
CKPT KERAS UTILS.

CANDLE checkpoint/restart utilities for Keras

"""

# Python imports
from typing import Dict

# TensorFlow imports
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
from tensorflow.keras.models import Model

# CANDLE imports
from .ckpt_utils import CandleCkpt


class MultiGPUCheckpoint(ModelCheckpoint):
    """ModelCheckpoint variant that can attach to a nested inner model.

    When the second-to-last layer of the model handed to ``set_model`` is
    itself a Keras ``Model``, that inner model is attached to the callback;
    otherwise the model is attached as given.

    NOTE(review): presumably intended for models wrapped by a multi-GPU
    helper that keeps the original model at ``layers[-2]`` -- confirm.
    """

    def set_model(self, model):
        """Attach ``model``, or its nested inner model, to this callback.

        Args:
            model: The Keras model supplied to the callback.
        """
        layers = model.layers
        # Guard the [-2] lookup: with fewer than two layers there is no
        # second-to-last layer, and indexing would raise IndexError.
        if len(layers) >= 2 and isinstance(layers[-2], Model):
            self.model = layers[-2]
        else:
            self.model = model


class CandleCkptKeras(CandleCkpt, Callback):
    """
    Keras Callback for CANDLE-compliant Benchmarks to use for checkpointing.

    Creates a JSON file alongside the weights and optimizer checkpoints
    that includes important metadata, particularly for restarting and
    tracking complex workflows.
    """

    # Known metrics and direction of progress. The direction string is
    # passed unchanged to ckpt_epoch().
    # NOTE(review): presumably "-" means lower is better and "+" means
    # higher is better -- confirm against CandleCkpt.ckpt_epoch.
    _KNOWN_METRICS = {
        "loss": "-",
        "accuracy": "+",
        "val_loss": "-",
        "val_accuracy": "+",
        "lr": "-",
    }

    def __init__(self, gParameters: Dict, logger="DEFAULT", verbose=True):
        """Forward all arguments to the parent initializer.

        Args:
            gParameters: Parameter dictionary passed through unchanged.
            logger: Logger setting passed through unchanged.
            verbose: Verbosity flag passed through unchanged.
        """
        super().__init__(gParameters, logger, verbose)

    def set_model(self, model):
        """model: The Keras model"""
        self.model = model

    # Keras Callback API:
    def on_epoch_end(self, epoch, logs=None):
        """Validate the tracked metric and hand its value to ``ckpt_epoch``.

        Args:
            epoch: Epoch index, forwarded to ``ckpt_epoch``.
            logs: Mapping of metric name to value for this epoch. ``None``
                is treated as an empty mapping.

        Raises:
            Exception: If ``self.save_best_metric`` is missing from
                ``logs`` or is not one of the known metrics.
        """
        # The signature allows logs=None; normalise it so a missing dict
        # yields the descriptive error below instead of an AttributeError.
        if logs is None:
            logs = {}

        metric = self.save_best_metric
        if metric not in logs:
            raise Exception(
                (
                    "CandleCheckpointCallback: "
                    + "save_best_metric='%s' "
                    + "not in list of model metrics: %s"
                )
                % (metric, str(logs.keys()))
            )

        known_metrics = self._KNOWN_METRICS
        if metric not in known_metrics:
            raise Exception(
                (
                    "CandleCheckpointCallback: "
                    + "save_best_metric='%s' "
                    + "not in list of known_metrics: %s"
                )
                % (metric, str(known_metrics.keys()))
            )

        metric_value = logs[metric]
        direction = known_metrics[metric]
        self.ckpt_epoch(epoch, direction, metric_value)

    def write_model_backend(self, model, epoch):
        """Keras-specific method: save ``model`` to ``self.model_file``.

        Args:
            model: The model to save.
            epoch: Unused by Keras.
        """
        model.save(self.model_file)  # save_format="h5"

    def build_model(self, model_file):
        """Load weights from ``model_file`` into the attached model.

        Args:
            model_file: Path handed to ``self.model.load_weights``.
        """
        self.model.load_weights(model_file)