Commit 90be5b7

Merge into v3.4-release
tomaarsen committed Jan 23, 2025
2 parents: f443625 + f4dc7b5

Showing 3 changed files with 15 additions and 9 deletions.
sentence_transformers/losses/AdaptiveLayerLoss.py (15 changes: 10 additions & 5 deletions)

@@ -12,6 +12,9 @@
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
 from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
+from sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss import (
+    CachedMultipleNegativesSymmetricRankingLoss,
+)
 from sentence_transformers.models import Transformer
@@ -149,7 +152,8 @@ def __init__(
     - `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_
 Requirements:
-    1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`.
+    1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`,
+       :class:`CachedMultipleNegativesSymmetricRankingLoss`, or :class:`CachedGISTEmbedLoss`.
 Inputs:
     +---------------------------------------+--------+
@@ -192,10 +196,11 @@ def __init__(
         self.kl_div_weight = kl_div_weight
         self.kl_temperature = kl_temperature
         assert isinstance(self.model[0], Transformer)
-        if isinstance(loss, CachedMultipleNegativesRankingLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
-        if isinstance(loss, CachedGISTEmbedLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)
+        if isinstance(
+            loss,
+            (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss),
+        ):
+            warnings.warn(f"MatryoshkaLoss is not compatible with {loss.__class__.__name__}.", stacklevel=2)

     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
         # Decorate the forward function of the transformer to cache the embeddings of all layers
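The refactor above collapses two separate isinstance checks into one call with a tuple of classes, and the f-string keeps the warning message accurate for whichever class matched. The pattern in isolation (class names here are illustrative stand-ins, not from the repository):

```python
import warnings


class CachedLossA: ...
class CachedLossB: ...
class PlainLoss: ...


def check_compatibility(loss: object) -> None:
    # One isinstance call with a tuple of classes replaces a chain of
    # per-class checks; the f-string names the offending class.
    if isinstance(loss, (CachedLossA, CachedLossB)):
        warnings.warn(
            f"This wrapper is not compatible with {loss.__class__.__name__}.",
            stacklevel=2,
        )


check_compatibility(PlainLoss())  # no warning emitted
```

Adding a new incompatible class then only touches the tuple, which is exactly what this commit does for CachedMultipleNegativesSymmetricRankingLoss.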
sentence_transformers/losses/Matryoshka2dLoss.py (3 changes: 2 additions & 1 deletion)

@@ -79,7 +79,8 @@ def __init__(
     - `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_
 Requirements:
-    1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`.
+    1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`,
+       :class:`CachedMultipleNegativesSymmetricRankingLoss`, or :class:`CachedGISTEmbedLoss`.
 Inputs:
     +---------------------------------------+--------+
sentence_transformers/losses/MatryoshkaLoss.py (6 changes: 3 additions & 3 deletions)

@@ -124,6 +124,9 @@ def __init__(
     different embedding dimensions. This is useful for when you want to train a model where users have the option
     to lower the embedding dimension to improve their embedding comparison speed and costs.
+    This loss is also compatible with the Cached... losses, which are in-batch negative losses that allow for
+    higher batch sizes. The higher batch sizes allow for more negatives, and often result in a stronger model.
 Args:
     model: SentenceTransformer model
     loss: The loss function to be used, e.g.
@@ -143,9 +146,6 @@ def __init__(
     - The concept was introduced in this paper: https://arxiv.org/abs/2205.13147
     - `Matryoshka Embeddings <../../examples/training/matryoshka/README.html>`_
-Requirements:
-    1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`.
 Inputs:
     +---------------------------------------+--------+
     | Texts | Labels |
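The removed requirement means MatryoshkaLoss now accepts the Cached losses. Separately, the docstring's claim that users can "lower the embedding dimension" at inference time can be sketched without the library: keep only the leading components of an embedding, then re-normalize. This is an illustrative, library-free reduction, not code from the repository:

```python
import math


def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components, then rescale to unit length,
    so truncated vectors remain comparable via dot product."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]


def cosine(a: list[float], b: list[float]) -> float:
    # For unit-length vectors, cosine similarity is just the dot product.
    return sum(x * y for x, y in zip(a, b))


full = [0.6, 0.8, 0.0, 0.0]
small = truncate_and_normalize(full, 2)  # 2-dim view of the same vector
```

Matryoshka-style training makes the leading components carry most of the signal, which is what makes this kind of truncation useful in the first place.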
