From 8272d937c78909e739eb0ad4962ecc61c595de21 Mon Sep 17 00:00:00 2001 From: Anleeos <2937160075@qq.com> Date: Thu, 23 Nov 2023 20:14:30 +0800 Subject: [PATCH] [v0.10.0 -> v0.13.0] Task6 (#217) --- docs/tutorial/05-python_AutoTVM.md | 39 +++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/tutorial/05-python_AutoTVM.md b/docs/tutorial/05-python_AutoTVM.md index 1e3c3165..ca7e90fd 100644 --- a/docs/tutorial/05-python_AutoTVM.md +++ b/docs/tutorial/05-python_AutoTVM.md @@ -279,7 +279,44 @@ tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=par # 按顺序调优提取的任务 for i, task in enumerate(tasks): prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) - tuner_obj = XGBTuner(task, loss_type="rank") + + # choose tuner + tuner = "xgb" + + # create tuner + if tuner == "xgb": + tuner_obj = XGBTuner(task, loss_type="reg") + elif tuner == "xgb_knob": + tuner_obj = XGBTuner(task, loss_type="reg", feature_type="knob") + elif tuner == "xgb_itervar": + tuner_obj = XGBTuner(task, loss_type="reg", feature_type="itervar") + elif tuner == "xgb_curve": + tuner_obj = XGBTuner(task, loss_type="reg", feature_type="curve") + elif tuner == "xgb_rank": + tuner_obj = XGBTuner(task, loss_type="rank") + elif tuner == "xgb_rank_knob": + tuner_obj = XGBTuner(task, loss_type="rank", feature_type="knob") + elif tuner == "xgb_rank_itervar": + tuner_obj = XGBTuner(task, loss_type="rank", feature_type="itervar") + elif tuner == "xgb_rank_curve": + tuner_obj = XGBTuner(task, loss_type="rank", feature_type="curve") + elif tuner == "xgb_rank_binary": + tuner_obj = XGBTuner(task, loss_type="rank-binary") + elif tuner == "xgb_rank_binary_knob": + tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="knob") + elif tuner == "xgb_rank_binary_itervar": + tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="itervar") + elif tuner == "xgb_rank_binary_curve": + tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="curve") + elif tuner == "ga": + tuner_obj = GATuner(task, pop_size=50) + elif tuner == "random": + tuner_obj = RandomTuner(task) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(task) + else: + raise ValueError("Invalid tuner: " + tuner) + tuner_obj.tune( n_trial=min(tuning_option["trials"], len(task.config_space)), early_stopping=tuning_option["early_stopping"],