Skip to content

Commit

Permalink
use torch constraints to check if covariance is positive definite during mean resizing. (#35693)
Browse files Browse the repository at this point in the history

* use torch constraints to check for psd

* small nit

* Small change

* Small change for the ci

* nit
  • Loading branch information
abuelnasr0 authored Jan 28, 2025
1 parent 61cbb72 commit ec7afad
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version
from torch import Tensor, nn
from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint

Expand Down Expand Up @@ -2425,14 +2426,12 @@ def _init_added_embeddings_weights_with_mean(
covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens

# Check if the covariance is positive definite.
eigenvalues = torch.linalg.eigvals(covariance)
is_covariance_psd = bool(
(covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all()
)
epsilon = 1e-9
is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
if is_covariance_psd:
# If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
distribution = torch.distributions.multivariate_normal.MultivariateNormal(
mean_embeddings, covariance_matrix=1e-9 * covariance
mean_embeddings, covariance_matrix=epsilon * covariance
)
new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
sample_shape=(added_num_tokens,)
Expand Down

0 comments on commit ec7afad

Please sign in to comment.