From 362cc61498c6115597baeb00df3cc3f48ddb5350 Mon Sep 17 00:00:00 2001 From: Yuyan Li Date: Wed, 27 Mar 2019 03:36:22 +0100 Subject: [PATCH 1/3] make serialization possible --- inferno/trainers/callbacks/tqdm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/inferno/trainers/callbacks/tqdm.py b/inferno/trainers/callbacks/tqdm.py index 4d5d2cb5..464c8cc2 100644 --- a/inferno/trainers/callbacks/tqdm.py +++ b/inferno/trainers/callbacks/tqdm.py @@ -29,6 +29,13 @@ def __init__(self, *args, **kwargs): self.is_training = False self.is_validation = False + def get_config(self): + config_dict = dict(self.__dict__) + config_dict.update({'_trainer': None}) + config_dict.update({'epoch_bar': None}) + config_dict.update({'outer_bar': None}) + return config_dict + def bind_trainer(self, *args, **kwargs): super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) self.trainer.console.toggle_progress(False) From 24b61c7eabcace6c8c1d4d284e7814bf90ebb96c Mon Sep 17 00:00:00 2001 From: Yuyan Li Date: Wed, 27 Mar 2019 15:44:35 +0100 Subject: [PATCH 2/3] use inheritance --- inferno/trainers/callbacks/tqdm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/inferno/trainers/callbacks/tqdm.py b/inferno/trainers/callbacks/tqdm.py index 464c8cc2..e0c1f8ab 100644 --- a/inferno/trainers/callbacks/tqdm.py +++ b/inferno/trainers/callbacks/tqdm.py @@ -30,8 +30,7 @@ def __init__(self, *args, **kwargs): self.is_validation = False def get_config(self): - config_dict = dict(self.__dict__) - config_dict.update({'_trainer': None}) + config_dict = super(TQDMProgressBar, self).get_config() config_dict.update({'epoch_bar': None}) config_dict.update({'outer_bar': None}) return config_dict From 8b5520588349e28d2bfddcc70de2d951a86c4f15 Mon Sep 17 00:00:00 2001 From: Yuyan Li Date: Thu, 15 Aug 2019 15:22:09 +0200 Subject: [PATCH 3/3] add epoch counter when loading checkpoint --- inferno/trainers/callbacks/tqdm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/inferno/trainers/callbacks/tqdm.py b/inferno/trainers/callbacks/tqdm.py index e0c1f8ab..4ee68609 100644 --- a/inferno/trainers/callbacks/tqdm.py +++ b/inferno/trainers/callbacks/tqdm.py @@ -58,6 +58,8 @@ def begin_of_fit(self, max_num_epochs, **_): self.outer_bar = tqdm(total=max_num_epochs, position=0, dynamic_ncols=True) else: self.outer_bar = tqdm(total=1000, position=0, dynamic_ncols=True) + + self.outer_bar.update(self.trainer._epoch_count) self.outer_bar.set_description("Epochs") def end_of_fit(self, **_):