Skip to content

Commit

Permalink
Cleaned up
Browse files Browse the repository at this point in the history
  • Loading branch information
ad045 committed Jan 6, 2025
1 parent a98fa3a commit b4ef612
Showing 1 changed file with 0 additions and 48 deletions.
48 changes: 0 additions & 48 deletions tests/test_gaussian_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,54 +135,6 @@ def test_no_valid_tokens_returns_zero_loss(self):
# Expect the loss to be zero
self.assertEqual(loss_smoothed.item(), 0.0, msg=f"Loss should be zero when all labels are ignore_index, but got {loss_smoothed.item()}")

def test_selector_not_none(self):
    """
    Test that when a selector is provided, only selected tokens are smoothed.
    """

    class FirstTokenSelector:
        """Stub selector that marks only position 0 of each sequence as a number token."""

        def __init__(self, num_classes, device):
            # nvocab holds NaN for non-number classes; class 0 decodes to the number 0.0.
            self.nvocab = torch.full((num_classes,), float('nan'), device=device)
            self.nvocab[0] = 0.0

        def select_number_tokens(self, logits, labels):
            # Flag only the first token of every sequence for smoothing.
            mask = torch.zeros_like(labels, dtype=torch.bool)
            mask[:, 0] = True
            return logits, labels, mask

    selector = FirstTokenSelector(num_classes=self.num_classes, device=self.device)
    smoother = GaussianLabelSmoother(sigma=1.0, ignore_index=self.ignore_index, selector=selector)

    loss_smoothed = smoother({"logits": self.logits}, self.labels, shift_labels=False)

    # Reference computation restricted to the first token of each sequence,
    # mirroring what the selector exposes to the smoother.
    first_logits = self.logits[:, 0, :]   # [B, C]
    first_labels = self.labels[:, 0]      # [B]

    # Build the normalized Gaussian target distribution over classes.
    class_ids = torch.arange(self.num_classes, device=self.device).unsqueeze(0)  # [1, C]
    centers = first_labels.unsqueeze(1).float()                                  # [B, 1]
    weights = torch.exp(-((class_ids - centers) ** 2) / (2 * (smoother.sigma ** 2)))
    weights = weights / weights.sum(dim=-1, keepdim=True)                        # [B, C]

    # Cross-entropy against the smoothed targets, averaged over the batch.
    log_probs = F.log_softmax(first_logits, dim=-1)      # [B, C]
    expected_loss = -(weights * log_probs).sum(dim=-1).mean()

    self.assertAlmostEqual(
        loss_smoothed.item(),
        expected_loss.item(),
        places=6,
        msg=f"Smoothed loss with selector ({loss_smoothed.item()}) does not match manually computed loss ({expected_loss.item()})"
    )


def test_sigma_zero_no_nan(self):
"""
Test that when sigma=0, the loss is correctly computed and does not result in NaN.
Expand Down

0 comments on commit b4ef612

Please sign in to comment.