OOM issue with same batch size that was running ok on 0.0.80 #184

Open
salrowili opened this issue Jan 8, 2025 · 14 comments

Comments

@salrowili

Hi,

I've noticed that recent updates are causing the SFT trainer code to throw an OutOfMemory (OOM) error with the same batch size that previously ran without issue on version 0.0.80.

I attempted SFT tuning using bfloat16 (no LoRA) with LLaMA 3.1 8B, max_length=1024, and a batch size of 8 on TPUv4-8, but encountered an OOM error. This fine-tuning setup worked fine on 0.0.80.

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Hello, and thanks for reporting the issue. Can you share the code, please?

@salrowili
Author

Hi @erfanzar ,
Please see the code below:

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines  # pip install jsonlines
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
import wandb
wandb.init(project="test")
def train():
	pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct"
	max_length=1024
	PartitionSpec, api = sharding.PartitionSpec, HfApi()
	json_entry=[]
	qa_data = load_dataset("Stanford/web_questions" ,split="train")
	instruction ="You are a helpful AI assistant."
	for item in qa_data:
		question=item["question"]
		answer=item["answers"][0]
		json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

	with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
		writer.write_all(json_entry)

	train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
	sharding_axis_dims = (1, 1, -1, 1)
	new_repo_id = "Test"
	dtype = jnp.bfloat16

	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
		pretrained_model_name_or_path,
		auto_shard_model=True,
		sharding_axis_dims=sharding_axis_dims,
		config_kwargs=ed.EasyDeLBaseConfigDict(
			use_scan_mlp=False,
			attn_dtype=jnp.bfloat16,
			freq_max_position_embeddings=max_length,
			mask_max_position_embeddings=max_length,
			attn_mechanism=ed.AttentionMechanisms.VANILLA,
		),
		param_dtype=dtype,
		torch_dtype=torch.bfloat16,
		dtype=dtype,
		from_torch=True,
		precision=lax.Precision("fastest"),
	)
	config = model.config
	tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
	tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
	tokenizer.padding_side = "right"
	train_arguments = ed.SFTConfig(
		num_train_epochs=10,
		learning_rate=5e-5,
		learning_rate_end=0,
		warmup_steps=100,
		optimizer=ed.EasyDeLOptimizers.ADAMW,
		scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
		weight_decay=0.02,
		total_batch_size=8,
		max_sequence_length=max_length,
		gradient_accumulation_steps=1,
		do_last_save=True,
		model_name=new_repo_id,
		track_memory=False,
		packing=True,
		num_of_sequences=max_length,
		dataset_text_field=None,
		dataset_num_proc=32,
	)

	trainer = ed.SFTTrainer(
		processing_class=tokenizer,
		arguments=train_arguments,
		model=model,
		train_dataset=train_dataset,
		eval_dataset=None,
		formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)]
	)

	output = trainer.train()
	print("Training Done")
	tokenizer.save_pretrained(output.last_save_file_name)
train()

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Can you rerun the code? There was an issue with the loss function, which wasn't using the fused version.
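
For intuition, here is a rough, illustrative sketch (not EasyDeL's actual implementation) of why a fused/chunked loss matters for memory: the cross-entropy is computed over chunks of tokens, so a float32 intermediate over the full [tokens, vocab] logits is never materialized all at once.

import jax
import jax.numpy as jnp

def chunked_cross_entropy(logits, labels, chunk_size=2048):
	# logits: [tokens, vocab] (bfloat16 is fine), labels: [tokens] integer ids.
	def one_chunk(logit_chunk, label_chunk):
		# Upcast only the current chunk and pick out the label log-probabilities.
		logprobs = jax.nn.log_softmax(logit_chunk.astype(jnp.float32), axis=-1)
		return -jnp.take_along_axis(logprobs, label_chunk[:, None], axis=-1)[:, 0]
	n = logits.shape[0]
	pieces = [
		one_chunk(logits[i : i + chunk_size], labels[i : i + chunk_size])
		for i in range(0, n, chunk_size)
	]
	return jnp.concatenate(pieces).mean()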

@erfanzar
Owner

erfanzar commented Jan 8, 2025

And since the sharding mechanism you're using is tensor parallelism, you can expect OOM, but not at a 1k sequence length.
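
To illustrate what (1, 1, -1, 1) means in mesh terms (a sketch only; the axis names below assume the usual (dp, fsdp, tp, sp) ordering rather than quoting the library's internals):

import jax
import numpy as np
from jax.sharding import Mesh

axis_dims = (1, 1, -1, 1)  # -1 takes all remaining devices, here the third axis
devices = np.array(jax.devices()).reshape(axis_dims)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp", "sp"))
print(mesh.shape)  # on a TPUv4-8 all 4 chips land on the tp axis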

@erfanzar
Owner

erfanzar commented Jan 8, 2025

In v0.0.80 the trainer would automatically use gradient checkpointing. This behavior was removed in 0.1.0, and you should now pass gradient_checkpointing to the model kwargs (I'll take the blame for not having good documentation).

@salrowili
Author

salrowili commented Jan 8, 2025

You are right! In 0.0.80, it was part of the training arguments, as we can see in this example:

gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,

However, it was removed from the TrainingArguments class in the recent updates.

After updating the code with:

                config_kwargs=ed.EasyDeLBaseConfigDict(
                        use_scan_mlp=False,
                        attn_dtype=jnp.bfloat16,
                        freq_max_position_embeddings=max_length,
                        mask_max_position_embeddings=max_length,
                        attn_mechanism=ed.AttentionMechanisms.VANILLA,
                        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
                ),

I was able to run the SFT code with a batch size of 8, but I got a couple of warnings:

[easydel.trainers.base_trainer] Prevent Running Model Due to NaN Loss

FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

The NaN issue is present even on 0.0.80 whenever I use the meta-llama/Llama-3.1-8B-Instruct model with packing=True. It goes away in some runs and appears in others. I have worked around it by re-running the script multiple times, without changing any arguments, until I got no NaN. With packing=False, the issue disappears. This issue does not occur with the Llama 3.2 models.

A last note on the sharding_axis_dims = (1, 1, -1, 1) choice: this setting gives me 113 TFLOPs (on 0.0.80) with TPUv4-8 versus 98 TFLOPs with other sharding axis settings, which is why I chose it over the other options.

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Look, the FLOPs calculation method changed in the last version. Previously everything was calculated manually, but in this version it is taken from JAX's cost analysis, so expect it to be off; for example, you might actually be getting 160 TFLOPs while XLA plays a bit dumb in some places and reports 130 TFLOPs.

Check this out

jax-ml/jax#17912
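
For a quick sanity check, a snippet like this prints the FLOPs number XLA itself reports for a jitted function (illustrative only; the exact return type of cost_analysis varies by JAX version):

import jax
import jax.numpy as jnp

def matmul(a, b):
	return a @ b

a = jnp.ones((4096, 4096), dtype=jnp.bfloat16)
b = jnp.ones((4096, 4096), dtype=jnp.bfloat16)
compiled = jax.jit(matmul).lower(a, b).compile()

analysis = compiled.cost_analysis()
# Older JAX versions return a list with one dict, newer ones a dict.
if isinstance(analysis, (list, tuple)):
	analysis = analysis[0]
print(analysis.get("flops"))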

@salrowili
Author

Thank you for the detailed reply.

I have tested the TFLOPs and the runtime speed on both 0.0.80 and 0.1.0.dev using the same WebQuestions example with Llama 3.1 8B.

As you can see from the results, there is a speed problem with the recent update, even when using different sharding strategies. I let the code run for a while until the s/it metric became stable.

You can see that 0.0.80 is about twice as fast. Also notice how (1,1,-1,1) is the best setting in terms of speed for TPUv4-8, as I stated earlier. In fact, the gap between (1,1,-1,1) and (1,1,1,-1) gets worse (almost double) with the new update. I have also noticed that it takes a while (3-5 minutes) before the script starts running with the 0.1.0.dev update, so we should add 3-5 minutes to the runtime for an accurate head-to-head comparison.

0.0.80

(1 ,-1 , 1, 1)

27%|▎| 71/260 [03:12<03:56, 1.25s/it, TFLOPs=101, accuracy=0.8993157, epoch=2, learning_rate=3.5015e-05, loss=0.389, max_grad_norm=1.3, mean_accuracy=0.8596358, mean_grad_norm=0.0757, mean_loss=0.78678674, perplexi

(1,1,-1,1)

46%|▍| 120/260 [03:26<02:21, 1.01s/it, TFLOPs=132, accuracy=0.93218476, epoch=4, learning_rate=4.8624297e-05, loss=0.248, max_grad_norm=2.33, mean_accuracy=0.88241583, mean_grad_norm=0.0815, mean_loss=0.5953737, pe

(1,1,1,-1)

22%|▏| 58/260 [03:06<04:23, 1.30s/it, TFLOPs=97.2, accuracy=0.8931932, epoch=2, learning_rate=2.85215e-05, loss=0.457, max_grad_norm=2.42, mean_accuracy=0.850847, mean_grad_norm=0.0981, mean_loss=0.87310237, perple

-------------------------------------------------------------------------------------------

0.1.0.dev

(1,-1,1,1)

15%|▏| 40/260 [02:48<15:06, 4.12s/it, TFLOPs=5.63e+13, accuracy=0.889, epoch=1, learning_rate=1.9530498e-05, loss=0.494, max_grad_norm=2.92, mean_accuracy=0.83462554, mean_grad_norm=0.109, mean_loss=1.0513492, perp

(1,1,-1,1)

42%|▍| 110/260 [03:53<05:15, 2.10s/it, TFLOPs=1.16e+14, accuracy=0.923, epoch=4, learning_rate=4.9610666e-05, loss=0.267, max_grad_norm=1.88, mean_accuracy=0.87859696, mean_grad_norm=0.0786, mean_loss=0.6243303, pe

(1,1,1,-1)

58%|▌| 150/260 [10:37<07:44, 4.22s/it, TFLOPs=5.58e+13, accuracy=0.94, epoch=5, learning_rate=3.9294693e-05, loss=0.208, max_grad_norm=1.25, mean_accuracy=0.8924935, mean_grad_norm=0.0669, mean_loss=0.52312225, per

-------------------------------------------------------------------------------------------------------

Script to run the code on 0.0.80

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines
import torch
import sys
import wandb
wandb.init(project="test",mode='online')

json_entry=[]
qa_data = load_dataset("Stanford/web_questions" ,split="train")
instruction ="You are a helpful AI assistant."
for item in qa_data:
	question=item["question"]
	answer=item["answers"][0]
	json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
	writer.write_all(json_entry)
train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
PartitionSpec, api = sharding.PartitionSpec, HfApi()
sharding_axis_dims = (1, -1, 1, 1)
max_length = 1024
input_shape = (len(jax.devices()), max_length)
pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
new_repo_id = "Test"
dtype = jnp.bfloat16

model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
	pretrained_model_name_or_path,
	input_shape=input_shape,
	auto_shard_params=True,
	sharding_axis_dims=sharding_axis_dims,
	config_kwargs=ed.EasyDeLBaseConfigDict(
		use_scan_mlp=False,
		freq_max_position_embeddings=max_length,
		mask_max_position_embeddings=max_length,
		attn_dtype=jnp.bfloat16,
		attn_mechanism=ed.AttentionMechanisms.VANILLA,
	),
	param_dtype=dtype,
	dtype=dtype,
	torch_dtype=torch.bfloat16,
	from_torch=True,
	precision=lax.Precision("fastest"),
)

config = model.config
model_use_tie_word_embedding = config.tie_word_embeddings
model_parameters = FrozenDict({"params": params})

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path_tokenizer, trust_remote_code=True)
tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
tokenizer.padding_side = "right"

train_arguments = ed.TrainingArguments(
	num_train_epochs=10,
	learning_rate=5e-5,
	learning_rate_end=0,
	warmup_steps=100,
	optimizer=ed.EasyDeLOptimizers.ADAMW,
	scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
	weight_decay=0.02,
	total_batch_size=8,
	max_sequence_length=max_length,
	gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
	sharding_array=sharding_axis_dims,
	gradient_accumulation_steps=1,
	init_input_shape=input_shape,
	dtype=dtype,
	do_last_save=False,
	param_dtype=dtype,
	model_name=new_repo_id,
	training_time="70H",
	track_memory=False,
)

trainer = ed.SFTTrainer(
	arguments=train_arguments,
	model=model,
	train_dataset=train_dataset,
	eval_dataset=None,
	tokenizer=tokenizer,
	dataset_text_field=None,
	formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)],
	packing=True,
	num_of_sequences=max_length,
	dataset_num_proc=32,
)

output = trainer.train(model_parameters=model_parameters, state=None)
print("Saving the PyTorch Model")
trainer.save_pretrained(output.state, to_torch=True)

Script to run the code on 0.1.0.dev

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines  # pip install jsonlines
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
import wandb
wandb.init(project="test")
def train():
	pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct"
	max_length=1024
	PartitionSpec, api = sharding.PartitionSpec, HfApi()
	json_entry=[]
	qa_data = load_dataset("Stanford/web_questions" ,split="train")
	instruction ="You are a helpful AI assistant."
	for item in qa_data:
		question=item["question"]
		answer=item["answers"][0]
		json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

	with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
		writer.write_all(json_entry)

	train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
	sharding_axis_dims = (1, -1, 1, 1)
	new_repo_id = "Test"
	dtype = jnp.bfloat16

	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
		pretrained_model_name_or_path,
		auto_shard_model=True,
		sharding_axis_dims=sharding_axis_dims,
		config_kwargs=ed.EasyDeLBaseConfigDict(
			use_scan_mlp=False,
			attn_dtype=jnp.bfloat16,
			freq_max_position_embeddings=max_length,
			mask_max_position_embeddings=max_length,
			attn_mechanism=ed.AttentionMechanisms.VANILLA,
			gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
		),
		param_dtype=dtype,
		torch_dtype=torch.bfloat16,
		dtype=dtype,
		from_torch=True,
		precision=lax.Precision("fastest"),
	)
	config = model.config
	tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
	tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
	tokenizer.padding_side = "right"
	train_arguments = ed.SFTConfig(
		num_train_epochs=10,
		learning_rate=5e-5,
		learning_rate_end=0,
		warmup_steps=100,
		optimizer=ed.EasyDeLOptimizers.ADAMW,
		scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
		weight_decay=0.02,
		total_batch_size=8,
		max_sequence_length=max_length,
		gradient_accumulation_steps=1,
		do_last_save=True,
		model_name=new_repo_id,
		track_memory=False,
		packing=True,
		num_of_sequences=max_length,
		dataset_text_field=None,
		dataset_num_proc=32,
	)

	trainer = ed.SFTTrainer(
		processing_class=tokenizer,
		arguments=train_arguments,
		model=model,
		train_dataset=train_dataset,
		eval_dataset=None,
		formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)]
	)

	output = trainer.train()
	print("Training Done")
	tokenizer.save_pretrained(output.last_save_file_name)
train()

@erfanzar
Owner

Thank you @salrowili for bringing up these issues and for your detailed feedback!
I want to clarify that several of the speed-related issues stem from our flax/NNX integration. I'm currently working on fixing these with @cgarciae's help. Going forward, any easydel issues that are specifically related to flax/NNX will be redirected to the flax/NNX repository with additional context and documentation.

@erfanzar
Owner

@salrowili #185

@salrowili
Author

salrowili commented Jan 14, 2025

Great! Thank you @erfanzar for opening the topic. I have one question. I am planning to start sharing my code in the topic you just opened, #185, but I am still struggling to run my code on the new EasyDeL 0.1.0.dev release. It is very slow compared to 0.0.80, and you have told me that this is due to the flax/NNX integration. Inference with the new 0.1.0.dev is fast, but the problem is with the SFT code. Do you have an estimate of when the issue will be fixed? If it will be fixed soon, I will wait and share my code with the 0.1.0.dev release.

@erfanzar
Owner

Hi @salrowili,

Many performance issues related to the new arguments and the updated base trainer have been resolved. These include fixes for duplicated if statements, redundant code checks, and incorrect caching mechanisms.

You can rerun your benchmark to see if there are any remaining performance issues (avoid using ahead-of-time compilation).

With Qwen-2 7B, batch size 8, and full sequence parallelism, I was able to achieve 6 seconds per iteration. Let me know how it goes!

@salrowili
Author

Hi @erfanzar. That's great news! Can you share the code you used to achieve this performance?

@erfanzar
Owner

@salrowili I'm using tests/trainer_test.py
