OOM issue with same batch size that was running ok on 0.0.80 #184
Comments
Hello, and thanks for reporting the issue. Can you share the code, please?
Hi @erfanzar,
Can you rerun the code? There was an issue with the loss function, which wasn't using the fused version.
And since the sharding mechanism you're using is tensor parallel, you can expect OOM, but not at a 1k sequence length.
In v0.0.80 the trainer would automatically use gradient checkpointing. This behavior is removed in 0.1.0, and you should pass gradient_checkpointing to model_kwargs (I'll take the blame for not having good documentation).
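For readers hitting the same OOM, here is a minimal sketch of what the comment above seems to suggest. The keyword gradient_checkpointing and the policy string are assumptions based on this thread, not a verified EasyDEL 0.1.0 signature:

```python
# Hypothetical sketch: in 0.1.0 the trainer no longer enables gradient
# checkpointing automatically, so it has to be requested through the model
# kwargs. The keyword and policy value below are assumptions based on this
# thread; check the EasyDEL docs for the exact names in your version.
import easydel as ed

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    gradient_checkpointing="nothing_saveable",  # assumed: re-materialize activations in backward
)
```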
You are right! In 0.0.80, it was part of the training arguments, as we can see in this example:
However, it was removed in the recent updates.
After updating the code with:
I was able to run the SFT code with batch size 8, but I got a couple of warnings:
The NaN issue is present even with 0.0.80 whenever I use the meta-llama/Llama-3.1-8B-Instruct model with Packing=True. It appears in some runs and not in others. I have worked around it by re-running the script multiple times, without changing any arguments, until I got a run with no NaN. With Packing=False, the issue disappears. The issue does not occur with the other Llama 3.2 models. A last note on the sharding_axis_dims = (1, 1, -1, 1) choice: this setting gives me 113 TFLOPs on 0.0.80 with TPUv4-8 versus 98 TFLOPs for the other sharding axis settings, which is why I chose it over the other options.
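For reference, a rough sketch of how these sharding_axis_dims are typically passed when loading the model. The ("dp", "fsdp", "tp", "sp") axis order and the from_pretrained keyword are assumptions and may differ between EasyDEL versions; under that ordering, (1, 1, -1, 1) places every device on the tensor-parallel axis, which matches the earlier comment about tensor parallelism:

```python
# Rough sketch, assuming the ("dp", "fsdp", "tp", "sp") axis order and that
# from_pretrained accepts sharding_axis_dims directly; verify against your
# EasyDEL version. -1 means "use all remaining devices on this axis".
import easydel as ed

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    sharding_axis_dims=(1, 1, -1, 1),   # all devices on tp (tensor parallel)
    # sharding_axis_dims=(1, -1, 1, 1)  # alternative: shard parameters via fsdp
)
```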
Look, the FLOPs calculation method changed in the last version. Everything used to be calculated manually, but in this version it is derived from JAX's analysis, so expect it to be off: for example, you might actually be getting 160 TFLOPs, but in some parts XLA plays a bit dumb and reports 130 TFLOPs. Check this out:
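To illustrate where a JAX-derived FLOPs number comes from (and why it can disagree with a hand-calculated estimate), here is a small standalone sketch using JAX's compiled cost analysis; nothing here is EasyDEL-specific:

```python
# Minimal sketch of an XLA-derived FLOPs estimate: JAX's compiled cost
# analysis reports "flops" per compiled function, and that estimate can
# undercount fused or custom ops, which is one reason the reported value
# may be lower than a hand-derived estimate.
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((1024, 1024), jnp.bfloat16)
b = jnp.ones((1024, 1024), jnp.bfloat16)

compiled = jax.jit(matmul).lower(a, b).compile()
analysis = compiled.cost_analysis()
# cost_analysis() returns a dict (a list of dicts in some older JAX versions)
cost = analysis[0] if isinstance(analysis, (list, tuple)) else analysis
print("XLA-estimated FLOPs:", cost.get("flops"))
```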
Thank you for the detailed reply. I have tested the TFLOPs in terms of runtime speed on both 0.0.80 and 0.1dev using the same WebQuestion example with Llama 3.1 8B. As you can see from the results, there is a problem with the speed in the recent update, even when using different sharding strategies. I let the code run for a while until the s/it metric became stable. You can see that 0.0.80 is twice as fast. Also notice how (1, 1, -1, 1) is the best setting in terms of speed for TPUv4-8, as I stated earlier. In fact, the gap between (1, 1, -1, 1) and (1, 1, 1, -1) gets worse (almost double) with the new update. I have also noticed that it takes a while (3-5 minutes) before the script starts running with the 0.1.0.dev update, so we should add 3-5 minutes to the runtime to get an accurate head-to-head comparison.

0.0.80
- (1, -1, 1, 1)
- (1, 1, -1, 1)
- (1, 1, 1, -1)

0.1.0.dev
- (1, -1, 1, 1)
- (1, 1, -1, 1)
- (1, 1, 1, -1)

Script to run the code on 0.0.80
Script to run the code on 0.1.0.dev
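Since 0.1.0.dev spends several minutes compiling before the first step, a fair head-to-head s/it comparison should exclude the warm-up iterations. Here is a small generic sketch of that measurement; train_step and batches are placeholders, not EasyDEL API:

```python
# Generic sketch for measuring steady-state seconds-per-iteration while
# excluding the first (compile/warm-up) steps, so 0.0.80 and 0.1.0.dev runs
# can be compared fairly. `train_step` and `batches` are placeholders.
import time

def measure_s_per_it(train_step, batches, warmup_steps=5):
    durations = []
    for i, batch in enumerate(batches):
        start = time.perf_counter()
        train_step(batch)
        elapsed = time.perf_counter() - start
        if i >= warmup_steps:  # skip compilation / warm-up iterations
            durations.append(elapsed)
    return sum(durations) / max(len(durations), 1)
```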
Thank you @salrowili for bringing up these issues and for your detailed feedback!
Great! Thank you @erfanzar for opening the topic. I have one question. I am planning to start sharing my code under the topic you just opened (#185), but I am still struggling to run my code on the new EasyDEL 0.1dev release. It is very slow compared to 0.0.80, and you told me that this is due to the Flax/NNX integration. Inference with the new 0.1dev is fast, but the problem is with the SFT code. Do you have any estimate of when the issue will be fixed? If it will be soon, I will wait until it is fixed and share my code with the 0.1dev release.
Hi @salrowili, many performance issues related to the new arguments and the updated base trainer have been resolved. These include fixes for duplicated … You can rerun your benchmark to see if there are any remaining performance issues (avoid using ahead-of-time compilation). With Qwen-2 7B, batch size 8, and full sequence parallelism, I was able to achieve 6 seconds per iteration. Let me know how it goes!
Hi @erfanzar. That's great news! Can you share the code you used to achieve this performance?
@salrowili I'm using tests/trainer_test.py
Hi,
I've noticed that recent updates are causing the SFT trainer code to throw an OutOfMemory (OOM) error with the same batch size that previously ran without issue on version 0.0.80.
I attempted SFT tuning using bfloat16 (no LoRA) with LLaMA 3.1 8B, max_length=1024, and batch=8 on TPUv4-8, but encountered an OOM error. This fine-tuning setup was working ok with 0.0.80.