-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
executable file
·84 lines (72 loc) · 2.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
from fipe import Pruner
from fipe.typing import (
AdaBoostClassifier,
BaseEnsemble,
GradientBoostingClassifier,
MClass,
MNumber,
RandomForestClassifier,
)
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
def load(dataset_path: Path) -> tuple[pd.DataFrame, MClass, list[str]]:
    """Load a dataset's feature table, encoded labels, and feature names.

    Two files are read from ``dataset_path``, both named after the stem of
    that path: ``<stem>.full.csv`` and ``<stem>.featurelist.csv``.

    Args:
        dataset_path: Directory containing the two dataset files.

    Returns:
        A ``(data, y, features)`` tuple where

        * ``data`` is every column of ``<stem>.full.csv`` except the last,
        * ``y`` is a NumPy array with the last column converted to integer
          category codes,
        * ``features`` is the comma-separated content of
          ``<stem>.featurelist.csv`` with its final entry dropped.
    """
    name = dataset_path.stem
    table = pd.read_csv(dataset_path / f"{name}.full.csv")

    # The last column holds the labels; encode them as integer category codes.
    y = np.array(table.iloc[:, -1].astype("category").cat.codes.values)
    data = table.iloc[:, :-1]

    # read_text opens and closes the file itself, so no manual close() is
    # needed. The final comma-separated entry is discarded.
    # NOTE(review): presumably that entry names the label column -- confirm
    # against the dataset format.
    featurelist_path = dataset_path / f"{name}.featurelist.csv"
    features = featurelist_path.read_text(encoding="utf-8").split(",")[:-1]
    return data, y, features
def train(
    model_cls: type,
    options: dict[str, int | float | str | None],
    X,
    y,
    n_estimators: int,
    seed: int,
) -> tuple[BaseEnsemble, MNumber, float]:
    """Fit an ensemble of the requested class and return it with weights.

    Args:
        model_cls: One of ``LGBMClassifier``, ``AdaBoostClassifier``,
            ``GradientBoostingClassifier``, ``RandomForestClassifier`` or
            ``XGBClassifier``.
        options: Extra keyword arguments forwarded to the ``model_cls``
            constructor.
        X: Training inputs passed to ``fit``.
        y: Training targets passed to ``fit``.
        n_estimators: Forwarded as ``n_estimators``; also the length of the
            returned weight vector.
        seed: Forwarded as ``random_state``.

    Returns:
        A ``(ensemble, w, eps)`` tuple. ``ensemble`` is the fitted model,
        except that LightGBM and XGBoost models are replaced by their
        underlying booster. ``w`` is a vector of ``n_estimators`` ones and
        ``eps`` is the constant ``1e-6``.

    Raises:
        ValueError: If ``model_cls`` is not one of the supported classes.
    """
    supported = {
        LGBMClassifier,
        AdaBoostClassifier,
        GradientBoostingClassifier,
        RandomForestClassifier,
        XGBClassifier,
    }
    if model_cls not in supported:
        msg = f"Invalid model class: {model_cls}"
        raise ValueError(msg)

    # FutureWarning is silenced only while building and fitting the model.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        model = model_cls(
            n_estimators=n_estimators,
            random_state=seed,
            **options,
        )
        model.fit(X, y)

    # LightGBM and XGBoost wrappers are swapped for their booster objects;
    # every other model is returned as fitted.
    if isinstance(model, LGBMClassifier):
        ensemble = model.booster_
    elif isinstance(model, XGBClassifier):
        ensemble = model.get_booster()
    else:
        ensemble = model

    return ensemble, np.ones(n_estimators), 1e-6
def evaluate(pruner: Pruner, X, y, w: MNumber) -> dict[str, float]:
    """Compare predictions of the full ensemble and the pruner on ``X``.

    Args:
        pruner: Object exposing ``ensemble.predict(X, w)`` and ``predict(X)``.
        X: Inputs to predict on.
        y: Reference labels compared element-wise with the predictions.
        w: Weights passed to ``pruner.ensemble.predict``.

    Returns:
        A dict with three entries:

        * ``"accuracy.before.pruning"``: mean agreement between the
          ensemble's predictions and ``y``,
        * ``"accuracy.after.pruning"``: mean agreement between the pruner's
          predictions and ``y``,
        * ``"fidelity"``: mean agreement between the two sets of predictions.
    """

    def agreement(left, right):
        # Fraction of positions where the two sequences are equal.
        return (left == right).mean()

    full_pred = pruner.ensemble.predict(X, w)
    pruned_pred = pruner.predict(X)

    scores = {}
    scores["accuracy.before.pruning"] = agreement(full_pred, y)
    scores["accuracy.after.pruning"] = agreement(pruned_pred, y)
    scores["fidelity"] = agreement(full_pred, pruned_pred)
    return scores