Commit: add layer norm weight plus 1

Yejing-Lai committed Apr 18, 2024
1 parent bcedecd commit 5895994
Showing 3 changed files with 39 additions and 9 deletions.
1 change: 1 addition & 0 deletions megatron/model/__init__.py
@@ -7,6 +7,7 @@
 else:
     from .rmsnorm import RMSNorm
     from torch.nn import LayerNorm
+from .layer_norm_p1 import LayerNorm1P
 
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
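
Assuming the new import is exposed at the package level, LayerNorm1P becomes importable alongside the existing norm classes. A tiny illustrative snippet (hypothetical usage, requiring the Megatron-DeepSpeed tree on the Python path):

from megatron.model import LayerNorm1P

norm = LayerNorm1P(1024, eps=1e-5)   # same constructor signature as torch.nn.LayerNorm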
14 changes: 14 additions & 0 deletions megatron/model/layer_norm_p1.py
@@ -0,0 +1,14 @@
import torch


class LayerNorm1P(torch.nn.LayerNorm):
    """LayerNorm whose learnable scale is stored zero-centered.

    The effective scale applied to the normalized input is (weight + 1),
    so a stored weight of zero behaves like a standard LayerNorm whose
    weight is one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # Shift the stored weight by 1, then apply the regular layer norm.
        weight_plus_1 = self.weight + 1
        return torch.nn.functional.layer_norm(
            input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
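
A quick way to see what the "+ 1" buys: once the stored weight is zero-centered, the effective scale is 1 and LayerNorm1P matches a freshly initialized torch.nn.LayerNorm. A minimal check of that equivalence (the explicit zeroing below is for illustration only, since the class keeps PyTorch's default weight init of ones, and the import path assumes the new file's location in the tree):

import torch

from megatron.model.layer_norm_p1 import LayerNorm1P

hidden_size = 8
x = torch.randn(2, 4, hidden_size)

ln = torch.nn.LayerNorm(hidden_size)   # default init: weight = 1, bias = 0
ln_1p = LayerNorm1P(hidden_size)

# Zero-center the stored scale so the effective scale (weight + 1) equals
# the standard LayerNorm's weight of 1.
torch.nn.init.zeros_(ln_1p.weight)

# Both modules should now produce the same output.
assert torch.allclose(ln(x), ln_1p(x), atol=1e-6)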
33 changes: 24 additions & 9 deletions megatron/model/transformer.py
@@ -911,9 +911,14 @@ def __init__(self, config,
                     apply_layernorm_1p=args.apply_layernorm_1p,
                     mem_efficient_ln=args.mem_efficient_ln)
             else:
-                self.input_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                if args.apply_layernorm_1p:
+                    self.input_layernorm = LayerNorm1P(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
+                else:
+                    self.input_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
         else:
             self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Self attention.
@@ -937,9 +942,14 @@ def __init__(self, config,
                     apply_layernorm_1p=args.apply_layernorm_1p,
                     mem_efficient_ln=args.mem_efficient_ln)
             else:
-                self.post_attention_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                if args.apply_layernorm_1p:
+                    self.post_attention_layernorm = LayerNorm1P(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
+                else:
+                    self.post_attention_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
         else:
             self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Cross attention.
@@ -1760,9 +1770,14 @@ def build_layer(layer_number, n_e):
                     apply_layernorm_1p=args.apply_layernorm_1p,
                     mem_efficient_ln=args.mem_efficient_ln)
             else:
-                self.final_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                if args.apply_layernorm_1p:
+                    self.final_layernorm = LayerNorm1P(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
+                else:
+                    self.final_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
         else:
             self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
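
All three edited blocks repeat the same choice on the non-fused path: apply_layernorm_1p selects the new LayerNorm1P, otherwise the stock torch LayerNorm is kept, and a non-layernorm normalization setting still falls through to RMSNorm. A hypothetical helper, not part of this commit, sketching that selection (the config.normalization attribute and its 'layernorm' value are assumptions inferred from the surrounding else branches):

from torch.nn import LayerNorm

from megatron.model.layer_norm_p1 import LayerNorm1P
from megatron.model.rmsnorm import RMSNorm


def build_norm(config, args):
    """Mirror the norm selection applied above on the non-fused path."""
    if config.normalization != 'layernorm':   # assumed attribute and value
        # Matches the outer else branches: fall back to RMSNorm.
        return RMSNorm(config.hidden_size, config.layernorm_epsilon)
    if args.apply_layernorm_1p:
        # New behaviour: scale is stored zero-centered and applied as weight + 1.
        return LayerNorm1P(config.hidden_size, eps=config.layernorm_epsilon)
    return LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)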
