Skip to content

Commit

Permalink
Implement ensemble binders for RandomForest and AdaBoost classifiers; remove legacy binder classes
Browse files Browse the repository at this point in the history
  • Loading branch information
eminyous committed Jan 2, 2025
1 parent f6bc8fc commit 179cfb4
Show file tree
Hide file tree
Showing 14 changed files with 48 additions and 176 deletions.
1 change: 0 additions & 1 deletion fipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
__all__ = [
"FIPE",
"OCEAN",
"OCEAN",
"BaseOCEAN",
"BasePruner",
"Ensemble",
Expand Down
10 changes: 5 additions & 5 deletions fipe/ensemble/generic.py → fipe/ensemble/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,33 @@
PT = TypeVar("PT", bound=ParsableTree)


class Callback:
class EnsembleBinderCallback:
__metaclass__ = ABCMeta

@abstractmethod
def predict_leaf(self, leaf_index: int, index: int) -> Prob:
raise NotImplementedError


class GenericEnsemble(Generic[BE, PT]):
class GenericEnsembleBinder(Generic[BE, PT]):
__metaclass__ = ABCMeta

NUM_BINARY_CLASSES = 2

_base: BE
__callback: Callback
__callback: EnsembleBinderCallback

def __init__(
self,
base: BE,
*,
callback: Callback,
callback: EnsembleBinderCallback,
) -> None:
self._base = base
self.__callback = callback

@property
def callback(self) -> Callback:
def callback(self) -> EnsembleBinderCallback:
return self.__callback

def predict(self, X: npt.ArrayLike, w: npt.ArrayLike) -> MClass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
LGBMClassifier,
RandomForestClassifier,
)
from ..generic import Callback
from ..binder import EnsembleBinderCallback
from .ab import AdaBoostBinder
from .gb import GradientBoostingBinder
from .lgbm import LightGBMBinder
Expand All @@ -22,10 +22,10 @@
)


