Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Adding TP overlap support for te.Linear with parallel_mode="column" #1343

Merged
merged 8 commits into from
Jan 13, 2025

Conversation

denera
Copy link
Collaborator

@denera denera commented Nov 20, 2024

Description

te.Linear currently only supports TP overlap with parallel_mode="row" where it overlaps reduce-scatter in the forward pass, and all-gather with dgrad in the backward pass.

This PR adds new options to enable all-gather overlap in the forward pass, and reduce-scatter overlap with dgrad in the backward pass, when parallel_mode="column".

Fixes #1312

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added enhancement New feature or request 1.13.0 labels Nov 20, 2024
@denera denera requested review from timmoon10 and ksivaman November 20, 2024 21:52
@denera denera self-assigned this Nov 20, 2024
@denera denera force-pushed the linear-tp-overlap-ag-fprop-rs-dgrad branch from 90458d4 to 4e3e61a Compare November 20, 2024 21:53
@denera
Copy link
Collaborator Author

denera commented Nov 20, 2024

/te-ci pytorch L1

Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, pending CI.

Comment on lines 866 to 869
ub_overlap_ag: bool = False,
ub_overlap_rs: bool = False,
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should seriously consider deprecating these UB options and just passing in a dict. The UB interface is unstable and will likely be so for some while. A dict would be better for backward compatibility (reinterpret old options) and forward compatibility (ignore unknown options). This would be especially helpful for Mcore integration.

For example, the operation-based API passes in UB options with a dict:

userbuffers_options: Optional[dict[str, Any]] = None,

Comment on lines 931 to 935
assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!"
assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!"
assert not (
self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
), "Internal TE error!"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More descriptive error messages would be helpful.

Copy link
Member

@ksivaman ksivaman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, much needed

@denera denera force-pushed the linear-tp-overlap-ag-fprop-rs-dgrad branch from 3951993 to 360c127 Compare December 17, 2024 20:48
@denera denera added 1.14.0 and removed 1.13.0 labels Dec 17, 2024
@denera
Copy link
Collaborator Author

denera commented Dec 17, 2024

/te-ci pytorch L1

1 similar comment
@denera
Copy link
Collaborator Author

denera commented Dec 18, 2024

/te-ci pytorch L1

@denera denera force-pushed the linear-tp-overlap-ag-fprop-rs-dgrad branch from 744a96f to 9adf99f Compare January 13, 2025 20:23
@denera denera merged commit 2402406 into NVIDIA:main Jan 13, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
1.14.0 enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Linear does not support TP comm overlap for Column Parallel mode
3 participants