Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Oct 31, 2024
1 parent 5e74d9b commit c45b099
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/depiction/calibration/models/constant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def zero(cls) -> ConstantModel:

@classmethod
def fit_mean(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64]) -> ConstantModel:
return ConstantModel(value=np.mean(y_arr))
return ConstantModel(value=float(np.mean(y_arr)))

@classmethod
def fit_median(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64]) -> ConstantModel:
return ConstantModel(value=np.median(y_arr))
return ConstantModel(value=float(np.median(y_arr)))
17 changes: 10 additions & 7 deletions src/depiction/calibration/models/fit_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Union
from numpy.typing import NDArray
from depiction.calibration.models import LinearModel, PolynomialModel
from __future__ import annotations

import numpy as np
from depiction.calibration.models import LinearModel, PolynomialModel
from numpy.typing import NDArray

ModelType = LinearModel | PolynomialModel


def fit_model(x: NDArray[np.float64], y: NDArray[np.float64], model_type: str) -> Union[LinearModel, PolynomialModel]:
def fit_model(x: NDArray[np.float64], y: NDArray[np.float64], model_type: str) -> ModelType:
"""Fits a model to the given data, with the particular model_type."""
model: ModelType
if len(x) < 3:
# If there are not enough points, return a zero model.
if model_type.startswith("poly_"):
model_class = PolynomialModel
model = PolynomialModel.zero()
elif model_type.startswith("linear"):
model_class = LinearModel
model = LinearModel.zero()
else:
raise ValueError(f"Unknown {model_type=}")
model = model_class.zero()
elif model_type == "linear":
model = LinearModel.fit_lsq(x_arr=x, y_arr=y)
elif model_type.startswith("poly_"):
Expand Down
6 changes: 3 additions & 3 deletions src/depiction/calibration/models/polynomial_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __post_init__(self) -> None:
@property
def is_zero(self) -> bool:
"""Returns True if the model is the (exact) zero function."""
return np.all(self.coef == 0)
return bool(np.all(self.coef == 0))

@property
def degree(self) -> int:
Expand All @@ -34,11 +34,11 @@ def predict(self, x: NDArray[np.float64]) -> NDArray[np.float64]:

@classmethod
def identity(cls, degree: int = 1) -> "PolynomialModel":
return cls([0] * degree + [1])
return cls(np.array([0] * degree + [1]))

@classmethod
def zero(cls, degree: int = 1) -> "PolynomialModel":
return cls([0] * (degree + 1))
return cls(np.array([0] * (degree + 1)))

@classmethod
def fit_lsq(cls, x_arr: NDArray[np.float64], y_arr: NDArray[np.float64], degree: int) -> "PolynomialModel":
Expand Down

0 comments on commit c45b099

Please sign in to comment.