import argparse
import configparser
import inspect
import os
from typing import Any, List, Optional, Set
from candle.helper_utils import eval_string_as_list_of_lists
from candle.parsing_utils import (
ConfigDict,
ParseDict,
finalize_parameters,
parse_common,
parse_from_dictlist,
registered_conf,
)
# Alias for data-type (dtype) values; currently just ``Any``.
DType = Any


class Benchmark:
    """
    Class that implements an interface to handle configuration options for
    the different CANDLE benchmarks.

    It provides access to all the common configuration options and
    configuration options particular to each individual benchmark. It
    describes what minimum requirements should be specified to
    instantiate the corresponding benchmark. It interacts with the
    argparser to extract command-line options and arguments from the
    benchmark's configuration files.
    """

    def __init__(
        self,
        filepath: str,
        defmodel: str,
        framework: str,
        prog: Optional[str] = None,
        desc: Optional[str] = None,
        parser=None,
        additional_definitions=None,
        required=None,
    ) -> None:
        """
        Initialize Benchmark object.

        :param string filepath: ./
            os.path.dirname where the benchmark is located. Necessary to locate utils and
            establish input/output paths
        :param string defmodel: 'p*b*_default_model.txt'
            string corresponding to the default model of the benchmark
        :param string framework : 'keras', 'neon', 'mxnet', 'pytorch'
            framework used to run the benchmark
        :param string prog: 'p*b*_baseline_*'
            string for program name (usually associated to benchmark and framework)
        :param string desc: ' '
            string describing benchmark (usually a description of the neural network model built)
        :param argparser parser: (default None)
            if 'neon' framework a NeonArgparser is passed. Otherwise an argparser is constructed.
        :param list additional_definitions: (default None)
            list of dictionaries describing benchmark-specific parameters;
            an empty list is used when None.
        :param required: (default None)
            iterable of parameter names that must be specified; an empty set
            is used when None.

        :raises Exception: if the CANDLE_DATA_DIR environment variable is not
            defined, if the default model file ``filepath/defmodel`` does not
            exist, or if that file does not define ``model_name``.
        """
        # Check that required system variable specifying path to data has been defined
        if os.getenv("CANDLE_DATA_DIR") is None:
            raise Exception(
                "ERROR ! Required system variable not specified. You must define CANDLE_DATA_DIR ... Exiting"
            )

        # Check that default model configuration exists
        fname = os.path.join(filepath, defmodel)
        if not os.path.isfile(fname):
            raise Exception(
                "ERROR ! Required default configuration file not available. File "
                + fname
                + " ... Exiting"
            )
        self.model_name = self.get_parameter_from_file(fname, "model_name")
        print("model name: ", self.model_name)

        if parser is None:
            parser = argparse.ArgumentParser(
                prog=prog,
                formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                description=desc,
                # options may be added again later; the newer definition replaces the older one
                conflict_handler="resolve",
            )

        self.parser = parser
        self.file_path = filepath
        self.default_model = defmodel
        self.framework = framework

        # Flatten the groups of common parameter definitions into a single list.
        self.registered_conf: List[ParseDict] = []
        for lst in registered_conf:
            self.registered_conf.extend(lst)

        self.required: Set[str] = set(required) if required is not None else set()
        self.additional_definitions: List[ParseDict] = (
            additional_definitions if additional_definitions is not None else []
        )

        # legacy call for compatibility with existing Benchmarks
        self.set_locals()

    def parse_parameters(self) -> None:
        """
        Functionality to parse options common for all benchmarks.

        Passes ``self.parser`` through ``parse_common`` (options shared by all
        benchmarks) and then through ``parse_from_dictlist`` with this
        benchmark's ``additional_definitions``, keeping the parser returned by
        each call. Finally sets ``self.conffile`` to the default model file
        ``file_path/default_model``.
        """
        # Parse has been split between arguments that are common with the default neon parser
        # and all the other options
        self.parser = parse_common(self.parser)
        self.parser = parse_from_dictlist(self.additional_definitions, self.parser)

        # Set default configuration file
        self.conffile = os.path.join(self.file_path, self.default_model)

    def read_config_file(self, file: str) -> ConfigDict:
        """
        Functionality to read the configuration file specific for each
        benchmark.

        Every option of every section is passed through ``eval``. If an option
        name appears in several sections, only its first occurrence is used
        (minimal validation). The resulting dictionary is then post-processed
        by ``format_benchmark_config_arguments``.

        :param string file: path to the configuration file
        :return: parameters read from configuration file
        :rtype: ConfigDict
        """
        config = configparser.ConfigParser()
        # NOTE: ConfigParser.read silently skips files it cannot open, which
        # yields an empty parameter dictionary rather than an error.
        config.read(file)

        file_params = {}
        for section in config.sections():
            for key, raw_value in config.items(section):
                if key not in file_params:
                    # NOTE(security): eval() executes arbitrary expressions;
                    # configuration files must come from a trusted source.
                    file_params[key] = eval(raw_value)

        return self.format_benchmark_config_arguments(file_params)

    def format_benchmark_config_arguments(
        self, dictfileparam: ConfigDict
    ) -> ConfigDict:
        """
        Functionality to format the particular parameters of the benchmark.

        For each definition in ``additional_definitions`` followed by the
        registered common definitions whose ``name`` is a key of
        ``dictfileparam``:

        - if the definition has an ``action`` key whose value is a class, a
          string value is parsed with ``eval_string_as_list_of_lists`` using
          ":" as outer separator, "," as inner separator and the
          definition's ``type`` (or None) as element type;
        - if it has no ``action`` key and its ``default`` is not
          ``argparse.SUPPRESS``, the option ``--<name>`` is (re-)added to
          ``self.parser`` with the configuration-file value as its default.

        :param ConfigDict dictfileparam: parameters read from configuration file
        :return: a shallow copy of ``dictfileparam`` with the formatting applied
        :rtype: ConfigDict
        """
        config_out = dict(dictfileparam)
        definitions = list(self.additional_definitions) + list(self.registered_conf)

        for definition in definitions:
            name = definition.get("name")
            if name not in config_out:
                continue
            dtype = definition.get("type")

            if "action" in definition:
                value = config_out[name]
                # Only raw strings need parsing; already-structured values are kept.
                if inspect.isclass(definition["action"]) and isinstance(value, str):
                    config_out[name] = eval_string_as_list_of_lists(
                        value, ":", ",", dtype
                    )
            elif definition.get("default") != argparse.SUPPRESS:
                # default value on benchmark definition cannot overwrite config file
                self.parser.add_argument(
                    "--" + name,
                    type=dtype,
                    default=config_out[name],
                    help=definition.get("help"),
                )

        return config_out

    def get_parameter_from_file(self, absfname: str, param: str) -> str:
        """
        Functionality to extract the value of one parameter from the
        configuration file given. An exception is raised if the parameter
        specified is not found in the configuration file.

        Only the first line of the form ``param = value`` is used. The key must
        match ``param`` exactly, apart from surrounding whitespace. Surrounding
        whitespace and single/double quotes are stripped from the value.

        :param string absfname: filename of the configuration file including absolute path.
        :param string param: parameter to extract from configuration file.
        :return: a string with the value of the parameter read from the configuration file.
        :rtype: string
        :raises Exception: if no matching line is found or its value is empty.
        """
        value = ""
        with open(absfname, "r") as fp:
            for line in fp:
                key, sep, raw = line.partition("=")
                # Exact key match: a substring test would also hit comments or
                # longer keys that merely contain ``param``.
                if sep and key.strip() == param:
                    value = raw.strip(" \t\r\n'\"")
                    # don't look for next lines
                    break

        if value == "":
            raise Exception(
                "ERROR ! Parameter "
                + param
                + " was not found in file "
                + absfname
                + "... Exiting"
            )
        return value

    def set_locals(self) -> None:
        """
        Functionality to set variables specific for the benchmark.

        This base implementation does nothing. It is called at the end of
        ``__init__`` as a legacy hook, so subclasses may override it to set:

        - required: set of required parameters for the benchmark.
        - additional_definitions: list of dictionaries describing
          the additional parameters for the benchmark.
        """
        pass

    def check_required_exists(self, gparam: ConfigDict) -> None:
        """
        Functionality to verify that the required model parameters have been
        specified.

        :param ConfigDict gparam: parameters collected for the benchmark
        :raises Exception: if any name in ``self.required`` is not a key of ``gparam``
        """
        missing = self.required.difference(gparam.keys())
        if missing:
            raise Exception(
                "ERROR ! Required parameters are not specified. These required parameters have not been initialized: "
                + str(sorted(missing))
                + "... Exiting"
            )
def create_params(
    file_path=None,
    default_model=None,
    framework=None,
    prog_name=None,
    desc=None,
    additional_definitions=None,
    required=None,
):
    """
    Build a standard ``Benchmark`` and return its finalized parameters.

    Despite the ``None`` defaults, ``file_path`` and ``default_model`` must be
    supplied: ``Benchmark`` joins them to locate the default model file.

    :param string file_path: directory containing the benchmark's default model file
    :param string default_model: name of the default model file inside ``file_path``
    :param string framework: framework used to run the benchmark
    :param string prog_name: program name given to the argparser built by ``Benchmark``
    :param string desc: description given to the argparser built by ``Benchmark``
    :param list additional_definitions: benchmark-specific parameter definitions
    :param required: names of parameters that must be specified
    :return: the value returned by ``finalize_parameters`` for the benchmark
    """
    print("Generating parameters for standard benchmark\n")
    benchmark = Benchmark(
        filepath=file_path,
        defmodel=default_model,
        framework=framework,
        prog=prog_name,
        desc=desc,
        additional_definitions=additional_definitions,
        required=required,
    )
    return finalize_parameters(benchmark)