Gce (#38)
* Refactored number token loss to get a number token selector

* Added first implementation of Gaussian label smoother

* Added test cases for the label smoother

* Small change to args

* added changes to number_token_loss and label_smoother

* added number_token_selector.py

* added init

* changed trainer

* Added comments, and valid mask before one-hot encoding

* Added a case differentiation for sigma=0

* For label_smoother: Fixed gradient flow for the case of no valid number tokens. Added more extensive testing of the Gaussian smoother.

* Fixed bug in test file

* nvocab fix

* Bigger commit, will probably fail the gce tests

* Cleaned up

* Cleaned up

* Cosmetic fix: Removed unused labels

* Cosmetic fix: Removed unused labels

* Fix: Removed comment and unused selector

* Fix: Removed unused labels

---------

Co-authored-by: ad045 <[email protected]>
zausin33 and ad045 authored Jan 9, 2025
1 parent e7f8108 commit d501e1b
Showing 4 changed files with 4 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/ntl/loss_functions/number_token_loss.py
@@ -21,7 +21,7 @@ def forward(self, logits: Tensor, labels: Tensor):
if labels.numel() == 0:
raise ValueError("Labels passed to the NumberTokenLoss are empty!")

-logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)
+logits, number_tokens = self.selector.select_number_tokens(logits)

# Compute the weighted average of number tokens (yhat)
softmaxed = F.softmax(logits, dim=-1)
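
For readers scanning the diff, here is a minimal self-contained sketch of the "weighted average of number tokens" idea this hunk leads into; the toy tensors and variable names below are assumptions for illustration, not the repository's actual forward pass.

import torch
import torch.nn.functional as F

# Toy setup (assumption): four number-token classes encoding the values 0..3
number_values = torch.tensor([0.0, 1.0, 2.0, 3.0])
number_logits = torch.randn(2, 5, number_values.numel())  # (batch, seq, number classes)

softmaxed = F.softmax(number_logits, dim=-1)
yhat = torch.sum(softmaxed * number_values, dim=-1)  # expected numeric value per position
print(yhat.shape)  # torch.Size([2, 5])
# The loss would then compare yhat against the numeric value encoded by the label token.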
2 changes: 0 additions & 2 deletions src/ntl/run_language_modeling.py
@@ -260,7 +260,6 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
selector=selector
)
else:
-selector = None
label_smoother = None

if model_args.model_name_or_path:
@@ -364,7 +363,6 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
-# selector=selector,
label_smoother=label_smoother,
# callbacks=[early_stopping_callback],
compute_metrics=custom_metrics,
2 changes: 1 addition & 1 deletion src/ntl/utils/label_smoother.py
@@ -52,7 +52,7 @@ def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) ->
raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.")

# Select number tokens
-logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)
+logits, number_tokens = self.selector.select_number_tokens(logits)

# Get the number of classes and the mask for number tokens
tokens_encoding_numbers = self.selector.nvocab[number_tokens]
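
For context on what tokens_encoding_numbers feeds into, a hedged, self-contained sketch of Gaussian label smoothing over number-token classes, including the sigma == 0 fallback mentioned in the commit messages; the function name gaussian_soft_targets and its exact formulation are assumptions, not the repository's code.

import torch
import torch.nn.functional as F

def gaussian_soft_targets(token_values: torch.Tensor, target_values: torch.Tensor, sigma: float) -> torch.Tensor:
    # token_values:  (C,) numeric value encoded by each number-token class
    # target_values: (...,) ground-truth numeric value per position
    # Returns a (..., C) target distribution centred on the true value.
    diffs = target_values.unsqueeze(-1) - token_values
    if sigma == 0:
        # Degenerate case: fall back to a one-hot target on the closest value
        return F.one_hot(diffs.abs().argmin(dim=-1), num_classes=token_values.numel()).float()
    weights = torch.exp(-diffs.pow(2) / (2 * sigma ** 2))
    return weights / weights.sum(dim=-1, keepdim=True)

# Example: three number tokens encoding 0, 1 and 2, true value 1.0, sigma 0.5
print(gaussian_soft_targets(torch.tensor([0.0, 1.0, 2.0]), torch.tensor([1.0]), 0.5))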
8 changes: 2 additions & 6 deletions src/ntl/utils/number_token_selector.py
@@ -18,14 +18,10 @@ def __init__(self, tokenizer: NumberEncodingTokenizer, vocab_size, device): # nv
if token in hashed_num_tokens:
self.nvocab[id] = self.tokenizer.decode_number_token(token, ignore_order=True)


-def select_number_tokens(self, logits: Tensor, labels: Tensor):

+def select_number_tokens(self, logits: Tensor):
# Create a mask to filter out non-digit tokens and labels
number_tokens = ~torch.isnan(self.nvocab)
logits = logits[:, :, number_tokens]
-# labels = labels.masked_fill(labels == -100, 0)

-return logits, labels, number_tokens
+return logits, number_tokens
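
A hedged usage sketch of the new contract: select_number_tokens now takes only logits, and the toy nvocab below (NaN for non-number tokens) is an assumption used purely to illustrate the masking shown in this hunk.

import torch

nvocab = torch.tensor([float("nan"), 0.0, 1.0, 2.0, float("nan")])  # toy vocab values (assumption)
logits = torch.randn(1, 3, 5)  # (batch, seq, vocab)

number_tokens = ~torch.isnan(nvocab)          # mask of number-token classes
number_logits = logits[:, :, number_tokens]   # what select_number_tokens(logits) now returns
print(number_logits.shape)  # torch.Size([1, 3, 3])
# Callers that previously received masked labels now handle labels themselves.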

