
Distributed Data Parallel Training #76

Draft · IanMagnusson wants to merge 5 commits into main

Conversation

@IanMagnusson (Contributor) commented on August 30, 2022

What is here

Fixes support for distributed training with data parallelism. Previously, torchmetrics would attempt to synchronize across processes during the validation callback, causing a crash. Also, the final model returned by the FinetuneStep was on CPU rather than GPU, unlike in non-distributed usage; the returned model is now on GPU.
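
Not the actual diff, but a minimal sketch of the general shape of such a fix, assuming a PyTorch-style post-epoch callback; the ValidationCallback class and validate_fn hook are hypothetical names used only for illustration:

import torch.distributed as dist

def is_main_process() -> bool:
    # In a non-distributed run there is no process group, so default to True.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0

class ValidationCallback:  # hypothetical name, for illustration only
    def __init__(self, validate_fn):
        self.validate_fn = validate_fn  # e.g. a closure that runs the eval loop

    def post_epoch(self, trainable_model) -> None:
        # Run validation on rank 0 only, so torchmetrics never tries to
        # synchronize metric state with workers that skip the callback.
        if not is_main_process():
            return
        print(self.validate_fn(trainable_model))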

Limitations

Validation is still done in a single process, even when training with data parallelism.

Reproduction

Running with and without multiple devices produces exactly the same validation metrics, though the printout differs slightly because the task objects get copied across processes (so the copied task's repr is printed instead of its name):

python -m catwalk.train --model rc::gpt2 --task piqa --device_count 1 --batch_size 16

....
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 270.13it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 4003.12it/s]
Metrics for piqa: acc: 0.647
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 260.74it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 4000.07it/s]
Metrics for piqa: acc: 0.648
...

python -m catwalk.train --model rc::gpt2 --task piqa --device_count 2 --batch_size 16

...
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 271.48it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 3683.20it/s]
Metrics for <catwalk.tasks.eleuther.EleutherTask object at 0x7fe861c05490>: acc: 0.647
Running log-likelihood queries: 100%|##########| 2000/2000 [00:06<00:00, 288.04it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 3767.45it/s]
Metrics for <catwalk.tasks.eleuther.EleutherTask object at 0x7fe861c05490>: acc: 0.648
...

@IanMagnusson marked this pull request as ready for review on September 12, 2022
@IanMagnusson requested a review from @dirkgr on September 12, 2022
@IanMagnusson changed the title from "Distributed training" to "Distributed Data Parallel Training" on September 12, 2022
@dirkgr (Member) left a comment

I don't like this. In this version of the callback, I didn't have to do this. I think the trick is to make sure that each worker runs through the same data.

But also, consider that in my latest training version (which I have only in a branch as yet), I don't even have the callback anymore. Can we just get rid of that whole problem area by making sure the trainable model computes its metrics during forward()?
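
Not the code in that branch, but a rough sketch of the idea, assuming the trainable model wraps a HuggingFace-style model whose output exposes .loss and .logits and that the batch carries a "labels" tensor; the wrapper name is made up for illustration:

import torch
import torch.nn as nn

class TrainableModelWithMetrics(nn.Module):  # hypothetical wrapper
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, batch: dict) -> dict:
        output = self.model(**batch)  # assumed to expose .loss and .logits
        preds = output.logits.argmax(dim=-1)
        # Compute the metric right here in forward(), so no separate validation
        # callback (and no cross-process torchmetrics sync) is needed; the
        # trainer can average these per-batch values across batches and workers.
        acc = (preds == batch["labels"]).float().mean()
        return {"loss": output.loss, "acc": acc.detach()}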

catwalk/steps.py Outdated
Comment on lines 263 to 264
# if distributed, this model hasn't been sent to device yet
trainable_model.to(resolve_device())
dirkgr (Member):

Why do we have to return a model that's tied to a device?

IanMagnusson (Contributor, Author):

> Why do we have to return a model that's tied to a device?

Hmm, I had a crash where not moving the model back to the GPU caused GPU inputs to be fed to the model while it was still on CPU during the final evaluation. I can't replicate this now, though. I suspect it was only required for the failed attempt to integrate the fairscale code in #79.

As it stands, there is no crash without this code, but it does mean that the final evaluation runs on CPU rather than GPU, as it would if it were running a non-trainable model. I can't just make TrainableRankClassificationModel.predict() use resolve_device() like the non-trainable predict() does, because invoking predict at all inside the training process would mess up the distributed device map.

Perhaps the best place to move it to GPU would be in catwalk.train.py like this?

model_step = FinetuneStep(
    model=args.model,
    tasks=tasks,
    batch_size=args.batch_size,
    grad_accum=args.grad_acc,
    device_count=args.device_count,
)

model_step = model_step.result().to(resolve_device())

@IanMagnusson (Contributor, Author)

> I don't like this. In this version of the callback, I didn't have to do this. I think the trick is to make sure that each worker runs through the same data.
>
> But also, consider that in my latest training version (which I have only in a branch as yet), I don't even have the callback anymore. Can we just get rid of that whole problem area by making sure the trainable model computes its metrics during forward()?

I agree, getting rid of the validation callback altogether would be the best solution. I'm concerned that's a bit beyond the scope of what I can accomplish this week, though. All this distributed processing stuff has me mostly feeling around in the dark because of my lack of systems background.

This PR is not a necessary dependency of the IA3 PR #81, so if it's going to be superseded by your rework of the training code in that branch then perhaps we should just skip this PR?

@IanMagnusson marked this pull request as draft on September 16, 2022
@IanMagnusson (Contributor, Author)

I've reverted these changes in the IA3 PR #81, as they are not actually necessary there, and I don't want this PR to block that one.

@dirkgr (Member) commented on September 20, 2022

I will revisit this after #84 is merged.
