diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index a8e226ab..c8293918 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -11,6 +11,22 @@ # adapted from https://github.com/HobbitLong/SupContrast # modified for multi-supcon class MultiSupConLoss(GenericPairLoss): + """ + Args: + num_classes: number of classes + temperature: temperature for scaling the similarity matrix + threshold: threshold for jaccard similarity + + Inputs: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size, num_classes) + each row is a binary vector of size num_classes that only has 1s for the positive + labels, and 0s for the negative labels + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) + Can also be left as None + ref_emb: tensor of size (batch_size, embedding_size) + """ def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs): super().__init__(mat_based_loss=True, **kwargs) self.temperature = temperature @@ -77,10 +93,13 @@ def forward( """ Args: embeddings: tensor of size (batch_size, embedding_size) - labels: tensor of size (batch_size) - indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) - or size 4 for pairs (anchor1, postives, anchor2, negatives) + labels: tensor of size (batch_size, num_classes) + each row is a binary vector of size num_classes that only has 1s for the positive + labels, and 0s for the negative labels + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) Can also be left as None + ref_emb: tensor of size (batch_size, embedding_size) Returns: the loss """ self.reset_stats()