Commit
[usability] debug tools dev
yizhenjia committed Oct 25, 2024
1 parent a66fbee commit 09a52ab
Showing 3 changed files with 23 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/lmflow/pipeline/finetuner.py
@@ -38,7 +38,7 @@
 from lmflow.datasets.dataset import Dataset
 from lmflow.pipeline.base_tuner import BaseTuner
 from lmflow.pipeline.utils.peft_trainer import PeftTrainer, PeftSavingCallback
-from lmflow.utils.debug import get_parameter_names_in_param_groups
+from lmflow.utils.debug.debug import get_parameter_names_in_param_groups
 
 
 logger = logging.getLogger(__name__)
@@ -550,7 +550,10 @@ def on_step_begin(self, args, state, control, **kwargs):
         if state.global_step % self.interval_steps == 0:
             self.switch_active_layers()
 
-        # layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
+        layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
+        print(f'>>> on step {state.global_step} begin model params')
+        print(layers[self.active_layers_indices[0]].attn.c_attn.weight)
+        print(f'<<< on step {state.global_step} begin model params')
         # self.previous_params = {
         #     name: param.clone().detach()
         #     for name, param in layers[self.active_layers_indices[0]].named_parameters()
@@ -563,6 +566,7 @@ def switch_active_layers(self):
         # Randomly select n_layers to activate
         layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
         self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
+        self.active_layers_indices.sort()
         print(f"Activating layers at indices: {self.active_layers_indices} for the next steps.", flush=True)
 
         # Enable gradients only for the selected layers
@@ -571,13 +575,16 @@
                 param.requires_grad = True
 
     def on_step_end(self, args, state, control, **kwargs):
-        # layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
+        layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
         # for name, param in layers[self.active_layers_indices[0]].named_parameters():
         #     if torch.equal(param, self.previous_params[name]):
         #         print(f"No change in parameter: {name}")
         #     else:
         #         print(f"Parameter updated: {name}")
-        pass
+        print(f'>>> on step {state.global_step-1} end model params')
+        print(layers[self.active_layers_indices[0]].attn.c_attn.weight.shape)
+        print(layers[self.active_layers_indices[0]].attn.c_attn.weight)
+        print(f'<<< on step {state.global_step-1} end model params')
 
     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         pass
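For context, the callback edited above periodically calls switch_active_layers() to enable gradients for only a few randomly chosen transformer blocks, and the new print statements dump the first active block's attn.c_attn.weight at step begin and step end so that parameter updates can be checked by eye. The snippet below is a minimal, self-contained sketch of that debugging pattern rather than the repository's implementation: the class name DebugLayerSwitchCallback is hypothetical, and model.transformer.h assumes a GPT-2-style block list.

# Hypothetical sketch of the debugging pattern above: snapshot the first active
# block's parameters at step begin, report which ones changed at step end.
# Assumes a GPT-2-style model with blocks at model.transformer.h; not LMFlow's code.
import numpy as np
import torch
from transformers import TrainerCallback

class DebugLayerSwitchCallback(TrainerCallback):
    def __init__(self, model, n_layers=2, interval_steps=20):
        self.layers = model.transformer.h  # GPT-2-style block list (assumption)
        self.n_layers = n_layers
        self.interval_steps = interval_steps
        self.active_layers_indices = []
        self.previous_params = {}
        self.switch_active_layers()

    def switch_active_layers(self):
        # Freeze everything, then randomly pick n_layers blocks to train.
        for layer in self.layers:
            for param in layer.parameters():
                param.requires_grad = False
        self.active_layers_indices = np.random.choice(len(self.layers), self.n_layers, replace=False)
        self.active_layers_indices.sort()  # keep indices ordered, as the commit now does
        for idx in self.active_layers_indices:
            for param in self.layers[idx].parameters():
                param.requires_grad = True

    def on_step_begin(self, args, state, control, **kwargs):
        if state.global_step % self.interval_steps == 0:
            self.switch_active_layers()
        # Snapshot the first active block so on_step_end can tell what changed.
        first = self.layers[self.active_layers_indices[0]]
        self.previous_params = {name: param.clone().detach() for name, param in first.named_parameters()}

    def on_step_end(self, args, state, control, **kwargs):
        first = self.layers[self.active_layers_indices[0]]
        for name, param in first.named_parameters():
            changed = not torch.equal(param, self.previous_params[name])
            print(f"step {state.global_step - 1}: {name} {'updated' if changed else 'unchanged'}")

Registering an instance with trainer.add_callback(...) would print, for the first active block, which parameters actually moved after each optimizer step, which is the same check the commented-out torch.equal lines in the diff sketch.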
12 changes: 12 additions & 0 deletions src/lmflow/utils/debug/constants.py
@@ -0,0 +1,12 @@
GPT2= {
"param_name_in_group": [
{'parameter_names': ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.c_proj.weight', 'transformer.h.2.mlp.c_fc.weight', 'transformer.h.2.mlp.c_proj.weight', 'transformer.h.3.attn.c_attn.weight', 'transformer.h.3.attn.c_proj.weight', 'transformer.h.3.mlp.c_fc.weight', 'transformer.h.3.mlp.c_proj.weight', 'transformer.h.4.attn.c_attn.weight', 'transformer.h.4.attn.c_proj.weight', 'transformer.h.4.mlp.c_fc.weight', 'transformer.h.4.mlp.c_proj.weight', 'transformer.h.5.attn.c_attn.weight', 'transformer.h.5.attn.c_proj.weight', 'transformer.h.5.mlp.c_fc.weight', 'transformer.h.5.mlp.c_proj.weight', 'transformer.h.6.attn.c_attn.weight', 'transformer.h.6.attn.c_proj.weight', 'transformer.h.6.mlp.c_fc.weight', 'transformer.h.6.mlp.c_proj.weight', 'transformer.h.7.attn.c_attn.weight', 'transformer.h.7.attn.c_proj.weight', 'transformer.h.7.mlp.c_fc.weight', 'transformer.h.7.mlp.c_proj.weight', 'transformer.h.8.attn.c_attn.weight', 'transformer.h.8.attn.c_proj.weight', 'transformer.h.8.mlp.c_fc.weight', 'transformer.h.8.mlp.c_proj.weight', 'transformer.h.9.attn.c_attn.weight', 'transformer.h.9.attn.c_proj.weight', 'transformer.h.9.mlp.c_fc.weight', 'transformer.h.9.mlp.c_proj.weight', 'transformer.h.10.attn.c_attn.weight', 'transformer.h.10.attn.c_proj.weight', 'transformer.h.10.mlp.c_fc.weight', 'transformer.h.10.mlp.c_proj.weight', 'transformer.h.11.attn.c_attn.weight', 'transformer.h.11.attn.c_proj.weight', 'transformer.h.11.mlp.c_fc.weight', 'transformer.h.11.mlp.c_proj.weight']},
{'parameter_names': ['transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.bias', 'transformer.h.2.attn.c_proj.bias', 'transformer.h.2.ln_2.weight', 'transformer.h.2.ln_2.bias', 'transformer.h.2.mlp.c_fc.bias', 'transformer.h.2.mlp.c_proj.bias', 'transformer.h.3.ln_1.weight', 'transformer.h.3.ln_1.bias', 'transformer.h.3.attn.c_attn.bias', 'transformer.h.3.attn.c_proj.bias', 'transformer.h.3.ln_2.weight', 'transformer.h.3.ln_2.bias', 'transformer.h.3.mlp.c_fc.bias', 'transformer.h.3.mlp.c_proj.bias', 'transformer.h.4.ln_1.weight', 'transformer.h.4.ln_1.bias', 'transformer.h.4.attn.c_attn.bias', 'transformer.h.4.attn.c_proj.bias', 'transformer.h.4.ln_2.weight', 'transformer.h.4.ln_2.bias', 'transformer.h.4.mlp.c_fc.bias', 'transformer.h.4.mlp.c_proj.bias', 'transformer.h.5.ln_1.weight', 'transformer.h.5.ln_1.bias', 'transformer.h.5.attn.c_attn.bias', 'transformer.h.5.attn.c_proj.bias', 'transformer.h.5.ln_2.weight', 'transformer.h.5.ln_2.bias', 'transformer.h.5.mlp.c_fc.bias', 'transformer.h.5.mlp.c_proj.bias', 'transformer.h.6.ln_1.weight', 'transformer.h.6.ln_1.bias', 'transformer.h.6.attn.c_attn.bias', 'transformer.h.6.attn.c_proj.bias', 'transformer.h.6.ln_2.weight', 'transformer.h.6.ln_2.bias', 'transformer.h.6.mlp.c_fc.bias', 'transformer.h.6.mlp.c_proj.bias', 'transformer.h.7.ln_1.weight', 'transformer.h.7.ln_1.bias', 'transformer.h.7.attn.c_attn.bias', 'transformer.h.7.attn.c_proj.bias', 'transformer.h.7.ln_2.weight', 'transformer.h.7.ln_2.bias', 'transformer.h.7.mlp.c_fc.bias', 'transformer.h.7.mlp.c_proj.bias', 'transformer.h.8.ln_1.weight', 'transformer.h.8.ln_1.bias', 'transformer.h.8.attn.c_attn.bias', 'transformer.h.8.attn.c_proj.bias', 'transformer.h.8.ln_2.weight', 'transformer.h.8.ln_2.bias', 'transformer.h.8.mlp.c_fc.bias', 'transformer.h.8.mlp.c_proj.bias', 'transformer.h.9.ln_1.weight', 'transformer.h.9.ln_1.bias', 'transformer.h.9.attn.c_attn.bias', 'transformer.h.9.attn.c_proj.bias', 'transformer.h.9.ln_2.weight', 'transformer.h.9.ln_2.bias', 'transformer.h.9.mlp.c_fc.bias', 'transformer.h.9.mlp.c_proj.bias', 'transformer.h.10.ln_1.weight', 'transformer.h.10.ln_1.bias', 'transformer.h.10.attn.c_attn.bias', 'transformer.h.10.attn.c_proj.bias', 'transformer.h.10.ln_2.weight', 'transformer.h.10.ln_2.bias', 'transformer.h.10.mlp.c_fc.bias', 'transformer.h.10.mlp.c_proj.bias', 'transformer.h.11.ln_1.weight', 'transformer.h.11.ln_1.bias', 'transformer.h.11.attn.c_attn.bias', 'transformer.h.11.attn.c_proj.bias', 'transformer.h.11.ln_2.weight', 'transformer.h.11.ln_2.bias', 'transformer.h.11.mlp.c_fc.bias', 'transformer.h.11.mlp.c_proj.bias', 'transformer.ln_f.weight', 'transformer.ln_f.bias']}
],
"num_params": {
"lm_head": 50257*768 + 1024*768, # wte, wpe
"gpt2block": 768*2304 + 768*768 + 768*3072 + 768*3072 + 6*768 + 2304 + 3072,
"gpt2block_in_pg0": 768*2304 + 768*768 + 768*3072 + 768*3072, # weight decay (if any)
"gpt2block_in_pg1": 768 + 768 + 2304 + 768 + 768 + 768 + 3072 + 768, # no weight decay (no matter what) ln_1.weight: 768 ln_1.bias: 768 attn.c_attn.bias: 2304 attn.c_proj.bias: 768 ln_2.weight: 768 ln_2.bias: 768 mlp.c_fc.bias: 3072 mlp.c_proj.bias: 768
}
}
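
The num_params entries above are hand-computed GPT-2 small shapes (12 blocks, hidden size 768, vocabulary 50257, 1024 positions). A quick sanity check of that arithmetic, assuming transformers' GPT2LMHeadModel built from the default GPT2Config (an assumption; constants.py only stores the numbers):

# Sanity-check the hand-computed GPT-2 small parameter counts above.
# Assumes transformers' GPT2LMHeadModel with the default "gpt2" configuration.
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config())  # 12 blocks, hidden size 768, built locally (no download)

embeddings = sum(p.numel() for n, p in model.named_parameters()
                 if n.startswith(("transformer.wte", "transformer.wpe")))
block0 = sum(p.numel() for n, p in model.named_parameters()
             if n.startswith("transformer.h.0."))

print(embeddings)  # expected 50257*768 + 1024*768 = 39,383,808 (the "lm_head" entry above)
print(block0)      # expected 768*2304 + 768*768 + 2*768*3072 (weights)
                   #        + 6*768 + 2304 + 3072 (LayerNorms and biases) = 7,087,872 ("gpt2block")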
src/lmflow/utils/debug.py → src/lmflow/utils/debug/debug.py (file renamed without changes).
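
The renamed module is presumably where get_parameter_names_in_param_groups, now imported from lmflow.utils.debug.debug in finetuner.py, is defined; its signature is not visible in this commit. Purely to illustrate the kind of data recorded in param_name_in_group above, a helper that lists parameter names per optimizer param group could be sketched as follows (hypothetical, not LMFlow's implementation):

# Illustrative only: rebuild "parameter names per optimizer param group" data
# such as the param_name_in_group lists in constants.py. Not LMFlow's actual helper.
import torch

def parameter_names_in_param_groups(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # Map each parameter tensor back to its name by object identity.
    id_to_name = {id(p): name for name, p in model.named_parameters()}
    return [
        {"parameter_names": [id_to_name[id(p)] for p in group["params"] if id(p) in id_to_name]}
        for group in optimizer.param_groups
    ]

With a Hugging Face Trainer, calling this after trainer.create_optimizer() on trainer.model and trainer.optimizer would typically return two groups for GPT-2, weight-decayed matrices versus biases and LayerNorm parameters, which is the split recorded in the two param_name_in_group entries above.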
