From 4f2e1d00772302164d61bf2b1f9e18e008a82edf Mon Sep 17 00:00:00 2001 From: eduardz1 <54178827+eduardz1@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:26:26 +0200 Subject: [PATCH] fix: tied was ignored when saving to JSON --- project/classifiers/gaussian_mixture_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/project/classifiers/gaussian_mixture_model.py b/project/classifiers/gaussian_mixture_model.py index 079d4e2..77c98f4 100644 --- a/project/classifiers/gaussian_mixture_model.py +++ b/project/classifiers/gaussian_mixture_model.py @@ -152,7 +152,11 @@ def from_json(data: dict) -> "SingleGMM": ( np.array(d["w"]), vcol(np.array(d["mu"])), - np.array(d["C"]) if gmm._type == "full" else np.diag(np.array(d["C"])), + ( + np.diag(np.array(d["C"])) + if gmm._type == "diagonal" + else np.array(d["C"]) + ), ) for d in data["params"] ] @@ -250,7 +254,9 @@ def to_json(self) -> dict: "w": w.tolist(), # Save as row to make the representation more compact "mu": vrow(mu).tolist(), - "C": C.tolist() if self._type == "full" else np.diag(C).tolist(), + "C": ( + np.diag(C).tolist() if self._type == "diagonal" else C.tolist() + ), } for w, mu, C in self.params ],