Added gaussian label smoother for number tokens #34
* Refactored number token loss to get a number token selector
* Added first implementation of gaussian label smoother
* Added test cases for the label smoother
* Small change to args
* Added changes to number_token_loss and label_smoother
* Added number_token_selector.py
* Added init
* Changed trainer
* Added comments, and valid mask before one-hot encoding
* Added a case differentiation for sigma=0
* For label_smoother: Fixed gradient flow for no valid number tokens case. Added more extensive testing of the gaussian smoother.
* Fixed bug in test file
* nvocab fix
* Bigger commit, will probably fail the gce tests
* Cleaned up
* Cleaned up
---------
Co-authored-by: ad045 <[email protected]>
Great work @ad045! Seems very flexible and easy to use. Super solid job on the test suite 👍🏼
I left some cosmetic comments where unnecessary things can be removed, but conceptually this is ready to merge now!
src/ntl/run_language_modeling.py
Outdated
@@ -349,6 +364,8 @@ def run_language_modeling(model_args: ModelArguments, training_args: TrainingArg
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
# selector=selector,
Remove comment
src/ntl/run_language_modeling.py
Outdated
selector=selector
)
else:
selector = None
Remove selector since it is unused downstream
self.nvocab[id] = self.tokenizer.decode_number_token(token, ignore_order=True)

def select_number_tokens(self, logits: Tensor, labels: Tensor):
Labels are not used here, so I would remove them from the signature.
And then remove them from the return statement below as well.
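A minimal sketch of what the requested change could look like, assuming the selector stores `nvocab` as a tensor with NaN entries for non-number tokens (as the diff in this PR suggests). The class name `NumberTokenSelectorSketch` and the toy vocabulary are hypothetical; this is not the merged implementation.

```python
import torch
from torch import Tensor


class NumberTokenSelectorSketch:
    """Hypothetical stand-in for the PR's NumberTokenSelector."""

    def __init__(self, nvocab: Tensor):
        # Assumption: nvocab[i] is the numeric value of token i, NaN otherwise
        self.nvocab = nvocab

    def select_number_tokens(self, logits: Tensor):
        # Boolean mask over the vocabulary: True where the token is a number
        number_tokens = ~torch.isnan(self.nvocab)
        # Keep only the number-token columns of the (batch, seq, vocab) logits
        return logits[:, :, number_tokens], number_tokens


# Toy vocabulary of 5 tokens, 3 of which are number tokens
nvocab = torch.tensor([float("nan"), 0.0, 1.0, float("nan"), 2.0])
selector = NumberTokenSelectorSketch(nvocab)
logits = torch.randn(2, 3, 5)
selected, mask = selector.select_number_tokens(logits)
```

With `labels` dropped from both the signature and the return value, call sites would unpack two values instead of three, e.g. `logits, number_tokens = self.selector.select_number_tokens(logits)`.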
src/ntl/utils/label_smoother.py
Outdated
raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.")

# Select number tokens
logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)
remove labels here
# Create a mask to filter out non-digit tokens
number_tokens = ~torch.isnan(self.nvocab)
logits = logits[:, :, number_tokens]
logits, labels, number_tokens = self.selector.select_number_tokens(logits, labels)
remove labels here also
Description:
This pull request introduces a label smoothing technique that applies Gaussian smoothing exclusively to number tokens. The key changes are outlined below:
Gaussian Label Smoother:
* Added the GaussianLabelSmoother class in src/ntl/utils/label_smoother.py.
* Added new arguments to ModelArguments: gaussian_label_smoother and label_smoother_sigma.
Number Token Selector:
* Refactored the number token selection out of NumberTokenLoss into a new NumberTokenSelector class located in src/ntl/utils/number_token_selector.py.
* The selector is used by both GaussianLabelSmoother and NumberTokenLoss to consistently select number tokens.
Other Changes:
* Updated src/ntl/run_language_modeling.py and src/ntl/trainer.py to support the new label smoothing feature.
* Updated src/ntl/loss_functions/number_token_loss.py to integrate with the new NumberTokenSelector.
* Updated src/ntl/args.py to include new arguments for the Gaussian Label Smoother.
* … GaussianLabelSmoother and NumberTokenSelector.
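To illustrate the idea behind Gaussian smoothing over number tokens, here is a framework-free sketch that turns a hard label into a soft distribution weighted by numeric distance to the target. The function name `gaussian_smoothed_label` and the one-hot fallback for `sigma == 0` are assumptions made for this example; the PR's `GaussianLabelSmoother` operates on tensors and may differ in detail.

```python
import math


def gaussian_smoothed_label(values, target, sigma):
    """Spread label mass over number tokens by their distance to the target."""
    if sigma == 0:
        # Assumed degenerate case: no smoothing, plain one-hot label
        return [1.0 if v == target else 0.0 for v in values]
    # Unnormalized Gaussian weight for each number token value
    weights = [math.exp(-((v - target) ** 2) / (2 * sigma ** 2)) for v in values]
    total = sum(weights)
    # Normalize so the smoothed label is a probability distribution
    return [w / total for w in weights]


digits = list(range(10))  # toy number-token values 0..9
soft = gaussian_smoothed_label(digits, target=3, sigma=1.0)
hard = gaussian_smoothed_label(digits, target=3, sigma=0)
```

Here `soft` peaks at the digit 3 and assigns equal, smaller mass to 2 and 4, while `hard` is the usual one-hot label. Larger values of `sigma` spread more mass onto numerically distant tokens.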
.Types of Changes