Skip to content

Commit

Permalink
Add optional validation (to see how to do, see ). Take out some check…
Browse files Browse the repository at this point in the history
…s that seem wrong, like param.dtype != torch.float + requires_grad for forward.
  • Loading branch information
kylematoba committed Nov 18, 2024
1 parent 51ca40b commit 2a34742
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 80 deletions.
134 changes: 134 additions & 0 deletions examples/config_validation_tiny_llama.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
valid_data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/oscar-en-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/oscar-en-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 8
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: debug
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: float32
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 16
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 13
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 5
micro_batch_size: 2
sequence_length: 256
train_steps: 15
val_check_interval: 2
78 changes: 42 additions & 36 deletions run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,44 +178,48 @@ def get_dataloader_from_data_stage(
return dataloader


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
def get_dataloader(trainer: DistributedTrainer,
data_stages_fieldname: str,
metadata_fieldname: str) -> Dict[str, DataLoader]:
dataloaders = {}
data_stages = getattr(trainer.config, data_stages_fieldname)
if data_stages:
metadata = getattr(trainer, metadata_fieldname)
for stage_idx, stage in enumerate(data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazy initialize the dataloader for the other stages
stage = cast(DatasetStageArgs, stage)
consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, metadata)
assert (
consumed_train_samples is not None
), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"

for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazy initialize the dataloader for the other stages
stage = cast(DatasetStageArgs, stage)
consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata)
assert (
consumed_train_samples is not None
), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"

num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, trainer.config, trainer.metadata
)
log_rank(
f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = (
get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, trainer.config, metadata
)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
log_rank(
f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
logger=logger,
level=logging.INFO,
rank=0,
)
)
dataloaders[stage.name] = dataloader

dataloader = (
get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
)
dataloaders[stage.name] = dataloader
return dataloaders


Expand All @@ -231,7 +235,9 @@ def get_args():

# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)
dataloader_train = get_dataloader(trainer, "data_stages", "metadata")
dataloader_valid = get_dataloader(trainer, "valid_data_stages", "valid_metadata")

# Train
trainer.train(dataloader)
trainer.train(dataloader_train, dataloader_valid)
print("Done")
1 change: 1 addition & 0 deletions src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ class Config:
tokens: Optional[TokensArgs] = None
optimizer: Optional[OptimizerArgs] = None
data_stages: Optional[List[DatasetStageArgs]] = None
valid_data_stages: Optional[List[DatasetStageArgs]] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
s3_upload: Optional[S3UploadArgs] = None
Expand Down
2 changes: 1 addition & 1 deletion src/nanotron/optim/gradient_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def build_grad_buffers(
if not param.requires_grad:
continue

assert param.dtype != torch.float, f"Expected {name} not to be float"
# assert param.dtype != torch.float, f"Expected {name} not to be float"
assert param.is_contiguous(), f"Expected {name} to be contiguous"

next_offset = offset + param.numel() * element_size
Expand Down
5 changes: 4 additions & 1 deletion src/nanotron/optim/zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import numpy as np
import torch.optim
from functorch.dim import tree_map
try:
from functorch.dim import tree_map
except:
from torch.utils._pytree import tree_map
from torch import nn
from tqdm import tqdm

Expand Down
10 changes: 4 additions & 6 deletions src/nanotron/parallel/pipeline_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState, PipelineEvalBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
Expand Down Expand Up @@ -53,7 +53,7 @@ def forward(

# Add output as activations that require backward pass
if not isinstance(output["loss"], TensorPointer):
assert output["loss"].requires_grad
# assert output["loss"].requires_grad
state.register_activation_requiring_backward(output["loss"])
return output

Expand Down Expand Up @@ -134,9 +134,9 @@ def validate_batch_iter(
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState() # TODO: do i need state?
# state = PipelineTrainBatchState() # TODO: do i need state?
state = PipelineEvalBatchState()
self.nb_microbatches = nb_microbatches

outputs = []

with attach_pipeline_state_to_model(model=model, pipeline_state=state):
Expand All @@ -158,7 +158,6 @@ def validate_batch_iter(
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)

return outputs


Expand Down Expand Up @@ -326,5 +325,4 @@ def train_batch_iter(

# Make sure that micro batches are all fully consumed
state.check_buffers_empty()

return outputs
2 changes: 2 additions & 0 deletions src/nanotron/parallel/pipeline_parallel/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ class PipelineEvalBatchState(PipelineBatchState):
microbatches_activations_to_send = collections.deque()
microbatches_activations_to_recv = collections.deque()
activations_buffer = collections.deque()
# Reinitialise counter
nb_forwards = 0

def register_activation_requiring_backward(self, activation: torch.Tensor):
pass
Expand Down
Loading

0 comments on commit 2a34742

Please sign in to comment.