added vanilla_t5_regression_head
Jonas Zausinger committed Oct 6, 2024
1 parent 96c35a0 commit 8cf9fb1
Showing 7 changed files with 136 additions and 61 deletions.
3 changes: 3 additions & 0 deletions config/model_args/vanilla_t5_regression_head.yaml
@@ -0,0 +1,3 @@
name: vanilla_t5_regression_head
config_name: t5-base
number_encoding: none_regression_head
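
This three-line config selects the new variant by name and routes number handling to the regression head. As a sanity check, it can be read like any other model-args YAML; a minimal sketch, assuming PyYAML is installed and the script runs from the repository root:

import yaml  # assumption: PyYAML is available

with open("config/model_args/vanilla_t5_regression_head.yaml") as f:
    args = yaml.safe_load(f)

assert args["number_encoding"] == "none_regression_head"
print(args["config_name"])  # t5-base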
2 changes: 1 addition & 1 deletion src/args.py
@@ -77,7 +77,7 @@ class ModelArguments:
number_encoding: Optional[str] = field(
default="rt",
metadata={
"help": "Chose either xval or rt or None for number encodings"
"help": "Choose either xval or rt or None, or none_regression_head for number encodings"
},
)
number_token_loss: Optional[bool] = field(
39 changes: 39 additions & 0 deletions src/collators/regression_head_question_answer_collator.py
@@ -0,0 +1,39 @@
from typing import Dict, List, Union
import re
import torch
from transformers import DataCollatorForLanguageModeling


class RegressionHeadQuestionAnswerCLMCollator(DataCollatorForLanguageModeling):
def __init__(self, tokenizer):
super().__init__(tokenizer, mlm=False)
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id

def __call__(self, examples: List[Dict[str, Union[str, List[int]]]]) -> Dict[str, torch.Tensor]:
# Tokenize questions and answers separately
questions = [example['question'] for example in examples]
answers = [example['answer'] for example in examples]

question_encodings = self.tokenizer(questions, padding=True, truncation=True, return_tensors="pt")

answer_numbers = []

for answer in answers:
answer_number = re.findall(r"\s*([+-]?\s*(\d+)(\.\d+)?)", answer)
if not answer_number or len(answer_number) == 0:
raise ValueError(f"Answer: {answer} does not contain any number")
if len(answer_number) > 1:
raise ValueError(f"Answer: {answer} contains more than one number")
answer_numbers.append(float(answer_number[0][0]))

answer_numbers = torch.tensor(answer_numbers, dtype=torch.float32).unsqueeze(1).to(question_encodings['input_ids'].device)

input_ids = question_encodings['input_ids']
attention_mask = question_encodings['attention_mask']

return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': answer_numbers
}
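
The new collator tokenizes only the questions and reduces each answer to a single float via the regex, which accepts an optional sign and an optional decimal part and raises on answers with zero or multiple numbers. A minimal usage sketch (hypothetical inputs; assumes a t5-base tokenizer is downloadable and the import path matches this repository's layout):

from transformers import AutoTokenizer
from src.collators.regression_head_question_answer_collator import RegressionHeadQuestionAnswerCLMCollator

tokenizer = AutoTokenizer.from_pretrained("t5-base")
collator = RegressionHeadQuestionAnswerCLMCollator(tokenizer)

batch = collator([
    {"question": "What is 17 plus 25?", "answer": "The result is 42"},
    {"question": "What is 1.5 times 2?", "answer": "3.0"},
])
print(batch["input_ids"].shape)  # (2, padded question length)
print(batch["labels"])           # tensor([[42.], [3.]])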
9 changes: 0 additions & 9 deletions src/collators/rt_question_answer_collator.py
@@ -30,17 +30,8 @@ def __call__(self, examples: List[Dict[str, Union[str, List[int]]]]) -> Dict[str
labels = answer_input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100

# Generate number_labels for easier evaluation
number_token_ids = self.tokenizer.get_num_token_ids()
label_num_mask = torch.isin(answer_input_ids, torch.tensor(number_token_ids, dtype=torch.long, device=answer_input_ids.device))
tokens = self.tokenizer.convert_ids_to_tokens(answer_input_ids[label_num_mask])
number_values = torch.tensor([encoding_to_number(token) for token in tokens], dtype=torch.float, device=answer_input_ids.device)
number_labels = torch.zeros_like(answer_input_ids, dtype=torch.float)
number_labels[label_num_mask] = number_values

return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
"number_labels": number_labels
}
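
The block deleted here had materialized dense number_labels for the rt evaluation path. For reference, the masking pattern it relied on, shown standalone with made-up token ids and values:

import torch

answer_input_ids = torch.tensor([[5, 17, 3], [9, 4, 17]])
number_token_ids = [17, 4]  # hypothetical ids of numeric vocabulary tokens

# Mark positions holding number tokens, then scatter their float values
# into a dense tensor with the same shape as the ids.
label_num_mask = torch.isin(answer_input_ids, torch.tensor(number_token_ids))
number_values = torch.tensor([7.0, 2.0, 7.0])  # one value per masked position
number_labels = torch.zeros_like(answer_input_ids, dtype=torch.float)
number_labels[label_num_mask] = number_values
print(number_labels)  # tensor([[0., 7., 0.], [0., 2., 7.]])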
118 changes: 74 additions & 44 deletions src/evaluation.py
@@ -157,48 +157,43 @@ def __call__(self, pred: EvalPrediction, compute_result: bool) -> Dict[str, floa
Overall results if compute_result else None
"""
if not self.number_encoding.lower() in ["xval", "rt", "none"]:
if not self.number_encoding.lower() in ["xval", "rt", "none", "none_regression_head"]:
raise NotImplementedError(
f"Requesting evaluation for not supported number_encoding: {self.number_encoding}")

# Extract predictions and labels from pred tuple
model_output, labels = pred
logits, predictions = model_output

if self.number_encoding in ["xval", "rt"]:
if self.number_encoding == "xval":
token_labels, number_labels = labels

else:
token_labels, number_labels = labels, None
token_labels = labels
number_labels = None

# replace -100 with padding token
token_labels_for_decoding = token_labels.clone()
token_labels_for_decoding[token_labels_for_decoding == -100] = self.tokenizer.pad_token_id
if self.number_encoding != "none_regression_head":
# replace -100 with padding token
token_labels_for_decoding = token_labels.clone()
token_labels_for_decoding[token_labels_for_decoding == -100] = self.tokenizer.pad_token_id

if self.number_encoding == "xval":
predictions, predicted_numbers = predictions
decoded_preds, count_invalid_number_prediction, count_no_number_prediction = self.tokenizer.decode_into_human_readable(
predictions, predicted_numbers)
decoded_labels, sanity_invalid_number_prediction, sanity_no_number_prediction = self.tokenizer.decode_into_human_readable(
token_labels_for_decoding, number_labels)
(
count_invalid_number_prediction, count_no_number_prediction,
decoded_labels,
decoded_preds,
predictions,
sanity_invalid_number_prediction,
sanity_no_number_prediction
) = self._decode_preds_and_labels(number_labels, predictions, token_labels_for_decoding)

# We should never observe invalid numbers and most likely never a missing number for gt
if max(sanity_invalid_number_prediction, sanity_no_number_prediction) > 0:
print(sanity_invalid_number_prediction)
print(sanity_no_number_prediction)
else:
if hasattr(self.tokenizer, "decode_into_human_readable"):
decoded_preds, count_invalid_number_prediction, count_no_number_prediction = self.tokenizer.decode_into_human_readable(
predictions)
decoded_labels, sanity_invalid_number_prediction, sanity_no_number_prediction = self.tokenizer.decode_into_human_readable(
token_labels_for_decoding)
else:
decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
count_invalid_number_prediction, count_no_number_prediction = check_number_predictions(
decoded_preds)
decoded_labels = self.tokenizer.batch_decode(token_labels_for_decoding, skip_special_tokens=True)
sanity_invalid_number_prediction, sanity_no_number_prediction = check_number_predictions(
decoded_labels)

# We should never observe invalid numbers and most likely never a missing number for gt
if max(sanity_invalid_number_prediction, sanity_no_number_prediction) > 0:
print(sanity_invalid_number_prediction)
print(sanity_no_number_prediction)
decoded_labels = [str("{0:.12f}".format(label).rstrip('0').rstrip('.')) for label in labels.squeeze(-1).tolist()]
decoded_preds = [str("{0:.12f}".format(logit).rstrip('0').rstrip('.')) for logit in logits.squeeze(-1).tolist()]
count_invalid_number_prediction = 0
count_no_number_prediction = 0

if compute_result or self.save_all_output:
# save decoded predictions and labels for debugging
@@ -209,26 +204,31 @@ def __call__(self, pred: EvalPrediction, compute_result: bool) -> Dict[str, floa
if compute_result:
self.eval_count += 1

# compute perplexity
perplexity_value = self.perplexity(logits, token_labels[:, :logits.size(1)])
if self.number_encoding != "none_regression_head":
# compute perplexity
perplexity_value = self.perplexity(logits, token_labels[:, :logits.size(1)])

bleu = self.compute_bleu(decoded_preds, decoded_labels)
rouge = self.compute_rouge(decoded_preds, decoded_labels)
# Mask to ignore padding tokens (-100)
mask = token_labels != PADDING_TOKEN

# Mask to ignore padding tokens (-100)
mask = token_labels != PADDING_TOKEN
# Apply mask to predictions and labels
masked_predictions = torch.where(mask, predictions, MASKED_OUT)
masked_labels = torch.where(mask, token_labels, MASKED_OUT)

# Apply mask to predictions and labels
masked_predictions = torch.where(mask, predictions, MASKED_OUT)
masked_labels = torch.where(mask, token_labels, MASKED_OUT)
# compute whole number accuracy and token accuracy
correct_predictions_w = torch.all(masked_predictions == masked_labels, dim=1)
accuracy_w = torch.mean(correct_predictions_w.float()).item()
correct_predictions = (predictions == token_labels) & mask
accuracy = (torch.sum(correct_predictions) / torch.sum(mask)).item() if torch.sum(mask) > 0 else 0
else:
perplexity_value = 0
accuracy_w = 0
accuracy = 0

# compute whole number accuracy and token accuracy
correct_predictions_w = torch.all(masked_predictions == masked_labels, dim=1)
accuracy_w = torch.mean(correct_predictions_w.float()).item()
correct_predictions = (predictions == token_labels) & mask
accuracy = (torch.sum(correct_predictions) / torch.sum(mask)).item() if torch.sum(mask) > 0 else 0


bleu = self.compute_bleu(decoded_preds, decoded_labels)
rouge = self.compute_rouge(decoded_preds, decoded_labels)

number_results = self.parse_number_result(decoded_preds, decoded_labels)

@@ -289,3 +289,33 @@ def __call__(self, pred: EvalPrediction, compute_result: bool) -> Dict[str, floa
}
self.batch_stats = []
return computed_metrics

def _decode_preds_and_labels(self, number_labels, predictions, token_labels_for_decoding):
if self.number_encoding == "xval":
predictions, predicted_numbers = predictions
decoded_preds, count_invalid_number_prediction, count_no_number_prediction \
= self.tokenizer.decode_into_human_readable(predictions, predicted_numbers)
decoded_labels, sanity_invalid_number_prediction, sanity_no_number_prediction \
= self.tokenizer.decode_into_human_readable(token_labels_for_decoding, number_labels)
else:
if hasattr(self.tokenizer, "decode_into_human_readable"):
decoded_preds, count_invalid_number_prediction, count_no_number_prediction \
= self.tokenizer.decode_into_human_readable(predictions)
decoded_labels, sanity_invalid_number_prediction, sanity_no_number_prediction \
= self.tokenizer.decode_into_human_readable(token_labels_for_decoding)
else:
decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
count_invalid_number_prediction, count_no_number_prediction = check_number_predictions(
decoded_preds)
decoded_labels = self.tokenizer.batch_decode(token_labels_for_decoding, skip_special_tokens=True)
sanity_invalid_number_prediction, sanity_no_number_prediction = check_number_predictions(
decoded_labels)
return (
count_invalid_number_prediction,
count_no_number_prediction,
decoded_labels,
decoded_preds,
predictions,
sanity_invalid_number_prediction,
sanity_no_number_prediction
)
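
In the none_regression_head branch, predictions and labels are already scalars, so "decoding" reduces to rendering floats without trailing zeros, the {0:.12f} plus rstrip idiom used above. A self-contained illustration of that formatting:

def format_number(x: float) -> str:
    # Fixed-point with 12 decimals, then trim trailing zeros and a dangling dot.
    return "{0:.12f}".format(x).rstrip("0").rstrip(".")

print(format_number(3.0))   # 3
print(format_number(2.5))   # 2.5
print(format_number(10.0))  # 10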
20 changes: 16 additions & 4 deletions src/run_language_modeling.py
@@ -8,6 +8,7 @@
import sys

from src.args import ModelArguments, TrainingArguments, DatasetArguments
from src.collators.regression_head_question_answer_collator import RegressionHeadQuestionAnswerCLMCollator

sys.path.append(".")

@@ -19,8 +20,6 @@
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd
@@ -32,7 +31,7 @@
MODEL_WITH_LM_HEAD_MAPPING,
AutoConfig,
set_seed,
EarlyStoppingCallback, T5ForConditionalGeneration, Seq2SeqTrainingArguments
EarlyStoppingCallback, T5ForConditionalGeneration, T5ForSequenceClassification, Trainer
)

from src.data.data import load_txt_dataset, load_json_dataset
@@ -156,6 +155,8 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
model_params = config.__dict__
logger.warning("You are instantiating a new config instance from scratch.")

trainer_class = CustomSeq2SeqTrainer

if model_args.number_encoding == "rt":
model_class = T5RegressionModelRT
tokenizer_class = RtTokenizer
@@ -169,6 +170,11 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
else:
model_class = T5ForConditionalGeneration
tokenizer_class = transformers.AutoTokenizer
elif model_args.number_encoding.lower() == "none_regression_head":
trainer_class = Trainer
config.num_labels = 1
model_class = T5ForSequenceClassification
tokenizer_class = transformers.AutoTokenizer
else:
raise ValueError(f"Unknown number encoding: {model_args.number_encoding}")

@@ -331,6 +337,10 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
data_collator = VanillaQuestionAnswerCLMCollator(
tokenizer=tokenizer
)
elif model_args.number_encoding.lower() == "none_regression_head":
data_collator = RegressionHeadQuestionAnswerCLMCollator(
tokenizer=tokenizer
)

# Custom Metric
custom_metrics = CustomMetrics(
@@ -347,8 +357,10 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg

# custom_trainer_params = get_trainer_dict(model_params)

logger.info("Trainer class: %s", trainer_class)

# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
trainer = trainer_class(
model=model,
args=training_args,
data_collator=data_collator,
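
The routing above swaps in the stock Trainer and T5ForSequenceClassification with config.num_labels = 1, which is what turns the model into a scalar regressor: Hugging Face classification heads infer a regression problem_type when num_labels is 1 and the labels are floats, and then apply an MSE loss. A minimal sketch of that behavior (assumes t5-base weights are downloadable; the classification head is freshly initialized, so outputs are untrained):

import torch
from transformers import AutoConfig, AutoTokenizer, T5ForSequenceClassification

config = AutoConfig.from_pretrained("t5-base", num_labels=1)
model = T5ForSequenceClassification.from_pretrained("t5-base", config=config)
tokenizer = AutoTokenizer.from_pretrained("t5-base")

enc = tokenizer(["What is 2 + 2?"], return_tensors="pt")
out = model(**enc, labels=torch.tensor([[4.0]]))
print(out.logits.shape)  # torch.Size([1, 1]), one scalar per example
print(out.loss)          # MSE between the scalar logits and the labels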
6 changes: 3 additions & 3 deletions tests/test_run_language_modeling.py
@@ -76,7 +76,7 @@ def test_model_training(self, mock_load_txt_dataset_fn, mock_load_json_dataset_f
mock_load_json_dataset_fn.side_effect = self.mock_load_json_dataset
mock_load_txt_dataset_fn.side_effect = self.mock_load_txt_dataset

number_encodings = ["rt", "xval", "none"]
number_encodings = ["rt", "xval", "none", "none_regression_head"]
number_token_losses = [True, False]
log_scale_embeddings_options = [True, False]
model_names_or_paths = [None, "google-t5/t5-small"]
@@ -89,13 +89,13 @@ def test_model_training(self, mock_load_txt_dataset_fn, mock_load_json_dataset_f
for xval_bigger_language_head in xval_bigger_language_heads:

# Skip invalid combinations
if number_encoding == "xval" and number_token_loss:
if number_encoding in ["xval", "none_regression_head"] and number_token_loss:
continue # NumberTokenLoss is only applicable when number_encoding is not 'xval'

if number_encoding != "xval" and xval_bigger_language_head:
continue

if number_encoding == "none" and log_scale_embeddings:
if number_encoding in ["none", "none_regression_head"] and log_scale_embeddings:
continue # Log scaling is only applicable for 'rt' and 'xval' encodings

checkpoint_dir = os.path.join(self.temp_dir, "checkpoint-10")
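
The two updated skip rules prune the sweep to valid combinations: NumberTokenLoss is incompatible with xval and with the regression head, and log-scaled embeddings exist only for rt and xval. Restated as a standalone loop over three of the swept dimensions (names shortened for illustration):

from itertools import product

encodings = ["rt", "xval", "none", "none_regression_head"]
for enc, ntl, log_scale in product(encodings, [True, False], [True, False]):
    if enc in ["xval", "none_regression_head"] and ntl:
        continue  # NumberTokenLoss does not apply to these encodings
    if enc in ["none", "none_regression_head"] and log_scale:
        continue  # log-scaled embeddings exist only for 'rt' and 'xval'
    print(enc, ntl, log_scale)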
