Source code for solr_keras

import json
import os
from datetime import datetime

import numpy as np
from tensorflow.keras.callbacks import Callback


def compute_trainable_params(model):
    """Extract the number of parameters from the given Keras model.

    Parameters
    ----------
    model : Keras model

    Returns
    -------
    dict
        Python dictionary containing trainable_params, non_trainable_params and total_params.
    """
    if str(type(model)).startswith("<class 'keras."):
        from keras import backend as K
    else:
        import tensorflow.keras.backend as K

    trainable_count = int(
        np.sum([K.count_params(w) for w in model.trainable_weights])
    )
    non_trainable_count = int(
        np.sum([K.count_params(w) for w in model.non_trainable_weights])
    )

    return {'trainable_params': trainable_count,
            'non_trainable_params': non_trainable_count,
            'total_params': (trainable_count + non_trainable_count)}
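
A minimal usage sketch (the two-layer toy model below is illustrative only and not part of solr_keras; it assumes TensorFlow 2.x is installed):

# Illustrative only: count the parameters of a small throwaway model.
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

toy_model = Sequential([Dense(16, activation='relu', input_shape=(8,)),
                        Dense(1)])
print(compute_trainable_params(toy_model))
# -> {'trainable_params': ..., 'non_trainable_params': ..., 'total_params': ...}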


class CandleRemoteMonitor(Callback):
    """Capture run-level output and store/send it for monitoring.
    """

    def __init__(self, params=None):
        super(CandleRemoteMonitor, self).__init__()

        self.global_params = params

        # init
        self.experiment_id = None
        self.run_id = None
        self.run_timestamp = None
        self.epoch_timestamp = None
        self.log_messages = []

    def on_train_begin(self, logs=None):
        logs = logs or {}
        self.run_timestamp = datetime.now()
        self.experiment_id = self.global_params['experiment_id'] if 'experiment_id' in self.global_params else "EXP_default"
        self.run_id = self.global_params['run_id'] if 'run_id' in self.global_params else "RUN_default"

        run_params = []
        for key, val in self.global_params.items():
            run_params.append("{}: {}".format(key, val))

        send = {'experiment_id': self.experiment_id,
                'run_id': self.run_id,
                'parameters': run_params,
                'start_time': str(self.run_timestamp),
                'status': 'Started'
                }
        # print("on_train_begin", send)
        self.log_messages.append(send)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_timestamp = datetime.now()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        epoch_total = self.global_params['epochs']
        epoch_duration = datetime.now() - self.epoch_timestamp
        epoch_in_sec = epoch_duration.total_seconds()
        epoch_line = "epoch: {}/{}, duration: {}s, loss: {}, val_loss: {}".format(
            (epoch + 1), epoch_total, epoch_in_sec, loss, val_loss)

        send = {'run_id': self.run_id,
                'status': {'set': 'Running'},
                'training_loss': {'set': loss},
                'validation_loss': {'set': val_loss},
                'run_progress': {'add': [epoch_line]}
                }
        # print("on_epoch_end", send)
        self.log_messages.append(send)

    def on_train_end(self, logs=None):
        logs = logs or {}
        run_end = datetime.now()
        run_duration = run_end - self.run_timestamp
        run_in_hour = run_duration.total_seconds() / (60 * 60)

        send = {'run_id': self.run_id,
                'runtime_hours': {'set': run_in_hour},
                'end_time': {'set': str(run_end)},
                'status': {'set': 'Finished'},
                'date_modified': {'set': 'NOW'}
                }
        # print("on_train_end", send)
        self.log_messages.append(send)

        # save to file when finished
        self.save()

    def save(self):
        """Save log_messages to file.
        """
        # path = os.getenv('TURBINE_OUTPUT') if 'TURBINE_OUTPUT' in os.environ else '.'
        path = self.global_params['output_dir'] if 'output_dir' in self.global_params else '.'
        if not os.path.exists(path):
            os.makedirs(path)

        filename = "/run.{}.json".format(self.run_id)
        with open(path + filename, "a") as file_run_json:
            file_run_json.write(json.dumps(self.log_messages, indent=4, separators=(',', ': ')))
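
A minimal sketch of how this callback might be attached to training; the params dictionary and the model.fit arguments below are hypothetical, showing only the keys the callback actually reads (experiment_id, run_id, epochs, output_dir):

# Hypothetical hyperparameter dictionary; only the keys used by the callback are shown.
params = {'experiment_id': 'EXP000', 'run_id': 'RUN000',
          'epochs': 10, 'output_dir': './save'}
monitor = CandleRemoteMonitor(params=params)
# model.fit(x_train, y_train, epochs=params['epochs'],
#           validation_data=(x_val, y_val), callbacks=[monitor])
# After training ends, ./save/run.RUN000.json holds the accumulated log messages.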


class TerminateOnTimeOut(Callback):
    """This class implements a timeout on model training.

    When the run reaches the timeout, this class sets model.stop_training = True.
    """

    def __init__(self, timeout_in_sec=10):
        """Initialize the TerminateOnTimeOut class.

        Parameters
        ----------
        timeout_in_sec : int
            Seconds to timeout.
        """
        super(TerminateOnTimeOut, self).__init__()
        self.run_timestamp = None
        self.timeout_in_sec = timeout_in_sec

    def on_train_begin(self, logs={}):
        """Start the clock used to measure the timeout."""
        self.run_timestamp = datetime.now()

    def on_epoch_end(self, epoch, logs={}):
        """At the end of every epoch, check whether the run exceeded the timeout and terminate training if necessary."""
        run_end = datetime.now()
        run_duration = run_end - self.run_timestamp
        run_in_sec = run_duration.total_seconds()
        print('Current time ....%2.3f' % run_in_sec)
        if self.timeout_in_sec != -1:
            if run_in_sec >= self.timeout_in_sec:
                print('Timeout==>Runtime: %2.3fs, Maxtime: %2.3fs' % (run_in_sec, self.timeout_in_sec))
                self.model.stop_training = True
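
A minimal usage sketch, assuming a compiled Keras model and training data are already in scope (the fit call is shown commented out for that reason):

# Illustrative only: stop training after roughly one hour of wall-clock time.
timeout_cb = TerminateOnTimeOut(timeout_in_sec=3600)
# model.fit(x_train, y_train, epochs=100, callbacks=[timeout_cb])
# The check runs at epoch boundaries, so the epoch in progress completes
# before training stops.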