diff --git a/src/depiction/calibration/models/constant_model.py b/src/depiction/calibration/models/constant_model.py index 5468a9f..6f66122 100644 --- a/src/depiction/calibration/models/constant_model.py +++ b/src/depiction/calibration/models/constant_model.py @@ -34,8 +34,8 @@ def zero(cls) -> ConstantModel: @classmethod def fit_mean(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64]) -> ConstantModel: - return ConstantModel(value=np.mean(y_arr)) + return ConstantModel(value=float(np.mean(y_arr))) @classmethod def fit_median(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64]) -> ConstantModel: - return ConstantModel(value=np.median(y_arr)) + return ConstantModel(value=float(np.median(y_arr))) diff --git a/src/depiction/calibration/models/fit_model.py b/src/depiction/calibration/models/fit_model.py index b7cb6d3..5195b3f 100644 --- a/src/depiction/calibration/models/fit_model.py +++ b/src/depiction/calibration/models/fit_model.py @@ -1,20 +1,23 @@ -from typing import Union -from numpy.typing import NDArray -from depiction.calibration.models import LinearModel, PolynomialModel +from __future__ import annotations + import numpy as np +from depiction.calibration.models import LinearModel, PolynomialModel +from numpy.typing import NDArray + +ModelType = LinearModel | PolynomialModel -def fit_model(x: NDArray[np.float64], y: NDArray[np.float64], model_type: str) -> Union[LinearModel, PolynomialModel]: +def fit_model(x: NDArray[np.float64], y: NDArray[np.float64], model_type: str) -> ModelType: """Fits a model to the given data, with the particular model_type.""" + model: ModelType if len(x) < 3: # If there are not enough points, return a zero model. if model_type.startswith("poly_"): - model_class = PolynomialModel + model = PolynomialModel.zero() elif model_type.startswith("linear"): - model_class = LinearModel + model = LinearModel.zero() else: raise ValueError(f"Unknown {model_type=}") - model = model_class.zero() elif model_type == "linear": model = LinearModel.fit_lsq(x_arr=x, y_arr=y) elif model_type.startswith("poly_"): diff --git a/src/depiction/calibration/models/polynomial_model.py b/src/depiction/calibration/models/polynomial_model.py index e5de15a..8728e43 100644 --- a/src/depiction/calibration/models/polynomial_model.py +++ b/src/depiction/calibration/models/polynomial_model.py @@ -22,7 +22,7 @@ def __post_init__(self) -> None: @property def is_zero(self) -> bool: """Returns True if the model is the (exact) zero function.""" - return np.all(self.coef == 0) + return bool(np.all(self.coef == 0)) @property def degree(self) -> int: @@ -34,11 +34,11 @@ def predict(self, x: NDArray[np.float64]) -> NDArray[np.float64]: @classmethod def identity(cls, degree: int = 1) -> "PolynomialModel": - return cls([0] * degree + [1]) + return cls(np.array([0] * degree + [1])) @classmethod def zero(cls, degree: int = 1) -> "PolynomialModel": - return cls([0] * (degree + 1)) + return cls(np.array([0] * (degree + 1))) @classmethod def fit_lsq(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64], degree: int) -> "PolynomialModel":