diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b5df36e12a9..b31368c6062 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -37,6 +37,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from packaging import version from torch import Tensor, nn +from torch.distributions import constraints from torch.nn import CrossEntropyLoss, Identity from torch.utils.checkpoint import checkpoint @@ -2425,14 +2426,12 @@ def _init_added_embeddings_weights_with_mean( covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens # Check if the covariance is positive definite. - eigenvalues = torch.linalg.eigvals(covariance) - is_covariance_psd = bool( - (covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all() - ) + epsilon = 1e-9 + is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all() if is_covariance_psd: # If covariances is positive definite, a distribution can be created. and we can sample new weights from it. distribution = torch.distributions.multivariate_normal.MultivariateNormal( - mean_embeddings, covariance_matrix=1e-9 * covariance + mean_embeddings, covariance_matrix=epsilon * covariance ) new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample( sample_shape=(added_num_tokens,)