Distributed Data Parallel Training #76
base: main
Conversation
I don't like this. In this version of the callback, I didn't have to do this. I think the trick is to make sure that each worker runs through the same data.
But also, consider that in my latest training version (which I have only in a branch as yet), I don't even have the callback anymore. Can we just get rid of that whole problem area by making sure the trainable model computes its metrics during forward()?
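For illustration, here is a minimal sketch of what "metrics during forward()" could look like. The class name, the wrapped encoder, and the metric choices are hypothetical and not the actual catwalk model; this is just the shape of the idea.

import torch
import torch.nn as nn


class MetricsInForwardModel(nn.Module):
    """Hypothetical sketch: the trainable model computes its own loss and
    accuracy inside forward(), so no separate validation callback (and no
    cross-process metric synchronization) is needed under DDP."""

    def __init__(self, encoder: nn.Module, hidden_size: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs, labels=None):
        logits = self.classifier(self.encoder(inputs))
        output = {"logits": logits}
        if labels is not None:
            # Each DDP worker computes metrics over its own shard only;
            # nothing outside forward() needs to touch metric state.
            output["loss"] = nn.functional.cross_entropy(logits, labels)
            output["accuracy"] = (logits.argmax(dim=-1) == labels).float().mean()
        return output

Because every worker follows the same code path, per-shard metrics could simply be averaged (or all-reduced) afterwards if a global number is wanted.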
catwalk/steps.py
Outdated
# if distributed, this model hasn't been sent to device yet
trainable_model.to(resolve_device())
Why do we have to return a model that's tied to a device?
Hmm, I had a crash where, without moving the model back to the GPU, inputs on the GPU would be fed to the model still sitting on CPU during the final evaluation. I can't replicate this now, though. I suspect it was required for the failed attempt to integrate the fairscale code in #79.
As it stands, there is no crash without this code, but it does mean that the final evaluation will run on CPU rather than GPU, as it would when running a non-trainable model. I can't just make TrainableRankClassificationModel.predict() use resolve_device() like the non-trainable predict() does, because invoking predict at all inside the training process would mess up the distributed device map.
Perhaps the best place to move it to the GPU would be in catwalk.train.py, like this?
model_step = FinetuneStep(
    model=args.model,
    tasks=tasks,
    batch_size=args.batch_size,
    grad_accum=args.grad_acc,
    device_count=args.device_count,
)
# The step result is the trained model; move it to the resolved device here,
# outside the distributed training process.
model_step = model_step.result().to(resolve_device())
I agree, getting rid of the validation callback altogether would be the best solution. I'm concerned that's a bit beyond the scope of what I can accomplish this week; all this distributed processing stuff has me mostly feeling around in the dark because of my lack of systems background. This PR is not a necessary dependency of the IA3 PR #81, so if it's going to be superseded by your rework of the training code in that branch, perhaps we should just skip this PR?
I've reverted these changes in the IA3 PR #81, as they are not actually necessary for that PR, and I don't want this one to block it.
I will revisit this after #84 is merged.
What is here
Fixes support for distributed training with data parallelism. Previously, torchmetrics would attempt to synchronize across processes during the validation callback, causing a crash. Also, the final model output by the FinetuneStep was on CPU rather than GPU, unlike in non-distributed usage; the returned model is now on GPU.
Limitations
Validation is still done in a single process, given how the data parallelism is set up here.
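For context, a minimal sketch of what single-process validation can look like under DDP: only rank 0 runs the loop, and the metric is told not to sync across the process group, so the cross-process synchronization crash described above cannot occur. The helper name, the rank-0 gating, and the sync_on_compute flag (available in recent torchmetrics releases) are my assumptions, not the actual catwalk implementation.

import torch
import torch.distributed as dist
from torchmetrics.classification import MulticlassAccuracy


def validate_on_rank_zero(model, val_loader, device, num_classes):
    """Hypothetical sketch: run validation only on rank 0 so metric state
    never tries to synchronize with ranks that skipped validation."""
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return None
    # sync_on_compute=False keeps compute() local to this process.
    metric = MulticlassAccuracy(num_classes=num_classes, sync_on_compute=False).to(device)
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            logits = model(inputs.to(device))
            metric.update(logits, labels.to(device))
    return metric.compute().item()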
Reproduction
Running with and without multiple devices produces exactly the same validation metrics, though the printout differs slightly because tasks are copied:
python -m catwalk.train --model rc::gpt2 --task piqa --device_count 1 --batch_size 16
python -m catwalk.train --model rc::gpt2 --task piqa --device_count 2 --batch_size 16