Commit

Merge branch 'main' into ablation_studies
# Conflicts:
#	config/run_specific_config/config.yaml
#	src/ntl/run_language_modeling.py
#	tests/test_evaluation.py
Jonas Zausinger committed Jan 16, 2025
2 parents c937db6 + c0f1567 commit b72e707
Showing 6 changed files with 110 additions and 63 deletions.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 TUM.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 changes: 2 additions & 1 deletion config/model_args/vanilla_t5_ntl.yaml
@@ -3,4 +3,5 @@ config_name: t5-base
number_encoding: none
number_token_loss: true
number_token_loss_weight: 0.3
number_token_loss_with_wasserstein: false
number_token_loss_with_wasserstein: false
#number_token_loss_function:
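
For orientation, a weight like the one set above typically scales an auxiliary number-token loss that is added to the standard cross-entropy term. The sketch below illustrates only that combination; the function name and the scalar losses are assumptions, not code from this repository.

# Hedged sketch: combining cross entropy with a weighted number-token loss,
# mirroring the number_token_loss_weight: 0.3 setting above. Illustrative only.
import torch

def combined_loss(ce_loss: torch.Tensor,
                  number_token_loss: torch.Tensor,
                  number_token_loss_weight: float = 0.3) -> torch.Tensor:
    return ce_loss + number_token_loss_weight * number_token_loss

total = combined_loss(torch.tensor(2.1), torch.tensor(0.8))  # 2.1 + 0.3 * 0.8 = 2.34
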
2 changes: 1 addition & 1 deletion config/run_specific_config/config.yaml
@@ -1,5 +1,5 @@
training_args:
trial: ablation_studies
trial:
special_name:
max_steps: 2500000
load_best_model_at_end: false
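
As an illustration of how a training_args block like this can be consumed (an assumption about usage, not the repository's loader), the keys map onto Hugging Face TrainingArguments fields of the same name:

# Hedged sketch: parsing the run-specific YAML and forwarding the keys that
# correspond to transformers.TrainingArguments. output_dir is assumed.
import yaml
from transformers import TrainingArguments

raw = """
training_args:
  trial:
  special_name:
  max_steps: 2500000
  load_best_model_at_end: false
"""
cfg = yaml.safe_load(raw)["training_args"]
training_args = TrainingArguments(
    output_dir="outputs",  # assumed; not part of the YAML above
    max_steps=cfg["max_steps"],
    load_best_model_at_end=cfg["load_best_model_at_end"],
)
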
2 changes: 2 additions & 0 deletions src/ntl/evaluation.py
@@ -21,6 +21,8 @@
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")




class CustomMetrics:
"""
Compute custom metrics for the model with access to the vocab to compute MSE
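
The class docstring mentions computing MSE over numeric predictions. A minimal sketch of that metric, purely illustrative and not taken from CustomMetrics:

# Hedged sketch: MSE between decoded numeric predictions and ground truth.
import torch

def number_mse(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.mean((predicted.float() - target.float()) ** 2)

number_mse(torch.tensor([1.0, 2.5]), torch.tensor([1.0, 2.0]))  # tensor(0.1250)
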
143 changes: 84 additions & 59 deletions src/ntl/utils/label_smoother.py
@@ -8,7 +8,7 @@


@dataclass
class GaussianLabelSmoother(LabelSmoother):
class GaussianLabelSmoother:
"""
A label smoother that applies Gaussian smoothing ONLY to number tokens, as
selected by `NumberTokenSelector`. Non-number tokens remain untouched or masked out.
@@ -26,20 +26,30 @@ class GaussianLabelSmoother(LabelSmoother):
sigma: float = 1.0
ignore_index: int = -100
selector: object = None # Instance of `NumberTokenSelector`
eps = 1e-8 # epsilon

def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) -> Tensor:
"""
Compute the Gaussian-smoothed cross-entropy loss.
Parameters:
model_output: torch.Tensor or Dict[str, torch.Tensor]
The model output logits or a dictionary containing the logits.
labels: torch.Tensor of shape (batch_size, seq_len)
shift_labels: bool
"""
# Get logits from model output
if isinstance(model_output, dict):
logits = model_output["logits"]
logits = model_output["logits"] # (batch_size, seq_len, voc_size)
else:
logits = model_output[0]
logits = model_output[0] # (batch_size, seq_len, voc_size)

# Handle empty logits or labels gracefully by returning zero loss
if logits.numel() == 0 or labels.numel() == 0:
return torch.tensor(0.0, device=logits.device)
# Return a zero that still has grad_fn
print("requires_grad:", logits.requires_grad)
return logits.sum() * 0.0


# Shift labels if needed
if shift_labels:
@@ -52,76 +62,91 @@ def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) ->
raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.")

# Select number tokens
logits, number_tokens = self.selector.select_number_tokens(logits)
number_logits, vocab_numbers_mask = self.selector.select_number_tokens(logits) # (batch_size, seq_len, num_classes_numbers)

# Get the number of classes and the mask for number tokens
tokens_encoding_numbers = self.selector.nvocab[number_tokens]
num_classes = tokens_encoding_numbers.shape[0]
labels_mask = torch.isin(labels, tokens_encoding_numbers)
tokens_encoding_numbers = self.selector.nvocab[vocab_numbers_mask]
num_classes_numbers = tokens_encoding_numbers.shape[0]
labels_number_mask = torch.isin(labels, tokens_encoding_numbers) # (batch_size, seq_len)

else:
# If no selector is given, assume all are number tokens
labels_mask = torch.ones_like(labels, dtype=torch.bool)
num_classes = logits.size(-1) # Dynamic determination of num_classes
# raise Exception("A NumberTokenSelector needs to be provided to the GaussianLabelSmoother.")
labels_number_mask = torch.ones_like(labels, dtype=torch.bool)
num_classes_numbers = logits.size(-1) # Dynamic determination of num_classes_numbers

# Mask for valid number labels and non-padding tokens. Potentially unnecessary, as number labels certainly do not include the ignore_index. Added for safety.
valid_mask = (labels != self.ignore_index) & labels_mask
# All labels that are not self.ignore_index
valid_mask = (labels != self.ignore_index) # (batch_size, seq_len)

# Validation to ensure that labels are within the valid range [0, num_classes - 1]
valid_labels = (labels[valid_mask] >= 0) & (labels[valid_mask] < num_classes)
if not torch.all(valid_labels):
raise RuntimeError("Some labels are out of the valid range [0, num_classes - 1].")
if not valid_mask.any():
# If no valid tokens are present, return zero loss that still has grad_fn
return logits.sum() * 0.0

if self.sigma == 0.0:
# When sigma is zero, use one-hot labels directly without smoothing.
# To avoid F.one_hot error, all labels outside of valid_mask are set to 0
safe_labels = labels.clone()
safe_labels = labels * valid_mask
labels_to_calculate_loss = F.one_hot(safe_labels, num_classes=num_classes).float()

# Zero out the labels_to_calculate_loss where not valid
labels_to_calculate_loss = labels_to_calculate_loss * valid_mask.unsqueeze(-1)

else:
# Check if there are any number tokens to smooth
if valid_mask.any():
# Create a tensor of class indices
class_indices = torch.arange(num_classes, device=labels.device).view(1, 1, num_classes) # (1, 1, num_classes)

# Expand labels to shape (batch_size, seq_length, 1). Cast to float32 if necessary
labels_expanded = labels.unsqueeze(-1).float() # (batch_size, seq_length, 1)

# Gaussian distribution around each label index:
# Over [0..num_classes-1] for each label l_i:
# dist_j = exp(-((j - l_i)^2 / (2*sigma^2)))

# Calculate the Gaussian probability for each class
gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2) # (batch_size, num_outputs, num_classes)
# Mask for valid number labels and non-padding tokens.
number_mask = valid_mask * labels_number_mask # (batch_size, seq_len) # should not change anything, as labels_number_mask is already a subset of valid_mask
non_number_mask = valid_mask * ~labels_number_mask # (batch_size, seq_len)

# Validation to ensure that labels are within the valid range [0, num_classes_numbers - 1]
if not torch.all((labels[number_mask] >= 0) & (labels[number_mask] < num_classes_numbers)):
print("min", labels[number_mask].min(), "max", labels[number_mask].max())
raise RuntimeError("Some labels are out of the valid range [0, num_classes_numbers - 1].")

# Compute log probabilities once for efficiency
log_probs = F.log_softmax(logits, dim=-1) # [B, S, C]

# Initialize loss tensors
loss_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device) # (batch_size, seq_len)
loss_non_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device) # (batch_size, seq_len)

# Compute loss for number tokens
if number_mask.any():
if self.sigma == 0.0:
# When sigma is zero, use one-hot labels directly without smoothing.
# To avoid F.one_hot error, all labels outside of valid_mask are set to 0
number_labels_filled = labels.clone()
number_labels_filled = labels.masked_fill(~number_mask, 0) # All non-number tokens are filled with zero
number_one_hot = F.one_hot(number_labels_filled, num_classes=num_classes_numbers).float()
number_one_hot = number_one_hot * number_mask.unsqueeze(-1) # Zero out non-number tokens

# Compute the loss for number tokens
loss_numbers = -(number_one_hot * log_probs[..., :num_classes_numbers]).sum(dim=-1)

# Normalize to ensure each (batch, output) sums to 1
gaussian_probs = gaussian / gaussian.sum(dim=2 , keepdim=True) # [B, S, C]

# Apply the valid mask
labels_to_calculate_loss = gaussian_probs * valid_mask.unsqueeze(-1)
else:
# Gaussian smoothing for number tokens
# Create a tensor of class indices
class_indices = torch.arange(num_classes_numbers, device=labels.device).view(1, 1, num_classes_numbers) # (1, 1, num_classes_numbers)

else:
# If no valid tokens, set labels_to_calculate_loss to zero
labels_to_calculate_loss = torch.zeros_like(logits)
# Expand labels to shape (batch_size, seq_length, 1). Cast to float32 if necessary
labels_expanded = labels.unsqueeze(-1).float() # (batch_size, seq_length, 1)

# Compute Gaussian distribution around each label index
gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2) # (batch_size, seq_len//number_outputs, num_classes_numbers)

# Compute cross-entropy using smoothed label distribution
log_probs = F.log_softmax(logits, dim=-1) # shape [B, S, C]
loss_per_token = -(labels_to_calculate_loss * log_probs).sum(dim=-1) # distribution = - sum_{j} (smoothed_label_j * log_probs_j)
# Normalize to ensure each (batch, output) sums to 1. Prevent division by zero
gaussian_probs = gaussian / (gaussian.sum(dim=2, keepdim=True) + self.eps)

# Apply mask to Gaussian probabilities
gaussian_probs = gaussian_probs * number_mask.unsqueeze(-1) # Zero out non-number tokens

# Compute the loss for number tokens
loss_numbers = -(gaussian_probs * log_probs[..., :num_classes_numbers]).sum(dim=-1) # (batch_size, seq_len)

# Average across the valid tokens. Also works in the case that num_valid == 0.
# Invalid positions are replaced with zero, ensuring that the tensor remains connected to the graph
loss_per_token = torch.where(valid_mask, loss_per_token, torch.zeros_like(loss_per_token))
# Compute loss for non-number tokens
if non_number_mask.any():
# One-hot encoding for non-number tokens
non_number_labels_filled = labels.clone()
non_number_labels_filled = non_number_labels_filled.masked_fill(~non_number_mask, 0) # Fill non-valid tokens with 0 # (batch_size, seq_len)
one_hot_non_num = F.one_hot(non_number_labels_filled, num_classes=logits.size(-1)).float()
one_hot_non_num = one_hot_non_num * non_number_mask.unsqueeze(-1).expand(-1, -1, one_hot_non_num.size(-1)) # non_number_mask.unsqueeze(-1) # Zero out non-number tokens

# Compute the loss for non-number tokens
loss_non_numbers = -(one_hot_non_num * log_probs).sum(dim=-1)

# Combine the two losses into a single tensor
loss_per_token = torch.where(number_mask, loss_numbers, loss_non_numbers) # (batch_size, seq_len)

# Average across the valid tokens.
num_valid = valid_mask.sum().float()
loss = loss_per_token.sum() / torch.clamp(num_valid, min=1.0)

return loss




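To make the smoothing step easier to follow in isolation, here is a minimal, self-contained sketch of the Gaussian target construction and the resulting cross-entropy. The exp(-0.5 * ((class_index - label) / sigma) ** 2) formula and the masking mirror the GaussianLabelSmoother code above; the toy shapes and the absence of a NumberTokenSelector are simplifications for illustration.

# Hedged sketch of Gaussian label smoothing over number-token classes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, num_classes_numbers = 2, 3, 10
sigma, eps = 1.0, 1e-8

logits = torch.randn(batch, seq_len, num_classes_numbers, requires_grad=True)
labels = torch.randint(0, num_classes_numbers, (batch, seq_len))   # treat all positions as number tokens
valid_mask = torch.ones_like(labels, dtype=torch.bool)             # no ignore_index positions in this toy case

# Gaussian distribution over class indices, centered on each label
class_indices = torch.arange(num_classes_numbers).view(1, 1, -1)   # (1, 1, C)
labels_expanded = labels.unsqueeze(-1).float()                     # (B, S, 1)
gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / sigma) ** 2)
gaussian_probs = gaussian / (gaussian.sum(dim=-1, keepdim=True) + eps)  # each row sums to ~1
gaussian_probs = gaussian_probs * valid_mask.unsqueeze(-1)

# Cross-entropy against the smoothed targets
log_probs = F.log_softmax(logits, dim=-1)
loss_per_token = -(gaussian_probs * log_probs).sum(dim=-1)          # (B, S)
loss = loss_per_token.sum() / valid_mask.sum().clamp(min=1).float()
loss.backward()  # the loss keeps a grad_fn, matching the smoother's zero-loss fallbacks

With sigma == 0.0 the diff instead uses plain one-hot targets, which is the same computation with gaussian_probs replaced by F.one_hot(labels, num_classes_numbers).float().
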
2 changes: 0 additions & 2 deletions tests/test_evaluation.py
@@ -97,8 +97,6 @@ def test_calculate_result_mse(self):
self.assertEqual(median_absolute_error, expected_median_absolute_error)
self.assertEqual(log_mae, expected_log_mae)
self.assertEqual(log_r2, expected_log_r2)
self.assertEqual(pearson, expected_pearson)
self.assertEqual(spearman, expected_spearman)


if __name__ == "__main__":
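
For reference, the metrics asserted in this test are available in standard libraries. The snippet below is an illustrative sketch; the assumption that the log_* variants are computed on log10-transformed values is not taken from the repository.

# Hedged sketch of the metrics named in test_calculate_result_mse.
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import median_absolute_error, mean_absolute_error, r2_score

y_true = np.array([1.0, 10.0, 100.0])
y_pred = np.array([1.2, 9.0, 120.0])

median_ae = median_absolute_error(y_true, y_pred)
log_mae = mean_absolute_error(np.log10(y_true), np.log10(y_pred))   # assumed log10 convention
log_r2 = r2_score(np.log10(y_true), np.log10(y_pred))               # assumed log10 convention
pearson, _ = pearsonr(y_true, y_pred)
spearman, _ = spearmanr(y_true, y_pred)
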
