diff --git a/hippynn/experiment/metric_tracker.py b/hippynn/experiment/metric_tracker.py
index 989042c7..f43426e6 100644
--- a/hippynn/experiment/metric_tracker.py
+++ b/hippynn/experiment/metric_tracker.py
@@ -30,7 +30,7 @@ class MetricTracker:
 
     """
 
-    def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train", "valid", "test")):
+    def __init__(self, metric_names, stopping_key, quiet=False):
         """
 
         :param metric_names:
@@ -48,7 +48,7 @@ def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train"
         self.n_metrics = len(metric_names)
 
         # State variables
-        self.best_metric_values = {split: {mtype: float("inf") for mtype in self.metric_names} for split in split_names}
+        self.best_metric_values = {}
         self.other_metric_values = {}
         self.best_model = None
         self.epoch_times = []
@@ -78,7 +78,18 @@ def register_metrics(self, metric_info, when):
         better_metrics = {k: {} for k in self.best_metric_values}
         for split_type, typevals in metric_info.items():
             for mname, mval in typevals.items():
-                better = self.best_metric_values[split_type][mname] > mval
+                try:
+                    old_best = self.best_metric_values[split_type][mname]
+                    better = old_best > mval
+                    del old_best  # marking not needed.
+                except KeyError:
+                    if split_type not in self.best_metric_values:
+                        # Haven't seen this split before!
+                        print("Adding metrics for split:", split_type)
+                        self.best_metric_values[split_type] = {}
+                        better_metrics[split_type] = {}
+                    better = True  # old best was not found!
+
                 if better:
                     self.best_metric_values[split_type][mname] = mval
                     better_metrics[split_type][mname] = better
@@ -97,13 +108,17 @@ def register_metrics(self, metric_info, when):
 
         return better_metrics, better_model, stopping_key_metric
 
-    def evaluation_print(self, evaluation_dict):
-        if self.quiet:
+    def evaluation_print(self, evaluation_dict, quiet=None):
+        if quiet is None:
+            quiet = self.quiet
+        if quiet:
             return
         table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width)
 
-    def evaluation_print_better(self, evaluation_dict, better_dict):
-        if self.quiet:
+    def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None):
+        if quiet is None:
+            quiet = self.quiet
+        if quiet:
             return
         table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width)
         if self.stopping_key:
@@ -119,14 +134,14 @@ def plot_over_time(self):
 
 # Driver for printing evaluation table results, with * for better entries.
 # Decoupled from the estate in case we want to more easily change print formatting.
-def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, ncs):
+def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns):
     """
     Print metric results as a table, add a '*' character for metrics in better_dict.
 
     :param evaluation_dict: dict[eval type]->dict[metric]->value
     :param better_dict: dict[eval type]->dict[metric]->bool
     :param metric_names: Names
-    :param ncs: Number of columns for name fields.
+    :param n_columns: Number of columns for name fields.
     :return: None
     """
     type_names = evaluation_dict.keys()
@@ -139,8 +154,8 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc
 
     n_types = len(type_names)
 
-    header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names)
-    rowstring = "{:<" + str(ncs) + "}: " + " {}{:>10.5g}" * n_types
+    header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names)
+    rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types
 
     print(header)
     print("-" * len(header))
@@ -151,13 +166,13 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc
 
 # Driver for printing evaluation table results.
 # Decoupled from the estate in case we want to more easily change print formatting.
-def table_evaluation_print(evaluation_dict, metric_names, ncs):
+def table_evaluation_print(evaluation_dict, metric_names, n_columns):
     """
     Print metric results as a table.
 
     :param evaluation_dict: dict[eval type]->dict[metric]->value
     :param metric_names: Names
-    :param ncs: Number of columns for name fields.
+    :param n_columns: Number of columns for name fields.
     :return: None
     """
 
@@ -166,8 +181,8 @@ def table_evaluation_print(evaluation_dict, metric_names, ncs):
 
     n_types = len(type_names)
 
-    header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names)
-    rowstring = "{:<" + str(ncs) + "}: " + " {:>10.5g}" * n_types
+    header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names)
+    rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types
 
     print(header)
     print("-" * len(header))
diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py
index 596e6fd2..f6aee191 100644
--- a/hippynn/experiment/routines.py
+++ b/hippynn/experiment/routines.py
@@ -207,8 +207,8 @@ def setup_training(
         model, loss, evaluator, optimizer, setup_params.device or tools.device_fallback()
     )
 
-    metrics = MetricTracker(evaluator.loss_names, stopping_key=controller.stopping_key)
-
+    metrics = MetricTracker(evaluator.loss_names,
+                            stopping_key=controller.stopping_key)
     return training_modules, controller, metrics
 
@@ -353,16 +353,21 @@ def test_model(database, evaluator, batch_size, when, metric_tracker=None):
 
     if metric_tracker is None:
         metric_tracker = MetricTracker(evaluator.loss_names, stopping_key=None)
-        metric_tracker.quiet = False
+
+    # A little dance to make sure train, valid, test always come first, when present.
+    basic_splits = ["train", "valid", "test"]
+    basic_splits = [s for s in basic_splits if s in database.splits]
+    splits = basic_splits + [s for s in database.splits if s not in basic_splits]
+
     evaluation_data = collections.OrderedDict(
         (
-            (key, database.make_generator(key, "eval", batch_size))  # During testing, run through all splits in the database.
-            for key in database.splits
+            (key, database.make_generator(key, "eval", batch_size))
+            for key in splits
         )
     )
     evaluation_metrics = {k: evaluator.evaluate(gen, eval_type=k, when=when) for k, gen in evaluation_data.items()}
     metric_tracker.register_metrics(evaluation_metrics, when=when)
-    metric_tracker.evaluation_print(evaluation_metrics)
+    metric_tracker.evaluation_print(evaluation_metrics, quiet=False)
     return metric_tracker
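
A minimal usage sketch of the patched behavior (not part of the diff itself; the metric names and values are invented for illustration). It assumes the patched `MetricTracker` is importable from `hippynn.experiment.metric_tracker`, and shows the lazy split registration in `register_metrics` together with the new `quiet` override on `evaluation_print`:

```python
from hippynn.experiment.metric_tracker import MetricTracker

# Stopping key must be one of the tracked metric names; "Loss" and "RMSE" are made up here.
tracker = MetricTracker(metric_names=["Loss", "RMSE"], stopping_key="Loss")

# First registration: the "valid" split has never been seen, so it is added to
# best_metric_values on the fly and every metric counts as better.
metrics_epoch1 = {"valid": {"Loss": 0.50, "RMSE": 1.20}}
better, better_model, best_loss = tracker.register_metrics(metrics_epoch1, when=1)

# Second registration: only metrics that improve on the stored best are flagged.
metrics_epoch2 = {"valid": {"Loss": 0.40, "RMSE": 1.30}}
better, better_model, best_loss = tracker.register_metrics(metrics_epoch2, when=2)

# Force the table to print even if the tracker was constructed with quiet=True,
# mirroring the quiet=False override that test_model now uses in routines.py.
tracker.evaluation_print(metrics_epoch2, quiet=False)
```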