Commit

Fixed code: GaussianLabelSmoother handles non-number tokens correctly
ad045 authored Jan 15, 2025
1 parent 75e8565 commit 0d8ccaf
Showing 1 changed file with 84 additions and 59 deletions.
143 changes: 84 additions & 59 deletions src/ntl/utils/label_smoother.py
@@ -8,7 +8,7 @@


@dataclass
-class GaussianLabelSmoother(LabelSmoother):
+class GaussianLabelSmoother:
    """
    A label smoother that applies Gaussian smoothing ONLY to number tokens, as
    selected by `NumberTokenSelector`. Non-number tokens remain untouched or masked out.
@@ -26,20 +26,30 @@ class GaussianLabelSmoother(LabelSmoother):
    sigma: float = 1.0
    ignore_index: int = -100
    selector: object = None  # Instance of `NumberTokenSelector`
+    eps = 1e-8  # epsilon

    def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) -> Tensor:
        """
        Compute the Gaussian-smoothed cross-entropy loss.
        Parameters:
            model_output: torch.Tensor or Dict[str, torch.Tensor]
                The model output logits or a dictionary containing the logits.
            labels: torch.Tensor of shape (batch_size, seq_len)
            shift_labels: bool
        """
        # Get logits from model output
        if isinstance(model_output, dict):
-            logits = model_output["logits"]
+            logits = model_output["logits"]  # (batch_size, seq_len, voc_size)
        else:
-            logits = model_output[0]
+            logits = model_output[0]  # (batch_size, seq_len, voc_size)

        # Handle empty logits or labels gracefully by returning zero loss
        if logits.numel() == 0 or labels.numel() == 0:
-            return torch.tensor(0.0, device=logits.device)
+            # Return a zero that still has grad_fn
+            print("requires_grad:", logits.requires_grad)
+            return logits.sum() * 0.0


        # Shift labels if needed
        if shift_labels:
@@ -52,76 +62,91 @@ def __call__(self, model_output, labels: Tensor, shift_labels: bool = False) -> Tensor:
                raise AttributeError("The selector must have an attribute 'nvocab' representing the number of valid vocab tokens.")

            # Select number tokens
-            logits, number_tokens = self.selector.select_number_tokens(logits)
+            number_logits, vocab_numbers_mask = self.selector.select_number_tokens(logits)  # (batch_size, seq_len, num_classes_numbers)

            # Get the number of classes and the mask for number tokens
-            tokens_encoding_numbers = self.selector.nvocab[number_tokens]
-            num_classes = tokens_encoding_numbers.shape[0]
-            labels_mask = torch.isin(labels, tokens_encoding_numbers)
+            tokens_encoding_numbers = self.selector.nvocab[vocab_numbers_mask]
+            num_classes_numbers = tokens_encoding_numbers.shape[0]
+            labels_number_mask = torch.isin(labels, tokens_encoding_numbers)  # (batch_size, seq_len)

        else:
            # If no selector is given, assume all are number tokens
-            labels_mask = torch.ones_like(labels, dtype=torch.bool)
-            num_classes = logits.size(-1)  # Dynamic determination of num_classes
+            # raise Exception("A NumberTokenSelector needs to be provided to the GaussianLabelSmoother.")
+            labels_number_mask = torch.ones_like(labels, dtype=torch.bool)
+            num_classes_numbers = logits.size(-1)  # Dynamic determination of num_classes_numbers

-        # Mask for valid number labels and non-padding tokens. Potentially unnecessary, as number labels certainly do not include the ignore_index. Added for safety.
-        valid_mask = (labels != self.ignore_index) & labels_mask
+        # All labels that are not self.ignore_index
+        valid_mask = (labels != self.ignore_index)  # (batch_size, seq_len)

-        # Validation to ensure that labels are within the valid range [0, num_classes - 1]
-        valid_labels = (labels[valid_mask] >= 0) & (labels[valid_mask] < num_classes)
-        if not torch.all(valid_labels):
-            raise RuntimeError("Some labels are out of the valid range [0, num_classes - 1].")
-
-        if self.sigma == 0.0:
-            # When sigma is zero, use one-hot labels directly without smoothing.
-            # To avoid F.one_hot error, all labels outside of valid_mask are set to 0
-            safe_labels = labels.clone()
-            safe_labels = labels * valid_mask
-            labels_to_calculate_loss = F.one_hot(safe_labels, num_classes=num_classes).float()
-
-            # Zero out the labels_to_calculate_loss where not valid
-            labels_to_calculate_loss = labels_to_calculate_loss * valid_mask.unsqueeze(-1)
-
-        else:
-            # Check if there are any number tokens to smooth
-            if valid_mask.any():
-                # Create a tensor of class indices
-                class_indices = torch.arange(num_classes, device=labels.device).view(1, 1, num_classes)  # (1, 1, num_classes)
-
-                # Expand labels to shape (batch_size, seq_length, 1). Cast to float32 if necessary
-                labels_expanded = labels.unsqueeze(-1).float()  # (batch_size, seq_length, 1)
-
-                # Gaussian distribution around each label index:
-                # over [0..num_classes-1] for each label l_i:
-                #     dist_j = exp(-((j - l_i)^2 / (2*sigma^2)))
-
-                # Calculate the Gaussian probability for each class
-                gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2)  # (batch_size, seq_len, num_classes)
-
-                # Normalize to ensure each (batch, output) sums to 1
-                gaussian_probs = gaussian / gaussian.sum(dim=2, keepdim=True)  # [B, S, C]
-
-                # Apply the valid mask
-                labels_to_calculate_loss = gaussian_probs * valid_mask.unsqueeze(-1)
-
-            else:
-                # If no valid tokens, set labels_to_calculate_loss to zero
-                labels_to_calculate_loss = torch.zeros_like(logits)
-
-        # Compute cross-entropy using smoothed label distribution
-        log_probs = F.log_softmax(logits, dim=-1)  # [B, S, C]
-        loss_per_token = -(labels_to_calculate_loss * log_probs).sum(dim=-1)  # - sum_j smoothed_label_j * log_prob_j
-
-        # Average across the valid tokens. Also works in the case that num_valid == 0.
-        # Invalid positions are replaced with zero, ensuring that the tensor remains connected to the graph
-        loss_per_token = torch.where(valid_mask, loss_per_token, torch.zeros_like(loss_per_token))
+        if not valid_mask.any():
+            # If no valid tokens are present, return zero loss that still has grad_fn
+            return logits.sum() * 0.0
+
+        # Masks for valid number and valid non-number tokens
+        number_mask = valid_mask * labels_number_mask  # (batch_size, seq_len); in practice a no-op, as number labels never carry the ignore_index
+        non_number_mask = valid_mask * ~labels_number_mask  # (batch_size, seq_len)
+
+        # Validation to ensure that labels are within the valid range [0, num_classes_numbers - 1]
+        if not torch.all((labels[number_mask] >= 0) & (labels[number_mask] < num_classes_numbers)):
+            print("min", labels[number_mask].min(), "max", labels[number_mask].max())
+            raise RuntimeError("Some labels are out of the valid range [0, num_classes_numbers - 1].")
+
+        # Compute log probabilities once for efficiency
+        log_probs = F.log_softmax(logits, dim=-1)  # [B, S, C]
+
+        # Initialize loss tensors
+        loss_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device)  # (batch_size, seq_len)
+        loss_non_numbers = torch.zeros_like(labels, dtype=logits.dtype, device=logits.device)  # (batch_size, seq_len)
+
+        # Compute loss for number tokens
+        if number_mask.any():
+            if self.sigma == 0.0:
+                # When sigma is zero, use one-hot labels directly without smoothing.
+                # To avoid an F.one_hot error, positions outside number_mask are filled with 0
+                number_labels_filled = labels.masked_fill(~number_mask, 0)
+                number_one_hot = F.one_hot(number_labels_filled, num_classes=num_classes_numbers).float()
+                number_one_hot = number_one_hot * number_mask.unsqueeze(-1)  # Zero out non-number tokens
+
+                # Compute the loss for number tokens
+                loss_numbers = -(number_one_hot * log_probs[..., :num_classes_numbers]).sum(dim=-1)
+
+            else:
+                # Gaussian smoothing for number tokens.
+                # Create a tensor of class indices
+                class_indices = torch.arange(num_classes_numbers, device=labels.device).view(1, 1, num_classes_numbers)  # (1, 1, num_classes_numbers)
+
+                # Expand labels to shape (batch_size, seq_length, 1). Cast to float32 if necessary
+                labels_expanded = labels.unsqueeze(-1).float()  # (batch_size, seq_length, 1)
+
+                # Compute Gaussian distribution around each label index
+                gaussian = torch.exp(-0.5 * ((class_indices - labels_expanded) / self.sigma) ** 2)  # (batch_size, seq_len, num_classes_numbers)
+
+                # Normalize to ensure each (batch, position) sums to 1. Prevent division by zero
+                gaussian_probs = gaussian / (gaussian.sum(dim=2, keepdim=True) + self.eps)
+
+                # Apply mask to Gaussian probabilities
+                gaussian_probs = gaussian_probs * number_mask.unsqueeze(-1)  # Zero out non-number tokens
+
+                # Compute the loss for number tokens
+                loss_numbers = -(gaussian_probs * log_probs[..., :num_classes_numbers]).sum(dim=-1)  # (batch_size, seq_len)
+
+        # Compute loss for non-number tokens
+        if non_number_mask.any():
+            # One-hot encoding for non-number tokens
+            non_number_labels_filled = labels.masked_fill(~non_number_mask, 0)  # Positions outside non_number_mask are filled with 0  # (batch_size, seq_len)
+            one_hot_non_num = F.one_hot(non_number_labels_filled, num_classes=logits.size(-1)).float()
+            one_hot_non_num = one_hot_non_num * non_number_mask.unsqueeze(-1)  # Zero out everything outside the non-number mask
+
+            # Compute the loss for non-number tokens
+            loss_non_numbers = -(one_hot_non_num * log_probs).sum(dim=-1)
+
+        # Combine the two losses into a single tensor
+        loss_per_token = torch.where(number_mask, loss_numbers, loss_non_numbers)  # (batch_size, seq_len)

        # Average across the valid tokens.
        num_valid = valid_mask.sum().float()
        loss = loss_per_token.sum() / torch.clamp(num_valid, min=1.0)

        return loss
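For intuition about the smoothing itself, here is a small self-contained sketch of the target distribution the new code builds for number tokens. The vocabulary size, labels, and sigma below are toy values chosen for illustration, not taken from the repository:

import torch

# Toy setup (assumed values): 10 number classes, sigma = 1.0
num_classes_numbers, sigma, eps = 10, 1.0, 1e-8
labels = torch.tensor([[3, 7]])  # (batch_size=1, seq_len=2)

# Same construction as in the commit: a Gaussian bump centered on each label id
class_indices = torch.arange(num_classes_numbers).view(1, 1, -1)  # (1, 1, 10)
gaussian = torch.exp(-0.5 * ((class_indices - labels.unsqueeze(-1).float()) / sigma) ** 2)
gaussian_probs = gaussian / (gaussian.sum(dim=2, keepdim=True) + eps)

print(torch.round(gaussian_probs[0, 0], decimals=3))
# ~0.399 at index 3, ~0.242 at indices 2 and 4, ~0.054 at 1 and 5, near 0 elsewhere
print(gaussian_probs.sum(dim=2))  # each position sums to ~1 (eps keeps it just under)

As sigma approaches 0 the bump collapses onto the label index itself, which is why the sigma == 0.0 branch can fall back to a plain one-hot target.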


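The switch from torch.tensor(0.0, device=logits.device) to logits.sum() * 0.0 is easy to miss but load-bearing: a freshly constructed scalar has no grad_fn, so calling loss.backward() on an empty or fully masked batch would raise, while a zero derived from the logits stays attached to the autograd graph. A minimal sketch of the difference:

import torch

logits = torch.randn(2, 4, 8, requires_grad=True)  # (batch, seq, vocab)

detached_zero = torch.tensor(0.0, device=logits.device)  # no grad_fn
connected_zero = logits.sum() * 0.0                      # zero-valued, graph-connected

print(detached_zero.grad_fn)   # None -> backward() from here would raise
print(connected_zero.grad_fn)  # <MulBackward0 ...>

connected_zero.backward()       # fine: the gradients are simply all zeros
print(logits.grad.abs().sum())  # tensor(0.)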


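Finally, a hypothetical end-to-end call. The stub below is not the real NumberTokenSelector from ntl; it only mimics the two members the smoother touches (nvocab and select_number_tokens) and assumes number tokens occupy the first slots of the vocabulary, which the slicing log_probs[..., :num_classes_numbers] above requires:

import torch
from dataclasses import dataclass

# Assumes GaussianLabelSmoother from src/ntl/utils/label_smoother.py is in scope.

@dataclass
class StubSelector:
    # Hypothetical stand-in for ntl's NumberTokenSelector
    nvocab: torch.Tensor  # (vocab_size,): token id where the slot encodes a number, -1 otherwise

    def select_number_tokens(self, logits):
        mask = self.nvocab >= 0  # vocab slots that encode numbers
        return logits[:, :, mask], mask

# Toy vocab of 8 tokens; ids 0-3 encode numbers (an assumption of this sketch)
nvocab = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1])
smoother = GaussianLabelSmoother(sigma=1.0, selector=StubSelector(nvocab))

logits = torch.randn(2, 5, 8, requires_grad=True)  # (batch_size, seq_len, vocab)
labels = torch.randint(0, 8, (2, 5))               # mix of number and non-number ids
loss = smoother({"logits": logits}, labels)
loss.backward()  # number tokens got Gaussian targets, the rest plain one-hot CE

Positions equal to ignore_index contribute nothing: both partial losses are zeroed there, and the final mean divides by the count of valid tokens only.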