From 3a45a340f353c12dc82fd4ce4c1bc0a99e96a8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 27 Aug 2024 15:22:26 +0000 Subject: [PATCH 1/3] remove torch compile --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 49ea86e6..ca6c2441 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -165,7 +165,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) + self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) From 3340a7bdaade8948246c2319805dcabe0e92e6f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 27 Aug 2024 22:18:29 +0000 Subject: [PATCH 2/3] only log ckpt on rank 0 --- src/nanotron/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 70d023fb..bef629c1 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -700,7 +700,7 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: ) reloaded_from_checkpoint = True if not reloaded_from_checkpoint: - log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) + log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) if isinstance(self.config.model.init_method, ExistingCheckpointInit): # Initialize model from an pretrained model checkpoint self.param_shard_metadata = load_weights( From 964ceae3097b178d307bf50e9ea573849f1f45b7 Mon Sep 17 00:00:00 2001 From: Tiancheng Chen Date: Tue, 3 Sep 2024 16:16:31 +0200 Subject: [PATCH 3/3] fix attempt 1: treat embedding as heavy as transformer layer --- src/nanotron/models/llama.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca6c2441..e9c04ab5 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -813,15 +813,12 @@ def forward_with_hidden_states( def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" - model_config = self.config - d_ff = model_config.intermediate_size - d_qkv = model_config.hidden_size // model_config.num_attention_heads block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 3 * d_ff * model_config.hidden_size, + Embedding: 1, + LlamaDecoderLayer: 1, # This is the last lm_head - TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + TensorParallelColumnLinear: 1, } return block_compute_costs