diff --git a/inferno/trainers/callbacks/tqdm.py b/inferno/trainers/callbacks/tqdm.py index 4d5d2cb5..4ee68609 100644 --- a/inferno/trainers/callbacks/tqdm.py +++ b/inferno/trainers/callbacks/tqdm.py @@ -29,6 +29,12 @@ def __init__(self, *args, **kwargs): self.is_training = False self.is_validation = False + def get_config(self): + config_dict = super(TQDMProgressBar, self).get_config() + config_dict.update({'epoch_bar': None}) + config_dict.update({'outer_bar': None}) + return config_dict + def bind_trainer(self, *args, **kwargs): super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) self.trainer.console.toggle_progress(False) @@ -52,6 +58,8 @@ def begin_of_fit(self, max_num_epochs, **_): self.outer_bar = tqdm(total=max_num_epochs, position=0, dynamic_ncols=True) else: self.outer_bar = tqdm(total=1000, position=0, dynamic_ncols=True) + + self.outer_bar.update(self.trainer._epoch_count) self.outer_bar.set_description("Epochs") def end_of_fit(self, **_):