diff --git a/.github/workflows/3d_parallelism_unit_tests.yaml b/.github/workflows/3d_parallelism_unit_tests.yaml
index 73804d6c..fe6c41d1 100644
--- a/.github/workflows/3d_parallelism_unit_tests.yaml
+++ b/.github/workflows/3d_parallelism_unit_tests.yaml
@@ -59,6 +59,7 @@ jobs:
           --durations=0 \
           --ignore tests/kernels \
           --ignore tests/fp8 \
+          --ignore tests/test_llama.py \
           --verbose \
           tests/ # NOTE: T4 can't run FA2, DoReMi's LLaMa needs FA2
diff --git a/.github/workflows/fa2_unit_tests.yaml b/.github/workflows/fa2_unit_tests.yaml
index cc8e58ee..c5b3346f 100644
--- a/.github/workflows/fa2_unit_tests.yaml
+++ b/.github/workflows/fa2_unit_tests.yaml
@@ -39,7 +39,7 @@ jobs:
          python -c "import torch; print('torch:', torch.__version__, torch)"
          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

-      - name: Instal nanotron
+      - name: Install nanotron
        run: |
          python -m pip install --upgrade pip
          pip install packaging
@@ -48,6 +48,7 @@ jobs:
          pip install -e .
          pip install -e .[dev]
          pip install -e .[test]
+         pip install transformers datasets

      - name: Show installed libraries and their versions
        run: pip freeze | tee installed.txt
diff --git a/.github/workflows/llama_tests.yaml b/.github/workflows/llama_tests.yaml
new file mode 100644
index 00000000..61ac7d79
--- /dev/null
+++ b/.github/workflows/llama_tests.yaml
@@ -0,0 +1,59 @@
+name: Run Llama loss test
+
+on:
+  push:
+    branches: [ main ]
+    # Only run tests if we modify the following files
+    paths:
+      - "src/**/*.py"
+      - "examples/**/*.py"
+      - "tests/**/*.py"
+
+  pull_request:
+    branches: [ '**' ]
+    paths:
+      - "src/**/*.py"
+      - "examples/**/*.py"
+      - "tests/**/*.py"
+
+jobs:
+  tests:
+    # NOTE: 8-t4 to run LLama
+    runs-on: [multi-gpu, nvidia-gpu, 8-t4, ci]
+    container:
+      image: runpod/pytorch:2.1.1-py3.10-cuda12.1.1-devel-ubuntu22.04
+      ports:
+        - 80
+      options: --gpus all --shm-size "8G"
+    steps:
+      - uses: actions/checkout@v3
+      - name: Python environment
+        run: |
+          which python
+          python --version
+
+      - name: Check Pytorch version
+        run: |
+          nvidia-smi
+          python -c "import torch; print('torch:', torch.__version__, torch)"
+          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+      - name: Install nanotron's dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install packaging
+          pip install wheel
+          pip install "flash-attn>=2.5.0" --no-build-isolation
+          pip install -e .
+          pip install -e .[dev]
+          pip install -e .[test]
+          pip install transformers datasets
+
+      - name: Show installed libraries and their versions
+        run: pip freeze | tee installed.txt
+
+      - name: Run Llama example
+        run: pytest --verbose tests/test_llama.py::test_tiny_llama
+
+      - name: Run Llama loss test
+        run: pytest --verbose tests/test_llama.py::test_train_llama
diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml
index 58645e2d..c4371564 100644
--- a/examples/config_tiny_llama.yaml
+++ b/examples/config_tiny_llama.yaml
@@ -1,6 +1,6 @@
 checkpoints:
   checkpoint_interval: 10
-  checkpoints_path: checkpoints
+  checkpoints_path: /fsx/haojun/nanotron/checkpoints
   checkpoints_path_is_shared_file_system: false
   resume_checkpoint_path: null
   save_initial_state: false
diff --git a/examples/config_train_llama.py b/examples/config_train_llama.py
new file mode 100644
index 00000000..d46fe610
--- /dev/null
+++ b/examples/config_train_llama.py
@@ -0,0 +1,120 @@
+""" Example python script to generate a YAML config file which can be used to run a training with nanotron.
Refer to "examples" section in the `/README.md` for more information.""" +import os + +from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LlamaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + RandomInit, + TokenizerArgs, + TokensArgs, +) +from nanotron.logging import human_format + +model_config = LlamaConfig( + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=768, + initializer_range=0.02, + intermediate_size=3072, + max_position_embeddings=512, + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=True, + use_cache=True, + vocab_size=50272, +) + +num_params = human_format( + model_config.vocab_size * model_config.hidden_size * 2 + + model_config.num_hidden_layers + * ( + 3 * model_config.hidden_size * model_config.intermediate_size + + 4 * model_config.hidden_size * model_config.hidden_size + ) +).replace(".", "p") + +print(f"Model has {num_params} parameters") + +seed = 42 + +learning_rate = LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 +) + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), +) + +parallelism = ParallelismArgs( + dp=4, + pp=1, + tp=2, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, +) + +# Tokens per batch = micro_batch_size * dp * sequence_length * batch_accumulation_per_replica +# 16 * 4 * 512 * 32 = 1,048,576. -> A global batch-size of 1M tokens. 
+# train 200 steps to observe the loss +tokens = TokensArgs(sequence_length=512, train_steps=200, micro_batch_size=16, batch_accumulation_per_replica=32) + +checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +os.makedirs(checkpoints_path, exist_ok=True) + +config = Config( + general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + parallelism=parallelism, + model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), + tokenizer=TokenizerArgs("gpt2"), + optimizer=optimizer, + logging=LoggingArgs(), + tokens=tokens, + data_stages=[ + DatasetStageArgs( + name="Stable Training Stage", + start_training_step=1, + data=DataArgs( + dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="roneneldan/TinyStories", text_column_name="text"), + seed=seed, + ), + ) + ], + profiler=None, +) + + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_train_llama.yaml") + + # You can now train a model with this config using `/run_train.py` diff --git a/examples/config_train_llama.yaml b/examples/config_train_llama.yaml new file mode 100644 index 00000000..9c0a2a72 --- /dev/null +++ b/examples/config_train_llama.yaml @@ -0,0 +1,97 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: /fsx/haojun/nanotron/checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 768 + initializer_range: 0.02 + intermediate_size: 3072 + is_llama_config: true + max_position_embeddings: 512 + num_attention_heads: 16 + num_hidden_layers: 12 + num_key_value_heads: 16 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 50272 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 198 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 32 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 16 + sequence_length: 512 + train_steps: 200 + 
val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 32aab9cd..1bc701bf 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch LLaMa model.""" - -from typing import Dict, Optional, Union, List +import os +from typing import Dict, Optional, Union import torch from torch import nn @@ -27,7 +27,7 @@ from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN -from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import RMSNorm, TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer @@ -46,6 +46,20 @@ logger = logging.get_logger(__name__) +# Replace flash attention implementations +# TritonRMSNorm + FlashRotaryEmbedding + CoreAttention -> Llama RMSNorm + RotaryEmbedding + SDPA +DISABLE_FLASH_ATTENTION = os.getenv("DISABLE_FLASH_ATTENTION", "0") == "1" + +if DISABLE_FLASH_ATTENTION: + print("Warning: Flash attention was disabled!") + # FSDP + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_flash_sdp(False) + + +RMSNorm = RMSNorm if DISABLE_FLASH_ATTENTION else TritonRMSNorm + + class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): super().__init__() @@ -262,7 +276,6 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding super().__init__() # Tensor parallel considerations: We split tensors along head dimension @@ -322,8 +335,22 @@ def __init__( end=config.max_position_embeddings, ) - # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True) + if not DISABLE_FLASH_ATTENTION: + from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + + self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True) + # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) + # forward( self, qkv: torch.Tensor, kv) + self.attention = CoreAttention( + config, + parallel_config=parallel_config, + layer_idx=layer_idx, + ) + else: + self.flash_rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -334,12 +361,6 @@ def __init__( async_communication=tp_linear_async_communication, ) - self.attention = CoreAttention( - config, - parallel_config=parallel_config, - layer_idx=layer_idx, - ) - self.prefill_kv_len = ( config.max_position_embeddings ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings @@ -558,12 +579,17 @@ def forward( # NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk] # Here it is, [batch_size, seq_length, num_heads, d_qk] # [2, batch_size, seq_length, num_heads, d_qk] - key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) - # [batch_size, seq_length, 2, num_heads, d_qk] - key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() - query_states, key_value_states = 
self.flash_rotary_embedding(query_states, kv=key_value_states) - # [batch_size, seq_length, num_heads, d_qk] - key_states, value_states = torch.split(key_value_states, 1, dim=2) + if not DISABLE_FLASH_ATTENTION: + key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) + # [batch_size, seq_length, 2, num_heads, d_qk] + key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() + query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) + # [batch_size, seq_length, num_heads, d_qk] + key_states, value_states = torch.split(key_value_states, 1, dim=2) + else: + position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1 + query_states = self.flash_rotary_embedding(query_states, position_ids=position_ids) + key_states = self.flash_rotary_embedding(key_states, position_ids=position_ids) q_sequence_mask = sequence_mask kv_sequence_mask = sequence_mask @@ -571,24 +597,45 @@ def forward( kv_length = key_states.shape[1] # [batch_size, seq_length, num_heads, d_qk] # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func` - query_states = query_states.view( - batch_size * q_length, self.n_local_q_heads, self.d_qk - ) # [batch_size * q_length, self.n_heads, d_qk] - - key_states = key_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_qk - ) # [batch_size * kv_length, self.n_heads, d_qk] - value_states = value_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_v - ) # [batch_size * kv_length, self.n_heads, d_v] - - attention_output = self.attention( - query_states=query_states, - key_states=key_states, - value_states=value_states, - q_sequence_mask=q_sequence_mask, - kv_sequence_mask=kv_sequence_mask, - ) + if not DISABLE_FLASH_ATTENTION: + query_states = query_states.view( + batch_size * q_length, self.n_local_q_heads, self.d_qk + ) # [batch_size * q_length, self.n_heads, d_qk] + key_states = key_states.view( + batch_size * kv_length, self.n_local_kv_heads, self.d_qk + ) # [batch_size * kv_length, self.n_heads, d_qk] + value_states = value_states.view( + batch_size * kv_length, self.n_local_kv_heads, self.d_v + ) # [batch_size * kv_length, self.n_heads, d_v] + attention_output = self.attention( + query_states=query_states, + key_states=key_states, + value_states=value_states, + q_sequence_mask=q_sequence_mask, + kv_sequence_mask=kv_sequence_mask, + ) + else: + # [batch_size, seq_length, n_local_q_heads, d_qk] -> [batch_size, n_local_q_heads, seq_length, d_qk] + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # attention_output: [batch_size, n_local_q_heads, seq_length, d_v] + attention_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + dropout_p=0.0, + is_causal=True, + ) + + # In order to have the same shape as flash attention output + # [batch_size, n_local_q_heads, q_length, d_v] -> [batch_size * q_length, n_local_q_heads, d_v] + attention_output = ( + attention_output.transpose(1, 2) + .contiguous() + .view(batch_size * q_length, self.n_local_q_heads, self.d_v) + ) attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) @@ -607,7 +654,7 @@ def __init__( layer_idx: int, ): super().__init__() - self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps)
         self.attn = CausalSelfAttention(
             config=config,
             parallel_config=parallel_config,
             tp_pg=tp_pg,
             layer_idx=layer_idx,
         )

-        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

     def forward(
@@ -724,7 +771,7 @@ def __init__(
         self.final_layer_norm = PipelineBlock(
             p2p=self.p2p,
-            module_builder=TritonRMSNorm,
+            module_builder=RMSNorm,
             module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
             module_input_keys={"input"},
             module_output_keys={"hidden_states"},
diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py
index 688eaa78..6caa439c 100644
--- a/src/nanotron/nn/layer_norm.py
+++ b/src/nanotron/nn/layer_norm.py
@@ -51,3 +51,20 @@ def forward(
             is_rms_norm=True,
             return_dropout_mask=return_dropout_mask,
         )
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, input):
+        input_dtype = input.dtype
+        input = input.to(torch.float32)
+        variance = input.pow(2).mean(-1, keepdim=True)
+        input = input * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * input.to(input_dtype)
diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py
index e6241651..d4c9c4ce 100644
--- a/src/nanotron/scaling/parametrization.py
+++ b/src/nanotron/scaling/parametrization.py
@@ -4,7 +4,7 @@
 from typing import Dict

 from nanotron.config import ModelArgs
-from nanotron.nn.layer_norm import TritonRMSNorm
+from nanotron.nn.layer_norm import RMSNorm, TritonRMSNorm
 from nanotron.parallel.tensor_parallel.nn import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs):
             TensorParallelColumnLinear: self._parametrize_column_linear,
             TensorParallelRowLinear: self._parametrize_row_linear,
             TritonRMSNorm: self._parametrize_layer_norm,
+            RMSNorm: self._parametrize_layer_norm,
             TensorParallelEmbedding: self._parametrize_embedding,
         }

diff --git a/tests/test_llama.py b/tests/test_llama.py
new file mode 100644
index 00000000..e69760b3
--- /dev/null
+++ b/tests/test_llama.py
@@ -0,0 +1,176 @@
+# Script to test the correctness of the training script by comparing the loss value after the 100th iteration with an expected loss value
+# pytest -sv tests/test_llama.py or python tests/test_llama.py
+
+import atexit
+import os
+import re
+import signal
+import subprocess
+
+CONFIG_FILE = "examples/config_train_llama.yaml"
+CREATE_CONFIG_FILE = "examples/config_train_llama.py"
+TRAIN_SCRIPT = "run_train.py"
+NUM_GPUS = 8
+
+TINY_LLLAMA_CONFIG_FILE = "examples/config_tiny_llama.yaml"
+TINY_LLLAMA_CREATE_CONFIG_FILE = "examples/config_tiny_llama.py"
+
+## Experiment results:
+## 100 steps: 3.28
+## 160 steps: 2.83
+## 200 steps: 2.75
+
+## Expected:
+## 100+ steps: lm_loss < 3.4
+## 200 steps: lm_loss < 2.8
+
+EXPECTED_LOSS = 3.4
+CHECK_ITERATION = 100
+
+EXPECTED_LOSS_END = 2.8
+CHECK_ITERATION_END = 200
+
+
+def exit_with_children():
+    """Kill all children processes when this process exits"""
+    os.killpg(0, signal.SIGKILL)
+
+
+def extract_loss(line):
+    """Extract loss value from the line"""
+    # extract a loss value of the form | lm_loss: 5.33 or lm_loss: 3
+    try:
+        return float(re.search(r"lm_loss: (\d+(?:\.\d+)?)", line.decode("utf-8")).group(1))
+    except AttributeError:
+        raise ValueError(f"Could not extract loss value from line: {line}")
+
+
+def test_train_llama():
+    # create config file
+    cmd = f"python {CREATE_CONFIG_FILE}"
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # Read and print output in real-time
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
+
+    # run training
+    # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
+    cmd = f'DISABLE_FLASH_ATTENTION=1 FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {CONFIG_FILE}'
+    os.setpgrp()  # create new process group, become its leader
+    atexit.register(exit_with_children)  # kill all children processes when this process exits
+
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+    # Read and print output in real-time
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+            # for all iterations >= CHECK_ITERATION, loss should be below EXPECTED_LOSS
+            if re.search(r"iteration: (\d+) / ", line.decode("utf-8")):
+                if int(re.search(r"iteration: (\d+) / ", line.decode("utf-8")).group(1)) >= CHECK_ITERATION:
+                    loss = extract_loss(line)
+                    assert loss < EXPECTED_LOSS
+            # at iteration == CHECK_ITERATION_END, loss should be below EXPECTED_LOSS_END
+            if re.search(rf"iteration: {CHECK_ITERATION_END} / ", line.decode("utf-8")):
+                loss = extract_loss(line)
+                assert loss < EXPECTED_LOSS_END
+
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
+
+
+# Also run the tiny llama example. We only want to assert that it runs.
+def test_tiny_llama():
+    # create config file
+    cmd = f"python {TINY_LLLAMA_CREATE_CONFIG_FILE}"
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # Read and print output in real-time
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
+
+    # run training
+    # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
+    cmd = f'DISABLE_FLASH_ATTENTION=1 FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {TINY_LLLAMA_CONFIG_FILE}'
+    os.setpgrp()  # create new process group, become its leader
+    atexit.register(exit_with_children)  # kill all children processes when this process exits
+
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
+
+
+if __name__ == "__main__":
+    # create config file
+    cmd = f"python {CREATE_CONFIG_FILE}"
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    try:
+        # Read and print output in real-time
+        while True:
+            line = process.stdout.readline()
+            if process.poll() is not None and line == b"":
+                break
+            if line:
+                print(line.decode("utf-8"), end="")
+        process.wait()  # Wait for the process to finish
+        assert process.returncode == 0
+    except AssertionError:
+        print("Command failed with exit code:", process.returncode)
+        exit()
+    else:
+        print("Config created successfully.")
+
+    # run training
+    # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
+    cmd = f'DISABLE_FLASH_ATTENTION=1 FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {CONFIG_FILE}'
+    os.setpgrp()  # create new process group, become its leader
+    atexit.register(exit_with_children)  # kill all children processes when this process exits
+
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+    try:
+        # Read and print output in real-time
+        while True:
+            line = process.stdout.readline()
+            if process.poll() is not None and line == b"":
+                break
+            if line:
+                print(line.decode("utf-8"), end="")
+
+                # for all iterations >= CHECK_ITERATION, loss should be below EXPECTED_LOSS
+                if re.search(r"iteration: (\d+) / ", line.decode("utf-8")):
+                    if int(re.search(r"iteration: (\d+) / ", line.decode("utf-8")).group(1)) >= CHECK_ITERATION:
+                        loss = extract_loss(line)
+                        assert loss < EXPECTED_LOSS
+                # at iteration == CHECK_ITERATION_END, loss should be below EXPECTED_LOSS_END
+                if re.search(rf"iteration: {CHECK_ITERATION_END} / ", line.decode("utf-8")):
+                    loss = extract_loss(line)
+                    assert loss < EXPECTED_LOSS_END
+        process.wait()  # Wait for the process to finish
+        assert process.returncode == 0
+    except AssertionError:
+        print("Command failed with exit code:", process.returncode)
+    else:
+        print("Command executed successfully.")
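
A minimal sketch (not part of the patch) of how the eager RMSNorm fallback added in src/nanotron/nn/layer_norm.py could be sanity-checked against the plain float32 formula before running the full loss test. It assumes nanotron is installed with this patch applied (the import below does not exist upstream without it); everything else is plain PyTorch.

# Hypothetical local check, assuming this patch is installed.
import torch

from nanotron.nn.layer_norm import RMSNorm  # class added by this patch

hidden_size, eps = 768, 1e-5
norm = RMSNorm(hidden_size, eps=eps)

# bfloat16 input: the module upcasts to float32 for the mean/rsqrt, then casts back.
x = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)

# Reference computation written out explicitly (the module's weight is initialized to ones).
x_fp32 = x.to(torch.float32)
ref = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)

# Round the reference through bfloat16 to mirror the module's final cast, then compare.
assert torch.allclose(norm(x).float(), ref.to(x.dtype).float())

Because the SDPA/eager path is only exercised with DISABLE_FLASH_ATTENTION=1, a cheap check like this can be run on a machine without flash-attn before launching the 8-GPU loss test.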