def create_ensemble(
def create_binder(
base: BaseEnsemble,
*,
callback: Callback,
callback: EnsembleBinderCallback,
) -> EnsembleBinder:
if isinstance(base, RandomForestClassifier):
return RandomForestBinder(base, callback=callback)
Expand All @@ -43,5 +43,5 @@ def create_ensemble(

__all__ = [
"EnsembleBinder",
"create_ensemble",
"create_binder",
]
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
import numpy.typing as npt

from ...typing import LGBMClassifier, MProb, LightGBMParsableTree, Prob
from ..generic import GenericEnsemble
from ...typing import LGBMClassifier, LightGBMParsableTree, MProb, Prob
from ..binder import GenericEnsembleBinder


class LightGBMBinder(GenericEnsemble[LGBMClassifier, LightGBMParsableTree]):
class LightGBMBinder(
GenericEnsembleBinder[LGBMClassifier, LightGBMParsableTree]
):
TREE_INFO_KEY = "tree_info"

@property
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions fipe/ensemble/classes/skl.py → fipe/ensemble/binders/skl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
AdaBoostClassifier,
BaseDecisionTree,
GradientBoostingClassifier,
SKLearnParsableTree,
RandomForestClassifier,
SKLearnParsableTree,
)
from ..generic import GenericEnsemble
from ..binder import GenericEnsembleBinder

Classifier = (
RandomForestClassifier | AdaBoostClassifier | GradientBoostingClassifier
Expand All @@ -20,7 +20,7 @@


class EnsembleBinderSKLearn(
GenericEnsemble[CL, SKLearnParsableTree], Generic[CL, DT]
GenericEnsembleBinder[CL, SKLearnParsableTree], Generic[CL, DT]
):
__metaclass__ = ABCMeta

Expand Down
4 changes: 2 additions & 2 deletions fipe/ensemble/classes/xgb.py → fipe/ensemble/binders/xgb.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from ...tree import XGBoostTreeParser
from ...typing import Booster, XGBoostParsableTree
from ..generic import GenericEnsemble
from ..binder import GenericEnsembleBinder


class XGBoostBinder(GenericEnsemble[Booster, XGBoostParsableTree]):
class XGBoostBinder(GenericEnsembleBinder[Booster, XGBoostParsableTree]):
TREE_KEY = "Tree"

INDEX = (
Expand Down
46 changes: 24 additions & 22 deletions fipe/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,62 +5,64 @@
from ..feature import FeatureEncoder
from ..tree import Tree, TreeParser, create_parser
from ..typing import BaseEnsemble, MClass, MProb, Prob
from .classes import EnsembleBinder, create_ensemble
from .generic import Callback
from .binder import EnsembleBinderCallback
from .binders import EnsembleBinder, create_binder


class Ensemble(Sequence[Tree], Callback):
ensemble: EnsembleBinder
tree_parser: TreeParser
trees: Sequence[Tree]
class Ensemble(Sequence[Tree], EnsembleBinderCallback):
_binder: EnsembleBinder
_tree_parser: TreeParser
_trees: Sequence[Tree]

def __init__(self, base: BaseEnsemble, encoder: FeatureEncoder) -> None:
self.ensemble = self.init_ensemble(base=base, callback=self)
self.tree_parser = self.init_tree_parser(base=base, encoder=encoder)
parse = self.tree_parser.parse
base_trees = self.ensemble.base_trees
self.trees = list(map(parse, base_trees))
self._binder = self.init_ensemble_binder(base=base, callback=self)
self._tree_parser = self.init_tree_parser(base=base, encoder=encoder)
parse = self._tree_parser.parse
base_trees = self._binder.base_trees
self._trees = list(map(parse, base_trees))

def predict(self, X: npt.ArrayLike, w: npt.ArrayLike) -> MClass:
return self.ensemble.predict(X=X, w=w)
return self._binder.predict(X=X, w=w)

def score(self, X: npt.ArrayLike, w: npt.ArrayLike) -> MProb:
return self.ensemble.score(X=X, w=w)
return self._binder.score(X=X, w=w)

def scores(self, X: npt.ArrayLike) -> MProb:
return self.ensemble.scores(X=X)
return self._binder.scores(X=X)

def predict_leaf(self, leaf_index: int, index: int) -> Prob:
return Prob(self[index].predict(leaf_index))

@property
def is_binary(self) -> bool:
return self.ensemble.is_binary
return self._binder.is_binary

@property
def n_classes(self) -> int:
return self.ensemble.n_classes
return self._binder.n_classes

@property
def n_estimators(self) -> int:
return self.ensemble.n_estimators
return self._binder.n_estimators

@property
def max_depth(self) -> int:
return max(tree.max_depth for tree in self)

def __getitem__(self, t: int) -> Tree:
return self.trees[t]
return self._trees[t]

def __iter__(self) -> Iterator[Tree]:
return iter(self.trees)
return iter(self._trees)

def __len__(self) -> int:
return len(self.trees)
return len(self._trees)

@staticmethod
def init_ensemble(base: BaseEnsemble, callback: Callback) -> EnsembleBinder:
return create_ensemble(base=base, callback=callback)
def init_ensemble_binder(
base: BaseEnsemble, callback: EnsembleBinderCallback
) -> EnsembleBinder:
return create_binder(base=base, callback=callback)

@staticmethod
def init_tree_parser(
Expand Down
131 changes: 0 additions & 131 deletions fipe/ensemble/parser.py

This file was deleted.

4 changes: 2 additions & 2 deletions fipe/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
)
from .container import TreeContainer
from .parsers import (
TreeParserCL,
LightGBMTreeParser,
TreeParserCL,
TreeParserRG,
XGBoostTreeParser,
)
Expand All @@ -37,11 +37,11 @@ def create_parser(base: BaseEnsemble, encoder: FeatureEncoder) -> TreeParser:


__all__ = [
"LightGBMTreeParser",
"Tree",
"TreeContainer",
"TreeParser",
"TreeParserCL",
"LightGBMTreeParser",
"TreeParserRG",
"XGBoostTreeParser",
"create_parser",
Expand Down
6 changes: 3 additions & 3 deletions fipe/tree/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .lgbm import LightGBMTreeParser
from .skl import TreeParserCL, TreeParserRG, SKLearnTreeParser
from .skl import SKLearnTreeParser, TreeParserCL, TreeParserRG
from .xgb import XGBoostTreeParser

__all__ = [
"TreeParserCL",
"LightGBMTreeParser",
"TreeParserRG",
"SKLearnTreeParser",
"TreeParserCL",
"TreeParserRG",
"XGBoostTreeParser",
]

0 comments on commit 179cfb4

Please sign in to comment.