
Commit

Refactor Pruner class to improve pruning logic and enhance clarity by separating L1 and L0 pruning methods
eminyous committed Jan 6, 2025
1 parent fa3a22a commit 0310657
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions fipe/prune/pruner.py
@@ -6,7 +6,7 @@
 
 from ..feature import FeatureEncoder
 from ..mip import MIP
-from ..typing import BaseEnsemble, MNumber, MProb
+from ..typing import BaseEnsemble, MNumber, MProb, Number
 from .base import BasePruner
 
 
@@ -45,12 +45,12 @@ def __init__(
 
     def build(self) -> None:
         self._add_weight_vars()
-        self._add_objective()
         self._n_samples = 0
 
     def add_samples(self, X: npt.ArrayLike) -> None:
         X = np.asarray(X)
-        classes = self.ensemble.predict(X=X, w=self._weights)
+        w = self._weights
+        classes = self.ensemble.predict(X=X, w=w)
         prob = self.ensemble.predict_proba(X=X)
         n = X.shape[0]
         for i in range(n):
@@ -60,7 +60,9 @@ def prune(self) -> None:
         if self._n_samples == 0:
             msg = "No samples have been added to the pruner."
             raise RuntimeError(msg)
-        self.optimize()
+        self._prune_l1()
+        if self._norm == 0:
+            self._prune_l0()
 
     @property
     def n_samples(self) -> int:
@@ -110,3 +112,19 @@ def _validate_norm(self, norm: int) -> None:
         if norm not in self.VALID_NOMRS:
             msg = "The norm must be either 0 or 1."
             raise ValueError(msg)
+
+    def _prune_l1(self) -> None:
+        w = self._weight_vars
+        self.setObjective(w.sum(), gp.GRB.MINIMIZE)
+        self.optimize()
+
+    def _prune_l0(self) -> None:
+        W = Number(np.sum(self._weight_vars.X))
+        n = self.n_estimators
+        w = self._weight_vars
+        u = self.addMVar(shape=n, vtype=gp.GRB.BINARY, name="u")
+        contrs = self.addConstr(w <= W * u, name="bigM")
+        self.setObjective(u.sum(), gp.GRB.MINIMIZE)
+        self.optimize()
+        self.remove(contrs)
+        self.remove(u)
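Reviewer note: the new _prune_l0 method uses the classic big-M indicator trick. The total weight found by the L1 phase bounds every individual weight, so a binary variable per estimator can switch its weight off, and minimizing the sum of binaries minimizes the number of estimators kept. Below is a minimal, self-contained gurobipy sketch of the same two-phase idea outside the fipe Pruner class; the model name, the size n, and the "fidelity" constraint are hypothetical stand-ins for the per-sample prediction constraints that add_samples builds, not the library's API.

# Minimal sketch, assuming gurobipy is installed and licensed.
# "fidelity" and n are illustrative placeholders, not part of fipe.
import gurobipy as gp
import numpy as np
from gurobipy import GRB

n = 5  # hypothetical number of estimators in the ensemble
model = gp.Model("two-phase-pruning-sketch")
w = model.addMVar(shape=n, lb=0.0, ub=1.0, name="w")
# Stand-in constraint: keep enough total weight to preserve predictions.
model.addConstr(w.sum() >= 1.0, name="fidelity")

# Phase 1 (L1 norm): minimize the total weight of the pruned ensemble.
model.setObjective(w.sum(), GRB.MINIMIZE)
model.optimize()

# Phase 2 (L0 norm): the optimal L1 mass bounds each weight, so it acts
# as a big-M constant; u[i] = 0 forces w[i] = 0, and minimizing u.sum()
# minimizes the number of estimators that stay active.
W = float(np.sum(w.X))
u = model.addMVar(shape=n, vtype=GRB.BINARY, name="u")
model.addConstr(w <= W * u, name="bigM")
model.setObjective(u.sum(), GRB.MINIMIZE)
model.optimize()
print("active estimators:", int(round(u.X.sum())))

In the commit itself, the big-M constraints and the binaries are removed again after the L0 solve (self.remove(contrs), self.remove(u)), so the underlying MIP model stays reusable for later pruning calls.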

0 comments on commit 0310657
