Debugging Falcon new-architecture model on TPUs
erfanzar committed May 26, 2024
1 parent f4d95af commit 61e11ee
Showing 1 changed file with 6 additions and 8 deletions.
src/python/easydel/modules/falcon/modelling_falcon_flax.py: 14 changes (6 additions & 8 deletions)
@@ -342,15 +342,13 @@ def setup(self) -> None:
             self.config.num_ln_in_parallel_attn = 2
         config = self.config
 
-        if not config.parallel_attn:
-            self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
-            self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
+        if config.new_decoder_architecture and config.num_ln_in_parallel_attn == 2:
+            self.ln_attn = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
+            self.ln_mlp = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
         else:
-            if config.num_ln_in_parallel_attn == 2:
-                self.ln_attn = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
-                self.ln_mlp = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
-            else:
-                self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
+            self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
+            if not config.parallel_attn:
+                self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
         attn_block = FlaxFalconAttention
         mlp_block = FlaxFalconMlp
         if self.config.gradient_checkpointing != "":
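For context, here is a minimal sketch of how the layer-norm wiring reads after this change. It assumes nn is flax.linen and a Falcon-style config exposing the fields used in the diff (new_decoder_architecture, num_ln_in_parallel_attn, parallel_attn, layer_norm_epsilon); FalconLikeConfig and LayerNormsOnly are illustrative stand-ins, not actual EasyDeL classes.

from dataclasses import dataclass

import flax.linen as nn
import jax
import jax.numpy as jnp


@dataclass
class FalconLikeConfig:
    # Hypothetical stand-in for the EasyDeL Falcon config; only the fields
    # touched by this commit are modelled.
    new_decoder_architecture: bool = True
    num_ln_in_parallel_attn: int = 2
    parallel_attn: bool = True
    layer_norm_epsilon: float = 1e-5


class LayerNormsOnly(nn.Module):
    # Only the layer-norm wiring of the Falcon block, as it stands after this commit.
    config: FalconLikeConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        config = self.config
        if config.new_decoder_architecture and config.num_ln_in_parallel_attn == 2:
            # New decoder architecture with two parallel layer norms:
            # one feeding attention, one feeding the MLP.
            self.ln_attn = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
            self.ln_mlp = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
        else:
            # Classic path: a single pre-attention norm, plus a post-attention
            # norm only when attention and MLP are not run in parallel.
            self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)
            if not config.parallel_attn:
                self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon, dtype=self.dtype)

    def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        # Exercise whichever norms were created, just to make the module initialisable.
        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
            return self.ln_attn(hidden_states) + self.ln_mlp(hidden_states)
        return self.input_layernorm(hidden_states)


if __name__ == "__main__":
    module = LayerNormsOnly(FalconLikeConfig())
    params = module.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 8)))
    print(jax.tree_util.tree_map(jnp.shape, params))

In effect, the commit makes new_decoder_architecture (together with num_ln_in_parallel_attn) decide between the two-norm parallel path and the classic input_layernorm path, rather than keying the whole choice off parallel_attn.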
