[wip] Entmax loss #3

Open

wants to merge 26 commits into base: main

Commits (26)
8543c0e  add entmax loss training for gpt-like models (bpopeters, Jan 9, 2024)
d3be645  add entmax gpt example for debugging (bpopeters, Jan 9, 2024)
f946b8f  fix loss indentation problem that breaks entmax training (bpopeters, Jan 25, 2024)
1c76bf8  remove nonfunctional, unused, and unnecessary shell script (bpopeters, Jan 25, 2024)
563f968  compute support during training (although do not log it yet) (bpopeters, Mar 21, 2024)
7228a8b  fix model training bug, accumulate support across gpus (bpopeters, Mar 27, 2024)
202caed  slightly more informative support logging (bpopeters, Apr 1, 2024)
8a83cf0  update return_support to return_support_size in all cases (bpopeters, May 20, 2024)
48a1781  [wip] compute more sparsity stats (bpopeters, May 22, 2024)
19b478f  add more support logging statistics (bpopeters, May 22, 2024)
c4cf60c  add multiple loss functions (with some print statements to make sure … (bpopeters, Jun 26, 2024)
a02cf72  add force_decoded_accuracy as a zero-shot evaluation metric (bpopeters, Jun 26, 2024)
a8733c4  update logging with teacher forced accuracy (bpopeters, Jun 26, 2024)
e51656d  update xavier_uniform init (which will probably be unused (bpopeters, Jun 26, 2024)
055f869  avoid unpacking error when labels is None (bpopeters, Jun 26, 2024)
e24ebc6  remove print statements for eval computation (bpopeters, Jun 26, 2024)
f77cf0c  call item() because json cannot handle tensors (bpopeters, Jun 26, 2024)
68a4a1b  add accuracy at k (bpopeters, Jun 27, 2024)
24a5ed3  fix guardrail for new task (bpopeters, Jun 27, 2024)
77fe108  small change to try to get LAMBADA to run (bpopeters, Jun 27, 2024)
c41f8b0  don't load optimizer state for lambada (bpopeters, Jun 27, 2024)
4cfc133  write lambada to results file (bpopeters, Jun 27, 2024)
0182532  fix entmax_bisect_loss return type (bpopeters, Jun 28, 2024)
9e7f53d  refactor zeroshot gpt evaluation for sparsemax score (bpopeters, Jul 12, 2024)
f3c66c4  add missing parentheses (bpopeters, Jul 12, 2024)
4f2c4fc  get item from sparsemax score tensor (bpopeters, Jul 12, 2024)
33 changes: 23 additions & 10 deletions megatron/arguments.py
@@ -83,7 +83,7 @@ def validate_args(args):
if args.no_pipeline_parallel:
assert args.pipeline_model_parallel_size == 1, \
"pipeline_model_parallel_size must be 1 if pipeline parallel is disabled"

if args.ds_sequence_parallel_size > 1:
assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+"

@@ -432,6 +432,10 @@ def validate_args(args):
assert not args.mos, 'GQA currently does not support args.mos'
assert not args.kd, 'GQA currently does not support args.kd'

# entmax loss
if args.loss_function != "cross_entropy":
assert not args.fp16_lm_cross_entropy, "entmax loss only supports fp32"

# Print arguments.
_print_args("arguments", args)
retro_args = get_retro_args()
@@ -477,7 +481,7 @@ def core_transformer_config_from_args(args):
kw_args['bias_gelu_fusion'] = False
if args.init_method_xavier_uniform:
kw_args['init_method'] = torch.nn.init.xavier_uniform_
kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
kw_args['output_layer_init_method'] = torch.nn.init.xavier_uniform_

return TransformerConfig(**kw_args)

@@ -632,7 +636,7 @@ def _add_network_size_args(parser):
group.add_argument('--apply-layernorm-1p', action='store_true',
help='Adjust LayerNorm weights such that they are centered '
'around zero. This improves numerical stability.')
group.add_argument('--disable-mem-efficient-ln', action='store_false',
help='Disable the memory-efficient fused LayerNorm optimization '
'introduced in https://github.com/NVIDIA/apex/pull/1715')
group.add_argument('--apply-residual-connection-post-layernorm',
@@ -848,7 +852,7 @@ def _add_training_args(parser):
'training runs.')
group.add_argument('--random-ltd',
action='store_true',
help='enable random layer token drop')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
@@ -940,6 +944,15 @@ def _add_training_args(parser):
dest='gradient_accumulation_fusion')
group.add_argument('--use-dataset-only', type=bool, required=False, default=False,
help='If set to True, only use the megatron dataset for external trainer ')
group.add_argument('--loss-function', default='cross_entropy',
choices=['cross_entropy', 'entmax15', 'sparsemax', 'entmax_bisect'],
help='Loss function for model training')
group.add_argument('--entmax-alpha', type=float, default=1.5,
help='Entmax alpha for entmax_bisect (unused otherwise)')
group.add_argument('--entmax-topk', type=int, default=512,
help='Top k for computation of exact entmax loss (for entmax15 and sparsemax)')
group.add_argument('--entmax-n-iter', type=int, default=30,
help='Number of bisection iterations for entmax_bisect')
return parser
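As a standalone illustration of the new flags (not part of the diff; the parser below mirrors only the four options added above, with the same names and defaults, and everything else is made up for the example):

import argparse

# Minimal stand-in for _add_training_args: only the new entmax-related options.
parser = argparse.ArgumentParser()
parser.add_argument('--loss-function', default='cross_entropy',
                    choices=['cross_entropy', 'entmax15', 'sparsemax', 'entmax_bisect'])
parser.add_argument('--entmax-alpha', type=float, default=1.5)
parser.add_argument('--entmax-topk', type=int, default=512)
parser.add_argument('--entmax-n-iter', type=int, default=30)

# e.g. a run that trains with the exact entmax15 loss over the top-512 logits
args = parser.parse_args(['--loss-function', 'entmax15', '--entmax-topk', '512'])
assert args.loss_function == 'entmax15' and args.entmax_alpha == 1.5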


@@ -1034,7 +1047,7 @@ def _add_checkpointing_args(parser):
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--no-load-lr-state', action='store_true',
help='Do not load lr state when loading checkpoint.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
@@ -1210,7 +1223,7 @@ def _add_data_args(parser):
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--multiple-valid-sets', action='store_true',
help='multiple separated validation steps')
group.add_argument('--test-data-path', nargs='*', default=None,
help='Path to the test dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
@@ -1490,15 +1503,15 @@ def _add_activation_checkpoint_args(parser):
def _add_distillation_args(parser):
group = parser.add_argument_group('Knowledge distillation',
'Distillation Configurations')

group.add_argument('--num-layers-teacher', type=int, default=None,
help='Number of the teacher transformer layers.')
group.add_argument('--num-experts-teacher', type=int, nargs='+', default=[1,],
help='number of teacher experts list, MoE related.')
group.add_argument('--hidden-size-teacher', type=int, default=None,
help='Transformer teacher hidden size.')
group.add_argument('--num-attention-heads-teacher', type=int, default=None,
help='Number of teacher transformer attention heads.')

group.add_argument('--mos', action='store_true',
help='Enable Mixture-of-Students via knowledge distillation.')
@@ -1509,7 +1522,7 @@ def _add_distillation_args(parser):
group.add_argument('--kd-temp', default=1.0, type=float)
group.add_argument('--reset-iteration', action='store_true',
help='Reset the iteration count.')

group.add_argument('--load-teacher', type=str, default=None,
help='Directory containing a teacher model checkpoint.')

82 changes: 71 additions & 11 deletions megatron/model/gpt_model.py
@@ -2,7 +2,10 @@

"""GPT-2 model."""

from functools import partial

import torch
import entmax

from megatron import get_args
from megatron.core import mpu, tensor_parallel, sequence_parallel
@@ -24,32 +27,36 @@
except ImportError:
MixedFusedRMSNorm = None

try:
from deepspeed.checkpoint import (
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
)
DS_UNIVERSAL_CHECKPOINT_INFO = True
except ImportError:
DS_UNIVERSAL_CHECKPOINT_INFO = False


def post_language_model_processing(lm_output, labels, logit_weights,
parallel_output,
fp16_lm_cross_entropy):
fp16_lm_cross_entropy,
loss_function, alpha, topk, n_iter, return_support_size=False):

# Output. Format [s b h]
output = parallel_lm_logits(
lm_output,
logit_weights,
parallel_output)

# should it return a None support size in this case? let's say no for now
if labels is None:
# [s b h] => [b s h]
return output.transpose(0,1).contiguous()
else:

if loss_function == "cross_entropy":
# cross entropy
# [b s] => [s b]
labels = labels.transpose(0,1).contiguous()
cross_entropy = sequence_parallel.vocab_sequence_parallel_cross_entropy if mpu.get_sequence_parallel_world_size() > 1 \
@@ -62,6 +69,42 @@ def post_language_model_processing(lm_output, labels, logit_weights,

# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
support = None
else:
# now: the loss function is "entmax15", "sparsemax", or "entmax_bisect"
loss_funcs = {
"entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True),
"sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True),
"entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter)
}
f = loss_funcs[loss_function]
b, s = labels.size()
output = output.transpose(0, 1).contiguous()
vocab_size = output.size(-1)

# currently entmax_bisect_loss always returns a None support size,
# which is not ideal. This is a stopgap until entmax_bisect_loss is
# fixed
loss = f(output.float().view(-1, vocab_size), labels.view(-1))
if isinstance(loss, tuple):
loss, support = loss
else:
support = None

# old version which breaks because entmax_bisect unexpectedly returned
# a tuple:
'''
if loss_function != "entmax_bisect":
loss, support = f(output.float().view(-1, vocab_size), labels.view(-1))
else:
loss = f(output.float().view(-1, vocab_size), labels.view(-1))
support = None
'''
loss = loss.view(b, s)

if return_support_size:
return loss, support
else:
return loss
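For reference, a minimal standalone sketch of the entmax branch above, run on toy tensors. It assumes the entmax build this PR targets, in which entmax15_loss and sparsemax_loss accept return_support_size (and k) as in the dispatch table, while entmax_bisect_loss returns only the loss:

import torch
import entmax

# toy data: 2 sequences of 3 tokens over an 8-word vocabulary
b, s, vocab_size = 2, 3, 8
logits = torch.randn(b, s, vocab_size)
labels = torch.randint(0, vocab_size, (b, s))

# mirrors the "entmax15" case (the training path also passes k=args.entmax_topk)
loss, support = entmax.entmax15_loss(
    logits.float().view(-1, vocab_size),
    labels.view(-1),
    return_support_size=True,
)

loss = loss.view(b, s)   # per-token losses, masked and averaged later in loss_func
print(loss.shape)        # torch.Size([2, 3])
print(support.shape)     # one support size per token: torch.Size([6])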


@@ -84,6 +127,10 @@ def __init__(self,
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.return_moe_loss = return_moe_loss
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
self.loss_function = args.loss_function
self.entmax_alpha = args.entmax_alpha
self.entmax_topk = args.entmax_topk
self.entmax_n_iter = args.entmax_n_iter

self.language_model, self._language_model_key = get_language_model(
config=config,
@@ -106,7 +153,8 @@ def forward(self, input_ids, position_ids, attention_mask,
retriever_position_ids=None,
retriever_attn_mask=None,
labels=None, tokentype_ids=None, inference_params=None,
curriculum_seqlen=None):
curriculum_seqlen=None,
return_support_size=False):
args = get_args()
if curriculum_seqlen is not None:
args.curriculum_seqlen = curriculum_seqlen
@@ -135,13 +183,25 @@ def forward(self, input_ids, position_ids, attention_mask,
inference_params=inference_params)

if self.post_process:
lm_output = post_language_model_processing(
post_lm_out = post_language_model_processing(
lm_output, labels,
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy)

return lm_output, moe_losses if self.return_moe_loss else lm_output
self.fp16_lm_cross_entropy,
self.loss_function,
self.entmax_alpha,
self.entmax_topk,
self.entmax_n_iter,
return_support_size=True)
if labels is not None:
lm_output, support_size = post_lm_out
else:
lm_output = post_lm_out
# labels were not provided, so no support size was computed
support_size = None
if return_support_size:
return lm_output, moe_losses if self.return_moe_loss else lm_output, support_size
else:
return lm_output, moe_losses if self.return_moe_loss else lm_output

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

@@ -210,7 +270,7 @@ def universal_checkpoint_info(self):
]

return info

def CrossEntropy(output, labels):
labels, loss_mask = labels[0], labels[1]

69 changes: 64 additions & 5 deletions pretrain_gpt.py
@@ -211,7 +211,7 @@ def get_batch_pipe(data):
return (tokens, position_ids, attention_mask), (labels, loss_mask)


def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor):
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
@@ -229,11 +229,68 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
print_rank_0('>>> total loss: {}, lm loss {}, kd loss {}'.format(loss, averaged_loss[0], mos_loss))
else:
if max(args.num_experts) <= 1:
return loss, {'lm loss': averaged_loss[0]}
if support_size is not None:
support_size = support_size.float()
n_tokens = support_size.size(0)

# how often is support size 1?
fully_peaked = support_size.eq(1).sum()

# what is the max support size in the batch?
max_support = support_size.max()
min_support = support_size.min()

# sum support stats across groups
current_group = mpu.get_data_parallel_group()
torch.distributed.all_reduce(support_size, group=current_group)
torch.distributed.all_reduce(max_support, group=current_group)
torch.distributed.all_reduce(min_support, group=current_group)
torch.distributed.all_reduce(fully_peaked, group=current_group)

# find number of groups
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group()
)

# compute mean support size
support_size = support_size / world_size

# compute mean max support
max_support = max_support / world_size
min_support = min_support / world_size

# compute how often support is fully peaked
fully_peaked = fully_peaked / (world_size * n_tokens)

# the stats for support size might be slightly wrong, but still
# illustrative
support_mean = support_size.mean()
support_std = support_size.std()
quantiles = torch.linspace(
0, 1, 5, dtype=support_size.dtype, device=support_size.device
)
support_quantiles = torch.quantile(support_size, quantiles)

loss_dict = {
'lm loss': averaged_loss[0],
'support size': support_mean,
"support std": support_std,
"support max": max_support,
"support min": min_support,
"support 25%": support_quantiles[1],
"support 50%": support_quantiles[2],
"support 75%": support_quantiles[3],
"fully peaked": fully_peaked
}

return loss, loss_dict
else:
return loss, {'lm loss': averaged_loss[0]}
else:
loss = loss + moe_loss
return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss}
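The distributed bookkeeping above follows one pattern: all_reduce sums each statistic over the data-parallel group, and dividing by the group size afterwards gives a (rough) mean. A single-process sketch of that pattern, with made-up support sizes and the gloo backend so it runs on CPU:

import os
import torch
import torch.distributed as dist

# world_size=1 makes the reduction a no-op, but the arithmetic matches loss_func above
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

support_size = torch.tensor([1.0, 3.0, 7.0, 1.0])  # pretend per-token support sizes
fully_peaked = support_size.eq(1).sum().float()    # tokens whose support collapsed to 1

dist.all_reduce(support_size)   # default op is SUM across ranks
dist.all_reduce(fully_peaked)

world_size = dist.get_world_size()
mean_support = (support_size / world_size).mean()
peaked_frac = fully_peaked / (world_size * support_size.numel())
print(mean_support.item(), peaked_frac.item())

dist.destroy_process_group()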


def calculate_mos_loss(args, stu_output, teacher_model, tokens, position_ids, attention_mask):
mos_loss = 0
alpha = args.kd_alpha_ce
@@ -277,6 +334,7 @@ def forward_step(data_iterator, model):
args.data_efficiency_curriculum_learning_seqlen_type == 'seqlen_reshape':
args.data_efficiency_curriculum_learning_numel = torch.numel(tokens)

support_size = None
if args.mos or args.kd:
# The forward func can return either the loss or the logits, depending on whether passing in the labels or not.
stu_output, other_losses = model(tokens, position_ids, attention_mask)
@@ -285,8 +343,9 @@ def forward_step(data_iterator, model):
labels = labels[:, :args.curriculum_seqlen].contiguous()
output_tensor = tensor_parallel.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels)
else:
output_tensor, other_losses = model(tokens, position_ids, attention_mask,
labels=labels)
output_tensor, other_losses, support_size = model(
tokens, position_ids, attention_mask, labels=labels, return_support_size=True
)
if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length:
loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

@@ -304,7 +363,7 @@ def forward_step(data_iterator, model):
args.teacher_model[0], tokens, position_ids, attention_mask)

# Output_tensor stores the standard loss, loss_func calculates the total loss.
return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss)
return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss, support_size)


def train_valid_test_datasets_provider(train_val_test_num_samples):
8 changes: 6 additions & 2 deletions tasks/main.py
@@ -70,7 +70,11 @@ def get_tasks_args(parser):
group.add_argument('--val-av-rank-other-neg', type=int, default=30,
help='Av.rank validation: how many other negatives to'
' take from each question pool')

group.add_argument('--eval-metric', default=None,
help='Eval metric to use other than a task-specific '
'default')
group.add_argument('--acc-k', default=5, type=int,
help='k for force-decoded accuracy at k')

return parser
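The --acc-k flag feeds the force-decoded accuracy-at-k metric added by this PR. The evaluation code itself is outside this diff, but the metric reduces to something like the following sketch (function name and shapes are illustrative, not the PR's actual implementation):

import torch

def accuracy_at_k(logits: torch.Tensor, labels: torch.Tensor, k: int = 5) -> torch.Tensor:
    # fraction of positions whose gold token appears among the top-k logits
    topk_ids = logits.topk(k, dim=-1).indices           # (..., k)
    hits = (topk_ids == labels.unsqueeze(-1)).any(-1)   # (...,)
    return hits.float().mean()

logits = torch.randn(2, 3, 8)            # batch=2, seq=3, vocab=8
labels = torch.randint(0, 8, (2, 3))
print(accuracy_at_k(logits, labels, k=5))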

@@ -79,7 +83,7 @@ def get_tasks_args(parser):

initialize_megatron(extra_args_provider=get_tasks_args)

args = get_args()
args = get_args() # will the task args be included here?

if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")