From 75e8565bc6de6629962d11eb9c98e6da21f071b1 Mon Sep 17 00:00:00 2001 From: AnamarijaKozina <91478609+AnamarijaKozina@users.noreply.github.com> Date: Wed, 15 Jan 2025 21:52:02 +0100 Subject: [PATCH 1/3] Ablation studies (#37) * Modified and ran run_language_modeling.py to test on all datasets individually. * Comparing losses * Arithmetic dataset testing + Pearson and Spearman metrics * removed nohup.out from git tracking * removed result files from git tracking * Changed calculation of metrics using scipy --------- Co-authored-by: thorben010 --- config/dataset_args/arithmetic.yaml | 1 + config/dataset_args/mathematics_dataset.yaml | 3 +- config/model_args/vanilla_t5_ntl.yaml | 3 +- config/run_specific_config/config.yaml | 5 +- .../arith_create_splits.py | 92 ++++++++++++++++ requirements.txt | 1 + src/ntl/args.py | 9 +- src/ntl/evaluation.py | 18 ++++ src/ntl/results_sorting.py | 64 +++++++++++ src/ntl/run_language_modeling.py | 100 ++++++++++++++++-- tests/test_evaluation.py | 8 ++ 11 files changed, 292 insertions(+), 12 deletions(-) create mode 100644 config/dataset_args/arithmetic.yaml create mode 100644 data/mathematics_dataset-v1.0/arith_create_splits.py create mode 100644 src/ntl/results_sorting.py diff --git a/config/dataset_args/arithmetic.yaml b/config/dataset_args/arithmetic.yaml new file mode 100644 index 0000000..ad82529 --- /dev/null +++ b/config/dataset_args/arithmetic.yaml @@ -0,0 +1 @@ +dataset_name: arithmetic \ No newline at end of file diff --git a/config/dataset_args/mathematics_dataset.yaml b/config/dataset_args/mathematics_dataset.yaml index fd952ff..0ed38f2 100644 --- a/config/dataset_args/mathematics_dataset.yaml +++ b/config/dataset_args/mathematics_dataset.yaml @@ -1 +1,2 @@ -dataset_name: mathematics_dataset \ No newline at end of file +dataset_name: mathematics_dataset +mode: interpolate_extrapolate \ No newline at end of file diff --git a/config/model_args/vanilla_t5_ntl.yaml b/config/model_args/vanilla_t5_ntl.yaml index 
0028d6a..5af9d72 100644 --- a/config/model_args/vanilla_t5_ntl.yaml +++ b/config/model_args/vanilla_t5_ntl.yaml @@ -3,4 +3,5 @@ config_name: t5-base number_encoding: none number_token_loss: true number_token_loss_weight: 0.3 -number_token_loss_with_wasserstein: false \ No newline at end of file +number_token_loss_with_wasserstein: false +#number_token_loss_function: \ No newline at end of file diff --git a/config/run_specific_config/config.yaml b/config/run_specific_config/config.yaml index 8d00c2b..c69d3fd 100644 --- a/config/run_specific_config/config.yaml +++ b/config/run_specific_config/config.yaml @@ -1,6 +1,7 @@ training_args: - trial: trial_6 + trial: special_name: model_args: - model_name_or_path: google-t5/t5-base \ No newline at end of file + model_name_or_path: google-t5/t5-small + config_name: t5-small diff --git a/data/mathematics_dataset-v1.0/arith_create_splits.py b/data/mathematics_dataset-v1.0/arith_create_splits.py new file mode 100644 index 0000000..fb4f6cf --- /dev/null +++ b/data/mathematics_dataset-v1.0/arith_create_splits.py @@ -0,0 +1,92 @@ +import random +from pathlib import Path + +selected_categories = [ + + "arithmetic__add_sub_multiple.txt", + +] + +train_folders = [ + "train-easy", + "train-medium", + "train-hard", +] + +train_file_path = Path("arithmetic_train.txt") +val_file_path = Path("arithmetic_val.txt") + +files = [Path("mathematics_dataset-v1.0", folder, category) for folder in train_folders for category in selected_categories] + + +length_data = 0 +length_train = 0 +length_val = 0 + +for file in files: + with open(file, 'r') as f: + lines = f.readlines() + + # Group every two lines (question, answer) together + pairs = [(lines[i], lines[i+1]) for i in range(0, len(lines), 2)] + random.shuffle(pairs) # Shuffle the pairs + + # Split the data + train_end = 33333 + train_pairs = pairs[:train_end] + val_pairs = pairs[train_end:(train_end+1000)] + + # Flatten the pairs back into a list of lines + train_data = [line for pair in 
train_pairs for line in pair] + val_data = [line for pair in val_pairs for line in pair] + + # Write to train and val files + with open(train_file_path, 'a') as f: + f.writelines(train_data) + with open(val_file_path, 'a') as f: + f.writelines(val_data) + + length_data += len(lines) + length_train += len(train_data) + length_val += len(val_data) + +print(f"Data of size {length_data} has been split into train ({length_train} samples) and val ({length_val} samples).") + + +test_interpolate_folder = "interpolate" +test_interpolate_file_path = Path("arithmetic_test_interpolate.txt") + +length_test_interpolate = 0 + +for category in selected_categories: + file = Path("mathematics_dataset-v1.0", test_interpolate_folder, category) + with open(file, 'r') as f: + data = f.readlines() + with open(test_interpolate_file_path, 'a') as f: + f.writelines(data) + + length_test_interpolate += len(data) + +print(f"Data of size {length_test_interpolate} has been written to arithmetic_test_interpolate.txt.") + + +test_extrapolate_folder = "extrapolate" +test_extrapolate_file_path = Path("arithmetic_test_extrapolate.txt") + +length_test_extrapolate = 0 +selected_categories = [ + + "arithmetic__add_sub_multiple_longer.txt", + +] + +for category in selected_categories: + file = Path("mathematics_dataset-v1.0", test_extrapolate_folder, category) + with open(file, 'r') as f: + data = f.readlines() + with open(test_extrapolate_file_path, 'a') as f: + f.writelines(data) + + length_test_extrapolate += len(data) + +print(f"Data of size {length_test_extrapolate} has been written to arithmetic_test_extrapolate.txt.") diff --git a/requirements.txt b/requirements.txt index 882a06d..615a381 100644 --- a/requirements.txt +++ b/requirements.txt @@ -66,3 +66,4 @@ xxhash==3.5.0 yarl>=1.12.0 hydra-core===1.3.2 pytest +scipy==1.13.1 diff --git a/src/ntl/args.py b/src/ntl/args.py index 3a087ba..a64d791 100644 --- a/src/ntl/args.py +++ b/src/ntl/args.py @@ -148,6 +148,13 @@ class DatasetArguments: 
dataset_name: str = field( default="mathematics_dataset", metadata={ - "help": "Name of the dataset. Allowed: mathematics_dataset, gsm8k, multiplication" + "help": "Name of the dataset. Allowed: mathematics_dataset, gsm8k, multiplication, arithmetic" + }, + ) + + mode: Optional[str] = field( + default="interpolate_extrapolate", + metadata={ + "help": "Whether we combine mathematics datasets in testing, or test individually. Allowed: interpolate_extrapolate, dataset_comparison" }, ) diff --git a/src/ntl/evaluation.py b/src/ntl/evaluation.py index a8775ed..97d5ea5 100644 --- a/src/ntl/evaluation.py +++ b/src/ntl/evaluation.py @@ -14,11 +14,15 @@ from ntl.tokenizer.t5custom_tokenizer import check_number_predictions from ntl.utils.numerical_operations import inverse_signed_log +from scipy import stats + PADDING_TOKEN = -100 MASKED_OUT = -1 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + class CustomMetrics: """ Compute custom metrics for the model with access to the vocab to compute MSE @@ -80,6 +84,7 @@ def parse_number_result_per_sample(self, prediction: str, label: str) -> Tuple[f label_number = float(label_number.replace(" ", "")) return prediction_number, label_number + def calculate_metrics(self, number_results, total_count): mae = np.mean([np.abs(result[0] - result[1]) for result in number_results if not np.isnan(result[0])]) @@ -96,6 +101,13 @@ def calculate_metrics(self, number_results, total_count): log_r2 = 1 - np.nansum((log_transformed_data[:, 0] - log_transformed_data[:, 1]) ** 2) / np.nansum( (log_transformed_data[:, 1] - np.nanmean(log_transformed_data[:, 1])) ** 2) log_mae = np.mean([np.abs(result[0] - result[1]) for result in log_transformed_data if not np.isnan(result[0])]) + + v1 = number_results[:,0] + v2 = number_results[:,1] + v1_valid = v1[~np.isnan(v1) & ~np.isnan(v2)] + v2_valid = v2[~np.isnan(v1) & ~np.isnan(v2)] + pearson = stats.pearsonr(v1_valid, v2_valid).statistic + spearman = stats.spearmanr(v1_valid, 
v2_valid).statistic return ( mae, @@ -107,6 +119,8 @@ def calculate_metrics(self, number_results, total_count): median_absolute_error, log_mae, log_r2, + pearson, + spearman, ) def perplexity(self, logits, labels): @@ -265,6 +279,8 @@ def __call__(self, pred: EvalPrediction, compute_result: bool) -> Dict[str, floa median_absolute_error, log_mae, log_r2, + pearson, + spearman ) = self.calculate_metrics(number_results, total_count) computed_metrics = { @@ -292,6 +308,8 @@ def __call__(self, pred: EvalPrediction, compute_result: bool) -> Dict[str, floa "rouge1": np.mean([stat['rouge1'] for stat in self.batch_stats]), "rouge2": np.mean([stat['rouge2'] for stat in self.batch_stats]), "rougeL": np.mean([stat['rougeL'] for stat in self.batch_stats]), + 'pearson': pearson, + 'spearman': spearman, } self.batch_stats = [] return computed_metrics diff --git a/src/ntl/results_sorting.py b/src/ntl/results_sorting.py new file mode 100644 index 0000000..0c5190d --- /dev/null +++ b/src/ntl/results_sorting.py @@ -0,0 +1,64 @@ + +vanilla_t5_int = [ + {'name': 'algebra__linear_1d', 'eval_loss': 0.4258742034435272, 'eval_token_accuracy_whole': 0.5923566878980892, 'eval_token_accuracy': 0.9615265730839626, 'eval_MSE': 69.6653, 'eval_MAE': 2.5855, 'eval_R2': 0.9163741393390171, 'eval_number_accuracy': 0.5925, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.06567459240507162, 'eval_log_r2': 0.9461691299199966, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.532950289689811, 'eval_bleu': 0.0, 'eval_rouge1': 0.5953423566878981, 'eval_rouge2': 0.0, 'eval_rougeL': 0.5954418789808917, 'eval_runtime': 78.7641, 'eval_samples_per_second': 126.961, 'eval_steps_per_second': 1.993}, + {'name': 'algebra__linear_1d_composed', 'eval_loss': 
0.2892512083053589, 'eval_token_accuracy_whole': 0.6854100318471338, 'eval_token_accuracy': 0.9676049309930984, 'eval_MSE': 3.8092, 'eval_MAE': 0.8426, 'eval_R2': 0.6178125134195044, 'eval_number_accuracy': 0.6854, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.13071760084236383, 'eval_log_r2': 0.6577526280495558, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.3367852177589563, 'eval_bleu': 0.0, 'eval_rouge1': 0.7258658439490446, 'eval_rouge2': 0.0, 'eval_rougeL': 0.7258160828025477, 'eval_runtime': 104.1375, 'eval_samples_per_second': 96.027, 'eval_steps_per_second': 1.508}, + {'name': 'algebra__linear_2d', 'eval_loss': 0.28746625781059265, 'eval_token_accuracy_whole': 0.7295979299363057, 'eval_token_accuracy': 0.9693018866192763, 'eval_MSE': 13.7938, 'eval_MAE': 1.3242, 'eval_R2': 0.8665982694839666, 'eval_number_accuracy': 0.7295, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.13450324958829207, 'eval_log_r2': 0.7919237242191184, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.3351908724778776, 'eval_bleu': 0.0, 'eval_rouge1': 0.7548765923566879, 'eval_rouge2': 0.0, 'eval_rougeL': 0.7548765923566879, 'eval_runtime': 80.4503, 'eval_samples_per_second': 124.3, 'eval_steps_per_second': 1.952}, + {'name': 'algebra__linear_2d_composed', 'eval_loss': 0.2291424721479416, 'eval_token_accuracy_whole': 0.7485071656050956, 'eval_token_accuracy': 0.970299925394119, 'eval_MSE': 2.9228, 'eval_MAE': 0.7116, 'eval_R2': 0.7120233250150461, 'eval_number_accuracy': 0.7488, 
'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.1209722837040559, 'eval_log_r2': 0.7075631143007152, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.25860724555459, 'eval_bleu': 0.0, 'eval_rouge1': 0.7740843949044586, 'eval_rouge2': 0.0, 'eval_rougeL': 0.7741839171974523, 'eval_runtime': 104.3274, 'eval_samples_per_second': 95.852, 'eval_steps_per_second': 1.505}, + {'name': 'algebra__sequence_next_term', 'eval_loss': 0.39050596952438354, 'eval_token_accuracy_whole': 0.56359474522293, 'eval_token_accuracy': 0.9314322429857437, 'eval_MSE': 230094385261.2377, 'eval_MAE': 13313.0003, 'eval_R2': 0.99891410925065, 'eval_number_accuracy': 0.6475, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.006801564481133004, 'eval_log_r2': 0.999718667600947, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.4801813690525711, 'eval_bleu': 0.0, 'eval_rouge1': 0.6405254777070064, 'eval_rouge2': 0.0, 'eval_rougeL': 0.6405254777070064, 'eval_runtime': 97.1366, 'eval_samples_per_second': 102.948, 'eval_steps_per_second': 1.616}, + {'name': 'arithmetic__add_or_sub', 'eval_loss': 0.11415749788284302, 'eval_token_accuracy_whole': 0.7467157643312102, 'eval_token_accuracy': 0.936481722221253, 'eval_MSE': 6.673432528040758e+20, 'eval_MAE': 4661304987.762815, 'eval_R2': 0.13633948665844453, 'eval_number_accuracy': 0.8577, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.05893851995067928, 'eval_log_r2': 0.9989206233478488, 'eval_count_not_produced_valid_results': 0, 
'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.1212423037571513, 'eval_bleu': 0.0, 'eval_rouge1': 0.9619824840764332, 'eval_rouge2': 0.7138734076433121, 'eval_rougeL': 0.9618663747346072, 'eval_runtime': 108.6351, 'eval_samples_per_second': 92.051, 'eval_steps_per_second': 1.445}, + {'name': 'arithmetic__add_sub_multiple', 'eval_loss': 0.41933146119117737, 'eval_token_accuracy_whole': 0.5085589171974523, 'eval_token_accuracy': 0.9554504634468419, 'eval_MSE': 7.8443, 'eval_MAE': 1.3773, 'eval_R2': 0.9908491399041514, 'eval_number_accuracy': 0.5092, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.039644936248351824, 'eval_log_r2': 0.9921403150060096, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.522248325833849, 'eval_bleu': 0.0, 'eval_rouge1': 0.5124402866242038, 'eval_rouge2': 0.0, 'eval_rougeL': 0.5124402866242038, 'eval_runtime': 72.7013, 'eval_samples_per_second': 137.549, 'eval_steps_per_second': 2.16}, + {'name': 'arithmetic__mul', 'eval_loss': 0.8451672196388245, 'eval_token_accuracy_whole': 0.3307125796178344, 'eval_token_accuracy': 0.852984408664096, 'eval_MSE': 86015900204285.73, 'eval_MAE': 340314.42553793255, 'eval_R2': -0.16316933091056773, 'eval_number_accuracy': 0.433, 'eval_median_absolute_error': 0.028100000000000014, 'eval_log_mae': 0.015827723981502534, 'eval_log_r2': 0.9993962064235589, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 
'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 2.3388660736144726, 'eval_bleu': 0.0, 'eval_rouge1': 0.5656017781316348, 'eval_rouge2': 0.2875199044585987, 'eval_rougeL': 0.5658505838641189, 'eval_runtime': 94.8632, 'eval_samples_per_second': 105.415, 'eval_steps_per_second': 1.655}, + {'name': 'numbers__div_remainder', 'eval_loss': 1.5148051977157593, 'eval_token_accuracy_whole': 0.23616640127388536, 'eval_token_accuracy': 0.896610211414896, 'eval_MSE': 917753.1246, 'eval_MAE': 328.1728, 'eval_R2': -1.092422477107827, 'eval_number_accuracy': 0.2395, 'eval_median_absolute_error': 24.0, 'eval_log_mae': 0.6093405276318679, 'eval_log_r2': -0.5041607434479516, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 4.574191984856964, 'eval_bleu': 0.0, 'eval_rouge1': 0.23915207006369427, 'eval_rouge2': 0.0, 'eval_rougeL': 0.2392515923566879, 'eval_runtime': 71.0488, 'eval_samples_per_second': 140.748, 'eval_steps_per_second': 2.21}, + {'name': 'numbers__div_remainder_composed', 'eval_loss': 1.1680721044540405, 'eval_token_accuracy_whole': 0.26990445859872614, 'eval_token_accuracy': 0.9279677621118582, 'eval_MSE': 615.1854, 'eval_MAE': 9.088, 'eval_R2': 0.06951730346549256, 'eval_number_accuracy': 0.2679, 'eval_median_absolute_error': 2.0, 'eval_log_mae': 0.19292789754353504, 'eval_log_r2': 0.3519345225450018, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 3.233200613860112, 'eval_bleu': 0.0, 'eval_rouge1': 0.26990445859872614, 
'eval_rouge2': 0.0, 'eval_rougeL': 0.26990445859872614, 'eval_runtime': 91.7118, 'eval_samples_per_second': 109.037, 'eval_steps_per_second': 1.712}, + {'name': 'numbers__place_value', 'eval_loss': 0.00021375504729803652, 'eval_token_accuracy_whole': 0.9999004777070064, 'eval_token_accuracy': 0.9999917495022913, 'eval_MSE': 0.0004, 'eval_MAE': 0.0002, 'eval_R2': 0.9999506834114683, 'eval_number_accuracy': 0.9999, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 1.2493873660829991e-05, 'eval_log_r2': 0.9999817787321196, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.0002167308406464, 'eval_bleu': 0.0, 'eval_rouge1': 0.9999004777070064, 'eval_rouge2': 0.0, 'eval_rougeL': 0.9999004777070064, 'eval_runtime': 73.0912, 'eval_samples_per_second': 136.815, 'eval_steps_per_second': 2.148}, + {'name': 'numbers__round_number', 'eval_loss': 0.013202344067394733, 'eval_token_accuracy_whole': 0.9808917197452229, 'eval_token_accuracy': 0.9959232697061672, 'eval_MSE': 8.297708350377116e+17, 'eval_MAE': 54089755.922709994, 'eval_R2': 0.6984818480937057, 'eval_number_accuracy': 0.992, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.003449627763777241, 'eval_log_r2': 0.999849256705451, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.0133366812566282, 'eval_bleu': 0.0, 'eval_rouge1': 0.9971636146496815, 'eval_rouge2': 0.48298168789808915, 'eval_rougeL': 0.9971636146496815, 'eval_runtime': 89.5934, 'eval_samples_per_second': 111.615, 'eval_steps_per_second': 1.752}, + {'name': 
'numbers__round_number_composed', 'eval_loss': 0.2976441979408264, 'eval_token_accuracy_whole': 0.6739649681528662, 'eval_token_accuracy': 0.9564651402698201, 'eval_MSE': 1985895974680671.5, 'eval_MAE': 2382862.2154631135, 'eval_R2': -6.9185205825717855, 'eval_number_accuracy': 0.6789, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.3154235044066668, 'eval_log_r2': 0.7579880454602687, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.348972459507596, 'eval_bleu': 0.0, 'eval_rouge1': 0.7162536491507431, 'eval_rouge2': 0.43531050955414013, 'eval_rougeL': 0.7163117038216561, 'eval_runtime': 114.2385, 'eval_samples_per_second': 87.536, 'eval_steps_per_second': 1.374}, + ] + +vanilla_t5_ext = [ + {'name': 'arithmetic__add_or_sub_big', 'eval_loss': 0.2507264316082001, 'eval_token_accuracy_whole': 0.5526472929936306, 'eval_token_accuracy': 0.8848125968769098, 'eval_MSE': 2.583616616028877e+35, 'eval_MAE': 4.118241409875469e+16, 'eval_R2': -0.00018917444741184397, 'eval_number_accuracy': 0.4474, 'eval_median_absolute_error': 1000850.0, 'eval_log_mae': 1.6530168219201293, 'eval_log_r2': 0.9248600736325592, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 2, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0002, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.2871854176187212, 'eval_bleu': 0.0, 'eval_rouge1': 0.8464188561571124, 'eval_rouge2': 0.5980294585987261, 'eval_rougeL': 0.8465780918259023, 'eval_runtime': 142.1211, 'eval_samples_per_second': 70.363, 'eval_steps_per_second': 1.105}, + {'name': 'arithmetic__add_sub_multiple_longer', 'eval_loss': 1.2236006259918213, 
'eval_token_accuracy_whole': 0.09822850318471338, 'eval_token_accuracy': 0.9076944323861675, 'eval_MSE': 5257.0654, 'eval_MAE': 21.131, 'eval_R2': 0.8386116807837074, 'eval_number_accuracy': 0.0982, 'eval_median_absolute_error': 3.0, 'eval_log_mae': 0.12978320960776954, 'eval_log_r2': 0.9284899166288128, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 3.41554453418513, 'eval_bleu': 0.0, 'eval_rouge1': 0.09952229299363058, 'eval_rouge2': 0.0, 'eval_rougeL': 0.09952229299363058, 'eval_runtime': 88.1948, 'eval_samples_per_second': 113.385, 'eval_steps_per_second': 1.78}, + {'name': 'arithmetic__mixed_longer', 'eval_loss': 2.1138792037963867, 'eval_token_accuracy_whole': 0.007663216560509554, 'eval_token_accuracy': 0.8549421771316771, 'eval_MSE': 29532.5228, 'eval_MAE': 131.62, 'eval_R2': 0.08975772273294524, 'eval_number_accuracy': 0.0077, 'eval_median_absolute_error': 111.0, 'eval_log_mae': 1.6590986202714866, 'eval_log_r2': 0.08228582206945667, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 8.309127625386427, 'eval_bleu': 0.0, 'eval_rouge1': 0.009355095541401274, 'eval_rouge2': 0.0, 'eval_rougeL': 0.009355095541401274, 'eval_runtime': 89.5319, 'eval_samples_per_second': 111.692, 'eval_steps_per_second': 1.754}, + {'name': 'arithmetic__mul_big', 'eval_loss': 1.6134669780731201, 'eval_token_accuracy_whole': 0.18531050955414013, 'eval_token_accuracy': 0.7657490608039176, 'eval_MSE': 5.5556807913926885e+20, 'eval_MAE': 3546304250.595108, 'eval_R2': 0.11181923583159037, 
'eval_number_accuracy': 0.2746, 'eval_median_absolute_error': 584.0010000000002, 'eval_log_mae': 0.12359730770643357, 'eval_log_r2': 0.9975416691180987, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 2, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0002, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 5.072481367998062, 'eval_bleu': 0.0, 'eval_rouge1': 0.39377488057324844, 'eval_rouge2': 0.176453025477707, 'eval_rougeL': 0.3935957404458599, 'eval_runtime': 107.0484, 'eval_samples_per_second': 93.416, 'eval_steps_per_second': 1.467}, + {'name': 'arithmetic__mul_div_multiple_longer', 'eval_loss': 2.1390328407287598, 'eval_token_accuracy_whole': 0.006170382165605096, 'eval_token_accuracy': 0.8581209353580597, 'eval_MSE': 32318.0636, 'eval_MAE': 145.1968, 'eval_R2': 0.03474937496316155, 'eval_number_accuracy': 0.0062, 'eval_median_absolute_error': 127.0, 'eval_log_mae': 1.8136232787337097, 'eval_log_r2': -0.006964487147834708, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 8.505382422428982, 'eval_bleu': 0.0, 'eval_rouge1': 0.00875796178343949, 'eval_rouge2': 0.0, 'eval_rougeL': 0.00875796178343949, 'eval_runtime': 91.236, 'eval_samples_per_second': 109.606, 'eval_steps_per_second': 1.721}, + {'name': 'numbers__place_value_big', 'eval_loss': 1.121049404144287, 'eval_token_accuracy_whole': 0.7986664012738853, 'eval_token_accuracy': 0.9800903850300297, 'eval_MSE': 3.3947, 'eval_MAE': 0.7081, 'eval_R2': 0.5815159287707787, 'eval_number_accuracy': 0.7983, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.0665311562579792, 'eval_log_r2': 0.6338582137643287, 
'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 3.265867238591431, 'eval_bleu': 0.0, 'eval_rouge1': 0.7986664012738853, 'eval_rouge2': 0.0, 'eval_rougeL': 0.7986664012738853, 'eval_runtime': 72.6584, 'eval_samples_per_second': 137.63, 'eval_steps_per_second': 2.161}, + {'name': 'numbers__round_number_big', 'eval_loss': 0.02936192788183689, 'eval_token_accuracy_whole': 0.9501393312101911, 'eval_token_accuracy': 0.9885114522496606, 'eval_MSE': 1.8705761996922658e+22, 'eval_MAE': 13459811101.707745, 'eval_R2': 0.014429220200584347, 'eval_number_accuracy': 0.948, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.04740547474012875, 'eval_log_r2': 0.9972067538741404, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.0298788463993438, 'eval_bleu': 0.0, 'eval_rouge1': 0.9928343949044586, 'eval_rouge2': 0.4922372611464968, 'eval_rougeL': 0.9928343949044586, 'eval_runtime': 92.6559, 'eval_samples_per_second': 107.926, 'eval_steps_per_second': 1.694}, + ] + +rt_ntl_int = [ + {'name': 'algebra__linear_1d', 'eval_loss': 1.2600773572921753, 'eval_token_loss': 0.5663610696792603, 'eval_number_loss': 2.312387704849243, 'eval_token_accuracy_whole': 0.5498606687898089, 'eval_token_accuracy': 0.9528332026141464, 'eval_MSE': 240.2661, 'eval_MAE': 4.1417, 'eval_R2': 0.7115858339782103, 'eval_number_accuracy': 0.5494, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.09852558094610887, 'eval_log_r2': 0.8979490192039823, 'eval_count_not_produced_valid_results': 0, 
'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 15, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0015, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.7694234756907081, 'eval_bleu': 0.17634257049937066, 'eval_rouge1': 0.554073779193206, 'eval_rouge2': 0.0, 'eval_rougeL': 0.5542064755838642, 'eval_runtime': 106.6194, 'eval_samples_per_second': 93.792, 'eval_steps_per_second': 1.473}, + {'name': 'algebra__linear_1d_composed', 'eval_loss': 0.9651321172714233, 'eval_token_loss': 0.5929774641990662, 'eval_number_loss': 1.2405154705047607, 'eval_token_accuracy_whole': 0.5567277070063694, 'eval_token_accuracy': 0.953914934662497, 'eval_MSE': 6.0143, 'eval_MAE': 1.2583, 'eval_R2': 0.3965687807043278, 'eval_number_accuracy': 0.5579, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.19801868301893155, 'eval_log_r2': 0.47037571435984527, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.82390563427263, 'eval_bleu': 0.0, 'eval_rouge1': 0.6133393046709129, 'eval_rouge2': 0.0, 'eval_rougeL': 0.6133227176220807, 'eval_runtime': 126.3358, 'eval_samples_per_second': 79.154, 'eval_steps_per_second': 1.243}, + {'name': 'algebra__linear_2d', 'eval_loss': 1.060610294342041, 'eval_token_loss': 0.5059592127799988, 'eval_number_loss': 1.848836898803711, 'eval_token_accuracy_whole': 0.6605294585987261, 'eval_token_accuracy': 0.9575892717215666, 'eval_MSE': 33.8816, 'eval_MAE': 1.8392, 'eval_R2': 0.6723264022494138, 'eval_number_accuracy': 0.6607, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.1622092811303612, 'eval_log_r2': 0.7460571135018932, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 
0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.6702218215177014, 'eval_bleu': 0.12634039142791645, 'eval_rouge1': 0.6898443029016278, 'eval_rouge2': 0.0, 'eval_rougeL': 0.6895125619249823, 'eval_runtime': 107.431, 'eval_samples_per_second': 93.083, 'eval_steps_per_second': 1.461}, + {'name': 'algebra__linear_2d_composed', 'eval_loss': 0.6940405964851379, 'eval_token_loss': 0.4334726631641388, 'eval_number_loss': 0.868559718132019, 'eval_token_accuracy_whole': 0.6892914012738853, 'eval_token_accuracy': 0.9645563246338231, 'eval_MSE': 2.4737, 'eval_MAE': 0.7187, 'eval_R2': 0.7562721017824413, 'eval_number_accuracy': 0.6893, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.12164884115246719, 'eval_log_r2': 0.7468453365438512, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.5509947796536099, 'eval_bleu': 0.31211945857829637, 'eval_rouge1': 0.7172254994209611, 'eval_rouge2': 0.0, 'eval_rougeL': 0.7172254994209611, 'eval_runtime': 130.6368, 'eval_samples_per_second': 76.548, 'eval_steps_per_second': 1.202}, + {'name': 'algebra__sequence_next_term', 'eval_loss': 0.4013504087924957, 'eval_token_loss': 0.25983375310897827, 'eval_number_loss': 0.47172221541404724, 'eval_token_accuracy_whole': 0.6824243630573248, 'eval_token_accuracy': 0.96863310580041, 'eval_MSE': 2145525866253.6248, 'eval_MAE': 42508.47803901951, 'eval_R2': 0.9898796234678747, 'eval_number_accuracy': 0.6939, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.04263236190792732, 'eval_log_r2': 0.991240379854552, 'eval_count_not_produced_valid_results': 5, 'eval_average_count_not_produced_valid_results': 
0.0005, 'eval_count_invalid_number_prediction': 27, 'eval_count_no_number_prediction': 5, 'eval_average_invalid_number_prediction': 0.0027, 'eval_average_no_number_prediction': 0.0005, 'eval_token_perplexity': 1.2992431259458992, 'eval_bleu': 6.090133438304074, 'eval_rouge1': 0.6838299969669397, 'eval_rouge2': 0.0, 'eval_rougeL': 0.6837956380800728, 'eval_runtime': 162.6094, 'eval_samples_per_second': 61.497, 'eval_steps_per_second': 0.966}, + {'name': 'arithmetic__add_or_sub', 'eval_loss': 0.09025327861309052, 'eval_token_loss': 0.07104214280843735, 'eval_number_loss': 0.06403713673353195, 'eval_token_accuracy_whole': 0.9381966560509554, 'eval_token_accuracy': 0.9941434017412222, 'eval_MSE': 6.672399390918147e+20, 'eval_MAE': 4656420096.486205, 'eval_R2': 0.13647319292338655, 'eval_number_accuracy': 0.8919, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.06580346248708865, 'eval_log_r2': 0.997956846508289, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 11, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0011, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.074762511405216, 'eval_bleu': 2.8112082785991563, 'eval_rouge1': 0.9685939237690832, 'eval_rouge2': 0.6983346602972399, 'eval_rougeL': 0.9684564882216156, 'eval_runtime': 178.61, 'eval_samples_per_second': 55.988, 'eval_steps_per_second': 0.879}, + {'name': 'arithmetic__add_sub_multiple', 'eval_loss': 0.35576826333999634, 'eval_token_loss': 0.21297582983970642, 'eval_number_loss': 0.4759747087955475, 'eval_token_accuracy_whole': 0.865047770700637, 'eval_token_accuracy': 0.9863450261437969, 'eval_MSE': 17.5129, 'eval_MAE': 1.1193, 'eval_R2': 0.9795701212635177, 'eval_number_accuracy': 0.8653, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.03100907014136144, 'eval_log_r2': 0.9861341705926026, 'eval_count_not_produced_valid_results': 0, 
'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 15, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0015, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.2432642165262988, 'eval_bleu': 0.39200403962989666, 'eval_rouge1': 0.86609275477707, 'eval_rouge2': 0.0, 'eval_rougeL': 0.8660595806794057, 'eval_runtime': 116.3628, 'eval_samples_per_second': 85.938, 'eval_steps_per_second': 1.349}, + {'name': 'arithmetic__mul', 'eval_loss': 1.7417913675308228, 'eval_token_loss': 0.8531287908554077, 'eval_number_loss': 2.9622082710266113, 'eval_token_accuracy_whole': 0.47253184713375795, 'eval_token_accuracy': 0.9068294224465728, 'eval_MSE': 4211493852888.456, 'eval_MAE': 170152.78892582332, 'eval_R2': 0.9430491284127239, 'eval_number_accuracy': 0.4801, 'eval_median_absolute_error': 0.0008000000000003993, 'eval_log_mae': 0.07955147576700512, 'eval_log_r2': 0.9778094817750923, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 12, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0012, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 2.386666980518657, 'eval_bleu': 0.3618406866495384, 'eval_rouge1': 0.5963381325194621, 'eval_rouge2': 0.32541657188353046, 'eval_rougeL': 0.5964415251238498, 'eval_runtime': 150.6602, 'eval_samples_per_second': 66.375, 'eval_steps_per_second': 1.042}, + {'name': 'numbers__div_remainder', 'eval_loss': 4.396943092346191, 'eval_token_loss': 2.10192608833313, 'eval_number_loss': 7.6500563621521, 'eval_token_accuracy_whole': 0.2327826433121019, 'eval_token_accuracy': 0.8634865587684, 'eval_MSE': 478739.2085, 'eval_MAE': 281.1305, 'eval_R2': -0.0914968891822725, 'eval_number_accuracy': 0.2344, 'eval_median_absolute_error': 22.0, 'eval_log_mae': 0.44689248922714503, 'eval_log_r2': 0.13251370725605183, 
'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 2, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0002, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 8.335521005521155, 'eval_bleu': 0.49480816076571416, 'eval_rouge1': 0.2333681661358811, 'eval_rouge2': 0.0, 'eval_rougeL': 0.23325703290870486, 'eval_runtime': 94.14, 'eval_samples_per_second': 106.225, 'eval_steps_per_second': 1.668}, + {'name': 'numbers__div_remainder_composed', 'eval_loss': 4.4499382972717285, 'eval_token_loss': 2.1079111099243164, 'eval_number_loss': 7.806756496429443, 'eval_token_accuracy_whole': 0.18481289808917198, 'eval_token_accuracy': 0.8994184550206372, 'eval_MSE': 855.1073, 'eval_MAE': 11.3817, 'eval_R2': -0.29337033409821167, 'eval_number_accuracy': 0.1836, 'eval_median_absolute_error': 3.0, 'eval_log_mae': 0.21725255261589146, 'eval_log_r2': 0.3495540000520768, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 8.444629701079837, 'eval_bleu': 0.13409862919343477, 'eval_rouge1': 0.18483777866242038, 'eval_rouge2': 0.0, 'eval_rougeL': 0.18483777866242038, 'eval_runtime': 109.7389, 'eval_samples_per_second': 91.125, 'eval_steps_per_second': 1.431}, + {'name': 'numbers__place_value', 'eval_loss': 1.1262856721878052, 'eval_token_loss': 0.6906732320785522, 'eval_number_loss': 1.4520413875579834, 'eval_token_accuracy_whole': 0.9103304140127388, 'eval_token_accuracy': 0.9753184690596951, 'eval_MSE': 996137048838806.8, 'eval_MAE': 2031765.0180757, 'eval_R2': -122815202397023.75, 'eval_number_accuracy': 0.9099, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.1941363800248147, 'eval_log_r2': -13.122405430609843, 
'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 12, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0012, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 2.112388398237289, 'eval_bleu': 0.0, 'eval_rouge1': 0.9116971868365181, 'eval_rouge2': 0.0, 'eval_rougeL': 0.9116474256900212, 'eval_runtime': 109.9693, 'eval_samples_per_second': 90.934, 'eval_steps_per_second': 1.428}, + {'name': 'numbers__round_number', 'eval_loss': 0.012691711075603962, 'eval_token_loss': 0.012106580659747124, 'eval_number_loss': 0.0019504346419125795, 'eval_token_accuracy_whole': 0.9940286624203821, 'eval_token_accuracy': 0.9995708871798911, 'eval_MSE': 8.010071407000001e+17, 'eval_MAE': 49643300.00039857, 'eval_R2': 0.7089338615804297, 'eval_number_accuracy': 0.9902, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.0013881553233657847, 'eval_log_r2': 0.999970313070524, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.0125112404489214, 'eval_bleu': 0.0, 'eval_rouge1': 0.9971636146496815, 'eval_rouge2': 0.477906050955414, 'eval_rougeL': 0.9971636146496815, 'eval_runtime': 142.4298, 'eval_samples_per_second': 70.21, 'eval_steps_per_second': 1.102}, + {'name': 'numbers__round_number_composed', 'eval_loss': 0.6396009922027588, 'eval_token_loss': 0.37430205941200256, 'eval_number_loss': 0.8843298554420471, 'eval_token_accuracy_whole': 0.678343949044586, 'eval_token_accuracy': 0.9323904442179735, 'eval_MSE': 90143761063992.8, 'eval_MAE': 978833.539920974, 'eval_R2': 0.6406345076573261, 'eval_number_accuracy': 0.6805, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.7775066897180416, 'eval_log_r2': 
0.2734046524398641, 'eval_count_not_produced_valid_results': 2, 'eval_average_count_not_produced_valid_results': 0.0002, 'eval_count_invalid_number_prediction': 0, 'eval_count_no_number_prediction': 2, 'eval_average_invalid_number_prediction': 0.0, 'eval_average_no_number_prediction': 0.0002, 'eval_token_perplexity': 1.45956001615828, 'eval_bleu': 0.20486280350904043, 'eval_rouge1': 0.7160214304670912, 'eval_rouge2': 0.45554670912951173, 'eval_rougeL': 0.7157643312101911, 'eval_runtime': 179.9002, 'eval_samples_per_second': 55.586, 'eval_steps_per_second': 0.873}, + ] + +rt_ntl_ext = [ +{'name': 'arithmetic__add_or_sub_big', 'eval_loss': 3.328371524810791, 'eval_token_loss': 2.6555933952331543, 'eval_number_loss': 2.242593288421631, 'eval_token_accuracy_whole': 0.5109474522292994, 'eval_token_accuracy': 0.8210418087661646, 'eval_MSE': 2.5836166248303327e+35, 'eval_MAE': 4.118241649322059e+16, 'eval_R2': -0.0001891778546980749, 'eval_number_accuracy': 0.4822, 'eval_median_absolute_error': 99999.99999999005, 'eval_log_mae': 4.1710831925709435, 'eval_log_r2': 0.2533772029372888, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 2859, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.2859, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 16.426856992351023, 'eval_bleu': 9.446392965805282, 'eval_rouge1': 0.6684312544210036, 'eval_rouge2': 0.39955697500482534, 'eval_rougeL': 0.6689365719869301, 'eval_runtime': 233.0489, 'eval_samples_per_second': 42.909, 'eval_steps_per_second': 0.674}, +{'name': 'arithmetic__add_sub_multiple_longer', 'eval_loss': 3.6553287506103516, 'eval_token_loss': 1.9846962690353394, 'eval_number_loss': 5.56877326965332, 'eval_token_accuracy_whole': 0.1747611464968153, 'eval_token_accuracy': 0.897322247742088, 'eval_MSE': 4658.2378, 'eval_MAE': 24.25, 'eval_R2': 0.8569952793336373, 'eval_number_accuracy': 0.1755, 
'eval_median_absolute_error': 6.0, 'eval_log_mae': 0.1682269042208606, 'eval_log_r2': 0.898908297469687, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 1139, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.1139, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 7.439464058845666, 'eval_bleu': 5.125718019753406, 'eval_rouge1': 0.17708333333333334, 'eval_rouge2': 0.0, 'eval_rougeL': 0.1770501592356688, 'eval_runtime': 129.3647, 'eval_samples_per_second': 77.301, 'eval_steps_per_second': 1.214}, +{'name': 'arithmetic__mixed_longer', 'eval_loss': 6.909751892089844, 'eval_token_loss': 3.183138132095337, 'eval_number_loss': 12.422048568725586, 'eval_token_accuracy_whole': 0.005075636942675159, 'eval_token_accuracy': 0.7913994033625171, 'eval_MSE': 25260792897240.17, 'eval_MAE': 82556.5586, 'eval_R2': -778580339.327557, 'eval_number_accuracy': 0.0051, 'eval_median_absolute_error': 128.0, 'eval_log_mae': 2.0421374610304515, 'eval_log_r2': -0.801077112160008, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 1, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0001, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 24.468366556106858, 'eval_bleu': 0.0, 'eval_rouge1': 0.007663216560509554, 'eval_rouge2': 0.0, 'eval_rougeL': 0.007663216560509554, 'eval_runtime': 132.1308, 'eval_samples_per_second': 75.683, 'eval_steps_per_second': 1.188}, +{'name': 'arithmetic__mul_big', 'eval_loss': 4.243931293487549, 'eval_token_loss': 2.3281009197235107, 'eval_number_loss': 6.386099815368652, 'eval_token_accuracy_whole': 0.20242834394904458, 'eval_token_accuracy': 0.7424696759813151, 'eval_MSE': 5.975979778898237e+20, 'eval_MAE': 4310915103.520292, 'eval_R2': 0.04462648485852594, 'eval_number_accuracy': 0.2405, 
'eval_median_absolute_error': 3476.3244856, 'eval_log_mae': 0.7650544070037113, 'eval_log_r2': 0.8408977842560506, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 581, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0581, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 10.825987791559498, 'eval_bleu': 5.9522079146070155, 'eval_rouge1': 0.3439358934258417, 'eval_rouge2': 0.14964835456475584, 'eval_rougeL': 0.3435352372105954, 'eval_runtime': 178.904, 'eval_samples_per_second': 55.896, 'eval_steps_per_second': 0.878}, +{'name': 'arithmetic__mul_div_multiple_longer', 'eval_loss': 7.075137138366699, 'eval_token_loss': 3.303149938583374, 'eval_number_loss': 12.573290824890137, 'eval_token_accuracy_whole': 0.0015923566878980893, 'eval_token_accuracy': 0.7781247918013554, 'eval_MSE': 199496744186248.16, 'eval_MAE': 590345.6431, 'eval_R2': -5958412588.378972, 'eval_number_accuracy': 0.0016, 'eval_median_absolute_error': 154.0, 'eval_log_mae': 2.3300404710403413, 'eval_log_r2': -1.2344501883266537, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 1, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0001, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 27.7457017837816, 'eval_bleu': 0.0, 'eval_rouge1': 0.0031847133757961785, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0031847133757961785, 'eval_runtime': 138.37, 'eval_samples_per_second': 72.27, 'eval_steps_per_second': 1.135}, +{'name': 'numbers__place_value_big', 'eval_loss': 5.235705852508545, 'eval_token_loss': 3.5247833728790283, 'eval_number_loss': 5.703073501586914, 'eval_token_accuracy_whole': 0.6030055732484076, 'eval_token_accuracy': 0.8612161616610873, 'eval_MSE': 1.0741987259489624e+19, 'eval_MAE': 1113509126.2240102, 'eval_R2': 
-1.3242261647402253e+18, 'eval_number_accuracy': 0.6021, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 1.2710678882682367, 'eval_log_r2': -125.21769369562054, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 1143, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.1143, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 43.86019048569309, 'eval_bleu': 0.10174253252329178, 'eval_rouge1': 0.6130888402335456, 'eval_rouge2': 0.0, 'eval_rougeL': 0.6131054272823779, 'eval_runtime': 110.5326, 'eval_samples_per_second': 90.471, 'eval_steps_per_second': 1.42}, +{'name': 'numbers__round_number_big', 'eval_loss': 0.02828773856163025, 'eval_token_loss': 0.02383282594382763, 'eval_number_loss': 0.014849713072180748, 'eval_token_accuracy_whole': 0.9847730891719745, 'eval_token_accuracy': 0.9984898407747791, 'eval_MSE': 1.870963705354604e+22, 'eval_MAE': 13454661048.027798, 'eval_R2': 0.014225050887476276, 'eval_number_accuracy': 0.9404, 'eval_median_absolute_error': 0.0, 'eval_log_mae': 0.04855593274846795, 'eval_log_r2': 0.9953372668065608, 'eval_count_not_produced_valid_results': 0, 'eval_average_count_not_produced_valid_results': 0.0, 'eval_count_invalid_number_prediction': 2, 'eval_count_no_number_prediction': 0, 'eval_average_invalid_number_prediction': 0.0002, 'eval_average_no_number_prediction': 0.0, 'eval_token_perplexity': 1.0246237015268604, 'eval_bleu': 0.0, 'eval_rouge1': 0.9919884554140127, 'eval_rouge2': 0.48059315286624205, 'eval_rougeL': 0.9919884554140127, 'eval_runtime': 151.6608, 'eval_samples_per_second': 65.937, 'eval_steps_per_second': 1.035}, + +] + + +criterion = 'eval_token_accuracy' +sorted_list = sorted(vanilla_t5_int, key=lambda x: x[criterion], reverse=False) +for dict in sorted_list: + print(str(dict[criterion]) + ' - ' + dict['name']) +print('_______________') + +sorted_list = sorted(rt_ntl_int, key=lambda x: 
x[criterion], reverse=False) +for dict in sorted_list: + print(str(dict[criterion]) + ' - ' + dict['name']) \ No newline at end of file diff --git a/src/ntl/run_language_modeling.py b/src/ntl/run_language_modeling.py index 023d19f..e9478d7 100644 --- a/src/ntl/run_language_modeling.py +++ b/src/ntl/run_language_modeling.py @@ -10,6 +10,7 @@ sys.path.append(".") os.environ["CUDA_VISIBLE_DEVICES"] = "0" +import time import json import logging import numpy as np @@ -94,6 +95,7 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg training_args.n_gpu, bool(training_args.local_rank != -1), ) + logger.info("Training on dataset: %s", dataset_args.dataset_name) logger.info("Training/evaluation parameters %s", training_args) @@ -121,6 +123,7 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg # Set seed set_seed(training_args.seed) + if model_args.config_name: # if file exists load it otherwise just use config name if os.path.isfile(model_args.config_name): @@ -134,8 +137,10 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg cache_dir=model_args.cache_dir, mem_len=model_params.get("mem_len", 1024), ) + elif model_args.model_name_or_path: + if "checkpoint" not in model_args.model_name_or_path: model_args.model_name_or_path = get_latest_checkpoint( model_args.model_name_or_path, @@ -148,11 +153,14 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg ) model_params = config.__dict__ + else: config = CONFIG_MAPPING[model_args.model_type]() model_params = config.__dict__ logger.warning("You are instantiating a new config instance from scratch.") + + if training_args.language_modelling == "clm": # Set generation arguments training_args.predict_with_generate = True @@ -324,6 +332,18 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg eval_dataset = load_txt_dataset(eval_data_path) test_interpolate_dataset = 
load_txt_dataset(test_interpolate_data_path) test_extrapolate_dataset = load_txt_dataset(test_extrapolate_data_path) + elif dataset_args.dataset_name == "arithmetic": + logger.info("Training on arithmetic dataset") + train_data_path = 'data/mathematics_dataset-v1.0/arithmetic_train.txt' + eval_data_path = 'data/mathematics_dataset-v1.0/arithmetic_val.txt' + test_interpolate_data_path = 'data/mathematics_dataset-v1.0/arithmetic_test_interpolate.txt' + test_extrapolate_data_path = 'data/mathematics_dataset-v1.0/arithmetic_test_extrapolate.txt' + + train_dataset = load_txt_dataset(train_data_path) + eval_dataset = load_txt_dataset(eval_data_path) + test_interpolate_dataset = load_txt_dataset(test_interpolate_data_path) + test_extrapolate_dataset = load_txt_dataset(test_extrapolate_data_path) + elif dataset_args.dataset_name == "multiplication": train_data_path = 'data/digit-multiplication/data/train.jsonl' eval_data_path = 'data/digit-multiplication/data/val.jsonl' @@ -354,6 +374,7 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg early_stopping_threshold=0.001) # custom_trainer_params = get_trainer_dict(model_params) + # Initialize our Trainer trainer = CustomSeq2SeqTrainer( @@ -377,7 +398,11 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg ) if not training_args.do_only_eval: + start_time = time.time() trainer.train(model_path=model_path) + end_time=time.time() + logger.info("Elapsed time:") + logger.info(end_time - start_time) trainer.save_model() # For convenience, we also re-save the tokenizer to the same directory, # so that you can share your model easily on huggingface.co/models =) @@ -386,19 +411,32 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg else: logger.info("Skipping training.") - logger.info("*** Evaluate on validation data ***") - eval_results_val = trainer.evaluate(eval_dataset=eval_dataset) - logger.info(f"eval_results validation data: 
{eval_results_val}") - - if not training_args.do_only_eval: - return eval_results_val, model + if not (dataset_args.dataset_name == "mathematics_dataset" and dataset_args.mode == "dataset_comparison"): + + logger.info("*** Evaluate on validation data ***") + eval_results_val = trainer.evaluate(eval_dataset=eval_dataset) + logger.info(f"eval_results validation data: {eval_results_val}") + + if not training_args.do_only_eval: + return eval_results_val, model + if dataset_args.dataset_name in ["gsm8k", "multiplication"]: logger.info("*** Evaluate on test set ***") eval_results_test = trainer.evaluate(eval_dataset=test_dataset) logger.info(f"eval_results test data: {eval_results_test}") return eval_results_val, eval_results_test - elif dataset_args.dataset_name == "mathematics_dataset": + elif dataset_args.dataset_name == "arithmetic": + logger.info("*** Evaluate on interpolation data for arithmetic ***") + eval_results_test_interpolate = trainer.evaluate(eval_dataset=test_interpolate_dataset) + logger.info(f"eval_results interpolate data: {eval_results_test_interpolate}") + + logger.info("*** Evaluate on extrapolation data for arithmetic ***") + eval_results_test_extrapolate = trainer.evaluate(eval_dataset=test_extrapolate_dataset) + logger.info(f"eval_results extrapolate data: {eval_results_test_extrapolate}") + + return eval_results_val, eval_results_test_interpolate, eval_results_test_extrapolate + elif dataset_args.dataset_name == "mathematics_dataset" and dataset_args.mode == "interpolate_extrapolate": logger.info("*** Evaluate on interpolate data ***") eval_results_test_interpolate = trainer.evaluate(eval_dataset=test_interpolate_dataset) logger.info(f"eval_results interpolate data: {eval_results_test_interpolate}") @@ -407,6 +445,54 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg eval_results_test_extrapolate = trainer.evaluate(eval_dataset=test_extrapolate_dataset) logger.info(f"eval_results extrapolate data: 
{eval_results_test_extrapolate}") return eval_results_val, eval_results_test_interpolate, eval_results_test_extrapolate + elif dataset_args.dataset_name == "mathematics_dataset" and dataset_args.mode == "dataset_comparison": + logger.info("*** Comparing loss on individual datasets ***") + + # interpolation + int_categories = [ + "algebra__linear_1d.txt", + "algebra__linear_1d_composed.txt", + "algebra__linear_2d.txt", + "algebra__linear_2d_composed.txt", + "algebra__sequence_next_term.txt", + "arithmetic__add_or_sub.txt", + "arithmetic__add_sub_multiple.txt", + "arithmetic__mul.txt", + "numbers__div_remainder.txt", + "numbers__div_remainder_composed.txt", + "numbers__place_value.txt", + "numbers__round_number.txt", + "numbers__round_number_composed.txt", + ] + int_results = [] + for name in int_categories: + path = 'data/mathematics_dataset-v1.0/mathematics_dataset-v1.0/interpolate/' + name + test_dataset = load_txt_dataset(path) + logger.info("*** Testing interpolation on " + name + " data ***") + result = trainer.evaluate(eval_dataset=test_dataset) + logger.info(f"Test results: {result}") + int_results.append(result) + + # extrapolation + ext_categories = [ + "arithmetic__add_or_sub_big.txt", + "arithmetic__add_sub_multiple_longer.txt", + "arithmetic__mixed_longer.txt", + "arithmetic__mul_big.txt", + "arithmetic__mul_div_multiple_longer.txt", + "numbers__place_value_big.txt", + "numbers__round_number_big.txt", + ] + ext_results = [] + for name in ext_categories: + path = 'data/mathematics_dataset-v1.0/mathematics_dataset-v1.0/extrapolate/' + name + test_dataset = load_txt_dataset(path) + logger.info("*** Testing extrapolation on " + name + " data ***") + result = trainer.evaluate(eval_dataset=test_dataset) + logger.info(f"Test results: {result}") + ext_results.append(result) + + return int_results, ext_results def get_data_collator(model_args: ModelArguments, tokenizer, training_args: TrainingArguments) -> transformers.DataCollator: diff --git 
a/tests/test_evaluation.py b/tests/test_evaluation.py index 29665d6..a944362 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -61,6 +61,8 @@ def test_calculate_result_mse(self): median_absolute_error, log_mae, log_r2, + pearson, + spearman ) = self.metrics_xval.calculate_metrics(number_results, 5) expected_mae = np.mean([abs(-1.0 - -4.0), abs(34.452 - 34.452), abs(80 - 78), abs(-1 - 0)]) @@ -83,6 +85,10 @@ def test_calculate_result_mse(self): expected_count_not_produced_valid_results = 1 expected_average_count_not_produced_valid_results = 1/5 + expected_pearson = 0.9989029443838093 + expected_spearman = 0.9486832980505139 + + self.assertEqual(mae, expected_mae) self.assertEqual(mse, expected_mse) self.assertEqual(r2, expected_r2) @@ -92,6 +98,8 @@ def test_calculate_result_mse(self): self.assertEqual(median_absolute_error, expected_median_absolute_error) self.assertEqual(log_mae, expected_log_mae) self.assertEqual(log_r2, expected_log_r2) + self.assertEqual(pearson, expected_pearson) + self.assertEqual(spearman, expected_spearman) if __name__ == "__main__": From 0d8ccafdaebdbf841c8f18d7a12786cd6145f199 Mon Sep 17 00:00:00 2001 From: ad045 <82264509+ad045@users.noreply.github.com> Date: Wed, 15 Jan 2025 23:05:24 +0100 Subject: [PATCH 2/3] Fixed code: GaussianLabelSmoother handles non-number tokens correctly (#41) --- src/ntl/utils/label_smoother.py | 143 +++++++++++++++++++------------- 1 file changed, 84 insertions(+), 59 deletions(-) diff --git a/src/ntl/utils/label_smoother.py b/src/ntl/utils/label_smoother.py index bc0207f..b9bdc0c 100644 --- a/src/ntl/utils/label_smoother.py +++ b/src/ntl/utils/label_smoother.py @@ -8,7 +8,7 @@ @dataclass -class GaussianLabelSmoother(LabelSmoother): +class GaussianLabelSmoother: """ A label smoother that applies Gaussian smoothing ONLY to number tokens, as selected by `NumberTokenSelector`. Non-number tokens remain untouched or masked out. 
@@ -26,20 +26,30 @@ class GaussianLabelSmoother(LabelSmoother): sigma: float = 1.0 ignore_index: int = -100 selector: object = None # Instance of `NumberTokenSelector` + eps = 1e-8 # epsilon def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) -> Tensor: """ Compute the Gaussian-smoothed cross-entropy loss. + + Parameters: + model_output: torch.Tensor or Dict[str, torch.Tensor] + The model output logits or a dictionary containing the logits. + labels: torch.Tensor of shape (batch_size, seq_len) + shift_labels: bool """ # Get logits from model output if isinstance(model_output, dict): - logits = model_output["logits"] + logits = model_output["logits"] # (batch_size, seq_len, voc_size) else: - logits = model_output[0] + logits = model_output[0] # (batch_size, seq_len, voc_size) # Handle empty logits or labels gracefully by returning zero loss if logits.numel() == 0 or labels.numel() == 0: - return torch.tensor(0.0, device=logits.device) + # Return a zero that still has grad_fn + print("requires_grad:", logits.requires_grad) + return logits.sum() * 0.0 + # Shift labels if needed if shift_labels: @@ -52,76 +62,91 @@ def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) -> raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.") # Select number tokens - logits, number_tokens = self.selector.select_number_tokens(logits) + number_logits, vocab_numbers_mask = self.selector.select_number_tokens(logits) # (batch_size, seq_len, num_classes_numbers) # Get the number of classes and the mask for number tokens - tokens_encoding_numbers = self.selector.nvocab[number_tokens] - num_classes = tokens_encoding_numbers.shape[0] - labels_mask = torch.isin(labels, tokens_encoding_numbers) + tokens_encoding_numbers = self.selector.nvocab[vocab_numbers_mask] + num_classes_numbers = tokens_encoding_numbers.shape[0] + labels_number_mask = torch.isin(labels, tokens_encoding_numbers) # 
(batch_size, seq_len) else: # If no selector is given, assume all are number tokens - labels_mask = torch.ones_like(labels, dtype=torch.bool) - num_classes = logits.size(-1) # Dynamic determination of num_classes - # raise Exception("A NumberTokenSelector needs to be provided to the GaussianLabelSmoother.") + labels_number_mask = torch.ones_like(labels, dtype=torch.bool) + num_classes_numbers = logits.size(-1) # Dynamic determination of num_classes_numbers - # Mask for valid number labels and non-padding tokens. Potentially unnecessary, as number labels certainly do not include the ignore_index. Added for safety. - valid_mask = (labels != self.ignore_index) & labels_mask + # All labels that are not self.ignore_index + valid_mask = (labels != self.ignore_index) # (batch_size, seq_len) - # Validation to ensure that labels are within the valid range [0, num_classes - 1] - valid_labels = (labels[valid_mask] >= 0) & (labels[valid_mask] < num_classes) - if not torch.all(valid_labels): - raise RuntimeError("Some labels are out of the valid range [0, num_classes - 1].") + if not valid_mask.any(): + # If no valid tokens are present, return zero loss that still has grad_fn + return logits.sum() * 0.0 - if self.sigma == 0.0: - # When sigma is zero, use one-hot labels directly without smoothing. - # To avoid F.one_hot error, all labels outside of valid_mask are set to 0 - safe_labels = labels.clone() - safe_labels = labels * valid_mask - labels_to_calculate_loss = F.one_hot(safe_labels, num_classes=num_classes).float() - - # Zero out the labels_to_calculate_loss where not valid - labels_to_calculate_loss = labels_to_calculate_loss * valid_mask.unsqueeze(-1) - - else: - # Check if there are any number tokens to smooth - if valid_mask.any(): - # Create a tensor of class indices - class_indices = torch.arange(num_classes, device=labels.device).view(1, 1, num_classes) # (1, 1, num_classes) - - # Expand labels to shape (batch_size, seq_length, 1). 
Cast to float32 if necessary - labels_expanded = labels.unsqueeze(-1).float() # (batch_size, seq_length, 1) - - # Gaussian distribution around each label index: - # Over [0..num_classes-1] for each label l_i: - # dist_j = exp(-((j - l_i)^2 / (2*sigma^2))) - - # Calculate the Gaussian probability for each class - gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2) # (batch_size, num_outputs, num_classes) + # Mask for valid number labels and non-padding tokens. + number_mask = valid_mask * labels_number_mask # (batch_size, seq_len) # should not change anything, as labels_number_mask is already a subset of valid_mask + non_number_mask = valid_mask * ~labels_number_mask # (batch_size, seq_len) + + # Validation to ensure that labels are within the valid range [0, num_classes_numbers - 1] + if not torch.all((labels[number_mask] >= 0) & (labels[number_mask] < num_classes_numbers)): + print("min", labels[number_mask].min(), "max", labels[number_mask].max()) + raise RuntimeError("Some labels are out of the valid range [0, num_classes_numbers - 1].") + + # Compute log probabilities once for efficiency + log_probs = F.log_softmax(logits, dim=-1) # [B, S, C] + + # Initialize loss tensors + loss_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device) # (batch_size, seq_len) + loss_non_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device) # (batch_size, seq_len) + + # Compute loss for number tokens + if number_mask.any(): + if self.sigma == 0.0: + # When sigma is zero, use one-hot labels directly without smoothing. 
+ # To avoid F.one_hot error, all labels outside of valid_mask are set to 0 + number_labels_filled = labels.clone() + number_labels_filled = labels.masked_fill(~number_mask, 0) # All non-number tokens are filled with zero + number_one_hot = F.one_hot(number_labels_filled, num_classes=num_classes_numbers).float() + number_one_hot = number_one_hot * number_mask.unsqueeze(-1) # Zero out non-number tokens + + # Compute the loss for number tokens + loss_numbers = -(number_one_hot * log_probs[..., :num_classes_numbers]).sum(dim=-1) - # Normalize to ensure each (batch, output) sums to 1 - gaussian_probs = gaussian / gaussian.sum(dim=2 , keepdim=True) # [B, S, C] - - # Apply the valid mask - labels_to_calculate_loss = gaussian_probs * valid_mask.unsqueeze(-1) + else: + # Gaussian smoothing for number tokens + # Create a tensor of class indices + class_indices = torch.arange(num_classes_numbers, device=labels.device).view(1, 1, num_classes_numbers) # (1, 1, num_classes_numbers) - else: - # If no valid tokens, set labels_to_calculate_loss to zero - labels_to_calculate_loss = torch.zeros_like(logits) + # Expand labels to shape (batch_size, seq_length, 1). Cast to float32 if necessary + labels_expanded = labels.unsqueeze(-1).float() # (batch_size, seq_length, 1) + # Compute Gaussian distribution around each label index + gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2) # (batch_size, seq_len//number_outputs, num_classes_numbers) - # Compute cross-entropy using smoothed label distribution - log_probs = F.log_softmax(logits, dim=-1) # shape [B, S, C] - loss_per_token = -(labels_to_calculate_loss * log_probs).sum(dim=-1) # distribution = - sum_{j} (smoothed_label_j * log_probs_j) + # Normalize to ensure each (batch, output) sums to 1. 
Prevent division by zero + gaussian_probs = gaussian / (gaussian.sum(dim=2, keepdim=True) + self.eps) + + # Apply mask to Gaussian probabilities + gaussian_probs = gaussian_probs * number_mask.unsqueeze(-1) # Zero out non-number tokens + + # Compute the loss for number tokens + loss_numbers = -(gaussian_probs * log_probs[..., :num_classes_numbers]).sum(dim=-1) # (batch_size, seq_len) - # Average across the valid tokens. Also works in the case that num_valid == 0. - # Invalid positions are replaced with zero, ensuring that the tensor remains connected to the graph - loss_per_token = torch.where(valid_mask, loss_per_token, torch.zeros_like(loss_per_token)) + # Compute loss for non-number tokens + if non_number_mask.any(): + # One-hot encoding for non-number tokens + non_number_labels_filled = labels.clone() + non_number_labels_filled = non_number_labels_filled.masked_fill(~non_number_mask, 0) # Fill non-valid tokens with 0 # (batch_size, seq_len) + one_hot_non_num = F.one_hot(non_number_labels_filled, num_classes=logits.size(-1)).float() + one_hot_non_num = one_hot_non_num * non_number_mask.unsqueeze(-1).expand(-1, -1, one_hot_non_num.size(-1)) # non_number_mask.unsqueeze(-1) # Zero out non-number tokens + + # Compute the loss for non-number tokens + loss_non_numbers = -(one_hot_non_num * log_probs).sum(dim=-1) + + # Combine the two losses into a single tensor + loss_per_token = torch.where(number_mask, loss_numbers, loss_non_numbers) # (batch_size, seq_len) + + # Average across the valid tokens. 
num_valid = valid_mask.sum().float() loss = loss_per_token.sum() / torch.clamp(num_valid, min=1.0) return loss - - - \ No newline at end of file From c0f15672b3a8c7036f8f0f79d2a2a2d461436861 Mon Sep 17 00:00:00 2001 From: Larspennig <116651828+Larspennig@users.noreply.github.com> Date: Thu, 16 Jan 2025 11:18:28 +0100 Subject: [PATCH 3/3] Add files via upload --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..73dff59 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 TUM.ai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file