Skip to content

Commit

Permalink
Implement custom rf
Browse files Browse the repository at this point in the history
  • Loading branch information
timovdk committed Jan 24, 2025
1 parent dd77170 commit cced274
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions asreview2-optuna/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
NaiveBayesClassifier,
LogisticClassifier,
SVMClassifier,
RandomForestClassifier,
)

from sklearn.ensemble import RandomForestClassifier


def naive_bayes_params(trial: optuna.trial.FrozenTrial):
# Use logarithmic normal distribution for alpha (alpha effect is non-linear)
Expand Down Expand Up @@ -52,9 +53,26 @@ def random_forest_params(trial: optuna.trial.FrozenTrial):
}


class RFClassifier(RandomForestClassifier):
"""Random forest classifier.
Based on the sklearn implementation of the random forest
sklearn.ensemble.RandomForestClassifier.
"""

name = "rf"
label = "Random forest"

def __init__(self, n_estimators=100, max_features=10, **kwargs):
super().__init__(
n_estimators=int(n_estimators),
max_features=max_features,
**kwargs,
)

classifiers = {
"nb": NaiveBayesClassifier,
"log": LogisticClassifier,
"svm": SVMClassifier,
"rf": RandomForestClassifier,
"rf": RFClassifier,
}

0 comments on commit cced274

Please sign in to comment.