From 8543c0e06da1a4e3aedc285a70d481e9b592ab90 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Tue, 9 Jan 2024 13:08:31 +0000 Subject: [PATCH 01/26] add entmax loss training for gpt-like models --- megatron/arguments.py | 31 ++++++++++++++++++++-------- megatron/model/gpt_model.py | 41 ++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 6297cb16bb..f650b6e331 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -83,7 +83,7 @@ def validate_args(args): if args.no_pipeline_parallel: assert args.pipeline_model_parallel_size == 1, \ "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled" - + if args.ds_sequence_parallel_size > 1: assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+" @@ -432,6 +432,10 @@ def validate_args(args): assert not args.mos, 'GQA currently does not support args.mos' assert not args.kd, 'GQA currently does not support args.kd' + # entmax loss + if args.loss_function != "cross_entropy": + assert not args.fp16_lm_cross_entropy, "entmax loss only supports fp32" + # Print arguments. _print_args("arguments", args) retro_args = get_retro_args() @@ -632,7 +636,7 @@ def _add_network_size_args(parser): group.add_argument('--apply-layernorm-1p', action='store_true', help='Adjust LayerNorm weights such that they are centered ' 'around zero. This improves numerical stability.') - group.add_argument('--disable-mem-efficient-ln', action='store_false', + group.add_argument('--disable-mem-efficient-ln', action='store_false', help='Disable the memory-efficient fused LayerNorm optimization ' 'introduced in https://github.com/NVIDIA/apex/pull/1715') group.add_argument('--apply-residual-connection-post-layernorm', @@ -848,7 +852,7 @@ def _add_training_args(parser): 'training runs.') group.add_argument('--random-ltd', action='store_true', - help='enable random layer token drop') + help='enable random layer token drop') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, @@ -940,6 +944,15 @@ def _add_training_args(parser): dest='gradient_accumulation_fusion') group.add_argument('--use-dataset-only', type=bool, required=False, default=False, help='If set to True, only use the megatron dataset for external trainer ') + group.add_argument('--loss-function', default='cross_entropy', + choices=['cross_entropy', 'entmax15', 'sparsemax', 'entmax_bisect'], + help='Loss function for model training') + group.add_argument('--entmax-alpha', type=float, default=1.5, + help='Entmax alpha for entmax_bisect (unused otherwise)') + group.add_argument('--entmax-topk', type=int, default=512, + help='Top k for computation of exact entmax loss (for entmax15 and sparsemax)') + group.add_argument('--entmax-n-iter', type=int, default=30, + help='Number of bisection interations for entmax_bisect') return parser @@ -1034,7 +1047,7 @@ def _add_checkpointing_args(parser): group.add_argument('--no-load-rng', action='store_true', default=None, help='Do not load rng state when loading checkpoint.') group.add_argument('--no-load-lr-state', action='store_true', - help='Do not load lr state when loading checkpoint.') + help='Do not load lr state when loading checkpoint.') group.add_argument('--finetune', action='store_true', help='Load model for finetuning. 
Do not load optimizer ' 'or rng state from checkpoint and set iteration to 0. ' @@ -1210,7 +1223,7 @@ def _add_data_args(parser): 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') group.add_argument('--multiple-valid-sets', action='store_true', - help='multiple separated validation steps') + help='multiple separated validation steps') group.add_argument('--test-data-path', nargs='*', default=None, help='Path to the test dataset. Accepted format:' '1) a single data path, 2) multiple datasets in the' @@ -1490,15 +1503,15 @@ def _add_activation_checkpoint_args(parser): def _add_distillation_args(parser): group = parser.add_argument_group('Knowledge distillation', 'Distillation Configurations') - + group.add_argument('--num-layers-teacher', type=int, default=None, - help='Number of the teacher transformer layers.') + help='Number of the teacher transformer layers.') group.add_argument('--num-experts-teacher', type=int, nargs='+', default=[1,], help='number of teacher experts list, MoE related.') group.add_argument('--hidden-size-teacher', type=int, default=None, help='Tansformer teacher hidden size.') group.add_argument('--num-attention-heads-teacher', type=int, default=None, - help='Number of teacher transformer attention heads.') + help='Number of teacher transformer attention heads.') group.add_argument('--mos', action='store_true', help='Enable Mixture-of-Students via knolwedge distillation.') @@ -1509,7 +1522,7 @@ def _add_distillation_args(parser): group.add_argument('--kd-temp', default=1.0, type=float) group.add_argument('--reset-iteration', action='store_true', help='Reset the iteration count.') - + group.add_argument('--load-teacher', type=str, default=None, help='Directory containing a teacher model checkpoint.') diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index f2eabe341e..e96ff97120 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -2,7 +2,10 @@ """GPT-2 model.""" +from functools import partial + import torch +import entmax from megatron import get_args from megatron.core import mpu, tensor_parallel, sequence_parallel @@ -24,21 +27,22 @@ except ImportError: MixedFusedRMSNorm = None -try: +try: from deepspeed.checkpoint import ( VOCABULARY_PARAMETER_PATTERNS, PIPELINE_REPLICATED_PARAMETER_PATTERNS, TP_REPLICATED_PARAMETER_PATTERNS, PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, ) - DS_UNIVERSAL_CHECKPOINT_INFO = True + DS_UNIVERSAL_CHECKPOINT_INFO = True except ImportError: - DS_UNIVERSAL_CHECKPOINT_INFO = False + DS_UNIVERSAL_CHECKPOINT_INFO = False def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, - fp16_lm_cross_entropy): + fp16_lm_cross_entropy, + loss_function, alpha, topk, n_iter): # Output. 
Format [s b h] output = parallel_lm_logits( @@ -49,7 +53,9 @@ def post_language_model_processing(lm_output, labels, logit_weights, if labels is None: # [s b h] => [b s h] return output.transpose(0,1).contiguous() - else: + + if loss_function == "cross_entropy": + # cross entropy # [b s] => [s b] labels = labels.transpose(0,1).contiguous() cross_entropy = sequence_parallel.vocab_sequence_parallel_cross_entropy if mpu.get_sequence_parallel_world_size() > 1 \ @@ -62,6 +68,19 @@ def post_language_model_processing(lm_output, labels, logit_weights, # [s b] => [b, s] loss = loss.transpose(0,1).contiguous() + else: + # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" + loss_funcs = { + "entmax15": partial(entmax.entmax15_loss, k=topk), + "sparsemax": partial(entmax.sparsemax_loss, k=topk), + "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) + } + f = loss_funcs[loss_function] + + b, s = labels.size() + vocab_size = output.size(-1) + loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + loss = loss.view(b, s) return loss @@ -84,6 +103,10 @@ def __init__(self, self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.return_moe_loss = return_moe_loss self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + self.loss_function = args.loss_function + self.entmax_alpha = args.entmax_alpha + self.entmax_topk = args.entmax_topk + self.entmax_n_iter = args.entmax_n_iter self.language_model, self._language_model_key = get_language_model( config=config, @@ -139,7 +162,11 @@ def forward(self, input_ids, position_ids, attention_mask, lm_output, labels, self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), self.parallel_output, - self.fp16_lm_cross_entropy) + self.fp16_lm_cross_entropy, + self.loss_function, + self.entmax_alpha, + self.entmax_topk, + self.entmax_n_iter) return lm_output, moe_losses if self.return_moe_loss else lm_output @@ -210,7 +237,7 @@ def universal_checkpoint_info(self): ] return info - + def CrossEntropy(output, labels): labels, loss_mask = labels[0], labels[1] From d3be645ea5f71e1b832e426e0e09074a56f7d045 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Tue, 9 Jan 2024 20:59:58 +0000 Subject: [PATCH 02/26] add entmax gpt example for debugging --- examples/pretrain_gpt_entmax_125M.sh | 228 +++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 examples/pretrain_gpt_entmax_125M.sh diff --git a/examples/pretrain_gpt_entmax_125M.sh b/examples/pretrain_gpt_entmax_125M.sh new file mode 100644 index 0000000000..5a66efb462 --- /dev/null +++ b/examples/pretrain_gpt_entmax_125M.sh @@ -0,0 +1,228 @@ +#!/bin/bash +dir=`pwd` +############################################################################### +### Main configs +## GPT-3 models use 2K sequence length/context window +seq_len=2048 + +## The "GPT-3 XXX" below are configs from GPT-3 paper +## https://arxiv.org/abs/2005.14165, choose based on +## your desired model size or build your own configs + +## init_std is standard deviation for weight initialization. Usually larger +## model needs lower std. We used a heuristic equation of sqrt(1/3/hidden_size) +## from the MT-NLG 530B work (https://arxiv.org/pdf/2201.11990.pdf) + +## We changed min_lr to a lower number (1.0e-6), which we found is able to +## provide better zero-shot eval results. 
+ +## GPT-3 Small 125M +model_size=0.125 +num_layers=12 +hidden_size=768 +num_attn_heads=12 +global_batch_size=256 +lr=6.0e-4 +min_lr=1.0e-6 +init_std=0.02 + +############################################################################### +### Training duration configs +## The main termination condition, original GPT-3 paper trains for 300B tokens. +train_tokens_in_billion=300 +train_tokens=$((${train_tokens_in_billion} * 1000000000)) + +## train_samples is another termination condition and also affect the number of +## data samples to be indexed. Since we want to reach the train_tokens +## above, and data efficiency techniques may change num tokens in some samples, +## so we just set this config large enough to make sure we have enough +## processed data and don't terminate by train_samples. +train_samples=$(( 300 * 1000000000 * 2 / ${seq_len} )) + +## Another wall-clock time termination condition in minutes. Set it large +## enough to avoid undesired early termination. +exit_duration=30000000 +############################################################################### +### lr configs +## lr warmup and decay duration. +## Original GPT-3 paper uses 375M warmup tokens and 260B cosine decay tokens. +## Here we increase the warmup tokens to 3B since when batch size warmup is not +## used, there are more tokens per step. Thus we need to increase warmup tokens +## to make sure there are enough warmup steps, which is important for training +## stability. +lr_warmup_tokens_in_million=3000 +lr_warmup_tokens=$((${lr_warmup_tokens_in_million} * 1000000)) +## Here we changed the LR decay tokens to align with total train tokens, since +## related works (e.g., https://arxiv.org/abs/2203.15556) find that setting the +## learning rate schedule to match the number of training tokens results in the +## best final model quality +lr_decay_tokens_in_billion=${train_tokens_in_billion} +lr_decay_tokens=$((${lr_decay_tokens_in_billion} * 1000000000)) +lr_decay_style="cosine" + +############################################################################### +### Parallelism configs +## Model parallelism, 1 is no MP +mp_size=2 + +## Pipeline parallelism. To disable PP, set pp_size to 1 and no_pp to true. +## Note that currently both curriculum learning and random-LTD are NOT +## compatible with pipeline parallelism. +pp_size=1 +no_pp="true" + +## ZeRO-based data parallelism, stage=0 will disable ZeRO +## The llama2 config uses stage=0... +zero_stage=1 + +## Total number of GPUs. ds_ssh is from DeepSpeed library. +num_gpus=$(($(ds_ssh nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)-2)) +num_gpus_pernode=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +num_node=$(( ${num_gpus} / ${num_gpus_pernode} )) + +## Data parallel size. +dp_size=$(( ${num_gpus} / ${pp_size} / ${mp_size} )) + +## Micro batch size per GPU +## Make sure that batch_size <= global_batch_size*pp_size*mp_size/num_gpus +## Reduce it manually if GPU OOM +# batch_size=$(( ${global_batch_size} / ${dp_size} )) +batch_size=2 +############################################################################### +### Misc configs +log_interval=1 +eval_iters=10 +eval_interval=1000 +# num_save controls how frequent to save checkpoint. num_save=20 means that a +# checkpoint will be saved every 5% of training. For longer training you would +# want larger num_save to save more frequently, and vice versa. 
+num_save=100 +estimated_train_iter=$((${train_tokens} / ${seq_len} / ${global_batch_size})) +# save_interval=$((${estimated_train_iter} / ${num_save})) +save_interval=100 + +## Activation checkpointing saves GPU memory, but reduces training speed +# bp: removing this +activation_checkpoint="false" + +## Whether or not log optimizer states (norms, max abs values) to tensorboard. +## This is not required for training and might save GPU memory when turned off. +log_optimizer_state="false" +############################################################################### +### Output and data configs +current_time=$(date "+%Y.%m.%d_%H.%M.%S") +host="${HOSTNAME}" +seed=1234 +num_workers=0 + +data_path="/home/bpop/multilinguality_megatron/data-bin/data_text_document" + +prescale_grad="true" +jobname="gpt_entmax_${model_size}B_tok${train_tokens_in_billion}B" +jobname="${jobname}_lr${lr}_min${min_lr}_w${lr_warmup_tokens_in_million}M_d${lr_decay_tokens_in_billion}B_${lr_decay_style}" +jobname="${jobname}_gbs${global_batch_size}_mbs${batch_size}_g${num_gpus}" +if [[ $zero_stage -gt 0 ]]; then + jobname="${jobname}_z${zero_stage}" + prescale_grad="false" +fi +if [[ $mp_size -gt 1 ]]; then + jobname="${jobname}_mp${mp_size}" +fi +if [ "${no_pp}" = "false" ]; then + jobname="${jobname}_pp${pp_size}" +fi +jobname="${jobname}_seed${seed}_rebase" + +username=$(whoami) +output_home="output" +log_path="${output_home}/log/" +checkpoint_path="${output_home}/checkpoint/${jobname}" +tensorboard_dir="${output_home}/tensorboard/" +tensorboard_path="${tensorboard_dir}${jobname}_${host}_${current_time}" +mkdir -p ${log_path} +mkdir -p ${checkpoint_path} +mkdir -p ${tensorboard_path} +############################################################################### +tokenizer_path="/home/bpop/llama-tokenizer/tokenizer.model" + +config_json="ds_config_gbs${global_batch_size}_mbs${batch_size}_log${log_interval}_zero${zero_stage}.json" +template_json="ds_config_gpt_TEMPLATE.json" +sed "s/GBSIZE/${global_batch_size}/" ${template_json} \ + | sed "s/MBSIZE/${batch_size}/" \ + | sed "s/LOG_INTERVAL/${log_interval}/" \ + | sed "s/ZERO_STAGE/${zero_stage}/" \ + | sed "s/PRESCALE_GRAD/${prescale_grad}/" \ + > ${config_json} + +## When saving checkpoint to a storage with cache, their could be consistency +## issue of the pointer to latest checkpoint. Here we find the correct pointer +## and broadcast it to all nodes. +iteration_file="$checkpoint_path/latest_checkpointed_iteration.txt" +iteration_file_2="$checkpoint_path/latest" +iteration=0 +for (( node = 0; node <= num_node-1; node++ )) +do + if $(ssh -q worker-"$node" "test -f \"$iteration_file\""); then + local_iteration=$(ssh -q worker-"$node" cat $iteration_file) + iteration=$(( ${local_iteration} > ${iteration} ? 
${local_iteration} : ${iteration} )) + fi +done +if [[ $iteration -gt 0 ]]; then + iteration_2="global_step${iteration}" + ds_ssh "echo $iteration > $iteration_file" + ds_ssh "echo $iteration_2 > $iteration_file_2" +fi + +deepspeed pretrain_gpt.py \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --override-opt_param-scheduler \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --tensor-model-parallel-size ${mp_size} \ + --init-method-std ${init_std} \ + --lr-decay-tokens ${lr_decay_tokens} \ + --lr-warmup-tokens ${lr_warmup_tokens} \ + --micro-batch-size ${batch_size} \ + --exit-duration-in-mins ${exit_duration} \ + --global-batch-size ${global_batch_size} \ + --num-layers ${num_layers} \ + --hidden-size ${hidden_size} \ + --num-attention-heads ${num_attn_heads} \ + --seq-length ${seq_len} \ + --max-position-embeddings ${seq_len} \ + --train-tokens ${train_tokens} \ + --train-samples ${train_samples} \ + --lr ${lr} \ + --min-lr ${min_lr} \ + --lr-decay-style ${lr_decay_style} \ + --split 949,50,1 \ + --log-interval ${log_interval} \ + --eval-interval ${eval_interval} \ + --eval-iters ${eval_iters} \ + --save-interval ${save_interval} \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --hysteresis 2 \ + --num-workers ${num_workers} \ + --bf16 \ + --seed ${seed} \ + --load ${checkpoint_path} \ + --save ${checkpoint_path} \ + --no-async-tensor-model-parallel-allreduce \ + --tensorboard-queue-size 1 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-dir ${tensorboard_path} \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${tokenizer_path} \ + --data-path ${data_path} \ + --data-impl mmap \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${zero_stage} \ + --pipeline-model-parallel-size ${pp_size} \ + --no-pipeline-parallel \ + --loss-function entmax15 \ + --entmax-topk 512 &>> ${log_path}/${jobname}_${host}_${current_time}.log From f946b8fcbd826bdd95b9461ef2f588fa4145f4b3 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 25 Jan 2024 12:13:54 +0000 Subject: [PATCH 03/26] fix loss indentation problem that breaks entmax training --- megatron/model/gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index e96ff97120..fb871c7b2c 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -81,7 +81,7 @@ def post_language_model_processing(lm_output, labels, logit_weights, vocab_size = output.size(-1) loss = f(output.float().view(-1, vocab_size), labels.view(-1)) loss = loss.view(b, s) - return loss + return loss class GPTModel(MegatronModule): From 1c76bf8cd928e4b9711a862f3dfddd36a607961d Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 25 Jan 2024 16:51:43 +0000 Subject: [PATCH 04/26] remove nonfunctional, unused, and unnecessary shell script --- examples/pretrain_gpt_entmax_125M.sh | 228 --------------------------- 1 file changed, 228 deletions(-) delete mode 100644 examples/pretrain_gpt_entmax_125M.sh diff --git a/examples/pretrain_gpt_entmax_125M.sh b/examples/pretrain_gpt_entmax_125M.sh deleted file mode 100644 index 5a66efb462..0000000000 --- a/examples/pretrain_gpt_entmax_125M.sh +++ /dev/null @@ -1,228 +0,0 @@ -#!/bin/bash -dir=`pwd` -############################################################################### -### Main configs -## GPT-3 models use 2K sequence length/context window -seq_len=2048 - -## The "GPT-3 XXX" below are configs from GPT-3 paper -## 
https://arxiv.org/abs/2005.14165, choose based on -## your desired model size or build your own configs - -## init_std is standard deviation for weight initialization. Usually larger -## model needs lower std. We used a heuristic equation of sqrt(1/3/hidden_size) -## from the MT-NLG 530B work (https://arxiv.org/pdf/2201.11990.pdf) - -## We changed min_lr to a lower number (1.0e-6), which we found is able to -## provide better zero-shot eval results. - -## GPT-3 Small 125M -model_size=0.125 -num_layers=12 -hidden_size=768 -num_attn_heads=12 -global_batch_size=256 -lr=6.0e-4 -min_lr=1.0e-6 -init_std=0.02 - -############################################################################### -### Training duration configs -## The main termination condition, original GPT-3 paper trains for 300B tokens. -train_tokens_in_billion=300 -train_tokens=$((${train_tokens_in_billion} * 1000000000)) - -## train_samples is another termination condition and also affect the number of -## data samples to be indexed. Since we want to reach the train_tokens -## above, and data efficiency techniques may change num tokens in some samples, -## so we just set this config large enough to make sure we have enough -## processed data and don't terminate by train_samples. -train_samples=$(( 300 * 1000000000 * 2 / ${seq_len} )) - -## Another wall-clock time termination condition in minutes. Set it large -## enough to avoid undesired early termination. -exit_duration=30000000 -############################################################################### -### lr configs -## lr warmup and decay duration. -## Original GPT-3 paper uses 375M warmup tokens and 260B cosine decay tokens. -## Here we increase the warmup tokens to 3B since when batch size warmup is not -## used, there are more tokens per step. Thus we need to increase warmup tokens -## to make sure there are enough warmup steps, which is important for training -## stability. -lr_warmup_tokens_in_million=3000 -lr_warmup_tokens=$((${lr_warmup_tokens_in_million} * 1000000)) -## Here we changed the LR decay tokens to align with total train tokens, since -## related works (e.g., https://arxiv.org/abs/2203.15556) find that setting the -## learning rate schedule to match the number of training tokens results in the -## best final model quality -lr_decay_tokens_in_billion=${train_tokens_in_billion} -lr_decay_tokens=$((${lr_decay_tokens_in_billion} * 1000000000)) -lr_decay_style="cosine" - -############################################################################### -### Parallelism configs -## Model parallelism, 1 is no MP -mp_size=2 - -## Pipeline parallelism. To disable PP, set pp_size to 1 and no_pp to true. -## Note that currently both curriculum learning and random-LTD are NOT -## compatible with pipeline parallelism. -pp_size=1 -no_pp="true" - -## ZeRO-based data parallelism, stage=0 will disable ZeRO -## The llama2 config uses stage=0... -zero_stage=1 - -## Total number of GPUs. ds_ssh is from DeepSpeed library. -num_gpus=$(($(ds_ssh nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)-2)) -num_gpus_pernode=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) -num_node=$(( ${num_gpus} / ${num_gpus_pernode} )) - -## Data parallel size. 
-dp_size=$(( ${num_gpus} / ${pp_size} / ${mp_size} )) - -## Micro batch size per GPU -## Make sure that batch_size <= global_batch_size*pp_size*mp_size/num_gpus -## Reduce it manually if GPU OOM -# batch_size=$(( ${global_batch_size} / ${dp_size} )) -batch_size=2 -############################################################################### -### Misc configs -log_interval=1 -eval_iters=10 -eval_interval=1000 -# num_save controls how frequent to save checkpoint. num_save=20 means that a -# checkpoint will be saved every 5% of training. For longer training you would -# want larger num_save to save more frequently, and vice versa. -num_save=100 -estimated_train_iter=$((${train_tokens} / ${seq_len} / ${global_batch_size})) -# save_interval=$((${estimated_train_iter} / ${num_save})) -save_interval=100 - -## Activation checkpointing saves GPU memory, but reduces training speed -# bp: removing this -activation_checkpoint="false" - -## Whether or not log optimizer states (norms, max abs values) to tensorboard. -## This is not required for training and might save GPU memory when turned off. -log_optimizer_state="false" -############################################################################### -### Output and data configs -current_time=$(date "+%Y.%m.%d_%H.%M.%S") -host="${HOSTNAME}" -seed=1234 -num_workers=0 - -data_path="/home/bpop/multilinguality_megatron/data-bin/data_text_document" - -prescale_grad="true" -jobname="gpt_entmax_${model_size}B_tok${train_tokens_in_billion}B" -jobname="${jobname}_lr${lr}_min${min_lr}_w${lr_warmup_tokens_in_million}M_d${lr_decay_tokens_in_billion}B_${lr_decay_style}" -jobname="${jobname}_gbs${global_batch_size}_mbs${batch_size}_g${num_gpus}" -if [[ $zero_stage -gt 0 ]]; then - jobname="${jobname}_z${zero_stage}" - prescale_grad="false" -fi -if [[ $mp_size -gt 1 ]]; then - jobname="${jobname}_mp${mp_size}" -fi -if [ "${no_pp}" = "false" ]; then - jobname="${jobname}_pp${pp_size}" -fi -jobname="${jobname}_seed${seed}_rebase" - -username=$(whoami) -output_home="output" -log_path="${output_home}/log/" -checkpoint_path="${output_home}/checkpoint/${jobname}" -tensorboard_dir="${output_home}/tensorboard/" -tensorboard_path="${tensorboard_dir}${jobname}_${host}_${current_time}" -mkdir -p ${log_path} -mkdir -p ${checkpoint_path} -mkdir -p ${tensorboard_path} -############################################################################### -tokenizer_path="/home/bpop/llama-tokenizer/tokenizer.model" - -config_json="ds_config_gbs${global_batch_size}_mbs${batch_size}_log${log_interval}_zero${zero_stage}.json" -template_json="ds_config_gpt_TEMPLATE.json" -sed "s/GBSIZE/${global_batch_size}/" ${template_json} \ - | sed "s/MBSIZE/${batch_size}/" \ - | sed "s/LOG_INTERVAL/${log_interval}/" \ - | sed "s/ZERO_STAGE/${zero_stage}/" \ - | sed "s/PRESCALE_GRAD/${prescale_grad}/" \ - > ${config_json} - -## When saving checkpoint to a storage with cache, their could be consistency -## issue of the pointer to latest checkpoint. Here we find the correct pointer -## and broadcast it to all nodes. -iteration_file="$checkpoint_path/latest_checkpointed_iteration.txt" -iteration_file_2="$checkpoint_path/latest" -iteration=0 -for (( node = 0; node <= num_node-1; node++ )) -do - if $(ssh -q worker-"$node" "test -f \"$iteration_file\""); then - local_iteration=$(ssh -q worker-"$node" cat $iteration_file) - iteration=$(( ${local_iteration} > ${iteration} ? 
${local_iteration} : ${iteration} )) - fi -done -if [[ $iteration -gt 0 ]]; then - iteration_2="global_step${iteration}" - ds_ssh "echo $iteration > $iteration_file" - ds_ssh "echo $iteration_2 > $iteration_file_2" -fi - -deepspeed pretrain_gpt.py \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --override-opt_param-scheduler \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --tensor-model-parallel-size ${mp_size} \ - --init-method-std ${init_std} \ - --lr-decay-tokens ${lr_decay_tokens} \ - --lr-warmup-tokens ${lr_warmup_tokens} \ - --micro-batch-size ${batch_size} \ - --exit-duration-in-mins ${exit_duration} \ - --global-batch-size ${global_batch_size} \ - --num-layers ${num_layers} \ - --hidden-size ${hidden_size} \ - --num-attention-heads ${num_attn_heads} \ - --seq-length ${seq_len} \ - --max-position-embeddings ${seq_len} \ - --train-tokens ${train_tokens} \ - --train-samples ${train_samples} \ - --lr ${lr} \ - --min-lr ${min_lr} \ - --lr-decay-style ${lr_decay_style} \ - --split 949,50,1 \ - --log-interval ${log_interval} \ - --eval-interval ${eval_interval} \ - --eval-iters ${eval_iters} \ - --save-interval ${save_interval} \ - --weight-decay 0.1 \ - --clip-grad 1.0 \ - --hysteresis 2 \ - --num-workers ${num_workers} \ - --bf16 \ - --seed ${seed} \ - --load ${checkpoint_path} \ - --save ${checkpoint_path} \ - --no-async-tensor-model-parallel-allreduce \ - --tensorboard-queue-size 1 \ - --log-timers-to-tensorboard \ - --log-batch-size-to-tensorboard \ - --log-validation-ppl-to-tensorboard \ - --tensorboard-dir ${tensorboard_path} \ - --tokenizer-type GPTSentencePieceTokenizer \ - --tokenizer-model ${tokenizer_path} \ - --data-path ${data_path} \ - --data-impl mmap \ - --deepspeed \ - --deepspeed_config ${config_json} \ - --zero-stage ${zero_stage} \ - --pipeline-model-parallel-size ${pp_size} \ - --no-pipeline-parallel \ - --loss-function entmax15 \ - --entmax-topk 512 &>> ${log_path}/${jobname}_${host}_${current_time}.log From 563f968a3641a8d23f9b38d709f9e90ec41424d8 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 21 Mar 2024 21:11:25 +0000 Subject: [PATCH 05/26] compute support during training (although do not log it yet) --- megatron/model/gpt_model.py | 35 +++++++++++++++++++++++++---------- pretrain_gpt.py | 17 ++++++++++++----- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index fb871c7b2c..951d379bf2 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -42,7 +42,7 @@ def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy, - loss_function, alpha, topk, n_iter): + loss_function, alpha, topk, n_iter, return_support=False): # Output. Format [s b h] output = parallel_lm_logits( @@ -50,6 +50,7 @@ def post_language_model_processing(lm_output, labels, logit_weights, logit_weights, parallel_output) + # should it return a None support size in this case? 
let's say no for now if labels is None: # [s b h] => [b s h] return output.transpose(0,1).contiguous() @@ -68,20 +69,29 @@ def post_language_model_processing(lm_output, labels, logit_weights, # [s b] => [b, s] loss = loss.transpose(0,1).contiguous() + support = None else: # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" loss_funcs = { - "entmax15": partial(entmax.entmax15_loss, k=topk), - "sparsemax": partial(entmax.sparsemax_loss, k=topk), + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support=True), "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) } f = loss_funcs[loss_function] b, s = labels.size() vocab_size = output.size(-1) - loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + if loss_function != "entmax_bisect": + loss, support = f(output.float().view(-1, vocab_size), labels.view(-1)) + else: + loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + support = None loss = loss.view(b, s) - return loss + + if return_support: + return loss, support + else: + return loss class GPTModel(MegatronModule): @@ -129,7 +139,8 @@ def forward(self, input_ids, position_ids, attention_mask, retriever_position_ids=None, retriever_attn_mask=None, labels=None, tokentype_ids=None, inference_params=None, - curriculum_seqlen=None): + curriculum_seqlen=None, + return_support=False): args = get_args() if curriculum_seqlen is not None: args.curriculum_seqlen = curriculum_seqlen @@ -158,7 +169,7 @@ def forward(self, input_ids, position_ids, attention_mask, inference_params=inference_params) if self.post_process: - lm_output = post_language_model_processing( + lm_output, support_size = post_language_model_processing( lm_output, labels, self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), self.parallel_output, @@ -166,9 +177,13 @@ def forward(self, input_ids, position_ids, attention_mask, self.loss_function, self.entmax_alpha, self.entmax_topk, - self.entmax_n_iter) - - return lm_output, moe_losses if self.return_moe_loss else lm_output + self.entmax_n_iter, + return_support=True) + # now...what do do about support_size? 
+ if return_support: + return lm_output, moe_losses if self.return_moe_loss else lm_output, support_size + else: + return lm_output, moe_losses if self.return_moe_loss else lm_output def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 1ff671f120..307f6dd388 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -211,7 +211,7 @@ def get_batch_pipe(data): return (tokens, position_ids, attention_mask), (labels, loss_mask) -def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): +def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): args = get_args() losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() @@ -229,7 +229,12 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): print_rank_0('>>> total loss: {}, lm loss {}, kd loss {}'.format(loss, averaged_loss[0], mos_loss)) else: if max(args.num_experts) <= 1: - return loss, {'lm loss': averaged_loss[0]} + if support_size is not None: + # need to average the support sizes, I guess + # todo: return support + return loss, {'lm loss': averaged_loss[0]} + else: + return loss, {'lm loss': averaged_loss[0]} else: loss = loss + moe_loss return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} @@ -277,6 +282,7 @@ def forward_step(data_iterator, model): args.data_efficiency_curriculum_learning_seqlen_type == 'seqlen_reshape': args.data_efficiency_curriculum_learning_numel = torch.numel(tokens) + support_size = None if args.mos or args.kd: # The forward func can return either the loss or the logits, depending on whether passing in the labels or not. stu_output, other_losses = model(tokens, position_ids, attention_mask) @@ -285,8 +291,9 @@ def forward_step(data_iterator, model): labels = labels[:, :args.curriculum_seqlen].contiguous() output_tensor = tensor_parallel.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels) else: - output_tensor, other_losses = model(tokens, position_ids, attention_mask, - labels=labels) + output_tensor, other_losses, support_size = model( + tokens, position_ids, attention_mask, labels=labels, return_support=True + ) if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() @@ -304,7 +311,7 @@ def forward_step(data_iterator, model): args.teacher_model[0], tokens, position_ids, attention_mask) # Output_tensor stores the standard loss, loos_func calculates the total loss. 
- return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss, support_size) def train_valid_test_datasets_provider(train_val_test_num_samples): From 7228a8b190b051f8f7df65c440df9d73ac235900 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 27 Mar 2024 11:55:47 +0000 Subject: [PATCH 06/26] fix model training bug, accumulate support across gpus --- megatron/model/gpt_model.py | 2 +- pretrain_gpt.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 951d379bf2..679152fa02 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -78,8 +78,8 @@ def post_language_model_processing(lm_output, labels, logit_weights, "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) } f = loss_funcs[loss_function] - b, s = labels.size() + output = output.transpose(0, 1).contiguous() vocab_size = output.size(-1) if loss_function != "entmax_bisect": loss, support = f(output.float().view(-1, vocab_size), labels.view(-1)) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 307f6dd388..8d5812ba3f 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -230,9 +230,13 @@ def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): else: if max(args.num_experts) <= 1: if support_size is not None: + support_size = support_size.float() # need to average the support sizes, I guess # todo: return support - return loss, {'lm loss': averaged_loss[0]} + # average_support = average_losses_across_data_parallel_group([support_size]) + torch.distributed.all_reduce(support_size, group=mpu.get_data_parallel_group()) + support_size = support_size / torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) + return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean()} else: return loss, {'lm loss': averaged_loss[0]} else: From 202caed32d9dda75cd92118ffb315eb4527d2207 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Mon, 1 Apr 2024 14:38:31 +0100 Subject: [PATCH 07/26] slightly more informative support logging --- pretrain_gpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 8d5812ba3f..bdc4eff794 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -231,12 +231,12 @@ def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): if max(args.num_experts) <= 1: if support_size is not None: support_size = support_size.float() - # need to average the support sizes, I guess - # todo: return support - # average_support = average_losses_across_data_parallel_group([support_size]) torch.distributed.all_reduce(support_size, group=mpu.get_data_parallel_group()) support_size = support_size / torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) - return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean()} + # this way to compute the max support will be incorrect on + # multiple GPUs, but the numbers should still be reasonably + # informative + return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean(), "support max": support_size.max()} else: return loss, {'lm loss': averaged_loss[0]} else: From 8a83cf0be33eef3dc3fc46728aa551d9a7d18dfa Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Mon, 20 May 2024 16:16:23 +0100 Subject: [PATCH 08/26] update return_support to return_support_size in all cases --- megatron/model/gpt_model.py | 14 +++++++------- 
pretrain_gpt.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 679152fa02..8acfd04251 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -42,7 +42,7 @@ def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy, - loss_function, alpha, topk, n_iter, return_support=False): + loss_function, alpha, topk, n_iter, return_support_size=False): # Output. Format [s b h] output = parallel_lm_logits( @@ -73,8 +73,8 @@ def post_language_model_processing(lm_output, labels, logit_weights, else: # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" loss_funcs = { - "entmax15": partial(entmax.entmax15_loss, k=topk, return_support=True), - "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support=True), + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True), "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) } f = loss_funcs[loss_function] @@ -88,7 +88,7 @@ def post_language_model_processing(lm_output, labels, logit_weights, support = None loss = loss.view(b, s) - if return_support: + if return_support_size: return loss, support else: return loss @@ -140,7 +140,7 @@ def forward(self, input_ids, position_ids, attention_mask, retriever_attn_mask=None, labels=None, tokentype_ids=None, inference_params=None, curriculum_seqlen=None, - return_support=False): + return_support_size=False): args = get_args() if curriculum_seqlen is not None: args.curriculum_seqlen = curriculum_seqlen @@ -178,9 +178,9 @@ def forward(self, input_ids, position_ids, attention_mask, self.entmax_alpha, self.entmax_topk, self.entmax_n_iter, - return_support=True) + return_support_size=True) # now...what do do about support_size? 
- if return_support: + if return_support_size: return lm_output, moe_losses if self.return_moe_loss else lm_output, support_size else: return lm_output, moe_losses if self.return_moe_loss else lm_output diff --git a/pretrain_gpt.py b/pretrain_gpt.py index bdc4eff794..2f4eab4b08 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -296,7 +296,7 @@ def forward_step(data_iterator, model): output_tensor = tensor_parallel.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels) else: output_tensor, other_losses, support_size = model( - tokens, position_ids, attention_mask, labels=labels, return_support=True + tokens, position_ids, attention_mask, labels=labels, return_support_size=True ) if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() From 48a1781a89cc20f7f09fa3d85c1275cec8596512 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 22 May 2024 15:14:13 +0100 Subject: [PATCH 09/26] [wip] compute more sparsity stats --- pretrain_gpt.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 2f4eab4b08..292f7b476e 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -231,18 +231,42 @@ def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): if max(args.num_experts) <= 1: if support_size is not None: support_size = support_size.float() - torch.distributed.all_reduce(support_size, group=mpu.get_data_parallel_group()) - support_size = support_size / torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) - # this way to compute the max support will be incorrect on - # multiple GPUs, but the numbers should still be reasonably - # informative - return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean(), "support max": support_size.max()} + n_tokens = support_size.size(0) + + # how often is support size 1? + fully_peaked = support_size.eq(1).sum() + + # what is the max support size in the batch? 
+ max_support = support_size.max() + + # sum support stats across groups + current_group = mpu.get_data_parallel_group() + torch.distributed.all_reduce(support_size, group=current_group) + torch.distributed.all_reduce(max_support, group=current_group) + torch.distributed.all_reduce(fully_peaked, group=current_group) + + # find number of groups + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group() + ) + + # compute mean support size + support_size = support_size / world_size + + # compute mean max support + max_support = max_support / world_size + + # compute how often support is fully peaked + fully_peaked = fully_peaked / (world_size * n_tokens) + + return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean(), "support max": max_support, "fully peaked": fully_peaked} else: return loss, {'lm loss': averaged_loss[0]} else: loss = loss + moe_loss return loss, {'lm loss': averaged_loss[0], 'moe loss': moe_loss} + def calculate_mos_loss(args, stu_output, teacher_model, tokens, position_ids, attention_mask): mos_loss = 0 alpha = args.kd_alpha_ce From 19b478f92a1d114374ef50606baa3a0ce00bda98 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 22 May 2024 15:51:36 +0100 Subject: [PATCH 10/26] add more support logging statistics --- pretrain_gpt.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 292f7b476e..5129129923 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -238,11 +238,13 @@ def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): # what is the max support size in the batch? max_support = support_size.max() + min_support = support_size.min() # sum support stats across groups current_group = mpu.get_data_parallel_group() torch.distributed.all_reduce(support_size, group=current_group) torch.distributed.all_reduce(max_support, group=current_group) + torch.distributed.all_reduce(min_support, group=current_group) torch.distributed.all_reduce(fully_peaked, group=current_group) # find number of groups @@ -255,11 +257,33 @@ def loss_func(loss_mask, moe_loss, mos_loss, support_size, output_tensor): # compute mean max support max_support = max_support / world_size + min_support = min_support / world_size # compute how often support is fully peaked fully_peaked = fully_peaked / (world_size * n_tokens) - return loss, {'lm loss': averaged_loss[0], 'support size': support_size.mean(), "support max": max_support, "fully peaked": fully_peaked} + # the stats for support size might be slightly wrong, but still + # illustrative + support_mean = support_size.mean() + support_std = support_size.std() + quantiles = torch.linspace( + 0, 1, 5, dtype=support_size.dtype, device=support_size.device + ) + support_quantiles = torch.quantile(support_size, quantiles) + + loss_dict = { + 'lm loss': averaged_loss[0], + 'support size': support_mean, + "support std": support_std, + "support max": max_support, + "support min": min_support, + "support 25%": support_quantiles[1], + "support 50%": support_quantiles[2], + "support 75%": support_quantiles[3], + "fully peaked": fully_peaked + } + + return loss, loss_dict else: return loss, {'lm loss': averaged_loss[0]} else: From c4cf60c5951d4c78642ad2ff169e968637dc4239 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 13:09:00 +0100 Subject: [PATCH 11/26] add multiple loss functions (with some print statements to make sure my assumptions of dimensions are correct) --- tasks/main.py | 6 +- 
tasks/zeroshot_gpt/evaluate.py | 107 +++++++++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 8 deletions(-) diff --git a/tasks/main.py b/tasks/main.py index 9bc38f5fd2..3e8ca2c6c6 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -70,7 +70,9 @@ def get_tasks_args(parser): group.add_argument('--val-av-rank-other-neg', type=int, default=30, help='Av.rank validation: how many other negatives to' ' take from each question pool') - + group.add_argument('--eval-metric', default=None, + help='Eval metric to use other than a task-specific' + 'default') return parser @@ -79,7 +81,7 @@ def get_tasks_args(parser): initialize_megatron(extra_args_provider=get_tasks_args) - args = get_args() + args = get_args() # will the task args be included here? if args.num_layers_per_virtual_pipeline_stage is not None: print("Interleaved pipeline schedule is not yet supported for downstream tasks.") diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index a9e27fc49c..13fece9359 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -3,9 +3,12 @@ """GPT zero-shot evaluation.""" import math +from functools import partial import torch +import entmax + from megatron import get_args from megatron import print_rank_0, is_last_rank from megatron import get_tokenizer @@ -53,7 +56,7 @@ def model_provider(pre_process=True, post_process=True): 'is not supported.'.format(eval_metric)) print_rank_0('building GPT model ...') - + args = get_args() config = core_transformer_config_from_args(args) if args.deepspeed: @@ -62,7 +65,7 @@ def model_provider(pre_process=True, post_process=True): config_dict_or_path=args.deepspeed_config, enabled=args.zero_stage == 3, mpu=mpu): - + model = GPTModel( config=config, num_tokentypes=0, @@ -73,7 +76,7 @@ def model_provider(pre_process=True, post_process=True): else: model = GPTModel(config=config, num_tokentypes=0, parallel_output=parallel_output, pre_process=pre_process, post_process=post_process) - + return model @@ -100,6 +103,88 @@ def process_batch(batch): return tokens, labels, attention_mask, position_ids, loss_mask +""" +if loss_function == "cross_entropy": + # cross entropy + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + cross_entropy = sequence_parallel.vocab_sequence_parallel_cross_entropy if mpu.get_sequence_parallel_world_size() > 1 \ + else tensor_parallel.vocab_parallel_cross_entropy + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = cross_entropy(output, labels) + else: + loss = cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + support = None +else: + # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" + loss_funcs = { + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) + } + f = loss_funcs[loss_function] + b, s = labels.size() + output = output.transpose(0, 1).contiguous() + vocab_size = output.size(-1) + if loss_function != "entmax_bisect": + loss, support = f(output.float().view(-1, vocab_size), labels.view(-1)) + else: + loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + support = None + loss = loss.view(b, s) + +""" + + +def _compute_loss(output, labels, loss_mask, loss_function="cross_entropy", topk=512, alpha=1.5, n_iter=30): + """ + Dimensions are confusing but I think I've figured it out. 
Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). + + Based on the documentation of tensor_parallel.vocab_parallel_cross_entropy, + we can expect output (or rather, output[0], which should be the decoder + output based on TransformerLanguageModel.forward) to be [s b V] and labels + to be [s b]. + """ + print("output type before _compute_loss", type(output)) + if isinstance(output, torch.Tensor): + print("size before indexing", output.size()) + output = output[0] # based on how loss was previously computed + print("size after indexing", output.size()) + + if loss_function == "cross_entropy": + # I believe (based on the commented-out block above) that this + # function takes [s b] as its input. + # But I'm not certain + losses = tensor_parallel.vocab_parallel_cross_entropy( + output.contiguous().float(), labels.contiguous()) + else: + # now: the loss function is "entmax15", "sparsemax", or "entmax_bisect" + loss_funcs = { + "entmax15": partial(entmax.entmax15_loss, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax_loss, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect_loss, alpha=alpha, n_iter=n_iter) + } + f = loss_funcs[loss_function] + print("labels size: ", labels.size()) + print("output size", output.size()) + vocab_size = output[0].size(-1) + if loss_function != "entmax_bisect": + losses, _ = f(output.float().view(-1, vocab_size), labels.view(-1)) + else: + losses = f(output.float().view(-1, vocab_size), labels.view(-1)) + # losses = losses.view(b, s) + + loss = torch.sum(losses.view(-1) * loss_mask.contiguous().view(-1).float()) + + return loss + def forward_step(batch, model, eval_metric): """Forward step.""" @@ -126,11 +211,19 @@ def forward_step(batch, model, eval_metric): if parallel_state.is_pipeline_last_stage(): # For loss, return the unreduced loss. if eval_metric == 'loss': + ''' losses = tensor_parallel.vocab_parallel_cross_entropy( output[0].contiguous().float(), labels.contiguous()) loss = torch.sum( losses.view(-1) * loss_mask.contiguous().view(-1).float()) return loss + ''' + + loss = _compute_loss( + output, labels, loss_mask, + loss_function=args.loss_function, topk=args.entmax_topk, n_iter=args.entmax_n_iter, alpha=args.entmax_alpha + ) + return loss # For accuracy, return the number of correctly predicted samples. if eval_metric == 'accuracy': @@ -225,7 +318,9 @@ def main(): print("Interleaved pipeline schedule is not yet supported for text generation.") exit() - if args.task == 'LAMBADA': + if args.eval_metric is not None: + eval_metric = args.eval_metric + elif args.task == 'LAMBADA': eval_metric = 'accuracy' elif args.task == 'WIKITEXT103': eval_metric = 'loss' @@ -248,7 +343,7 @@ def main(): mpu=mpu if args.no_pipeline_parallel else None ) model = [model] - + if args.load is not None: _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration) @@ -262,7 +357,7 @@ def main(): # Run evaluation. 
evaluate_and_print_results(args.task, dataloader, model, eval_metric) - + print_rank_0('done :-)') From a02cf72da485b6350e95d23eadabf7069b076706 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 14:06:06 +0100 Subject: [PATCH 12/26] add force_decoded_accuracy as a zero-shot evaluation metric --- tasks/zeroshot_gpt/evaluate.py | 37 +++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 13fece9359..0b480eab5c 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -47,7 +47,7 @@ def model_provider(pre_process=True, post_process=True): config = core_transformer_config_from_args(get_args()) - if eval_metric == 'loss': + if eval_metric == 'loss' or eval_metric == 'force_decoded_accuracy': parallel_output = True elif eval_metric == 'accuracy': parallel_output = False @@ -186,6 +186,37 @@ def _compute_loss(output, labels, loss_mask, loss_function="cross_entropy", topk return loss +def _force_decoded_accuracy(output, labels, loss_mask): + """ + This is different from LAMBADA accuracy, which is only about getting the + final word right based on a long context. + + Dimensions are confusing but I think I've figured it out. Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). + + Based on the documentation of tensor_parallel.vocab_parallel_cross_entropy, + we can expect output (or rather, output[0], which should be the decoder + output based on TransformerLanguageModel.forward) to be [s b V] and labels + to be [s b]. + """ + + # same as for the loss one, but not the accuracy one. + # keep these print statements for a bit + print("output type before _compute_loss", type(output)) + if isinstance(output, torch.Tensor): + print("size before indexing", output.size()) + output = output[0] # based on how loss was previously computed + print("size after indexing", output.size()) + + predictions = output.argmax(dim=-1).view(-1) + correct = predictions.eq(labels.view(-1)).float() + + correct_sum = torch.sum(correct * loss_mask.contiguous().view(-1).float()) + return correct_sum + + def forward_step(batch, model, eval_metric): """Forward step.""" @@ -225,6 +256,10 @@ def forward_step(batch, model, eval_metric): ) return loss + if eval_metric == "force_decoded_accuracy": + correct_sum = _force_decoded_accuracy(output, labels, loss_mask) + return correct_sum + # For accuracy, return the number of correctly predicted samples. 
if eval_metric == 'accuracy': outputs = torch.argmax(output, -1) From a8733c4d6df5486fe6e37abf9d9e4c4aa2a80d34 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 14:28:26 +0100 Subject: [PATCH 13/26] update logging with teacher forced accuracy --- tasks/zeroshot_gpt/evaluate.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 0b480eab5c..a59e638bf2 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -336,6 +336,21 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): string += 'total examples: {:.4E} | '.format(num_examples) string += 'avg accuracy: {:.4E}'.format(acc) + elif eval_metric == "force_decoded_accuracy": + num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens + acc = output / (num_tokenized_tokens - 1) + string += 'number correct: {:.4E} | '.format(output) + string += 'total tokens: {:.4E} | '.format(num_tokenized_tokens) + string += 'avg accuracy: {:.4E}'.format(acc) + + results = { + "accuracy": acc, + "n_correct": output, + "n_tokens": num_tokenized_tokens + } + with open('./eval_results', 'w') as json_file: + json.dump(results, json_file) + else: raise NotImplementedError('evaluation method for {} metric is not ' 'implemented yet.'.format(eval_metric)) From e51656d3c7d8142f672ea53521da081d0f3791f8 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 14:30:54 +0100 Subject: [PATCH 14/26] update xavier_uniform init (which will probably be unused --- megatron/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index f650b6e331..cdbe49f803 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -481,7 +481,7 @@ def core_transformer_config_from_args(args): kw_args['bias_gelu_fusion'] = False if args.init_method_xavier_uniform: kw_args['init_method'] = torch.nn.init.xavier_uniform_ - kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ + kw_args['output_layer_init_method'] = torch.nn.init.xavier_uniform_ return TransformerConfig(**kw_args) From 055f869ddf2f1a67a694ac060c500c2aa7fab788 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 15:30:46 +0100 Subject: [PATCH 15/26] avoid unpacking error when labels is None --- megatron/model/gpt_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 8acfd04251..5ed0804d9f 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -169,7 +169,7 @@ def forward(self, input_ids, position_ids, attention_mask, inference_params=inference_params) if self.post_process: - lm_output, support_size = post_language_model_processing( + post_lm_out = post_language_model_processing( lm_output, labels, self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), self.parallel_output, @@ -179,6 +179,10 @@ def forward(self, input_ids, position_ids, attention_mask, self.entmax_topk, self.entmax_n_iter, return_support_size=True) + if labels is not None: + lm_output, support_size = post_lm_out + else: + lm_output = post_lm_out # now...what do do about support_size? 
if return_support_size: return lm_output, moe_losses if self.return_moe_loss else lm_output, support_size From e24ebc6c7eed49a1d42a0f80533122347bcdcc85 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 16:06:35 +0100 Subject: [PATCH 16/26] remove print statements for eval computation --- tasks/zeroshot_gpt/evaluate.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index a59e638bf2..ba2dc90e5b 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -202,13 +202,8 @@ def _force_decoded_accuracy(output, labels, loss_mask): to be [s b]. """ - # same as for the loss one, but not the accuracy one. - # keep these print statements for a bit - print("output type before _compute_loss", type(output)) - if isinstance(output, torch.Tensor): - print("size before indexing", output.size()) - output = output[0] # based on how loss was previously computed - print("size after indexing", output.size()) + # the raw output is a tuple (same as for eval_metric=="loss") + output = output[0] predictions = output.argmax(dim=-1).view(-1) correct = predictions.eq(labels.view(-1)).float() From f77cf0c87687ea651fc370f5e8b718fd5c53a236 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Wed, 26 Jun 2024 16:38:32 +0100 Subject: [PATCH 17/26] call item() because json cannot handle tensors --- tasks/zeroshot_gpt/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index ba2dc90e5b..7a94735490 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -339,8 +339,8 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): string += 'avg accuracy: {:.4E}'.format(acc) results = { - "accuracy": acc, - "n_correct": output, + "accuracy": acc.item(), + "n_correct": output.item(), "n_tokens": num_tokenized_tokens } with open('./eval_results', 'w') as json_file: From 68a4a1bc73ec22637db36b51fa4e2d21d0cbbdbc Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 27 Jun 2024 17:38:30 +0100 Subject: [PATCH 18/26] add accuracy at k --- tasks/main.py | 2 ++ tasks/zeroshot_gpt/evaluate.py | 37 ++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tasks/main.py b/tasks/main.py index 3e8ca2c6c6..913e519b0f 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -73,6 +73,8 @@ def get_tasks_args(parser): group.add_argument('--eval-metric', default=None, help='Eval metric to use other than a task-specific' 'default') + group.add_argument('--acc-k', default=5, type=int, + help='k for force-decoded accuracy at k') return parser diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 7a94735490..bd2049ff44 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -195,11 +195,6 @@ def _force_decoded_accuracy(output, labels, loss_mask): process_batch (defined above) and forward_step, labels is [b s]. I assume loss_mask is the same shape. And I assume that output is [b s V] (it would be ridiculously confusing otherwise). - - Based on the documentation of tensor_parallel.vocab_parallel_cross_entropy, - we can expect output (or rather, output[0], which should be the decoder - output based on TransformerLanguageModel.forward) to be [s b V] and labels - to be [s b]. 
""" # the raw output is a tuple (same as for eval_metric=="loss") @@ -212,6 +207,31 @@ def _force_decoded_accuracy(output, labels, loss_mask): return correct_sum +def _force_decoded_accuracy_at_k(output, labels, loss_mask, k): + """ + Accuracy at k -- do any of the top k outputs match? + + This is different from LAMBADA accuracy, which is only about getting the + final word right based on a long context. + + Dimensions are confusing but I think I've figured it out. Based on + process_batch (defined above) and forward_step, labels is [b s]. I assume + loss_mask is the same shape. And I assume that output is [b s V] (it would + be ridiculously confusing otherwise). + """ + + # the raw output is a tuple (same as for eval_metric=="loss") + output = output[0] + + _, predictions = torch.topk(output, k, dim=-1) + predictions = predictions.view(-1, k) + + correct = predictions.eq(labels.view(-1, 1)).any(dim=-1).float() + + correct_sum = torch.sum(correct * loss_mask.contiguous().view(-1).float()) + return correct_sum + + def forward_step(batch, model, eval_metric): """Forward step.""" @@ -255,6 +275,11 @@ def forward_step(batch, model, eval_metric): correct_sum = _force_decoded_accuracy(output, labels, loss_mask) return correct_sum + if eval_metric == "force_decoded_accuracy_at_k": + k = args.acc_k + correct_sum = _force_decoded_accuracy_at_k(output, labels, loss_mask, k) + return correct_sum + # For accuracy, return the number of correctly predicted samples. if eval_metric == 'accuracy': outputs = torch.argmax(output, -1) @@ -331,7 +356,7 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): string += 'total examples: {:.4E} | '.format(num_examples) string += 'avg accuracy: {:.4E}'.format(acc) - elif eval_metric == "force_decoded_accuracy": + elif eval_metric == "force_decoded_accuracy" or eval_metric == "force_decoded_accuracy_at_k": num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens acc = output / (num_tokenized_tokens - 1) string += 'number correct: {:.4E} | '.format(output) From 24a5ed3e2533d9ac0a01cf3cd23efa8a28c11a10 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 27 Jun 2024 17:42:23 +0100 Subject: [PATCH 19/26] fix guardrail for new task --- tasks/zeroshot_gpt/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index bd2049ff44..603fc1ea44 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -47,7 +47,7 @@ def model_provider(pre_process=True, post_process=True): config = core_transformer_config_from_args(get_args()) - if eval_metric == 'loss' or eval_metric == 'force_decoded_accuracy': + if eval_metric in {"loss", "force_decoded_accuracy", "force_decoded_accuracy_at_k"}: parallel_output = True elif eval_metric == 'accuracy': parallel_output = False From 77fe1089a893726097a67b389c88f43c320acfdc Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 27 Jun 2024 19:13:02 +0100 Subject: [PATCH 20/26] small change to try to get LAMBADA to run --- tasks/zeroshot_gpt/evaluate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 603fc1ea44..6fb725cd80 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -415,7 +415,10 @@ def main(): model = [model] if args.load is not None: - _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration) + if args.task == "LAMBADA": + _ = load_checkpoint(model, None, None, 
load_iteration=args.load_iteration, load_only_weights=True) + else: + _ = load_checkpoint(model, None, None, load_iteration=args.load_iteration) assert len(model) == 1, "Above condition should have caught this" model = model[0] From c41f8b02c2365d35948f82c48591d98d7a8866e8 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 27 Jun 2024 21:42:00 +0000 Subject: [PATCH 21/26] don't load optimizer state for lambada --- tasks/zeroshot_gpt/evaluate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 6fb725cd80..a656f7228c 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -282,6 +282,8 @@ def forward_step(batch, model, eval_metric): # For accuracy, return the number of correctly predicted samples. if eval_metric == 'accuracy': + if isinstance(output, tuple): + output = output[0] # not sure why this was necessary outputs = torch.argmax(output, -1) correct = (outputs == labels).float() correct[(1 - loss_mask).bool()] = 1 From 4cfc1336b5ceeb0d34b7001f4408fb5011318a5f Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Thu, 27 Jun 2024 22:46:53 +0100 Subject: [PATCH 22/26] write lambada to results file --- tasks/zeroshot_gpt/evaluate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index a656f7228c..9fb6fbaf73 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -357,6 +357,9 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): string += 'number correct: {:.4E} | '.format(output) string += 'total examples: {:.4E} | '.format(num_examples) string += 'avg accuracy: {:.4E}'.format(acc) + results = {"accuracy": acc.item()} + with open('./eval_results', 'w') as json_file: + json.dump(results, json_file) elif eval_metric == "force_decoded_accuracy" or eval_metric == "force_decoded_accuracy_at_k": num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens From 01825324e5d22373a4dac17aa5d389a99ce37595 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Fri, 28 Jun 2024 12:39:46 +0100 Subject: [PATCH 23/26] fix entmax_bisect_loss return type --- megatron/model/gpt_model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 5ed0804d9f..e905825c36 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -81,11 +81,25 @@ def post_language_model_processing(lm_output, labels, logit_weights, b, s = labels.size() output = output.transpose(0, 1).contiguous() vocab_size = output.size(-1) + + # currently entmax_bisect_loss always returns a None support size, + # which is not ideal. 
This is a stopgap until entmax_bisect_loss is + # fixed + loss = f(output.float().view(-1, vocab_size), labels.view(-1)) + if isinstance(loss, tuple): + loss, support = loss + else: + support = None + + # old version which breaks because entmax_bisect unexpectedly returned + # a tuple: + ''' if loss_function != "entmax_bisect": loss, support = f(output.float().view(-1, vocab_size), labels.view(-1)) else: loss = f(output.float().view(-1, vocab_size), labels.view(-1)) support = None + ''' loss = loss.view(b, s) if return_support_size: From 9e7f53d3f0b25daef9080016fb0c7b3deff1781d Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Fri, 12 Jul 2024 11:49:22 +0100 Subject: [PATCH 24/26] refactor zeroshot gpt evaluation for sparsemax score --- tasks/zeroshot_gpt/evaluate.py | 202 ++++++++++++++++++++++----------- 1 file changed, 135 insertions(+), 67 deletions(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 9fb6fbaf73..534ab75d44 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -4,6 +4,7 @@ import math from functools import partial +from collections import defaultdict import torch @@ -47,7 +48,7 @@ def model_provider(pre_process=True, post_process=True): config = core_transformer_config_from_args(get_args()) - if eval_metric in {"loss", "force_decoded_accuracy", "force_decoded_accuracy_at_k"}: + if eval_metric in {"loss", "force_decoded_accuracy", "force_decoded_accuracy_at_k", "sparsemax_score"}: parallel_output = True elif eval_metric == 'accuracy': parallel_output = False @@ -232,8 +233,60 @@ def _force_decoded_accuracy_at_k(output, labels, loss_mask, k): return correct_sum +def _gini_entropy(probs): + return probs * (1 - probs).sum(dim=-1) / 2 + + +def _sparsemax_score(output, labels, loss_mask, loss_function="cross_entropy", topk=512, alpha=1.5, n_iter=30): + # loss_function is really the generator function here + + if isinstance(output, torch.Tensor): + print("size before indexing", output.size()) + output = output[0] # based on how loss was previously computed + vocab_size = output.size(-1) + + output = output.view(-1, vocab_size) + labels = labels.view(-1) + loss_mask = loss_mask.contiguous().view(-1).float() + + # you can get the accuracy almost for free, so you might as well + predictions = output.argmax(dim=-1) + correct = predictions.eq(labels).float() + correct_sum = torch.sum(correct * loss_mask) + + gen_funcs = { + "cross_entropy": torch.softmax, + "entmax15": partial(entmax.entmax15, k=topk, return_support_size=True), + "sparsemax": partial(entmax.sparsemax, k=topk, return_support_size=True), + "entmax_bisect": partial(entmax.entmax_bisect, alpha=alpha, n_iter=n_iter) + } + + f = gen_funcs[loss_function] + + if loss_function not in {"cross_entropy", "entmax_bisect"}: + probs, support_size = f(output.float(), dim=-1) + else: + probs = f(output.float(), dim=-1) + support_size = None + + # now...p_theta(x) + # sp = p_theta(x) + H_2(p_theta) + + gold_probs = probs.gather(1, labels.unsqueeze(1)).view(-1) + entropy = _gini_entropy(probs) + sp = ((gold_probs + entropy) * loss_mask).sum() + + return sp, correct_sum + + def forward_step(batch, model, eval_metric): """Forward step.""" + # TODO: return dict + eval_metrics = {"loss", "accuracy", "force_decoded_accuracy", + "force_decoded_accuracy_at_k", "sparsemax_score"} + if eval_metric not in eval_metrics: + raise NotImplementedError('forward method for evaluation metric {} ' + 'is not implemented.'.format(eval_metric)) # Get the batch. 
tokens, labels, attention_mask, position_ids, loss_mask = process_batch( @@ -256,6 +309,9 @@ def forward_step(batch, model, eval_metric): if parallel_state.is_pipeline_last_stage(): # For loss, return the unreduced loss. + + scores = dict() + if eval_metric == 'loss': ''' losses = tensor_parallel.vocab_parallel_cross_entropy( @@ -269,16 +325,25 @@ def forward_step(batch, model, eval_metric): output, labels, loss_mask, loss_function=args.loss_function, topk=args.entmax_topk, n_iter=args.entmax_n_iter, alpha=args.entmax_alpha ) - return loss + scores["loss"] = loss if eval_metric == "force_decoded_accuracy": correct_sum = _force_decoded_accuracy(output, labels, loss_mask) - return correct_sum + scores["force_decoded_accuracy"] = correct_sum if eval_metric == "force_decoded_accuracy_at_k": - k = args.acc_k - correct_sum = _force_decoded_accuracy_at_k(output, labels, loss_mask, k) - return correct_sum + correct_sum = _force_decoded_accuracy_at_k(output, labels, loss_mask, args.acc_k) + scores["force_decoded_accuracy_at_k"] = correct_sum + + if eval_metric == "sparsemax_score": + sp_sum, correct_sum = _sparsemax_score( + output, labels, loss_mask, + loss_function=args.loss_function, topk=args.entmax_topk, n_iter=args.entmax_n_iter, alpha=args.entmax_alpha + ) + scores["sparsemax_score"] = sp_sum + scores["force_decoded_accuracy"] = correct_sum + # currently computes the accuracy but doesn't return it. annoying. + # return {"sparsemax_score": sp_sum, "force_decoded_accuracy": correct_sum} # For accuracy, return the number of correctly predicted samples. if eval_metric == 'accuracy': @@ -288,10 +353,10 @@ def forward_step(batch, model, eval_metric): correct = (outputs == labels).float() correct[(1 - loss_mask).bool()] = 1 correct = correct.prod(-1) - return correct.sum() + scores["accuracy"] = correct.sum() + + return scores - raise NotImplementedError('forward method for evaluation metric {} ' - 'is not implemented.'.format(eval_metric)) return None @@ -302,21 +367,24 @@ def evaluate(data_loader, model, eval_metric): # Turn on evaluation mode which disables dropout. model.eval() - total_output = 0.0 + total_output = defaultdict(float) + with torch.no_grad(): # For all the batches in the dataset. for iteration, batch in enumerate(data_loader): if iteration % args.log_interval == 0: print_rank_0('> working on iteration: {}'.format(iteration)) # Forward evaluation. - output = forward_step(batch, model, eval_metric) + output_dict = forward_step(batch, model, eval_metric) # problem if this doesn't return a tensor # Reduce across processes. if parallel_state.is_pipeline_last_stage(): - torch.distributed.all_reduce(output, - group=parallel_state.get_data_parallel_group()) - - total_output += output + for metric_name, output in output_dict.items(): + torch.distributed.all_reduce( + output, + group=parallel_state.get_data_parallel_group() + ) + total_output[metric_name] += output return total_output @@ -325,60 +393,61 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): """Evaluate and print results on screen.""" # Evaluate and get results. 
- output = evaluate(data_loader, model, eval_metric) + output_dict = evaluate(data_loader, model, eval_metric) # this is a dict string = ' validation results on {} | '.format(task) if is_last_rank(): - if eval_metric == 'loss': - num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens - num_original_tokens = data_loader.dataset.num_original_tokens - val_loss = output / (num_tokenized_tokens - 1) - ppl = math.exp(min(20, val_loss)) - token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) - adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) - string += 'avg loss: {:.4E} | '.format(val_loss) - string += 'ppl: {:.4E} | '.format(ppl) - string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) - string += 'token ratio: {} |'.format(token_ratio) - - results = { - "loss": val_loss.item(), - "ppl": ppl, - "ajusted_ppl": adjusted_ppl, - "token_ratio": token_ratio - } - - with open('./eval_results', 'w') as json_file: - json.dump(results, json_file) - - elif eval_metric == 'accuracy': - num_examples = len(data_loader.dataset) - acc = output / num_examples - string += 'number correct: {:.4E} | '.format(output) - string += 'total examples: {:.4E} | '.format(num_examples) - string += 'avg accuracy: {:.4E}'.format(acc) - results = {"accuracy": acc.item()} - with open('./eval_results', 'w') as json_file: - json.dump(results, json_file) - - elif eval_metric == "force_decoded_accuracy" or eval_metric == "force_decoded_accuracy_at_k": - num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens - acc = output / (num_tokenized_tokens - 1) - string += 'number correct: {:.4E} | '.format(output) - string += 'total tokens: {:.4E} | '.format(num_tokenized_tokens) - string += 'avg accuracy: {:.4E}'.format(acc) - - results = { - "accuracy": acc.item(), - "n_correct": output.item(), - "n_tokens": num_tokenized_tokens - } - with open('./eval_results', 'w') as json_file: - json.dump(results, json_file) - - else: - raise NotImplementedError('evaluation method for {} metric is not ' - 'implemented yet.'.format(eval_metric)) + results = dict() + + num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens + num_original_tokens = data_loader.dataset.num_original_tokens + num_examples = len(data_loader.dataset) + + results["n_tokens"] = num_tokenized_tokens + + for eval_metric, output in output_dict.items(): + if eval_metric == 'loss': + + val_loss = output / (num_tokenized_tokens - 1) + ppl = math.exp(min(20, val_loss)) + token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) + adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) + string += 'avg loss: {:.4E} | '.format(val_loss) + string += 'ppl: {:.4E} | '.format(ppl) + string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) + string += 'token ratio: {} |'.format(token_ratio) + + results["loss"] = val_loss.item() + results["ppl"] = ppl + results["adjusted_ppl"] = adjusted_ppl + results["token_ratio"] = token_ratio + + elif eval_metric == 'accuracy': + # remember this is Lambada accuracy + acc = output / num_examples + string += 'number correct: {:.4E} | '.format(output) + string += 'total examples: {:.4E} | '.format(num_examples) + string += 'avg accuracy: {:.4E}'.format(acc) + results["accuracy"] = acc.item() + + elif eval_metric == "force_decoded_accuracy" or eval_metric == "force_decoded_accuracy_at_k": + acc = output / (num_tokenized_tokens - 1) + string += 'number correct: {:.4E} | '.format(output) + string += 'total tokens: {:.4E} | '.format(num_tokenized_tokens) + string += 'avg accuracy: {:.4E}'.format(acc) 
+ + results["accuracy"] = acc.item() + results["n_correct"] = output.item() + elif eval_metric == "sparsemax_score": + avg_sparsemax_score = output / (num_tokenized_tokens - 1) + string += 'sparsemax score: {:.4E} | '.format(avg_sparsemax_score) + results["sparsemax_score"] = avg_sparsemax_score + else: + raise NotImplementedError('evaluation method for {} metric is not ' + 'implemented yet.'.format(eval_metric)) + + with open('./eval_results', 'w') as json_file: + json.dump(results, json_file) length = len(string) + 1 print('-' * length) @@ -438,4 +507,3 @@ def main(): print_rank_0('done :-)') - From f3c66c4943c476ed2a445dfe2e2a8701d102d5b5 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Fri, 12 Jul 2024 14:29:10 +0100 Subject: [PATCH 25/26] add missing parentheses --- tasks/zeroshot_gpt/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 534ab75d44..8501a98026 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -234,7 +234,7 @@ def _force_decoded_accuracy_at_k(output, labels, loss_mask, k): def _gini_entropy(probs): - return probs * (1 - probs).sum(dim=-1) / 2 + return (probs * (1 - probs)).sum(dim=-1) / 2 def _sparsemax_score(output, labels, loss_mask, loss_function="cross_entropy", topk=512, alpha=1.5, n_iter=30): From 4f2c4fc4bccf79d492bb494d0ae11e5c7baaa1b8 Mon Sep 17 00:00:00 2001 From: Ben Peters Date: Fri, 12 Jul 2024 15:26:11 +0100 Subject: [PATCH 26/26] get item from sparsemax score tensor --- tasks/zeroshot_gpt/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 8501a98026..6d33c4c20c 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -441,7 +441,7 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric): elif eval_metric == "sparsemax_score": avg_sparsemax_score = output / (num_tokenized_tokens - 1) string += 'sparsemax score: {:.4E} | '.format(avg_sparsemax_score) - results["sparsemax_score"] = avg_sparsemax_score + results["sparsemax_score"] = avg_sparsemax_score.item() else: raise NotImplementedError('evaluation method for {} metric is not ' 'implemented yet.'.format(eval_metric))
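
For reference, the force-decoded accuracy metric from patch 12 and the accuracy-at-k variant from patch 18 both reduce to the sketch below. This is a minimal, dependency-free version that assumes the logits have already been gathered into a single [N, V] tensor (the patches themselves handle pipeline and tensor parallelism); the function name and the toy usage are illustrative, not taken from the diffs.

import torch

def force_decoded_accuracy(logits, labels, loss_mask, k=1):
    """Number of positions whose gold token appears in the top-k predictions.

    logits: [N, V] float, labels: [N] long, loss_mask: [N] with 1 for positions
    that count. k=1 gives the plain force-decoded accuracy; k>1 gives the
    accuracy-at-k added in patch 18.
    """
    _, topk_preds = torch.topk(logits, k, dim=-1)                   # [N, k]
    correct = topk_preds.eq(labels.unsqueeze(-1)).any(dim=-1).float()
    return (correct * loss_mask.float()).sum()

# toy usage; evaluate_and_print_results divides the summed count by
# (num_tokenized_tokens - 1) to report an average
logits = torch.randn(10, 50)
labels = torch.randint(0, 50, (10,))
loss_mask = torch.ones(10)
acc_at_5 = force_decoded_accuracy(logits, labels, loss_mask, k=5) / loss_mask.sum()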
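The sparsemax score added in patches 24 through 26 is, per token, the probability assigned to the gold token plus the Gini entropy of the predicted distribution: sp = p_gold + H_2(p), with H_2(p) = (1 - sum_j p_j^2) / 2, which is what the corrected _gini_entropy computes via (probs * (1 - probs)).sum(-1) / 2. Below is a minimal sketch under the same gathered [N, V] assumption as above; it uses torch.softmax as the generator for simplicity, whereas the patch selects softmax, entmax.entmax15, entmax.sparsemax, or entmax.entmax_bisect according to --loss-function.

import torch

def gini_entropy(probs):
    # H_2(p) = (1 - sum_j p_j^2) / 2, written the same way as the patched helper
    return (probs * (1 - probs)).sum(dim=-1) / 2

def sparsemax_score(logits, labels, loss_mask):
    """Masked sum of per-token scores sp = p_gold + H_2(p).

    softmax keeps this sketch dependency-free; substituting a sparse generator
    such as entmax.sparsemax (as the patch does for --loss-function sparsemax)
    yields the score the series actually reports.
    """
    probs = torch.softmax(logits.float(), dim=-1)                   # [N, V]
    gold_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)    # [N]
    per_token = gold_probs + gini_entropy(probs)
    return (per_token * loss_mask.float()).sum()

A confidently correct one-hot prediction scores 1 for that token, a confidently wrong one scores 0, and a uniform distribution scores roughly 1/2, which is why the reported value (the masked sum divided by num_tokenized_tokens - 1) lies between 0 and 1.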
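Patch 24 also reworks forward_step and evaluate so that metrics travel as a dict of tensors rather than a single tensor. Stripped of the pipeline-stage checks and the data-parallel all_reduce, the accumulation pattern looks roughly like the following; the batch iterable and score_fn are placeholders, not names from the diff.

from collections import defaultdict

import torch

def evaluate_sketch(batches, score_fn):
    # score_fn(batch) returns e.g. {"loss": tensor, "force_decoded_accuracy": tensor}
    totals = defaultdict(float)
    with torch.no_grad():
        for batch in batches:
            scores = score_fn(batch)
            # the real evaluate() all_reduces each tensor across the
            # data-parallel group before accumulating
            for name, value in scores.items():
                totals[name] += value
    return totals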