
Commit

Add docstring
QibinLiang committed Oct 11, 2023
1 parent 949a45f commit 855baa6
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/pytorch_metric_learning/losses/multilabel_supcon_loss.py
@@ -11,6 +11,22 @@
# adapted from https://github.com/HobbitLong/SupContrast
# modified for multi-supcon
class MultiSupConLoss(GenericPairLoss):
"""
Args:
num_classes: number of classes
temperature: temperature for scaling the similarity matrix
threshold: threshold for Jaccard similarity
Inputs:
embeddings: tensor of size (batch_size, embedding_size)
labels: tensor of size (batch_size, num_classes)
each row is a binary vector of size num_classes that only has 1s for the positive
labels, and 0s for the negative labels
indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
or size 5 for pairs (anchor1, positives, anchor2, negatives, jaccard_matrix)
Can also be left as None
ref_emb: tensor of size (batch_size, embedding_size)
"""
def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
super().__init__(mat_based_loss=True, **kwargs)
self.temperature = temperature
@@ -77,10 +93,13 @@ def forward(
"""
Args:
embeddings: tensor of size (batch_size, embedding_size)
labels: tensor of size (batch_size)
indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
or size 4 for pairs (anchor1, positives, anchor2, negatives)
labels: tensor of size (batch_size, num_classes)
each row is a binary vector of size num_classes that only has 1s for the positive
labels, and 0s for the negative labels
indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
or size 5 for pairs (anchor1, positives, anchor2, negatives, jaccard_matrix)
Can also be left as None
ref_emb: tensor of size (batch_size, embedding_size)
Returns: the loss
"""
self.reset_stats()
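A minimal usage sketch based on the new docstring (an illustration, not part of the commit): it assumes MultiSupConLoss is exported from pytorch_metric_learning.losses, matching this file's path, and that multi-hot label tensors are passed directly. The Jaccard computation at the end only illustrates the quantity that threshold is presumably compared against; it is not the library's internal code.

import torch
from pytorch_metric_learning.losses import MultiSupConLoss

num_classes = 5
loss_func = MultiSupConLoss(num_classes=num_classes, temperature=0.1, threshold=0.3)

batch_size, embedding_size = 8, 128
embeddings = torch.randn(batch_size, embedding_size)

# Multi-hot labels: each row has 1s for that sample's positive classes, 0s elsewhere.
labels = torch.randint(0, 2, (batch_size, num_classes)).float()

# indices_tuple and ref_emb are left as None, as the docstring allows.
loss = loss_func(embeddings, labels)

# Illustration only (an assumption about the internals): pairwise Jaccard
# similarity between multi-hot label vectors, to which the threshold would be applied
# when deciding which pairs count as positives.
intersection = labels @ labels.T
union = labels.sum(1, keepdim=True) + labels.sum(1) - intersection
jaccard = intersection / union.clamp(min=1)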
