diff --git a/fipe/prune/pruner.py b/fipe/prune/pruner.py
index 9ecf410..3aa846d 100644
--- a/fipe/prune/pruner.py
+++ b/fipe/prune/pruner.py
@@ -6,7 +6,7 @@
 
 from ..feature import FeatureEncoder
 from ..mip import MIP
-from ..typing import BaseEnsemble, MNumber, MProb
+from ..typing import BaseEnsemble, MNumber, MProb, Number
 from .base import BasePruner
 
 
@@ -45,12 +45,12 @@ def __init__(
 
     def build(self) -> None:
         self._add_weight_vars()
-        self._add_objective()
         self._n_samples = 0
 
     def add_samples(self, X: npt.ArrayLike) -> None:
         X = np.asarray(X)
-        classes = self.ensemble.predict(X=X, w=self._weights)
+        w = self._weights
+        classes = self.ensemble.predict(X=X, w=w)
         prob = self.ensemble.predict_proba(X=X)
         n = X.shape[0]
         for i in range(n):
@@ -60,7 +60,9 @@ def prune(self) -> None:
         if self._n_samples == 0:
             msg = "No samples have been added to the pruner."
             raise RuntimeError(msg)
-        self.optimize()
+        self._prune_l1()
+        if self._norm == 0:
+            self._prune_l0()
 
     @property
     def n_samples(self) -> int:
@@ -110,3 +112,19 @@ def _validate_norm(self, norm: int) -> None:
         if norm not in self.VALID_NOMRS:
             msg = "The norm must be either 0 or 1."
             raise ValueError(msg)
+
+    def _prune_l1(self) -> None:
+        w = self._weight_vars
+        self.setObjective(w.sum(), gp.GRB.MINIMIZE)
+        self.optimize()
+
+    def _prune_l0(self) -> None:
+        W = Number(np.sum(self._weight_vars.X))
+        n = self.n_estimators
+        w = self._weight_vars
+        u = self.addMVar(shape=n, vtype=gp.GRB.BINARY, name="u")
+        contrs = self.addConstr(w <= W * u, name="bigM")
+        self.setObjective(u.sum(), gp.GRB.MINIMIZE)
+        self.optimize()
+        self.remove(contrs)
+        self.remove(u)