Skip to content

Commit

Permalink
bug fixes
Browse files · Browse the repository at this point in the history
  • Branch information: loading…
Larspennig committed Dec 23, 2024
1 parent b08630a commit 32ee7b9
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tests/loss_functions/test_expression_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pdb


class TestNumberTokenLoss(unittest.TestCase):
class TestExpressionLoss(unittest.TestCase):

def setUp(self):
self.device = torch.device("cpu")
Expand Down Expand Up @@ -83,16 +83,13 @@ def test_convert_logit_seq_to_number(self):

logits = self.create_logits(self.t5_tokenizer, token_logit_value_dict_list)
logits = logits[:, :, self.expression_loss.number_tokens]
softmaxed_logits = F.softmax(logits, dim=-1)

labels = torch.tensor(
self.t5_tokenizer.convert_tokens_to_ids(["1", "2"]), dtype=torch.long
).unsqueeze(0)

# call convert_logit_seq_to_number from the ExpressionLoss instance
result = self.expression_loss.convert_logit_seq_to_number(
softmaxed_logits, labels
)
result = self.expression_loss.convert_logit_seq_to_number(logits, labels)

expected = 16.5
self.assertAlmostEqual(result.item(), expected, places=2)
Expand Down

0 comments on commit 32ee7b9

Please sign in to comment.