From 79b3d52a6842d6ba12f0a544e27a444562a486df Mon Sep 17 00:00:00 2001 From: Wonho <42860714+GoldenCorgi@users.noreply.github.com> Date: Thu, 21 Oct 2021 15:37:47 +0800 Subject: [PATCH] Optuna Integration (#215) * optuna integration and tests * optuna integration * optuna integration * added typing for tests --- hiplot/experiment.py | 26 ++++++++++++++++++++++++++ hiplot/test_experiment.py | 19 +++++++++++++++++++ requirements/dev.txt | 1 + 3 files changed, 46 insertions(+) diff --git a/hiplot/experiment.py b/hiplot/experiment.py index bb11cc8d..5c9497c1 100644 --- a/hiplot/experiment.py +++ b/hiplot/experiment.py @@ -15,6 +15,7 @@ if tp.TYPE_CHECKING: import pandas as pd from .streamlit_helpers import ExperimentStreamlitComponent + import optuna DisplayableType = tp.Union[bool, int, float, str] @@ -502,6 +503,31 @@ def from_dataframe(dataframe: "pd.DataFrame") -> "Experiment": # No type hint t return experiment + @staticmethod + def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to avoid having optuna as an additional dependency + """ + Creates a HiPlot experiment from a Optuna Study. + + :param study: Optuna Study + """ + + + # Create a list of dictionary objects using study trials + # All parameters are taken using params.copy() + + hyper_opt_data = [] + for each_trial in study.trials: + trial_params = {} + trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning + trial_params["uid"] = each_trial.number + trial_params.update(each_trial.params.copy()) + hyper_opt_data.append(trial_params) + experiment = Experiment.from_iterable(hyper_opt_data) + + return experiment + + + @staticmethod def merge(xp_dict: tp.Dict[str, "Experiment"]) -> "Experiment": """ diff --git a/hiplot/test_experiment.py b/hiplot/test_experiment.py index 3a8d41bd..3c36b650 100644 --- a/hiplot/test_experiment.py +++ b/hiplot/test_experiment.py @@ -10,6 +10,7 @@ import pytest import pandas as pd +import optuna import hiplot as hip @@ -39,6 +40,24 @@ def test_from_dataframe() -> None: xp.validate() xp._asdict() +def test_from_optuna() -> None: + + def objective(trial: "optuna.trial.Trial") -> float: + x = trial.suggest_float("x", -1, 1) + return x ** 2 + + study = optuna.create_study() + study.optimize(objective, n_trials=3) + + # Create a dataframe from the study. + df = study.trials_dataframe() + assert isinstance(df, pd.DataFrame) + assert df.shape[0] == 3 # n_trials. + xp = hip.Experiment.from_optuna(study) + assert len(xp.datapoints) == 3 + xp.validate() + xp._asdict() + def test_from_dataframe_nan_values() -> None: # Pandas automatically convert numeric-based columns None to NaN in dataframes diff --git a/requirements/dev.txt b/requirements/dev.txt index 36784a42..346f535c 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -11,3 +11,4 @@ pre-commit pandas streamlit>=0.63 beautifulsoup4 +optuna \ No newline at end of file