Commit 082d589
added @DataClass so that lint check does not complain about field()
Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng committed Dec 6, 2024
1 parent 4eda388 commit 082d589
Showing 2 changed files with 12 additions and 2 deletions.
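For context: the Praxis layers below declare attributes with dataclasses.field(), but a dataclass-aware linter only accepts field() calls inside classes it can recognize as dataclasses (e.g. pylint's invalid-field-call check), hence the explicit @dataclass decorators added in this commit. A minimal sketch of the pattern, using hypothetical names not taken from the repository:

# Hypothetical example; ExampleLayer is not part of this diff.
from dataclasses import dataclass, field
from typing import Sequence

@dataclass  # without this decorator, a linter may flag the field() call below
class ExampleLayer:
    """Toy layer mirroring the field() usage in the Praxis modules."""

    logical_axes: Sequence[str] = field(default_factory=tuple)

layer = ExampleLayer()
assert layer.logical_axes == ()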
8 changes: 7 additions & 1 deletion transformer_engine/jax/praxis/module.py
@@ -6,7 +6,7 @@
 """
 from functools import partial
 from typing import Callable, Iterable, Sequence, Tuple, Union
-from dataclasses import field
+from dataclasses import field, dataclass
 
 from praxis import pax_fiddle
 from praxis.base_layer import init_var
@@ -28,6 +28,7 @@ def _generate_ln_scale_init(scale_init):
     return scale_init
 
 
+@dataclass
 class TransformerEngineBaseLayer(BaseLayer):
     """TransformerEngineBaseLayer"""
 
@@ -67,6 +68,7 @@ def create_layer(self, name, flax_module_cls):
         self.create_child(name, flax_module_p.clone())
 
 
+@dataclass
 class LayerNorm(TransformerEngineBaseLayer):
     """LayerNorm"""
 
@@ -103,6 +105,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.layer_norm(x)
 
 
+@dataclass
 class FusedSoftmax(TransformerEngineBaseLayer):
     """FusedSoftmax"""
 
@@ -124,6 +127,7 @@ def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
         return self.fused_softmax(x, mask, bias)
 
 
+@dataclass
 class Linear(TransformerEngineBaseLayer):
     """Linear"""
 
@@ -165,6 +169,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.linear(x)
 
 
+@dataclass
 class LayerNormLinear(TransformerEngineBaseLayer):
     """LayerNormLinear"""
 
@@ -228,6 +233,7 @@ def __call__(self, x: JTensor) -> JTensor:
         return self.ln_linear(x)
 
 
+@dataclass
 class LayerNormMLP(TransformerEngineBaseLayer):
     """LayerNormMLP"""
 
6 changes: 5 additions & 1 deletion transformer_engine/jax/praxis/transformer.py
@@ -6,7 +6,7 @@
 """
 from functools import partial
 from typing import Optional, Sequence, Tuple
-from dataclasses import field
+from dataclasses import field, dataclass
 import warnings
 
 from praxis import pax_fiddle
@@ -22,6 +22,7 @@
 from ..attention import AttnBiasType, AttnMaskType
 
 
+@dataclass
 class RelativePositionBiases(TransformerEngineBaseLayer):
     """RelativePositionBiases"""
 
@@ -67,6 +68,7 @@ def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True):
         return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
 
 
+@dataclass
 class DotProductAttention(TransformerEngineBaseLayer):
     """DotProductAttention"""
 
@@ -125,6 +127,7 @@ def __call__(
         )
 
 
+@dataclass
 class MultiHeadAttention(TransformerEngineBaseLayer):
     """MultiHeadAttention"""
 
@@ -258,6 +261,7 @@ def __call__(
         )
 
 
+@dataclass
 class TransformerLayer(TransformerEngineBaseLayer):
     """TransformerLayer"""
 