From e7e1a6d300037d4fb8c951509a11dd65b0454e96 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 15 Jun 2024 15:07:42 +0200 Subject: [PATCH] Casted sample_weights as tensor as well --- pomegranate/distributions/normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pomegranate/distributions/normal.py b/pomegranate/distributions/normal.py index 3ca19c76..26133b60 100644 --- a/pomegranate/distributions/normal.py +++ b/pomegranate/distributions/normal.py @@ -257,7 +257,7 @@ def summarize(self, X, sample_weight=None): X, sample_weight = super().summarize(X, sample_weight=sample_weight) X = _cast_as_tensor(X, dtype=self.means.dtype) - + sample_weight = _cast_as_tensor(sample_weight, dtype=self.means.dtype) if self.covariance_type == 'full': self._w_sum += torch.sum(sample_weight, dim=0) self._xw_sum += torch.sum(X * sample_weight, axis=0)