diff --git a/project/classifiers/gaussian_mixture_model.py b/project/classifiers/gaussian_mixture_model.py index 079d4e2..77c98f4 100644 --- a/project/classifiers/gaussian_mixture_model.py +++ b/project/classifiers/gaussian_mixture_model.py @@ -152,7 +152,11 @@ def from_json(data: dict) -> "SingleGMM": ( np.array(d["w"]), vcol(np.array(d["mu"])), - np.array(d["C"]) if gmm._type == "full" else np.diag(np.array(d["C"])), + ( + np.diag(np.array(d["C"])) + if gmm._type == "diagonal" + else np.array(d["C"]) + ), ) for d in data["params"] ] @@ -250,7 +254,9 @@ def to_json(self) -> dict: "w": w.tolist(), # Save as row to make the representation more compact "mu": vrow(mu).tolist(), - "C": C.tolist() if self._type == "full" else np.diag(C).tolist(), + "C": ( + np.diag(C).tolist() if self._type == "diagonal" else C.tolist() + ), } for w, mu, C in self.params ],