Commit

add evaluation of extra splits to test_model routine
lubbersnick committed Aug 9, 2024
1 parent 01b7c29 commit 864a8cb
Showing 2 changed files with 41 additions and 21 deletions.
45 changes: 30 additions & 15 deletions hippynn/experiment/metric_tracker.py
@@ -30,7 +30,7 @@ class MetricTracker:
     """

-    def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train", "valid", "test")):
+    def __init__(self, metric_names, stopping_key, quiet=False):
         """
         :param metric_names:
@@ -48,7 +48,7 @@ def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train"
         self.n_metrics = len(metric_names)

         # State variables
-        self.best_metric_values = {split: {mtype: float("inf") for mtype in self.metric_names} for split in split_names}
+        self.best_metric_values = {}
         self.other_metric_values = {}
         self.best_model = None
         self.epoch_times = []
@@ -78,7 +78,18 @@ def register_metrics(self, metric_info, when):
         better_metrics = {k: {} for k in self.best_metric_values}
         for split_type, typevals in metric_info.items():
             for mname, mval in typevals.items():
-                better = self.best_metric_values[split_type][mname] > mval
+                try:
+                    old_best = self.best_metric_values[split_type][mname]
+                    better = old_best > mval
+                    del old_best # marking not needed.
+                except KeyError:
+                    if split_type not in self.best_metric_values:
+                        # Haven't seen this split before!
+                        print("ADDING ",split_type)
+                        self.best_metric_values[split_type] = {}
+                        better_metrics[split_type] = {}
+                    better = True # old best was not found!
+
                 if better:
                     self.best_metric_values[split_type][mname] = mval
                 better_metrics[split_type][mname] = better
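
The upshot of this hunk is that MetricTracker no longer needs to know the split names at construction time: any split that appears in the metric dict is registered the first time it is seen (announced by the "ADDING" print). A minimal sketch of that behavior, driving the class directly as test_model does below; the metric names, values, and the "outlier_set" split name are illustrative:

    from hippynn.experiment.metric_tracker import MetricTracker

    tracker = MetricTracker(metric_names=["Loss", "RMSE"], stopping_key=None)

    # First registration: both splits are added lazily ("ADDING train", "ADDING outlier_set"),
    # so every metric counts as an improvement over the (missing) previous best.
    better, better_model, stop_metric = tracker.register_metrics(
        {"train": {"Loss": 0.50, "RMSE": 0.90},
         "outlier_set": {"Loss": 0.70, "RMSE": 1.10}},  # split never declared up front
        when="initial",
    )

    # Second registration: only genuinely improved values are flagged as better.
    better, _, _ = tracker.register_metrics({"train": {"Loss": 0.40, "RMSE": 0.95}}, when="later")
    # better["train"]["Loss"] is True; RMSE did not beat the stored best of 0.90.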
@@ -97,13 +108,17 @@ def register_metrics(self, metric_info, when):

         return better_metrics, better_model, stopping_key_metric

-    def evaluation_print(self, evaluation_dict):
-        if self.quiet:
+    def evaluation_print(self, evaluation_dict, quiet=None):
+        if quiet is None:
+            quiet = self.quiet
+        if quiet:
             return
         table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width)

-    def evaluation_print_better(self, evaluation_dict, better_dict):
-        if self.quiet:
+    def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None):
+        if quiet is None:
+            quiet = self.quiet
+        if quiet:
             return
         table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width)
         if self.stopping_key:
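
The new quiet argument lets a single call override the tracker-wide quiet flag without mutating it, which is what the test_model change below relies on. A small usage sketch, assuming the tracker is constructed directly with a list of metric names as elsewhere in this commit (the metric name and value are illustrative):

    from hippynn.experiment.metric_tracker import MetricTracker

    tracker = MetricTracker(["Loss"], stopping_key=None, quiet=True)
    results = {"test": {"Loss": 0.42}}   # dict[eval type] -> dict[metric] -> value

    tracker.evaluation_print(results)                # suppressed: the tracker was built quiet
    tracker.evaluation_print(results, quiet=False)   # forces the table to print for this call only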
@@ -119,14 +134,14 @@ def plot_over_time(self):

 # Driver for printing evaluation table results, with * for better entries.
 # Decoupled from the estate in case we want to more easily change print formatting.
-def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, ncs):
+def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns):
     """
     Print metric results as a table, add a '*' character for metrics in better_dict.

     :param evaluation_dict: dict[eval type]->dict[metric]->value
     :param better_dict: dict[eval type]->dict[metric]->bool
     :param metric_names: Names
-    :param ncs: Number of columns for name fields.
+    :param n_columns: Number of columns for name fields.
     :return: None
     """
     type_names = evaluation_dict.keys()
@@ -139,8 +154,8 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc

     n_types = len(type_names)

-    header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names)
-    rowstring = "{:<" + str(ncs) + "}: " + " {}{:>10.5g}" * n_types
+    header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names)
+    rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types

     print(header)
     print("-" * len(header))
@@ -151,13 +166,13 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc

 # Driver for printing evaluation table results.
 # Decoupled from the estate in case we want to more easily change print formatting.
-def table_evaluation_print(evaluation_dict, metric_names, ncs):
+def table_evaluation_print(evaluation_dict, metric_names, n_columns):
     """
     Print metric results as a table.

     :param evaluation_dict: dict[eval type]->dict[metric]->value
     :param metric_names: Names
-    :param ncs: Number of columns for name fields.
+    :param n_columns: Number of columns for name fields.
     :return: None
     """

@@ -166,8 +181,8 @@ def table_evaluation_print(evaluation_dict, metric_names, ncs):

     n_types = len(type_names)

-    header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names)
-    rowstring = "{:<" + str(ncs) + "}: " + " {:>10.5g}" * n_types
+    header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names)
+    rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types

     print(header)
     print("-" * len(header))
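
The table layout itself is unchanged by the rename from ncs to n_columns. For reference, here is a self-contained sketch of what these header/rowstring expressions print; only the two format expressions come from the code above, while the metric names, values, n_columns heuristic, and per-row loop are assumptions for illustration:

    metric_names = ["Loss", "RMSE"]                                  # illustrative
    evaluation_dict = {"train": {"Loss": 0.41, "RMSE": 0.83},
                       "valid": {"Loss": 0.52, "RMSE": 0.91}}
    n_columns = max(len(m) for m in metric_names)                    # assumed width of the name field

    type_names = list(evaluation_dict.keys())
    n_types = len(type_names)
    header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names)
    rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types

    print(header)
    print("-" * len(header))
    for mname in metric_names:                                       # assumed row loop; not shown in the diff
        print(rowstring.format(mname, *(evaluation_dict[t][mname] for t in type_names)))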
17 changes: 11 additions & 6 deletions hippynn/experiment/routines.py
@@ -207,8 +207,8 @@ def setup_training(
         model, loss, evaluator, optimizer, setup_params.device or tools.device_fallback()
     )

-    metrics = MetricTracker(evaluator.loss_names, stopping_key=controller.stopping_key)
-
+    metrics = MetricTracker(evaluator.loss_names,
+                            stopping_key=controller.stopping_key)
     return training_modules, controller, metrics


@@ -353,16 +353,21 @@ def test_model(database, evaluator, batch_size, when, metric_tracker=None):

     if metric_tracker is None:
         metric_tracker = MetricTracker(evaluator.loss_names, stopping_key=None)
-        metric_tracker.quiet = False

+    # A little dance to make sure train, valid, test always come first, when present.
+    basic_splits = ["train", "valid", "test"]
+    basic_splits = [s for s in basic_splits if s in database.splits]
+    splits = basic_splits + [s for s in database.splits if s not in basic_splits]
+
     evaluation_data = collections.OrderedDict(
         (
-            (key, database.make_generator(key, "eval", batch_size))  # During testing, run through all splits in the database.
-            for key in database.splits
+            (key, database.make_generator(key, "eval", batch_size))
+            for key in splits
         )
     )
     evaluation_metrics = {k: evaluator.evaluate(gen, eval_type=k, when=when) for k, gen in evaluation_data.items()}
     metric_tracker.register_metrics(evaluation_metrics, when=when)
-    metric_tracker.evaluation_print(evaluation_metrics)
+    metric_tracker.evaluation_print(evaluation_metrics, quiet=False)
     return metric_tracker
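
With this change, test_model evaluates and prints every split in the database, not just train/valid/test, while keeping the standard three at the front of the table when they exist. The reordering logic is self-contained, so its effect can be checked in isolation; the "outlier_set" split name and the stand-in for database.splits below are illustrative:

    # Standalone check of the split-ordering "dance" used in test_model above.
    database_splits = {"outlier_set": None, "test": None, "train": None, "valid": None}  # stand-in for database.splits

    basic_splits = ["train", "valid", "test"]
    basic_splits = [s for s in basic_splits if s in database_splits]
    splits = basic_splits + [s for s in database_splits if s not in basic_splits]

    print(splits)  # -> ['train', 'valid', 'test', 'outlier_set']

In a real run, the extra split then shows up as an additional column in the printed table and as an additional key in metric_tracker.best_metric_values.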


