Added gaussian label smoother for number tokens #34

Closed
wants to merge 23 commits into from

@ad045 ad045 commented Jan 6, 2025

Description:

This pull request introduces a label smoothing technique that applies Gaussian smoothing exclusively to number tokens. The key changes are outlined below:

  • Gaussian Label Smoother:

    • Implemented the GaussianLabelSmoother class in src/ntl/utils/label_smoother.py.
    • Controlled via two new ModelArguments: gaussian_label_smoother and label_smoother_sigma.
  • Number Token Selector:

    • Extracted the token selection logic from NumberTokenLoss into a new NumberTokenSelector class located in src/ntl/utils/number_token_selector.py.
    • Utilized by both GaussianLabelSmoother and NumberTokenLoss to consistently select number tokens.
  • Other Changes:

    • Updated src/ntl/run_language_modeling.py and src/ntl/trainer.py to support the new label smoothing feature.
    • Modified src/ntl/loss_functions/number_token_loss.py to integrate with the new NumberTokenSelector.
    • Updated src/ntl/args.py to include new arguments for the Gaussian Label Smoother.
    • Added unit tests to ensure the correctness of the GaussianLabelSmoother and NumberTokenSelector.
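To illustrate the idea behind the smoother: the sketch below replaces the one-hot target of a number token with a Gaussian centered on the true value, and falls back to one-hot when sigma is 0. It is not the code from src/ntl/utils/label_smoother.py, which is not shown in this conversation. The function names `gaussian_smooth_label` and `soft_cross_entropy` are hypothetical, and plain Python lists stand in for tensors so the example runs without PyTorch.

```python
import math


def gaussian_smooth_label(label_value, number_values, sigma):
    """Build a smoothed target distribution over number tokens (illustrative only).

    label_value:   numeric value of the ground-truth number token
    number_values: numeric values of all number tokens (assumed vocabulary)
    sigma:         standard deviation of the Gaussian; 0 means no smoothing
    """
    if sigma == 0:
        # Degenerate case: keep the plain one-hot target
        return [1.0 if v == label_value else 0.0 for v in number_values]
    # Unnormalized Gaussian weight for every number token
    weights = [
        math.exp(-((v - label_value) ** 2) / (2 * sigma ** 2))
        for v in number_values
    ]
    total = sum(weights)
    # Normalize so the target sums to 1
    return [w / total for w in weights]


def soft_cross_entropy(target, log_probs):
    """Cross entropy between a soft target and predicted log-probabilities."""
    return -sum(t * lp for t, lp in zip(target, log_probs))


digits = list(range(10))  # assumed number tokens: the digits 0..9
target = gaussian_smooth_label(3, digits, sigma=1.0)
uniform_log_probs = [math.log(0.1)] * 10
loss = soft_cross_entropy(target, uniform_log_probs)
```

With sigma > 0, neighboring digits (here 2 and 4) receive equal, non-zero probability mass, so a prediction that is numerically close to the label is penalized less than a distant one.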

Types of Changes

  • New feature (non-breaking change which adds functionality)
  • Refactor (non-breaking change which improves the code structure)

@ad045 ad045 marked this pull request as ready for review January 6, 2025 21:07
@jannisborn jannisborn mentioned this pull request Jan 7, 2025
Merged
* Refactored number token loss to get a number token selector

* Added first implementation of gaussian label smoother

* Added test cases for the label smoother

* Small change to args

* added changes to number_token_loss and label_smoother

* added number_token_selector.py

* added init

* changed trainer

* Added comments, and valid mask before one-hot encoding

* added a case distinction for sigma=0

* For label_smoother: Fixed gradient flow for no valid number tokens case. Added more extensive testing of the gaussian smoother.

* Fixed bug in test file

* nvocab fix

* Bigger commit, will probably fail the gce tests

* Cleaned up

* Cleaned up

---------

Co-authored-by: ad045 <[email protected]>

@jannisborn jannisborn left a comment

Great work @ad045 ! Seems very flexible and easy to use. Super solid job on the test suite 👍🏼

I left some cosmetic comments where unnecessary things can be removed, but conceptually this is ready to merge now!

@@ -349,6 +364,8 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
# selector=selector,

Remove this commented-out line

selector=selector
)
else:
selector = None

Remove selector since it is unused downstream

self.nvocab[id] = self.tokenizer.decode_number_token(token, ignore_order=True)


def select_number_tokens(self, logits: Tensor, labels: Tensor):

Labels are not used here, so I would remove them from the signature


and then remove them from the return statement below as well
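A rough sketch of what the selector could look like once `labels` is dropped from both the signature and the return value. This is not the code from src/ntl/utils/number_token_selector.py: the class name `NumberTokenSelectorSketch` is hypothetical, and nested Python lists stand in for tensors so the example runs without PyTorch. It only assumes, as in the diff, that `nvocab` holds NaN for non-number tokens and that logits are filtered along the vocabulary dimension.

```python
import math


class NumberTokenSelectorSketch:
    """Illustrative stand-in for the selector; not the project's implementation."""

    def __init__(self, nvocab):
        # nvocab[i] is the numeric value of token i, NaN for non-number tokens
        self.nvocab = nvocab

    def select_number_tokens(self, logits):
        # Mask of number tokens (mirrors `~torch.isnan(self.nvocab)` in the diff)
        number_tokens = [not math.isnan(v) for v in self.nvocab]
        keep = [i for i, is_number in enumerate(number_tokens) if is_number]
        # Keep only number-token columns (mirrors `logits[:, :, number_tokens]`)
        selected = [[[row[i] for i in keep] for row in seq] for seq in logits]
        # No labels in the return value: only the filtered logits and the mask
        return selected, number_tokens


selector = NumberTokenSelectorSketch([float("nan"), 0.0, 1.0, float("nan")])
logits = [[[0.1, 0.2, 0.3, 0.4]]]  # shape: batch=1, sequence=1, vocab=4
selected, mask = selector.select_number_tokens(logits)
```

Call sites would then unpack two values, for example `logits, number_tokens = self.selector.select_number_tokens(logits)`.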

raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.")

# Select number tokens
logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)

remove labels here

# Create a mask to filter out non-digit tokens
number_tokens = ~torch.isnan(self.nvocab)
logits = logits[:, :, number_tokens]
logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)

remove labels here also

@zausin33 zausin33 closed this Jan 9, 2025