Allow external positions to be inputed in RoPE embedding layer #926

Open
wants to merge 11 commits into
base: main

Firenze11

Use case: In RoPE embedding, position embeddings are applied to the Q, K, V values after `i_proj`. Unlike the current `RoFormerQKVLinear` implementation, MaskedDiT needs to customize the positions so that the position embedding distinguishes masked from non-masked positions. When we convert this masked RoFormer attention module to flash attention, its signature must also be supported by `MultiheadAttention`.
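To make the use case concrete, here is a minimal sketch of computing sinusoidal RoPE embeddings from explicitly supplied positions. The function name `rope_sinusoidal_embeddings` and the sin/cos concatenation layout are assumptions for illustration and are not taken from the axlearn code; only `dim` and `theta` mirror the config fields shown in the diff.

```python
import jax.numpy as jnp


def rope_sinusoidal_embeddings(positions, *, dim, theta=10000.0):
    """Maps integer positions [batch, seq_len] to embeddings [batch, seq_len, dim].

    Hypothetical helper: the first dim/2 channels hold sines and the last
    dim/2 hold cosines. A real layer may interleave them instead.
    """
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}.")
    # One inverse frequency per sin/cos pair: theta^(-2i/dim).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    # Broadcast positions against frequencies: [batch, seq_len, dim/2].
    angles = positions[..., None].astype(jnp.float32) * inv_freq
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


# Explicit, non-contiguous positions, e.g. the original indices of the
# tokens that remain after masking (illustrative values).
kept_positions = jnp.array([[0, 2, 5, 7]])
embeddings = rope_sinusoidal_embeddings(kept_positions, dim=4)
```

Because the positions are an input rather than an implicit `jnp.arange`, the caller decides which index each token is rotated by.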
@Firenze11 Firenze11 requested review from ruomingp, markblee and a team as code owners January 15, 2025 02:56
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config):
         dim: Required[int] = REQUIRED  # The dimensionality of the positional embedding.
         theta: float = 10000.0  # The scale of base frequency.

-    def forward(self, positions: Tensor) -> Tensor:
+    def default_query_positions(self, max_seq_len: int) -> Tensor:
Contributor

Suggested change
-    def default_query_positions(self, max_seq_len: int) -> Tensor:
+    def _default_query_positions(self, max_seq_len: int) -> Tensor:

Users should pass `max_seq_len` rather than calling this method publicly.

Author

There may be situations where we want to access the default query positions from outside the class, for example to fetch the default positions and compute custom positions from them before passing the result to `forward`. We therefore want to keep this method public.
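A rough sketch of that pattern follows: fetch the defaults, derive custom positions from them, then hand the result to the layer. The free function `default_query_positions` stands in for the method under review, and the offsetting rule in `offset_masked_positions` is purely hypothetical; the thread does not say how MaskedDiT actually encodes masked tokens.

```python
import jax.numpy as jnp


def default_query_positions(max_seq_len):
    """Stand-in for the method under review: positions 0..max_seq_len-1, shape [1, max_seq_len]."""
    return jnp.arange(max_seq_len)[None]


def offset_masked_positions(is_masked):
    """Hypothetical rule: shift masked tokens into a disjoint position range.

    Args:
        is_masked: bool array of shape [batch, seq_len].

    Returns:
        int array of shape [batch, seq_len].
    """
    seq_len = is_masked.shape[-1]
    base = default_query_positions(seq_len)  # [1, seq_len]
    # Non-masked tokens keep their default position; masked ones are offset by seq_len.
    return jnp.where(is_masked, base + seq_len, base)


is_masked = jnp.array([[False, True, False, True]])
custom_positions = offset_masked_positions(is_masked)  # [[0, 5, 2, 7]]
```

The resulting `custom_positions` would then be passed as the explicit `positions` argument instead of relying on `max_seq_len`.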

Contributor

Do we expect subclasses to override this method?

If not, it would be more readable for callers to call `jnp.arange` directly, given that the implementation is only one line.

Author

Yes, we do expect subclasses to override this method. For example, we might want to specify a rotation start index and a rotation end index for this embedding class, in which case the default embedding positions would be different.
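As a sketch of what such an override could look like, the example below uses plain Python classes rather than axlearn layers. `PositionSource`, `WindowedPositionSource`, and the meaning given to `rotation_start` / `rotation_end` are all assumptions; the thread does not define the intended semantics.

```python
import jax.numpy as jnp


class PositionSource:
    """Hypothetical base class exposing overridable default positions."""

    def default_query_positions(self, max_seq_len):
        # Default: 0..max_seq_len-1 with a leading batch dimension.
        return jnp.arange(max_seq_len)[None]


class WindowedPositionSource(PositionSource):
    """Hypothetical subclass: positions advance only inside [rotation_start, rotation_end)."""

    def __init__(self, rotation_start, rotation_end):
        self.rotation_start = rotation_start
        self.rotation_end = rotation_end

    def default_query_positions(self, max_seq_len):
        idx = jnp.arange(max_seq_len)
        in_window = (idx >= self.rotation_start) & (idx < self.rotation_end)
        # Tokens outside the window are pinned to position 0, where the rotation is the identity.
        return jnp.where(in_window, idx - self.rotation_start, 0)[None]


base_positions = PositionSource().default_query_positions(6)
windowed_positions = WindowedPositionSource(2, 5).default_query_positions(6)
```

Keeping the default computation behind a method lets callers stay unchanged while a subclass swaps in a different default.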

Contributor

Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.

axlearn/common/attention.py (outdated)
Contributor

@kelvin-zou kelvin-zou left a comment

Thanks

axlearn/common/attention.py (outdated)
Comment on lines +1243 to +1244
+        if positions is None:
+            if max_seq_len is None:
Contributor

What happens if both `positions` and `max_seq_len` are provided? Should we check that they are consistent?

Author

In that case `positions` takes precedence and `max_seq_len` is ignored, since we don't need `max_seq_len` when the client provides explicit positions. I will add that to the docstring.

Contributor

Maybe we can error if they are both provided?

Author

Sure, we can do this. @ruomingp, any opinion on this?

Contributor

It seems better to throw an error if they are both provided or when they are inconsistent. WDYT?
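One way the stricter behaviour could look is sketched below, assuming the "error if both are provided" option. The helper name `resolve_positions` and the error messages are hypothetical; the actual change in the PR may differ.

```python
import jax.numpy as jnp


def resolve_positions(positions=None, max_seq_len=None):
    """Returns explicit positions, or defaults derived from max_seq_len.

    Hypothetical helper: exactly one of the two arguments must be set.
    """
    if positions is not None and max_seq_len is not None:
        # Reject ambiguous calls instead of silently ignoring max_seq_len.
        raise ValueError("Only one of `positions` and `max_seq_len` may be provided.")
    if positions is None:
        if max_seq_len is None:
            raise ValueError("One of `positions` or `max_seq_len` must be provided.")
        positions = jnp.arange(max_seq_len)[None]  # [1, max_seq_len]
    return positions


default_positions = resolve_positions(max_seq_len=4)
explicit_positions = resolve_positions(positions=jnp.array([[3, 1, 2]]))
```

A looser alternative would accept both arguments and raise only when `positions.shape[-1]` differs from `max_seq_len`, which covers the "inconsistent" case raised in the thread.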

5 participants