From 2e8f1b0322e97ec27cb21006f68176c0ef9bfa7d Mon Sep 17 00:00:00 2001
From: Grigory Reznikov
Date: Sat, 2 Nov 2024 18:50:44 +0100
Subject: [PATCH] wip

---
 src/nanotron/trainer.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 45d704ee..75a943a4 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -87,6 +87,7 @@ from nanotron.serialize import (
     load_lr_scheduler,
     load_meta,
+    load_random_states,
     load_weights,
     parse_ckpt_path,
     save,
@@ -170,6 +171,11 @@ def __init__(
         self.random_states = init_random_states(
             parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
         )
+        if self.init_checkpoint_path is not None:
+            self.random_states = load_random_states(
+                parallel_context=self.parallel_context,
+                root_folder=self.init_checkpoint_path,
+            )
         self.model = self.init_model()  # Defines self.model
         self.unwrapped_model: NanotronModel = (
             self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
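
Context for reviewers: this patch makes the trainer restore the checkpointed RNG
states when resuming (self.init_checkpoint_path is set), overwriting the freshly
seeded states from init_random_states. Without it, a resumed run draws different
dropout masks and data-order randomness than the uninterrupted run would have.
Below is a minimal, self-contained sketch of the underlying idea using plain
PyTorch generator state; it does not use nanotron's load_random_states API, whose
on-disk format is internal to nanotron.serialize, and is only an illustration:

    import torch

    # Advance the generator, then snapshot its state (what a checkpoint save does).
    torch.manual_seed(42)
    _ = torch.rand(3)                  # some training-time randomness
    state = torch.get_rng_state()      # snapshot the CPU RNG state
    expected = torch.rand(3)           # the next draws of an uninterrupted run

    # "Resume": restoring the snapshot continues the exact same random stream.
    torch.set_rng_state(state)
    resumed = torch.rand(3)
    assert torch.equal(expected, resumed)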