Allow external positions to be inputed in RoPE embedding layer #926
base: main
Use case: In RoPE embedding, position embeddings are applied to the Q, K, V values after `i_proj`. Unlike the current `RoFormerQKVLinear` implementation, MaskedDiT needs custom positions to distinguish masked from non-masked positions in the position embedding. When we convert this masked RoFormer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.
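To make the use case concrete, here is a minimal sketch of a rotary-style sinusoidal embedding that accepts explicit positions. It is not the library's implementation: the function name `rope_sinusoids`, the `[sin, cos]` output layout, and the example position values are assumptions for illustration; only `dim` and `theta` come from the config shown in the diff.

```python
import jax.numpy as jnp


def rope_sinusoids(positions, *, dim, theta=10000.0):
    """Hypothetical helper: maps integer positions to sinusoidal features.

    Args:
        positions: int array of shape [batch, seq_len].
        dim: even feature dimension (mirrors `dim` in the config).
        theta: base frequency scale (mirrors `theta` in the config).

    Returns:
        Float array of shape [batch, seq_len, dim], laid out as [sin, cos].
        The layout is an assumption, not taken from the PR.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even.")
    # Inverse frequencies theta^(-2i/dim) for i in [0, dim/2).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    # Broadcast [batch, seq_len, 1] against [dim/2].
    angles = positions[..., None].astype(jnp.float32) * inv_freq
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


# Default positions are simply 0..seq_len-1.
default_positions = jnp.arange(4)[None, :]
# Caller-supplied positions (made-up values): two tokens share position 1.
custom_positions = jnp.array([[0, 1, 1, 5]])

default_emb = rope_sinusoids(default_positions, dim=8)
custom_emb = rope_sinusoids(custom_positions, dim=8)
```

With explicit positions, tokens that share a position value receive the same embedding, which is what lets a caller encode structure such as masked versus non-masked tokens.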
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config):
dim: Required[int] = REQUIRED # The dimensionality of the positional embedding.
theta: float = 10000.0 # The scale of base frequency.

def forward(self, positions: Tensor) -> Tensor:
def default_query_positions(self, max_seq_len: int) -> Tensor:
- def default_query_positions(self, max_seq_len: int) -> Tensor:
+ def _default_query_positions(self, max_seq_len: int) -> Tensor:
Users should pass `max_seq_len` rather than calling this method publicly.
There might be situations where we want to access the default query positions from outside, such as getting the default positions and computing new positions from them before passing them to `forward`. Therefore we want to keep this public.
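As an illustration of that workflow, the sketch below starts from default positions and derives custom ones before they would be handed to `forward`. The PR does not say how masked positions are encoded, so the offset scheme, the name `positions_with_masked_offset`, and the `[1, max_seq_len]` shape of the defaults are all assumptions.

```python
import jax.numpy as jnp


def default_query_positions(max_seq_len):
    # Stand-in for the public method discussed above; the [1, max_seq_len]
    # shape is an assumption.
    return jnp.arange(max_seq_len)[None]


def positions_with_masked_offset(is_masked, *, masked_offset):
    """Hypothetical scheme: shift masked tokens into a separate position range.

    Args:
        is_masked: bool array of shape [batch, seq_len].
        masked_offset: int added to the default position of masked tokens.

    Returns:
        Int array of shape [batch, seq_len].
    """
    base = default_query_positions(is_masked.shape[-1])
    return jnp.where(is_masked, base + masked_offset, base)


is_masked = jnp.array([[False, True, False, True]])
positions = positions_with_masked_offset(is_masked, masked_offset=100)
```

The resulting `positions` array is what a caller would pass as the explicit `positions` argument instead of relying on the defaults.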
Do we expect subclasses to override this method? If not, it will be more readable for callers to call `jnp.arange` directly, given that the implementation is only one line.
Yes, we do expect subclasses to override this method. For example, we might want to specify a rotation start index and a rotation end index for this embedding class, in which case the default embedding positions will be different.
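A rough sketch of such an override follows, using plain Python classes rather than the actual layer classes. The class names and the interpretation of "rotation start/end index" (positions held at 0 before the start, counted up inside the window, and held constant after the end) are assumptions; the comment above does not specify the semantics.

```python
import jax.numpy as jnp


class PositionsBase:
    """Stand-in for the embedding class; not the real layer."""

    def default_query_positions(self, max_seq_len):
        # One-line default: positions 0..max_seq_len-1, shape [1, max_seq_len].
        return jnp.arange(max_seq_len)[None]


class WindowedPositions(PositionsBase):
    """Hypothetical subclass with a rotation start and end index."""

    def __init__(self, rotation_start, rotation_end):
        if rotation_end < rotation_start:
            raise ValueError("rotation_end must be >= rotation_start.")
        self.rotation_start = rotation_start
        self.rotation_end = rotation_end

    def default_query_positions(self, max_seq_len):
        pos = jnp.arange(max_seq_len)
        # Assumed semantics: clamp positions to the rotation window.
        span = self.rotation_end - self.rotation_start
        return jnp.clip(pos - self.rotation_start, 0, span)[None]


base_positions = PositionsBase().default_query_positions(8)
windowed_positions = WindowedPositions(2, 5).default_query_positions(8)
```

Because callers go through `default_query_positions` rather than calling `jnp.arange` themselves, swapping in the subclass changes the defaults without touching call sites.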
Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.
Co-authored-by: Mark Lee <[email protected]>
Thanks
if positions is None:
    if max_seq_len is None:
What happens if both `positions` and `max_seq_len` are provided? Should we check that they are consistent?
In that case `positions` will take precedence and we ignore `max_seq_len`. We won't need `max_seq_len` if the client provides explicit positions. Will add that to the docstring.
Maybe we can error if they are both provided?
Sure, can do this. @ruomingp any opinion on this?
It seems better to throw an error if they are both provided or when they are inconsistent. WDYT?
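The stricter option discussed in this thread could look roughly like the sketch below. It is not the merged code: the helper name `resolve_positions` and the error messages are hypothetical, and the sketch rejects any call that supplies both arguments rather than checking them for consistency.

```python
import jax.numpy as jnp


def resolve_positions(positions=None, max_seq_len=None):
    """Hypothetical helper: pick explicit positions or build the defaults.

    Exactly one of `positions` and `max_seq_len` must be provided.
    """
    if positions is not None and max_seq_len is not None:
        raise ValueError("Provide either positions or max_seq_len, not both.")
    if positions is None:
        if max_seq_len is None:
            raise ValueError("One of positions or max_seq_len is required.")
        # Default positions 0..max_seq_len-1 with a leading batch dimension
        # (the shape is an assumption).
        positions = jnp.arange(max_seq_len)[None]
    return positions


from_length = resolve_positions(max_seq_len=4)
from_explicit = resolve_positions(positions=jnp.array([[3, 1, 2]]))
```

Raising on ambiguous input keeps a stale `max_seq_len` from being silently ignored, at the cost of requiring callers to drop one of the two arguments.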
Co-authored-by: Ruoming Pang <[email protected]>