Skip to content

Commit

Permalink
[v0.10.0 -> v0.13.0] Task6 (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anleeos authored Nov 23, 2023
1 parent 8878bd3 commit 8272d93
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion docs/tutorial/05-python_AutoTVM.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,44 @@ tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=par
# 按顺序调优提取的任务
for i, task in enumerate(tasks):
prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
tuner_obj = XGBTuner(task, loss_type="rank")

# choose tuner
tuner = "xgb"

# create tuner
if tuner == "xgb":
tuner_obj = XGBTuner(task, loss_type="reg")
elif tuner == "xgb_knob":
tuner_obj = XGBTuner(task, loss_type="reg", feature_type="knob")
elif tuner == "xgb_itervar":
tuner_obj = XGBTuner(task, loss_type="reg", feature_type="itervar")
elif tuner == "xgb_curve":
tuner_obj = XGBTuner(task, loss_type="reg", feature_type="curve")
elif tuner == "xgb_rank":
tuner_obj = XGBTuner(task, loss_type="rank")
elif tuner == "xgb_rank_knob":
tuner_obj = XGBTuner(task, loss_type="rank", feature_type="knob")
elif tuner == "xgb_rank_itervar":
tuner_obj = XGBTuner(task, loss_type="rank", feature_type="itervar")
elif tuner == "xgb_rank_curve":
tuner_obj = XGBTuner(task, loss_type="rank", feature_type="curve")
elif tuner == "xgb_rank_binary":
tuner_obj = XGBTuner(task, loss_type="rank-binary")
elif tuner == "xgb_rank_binary_knob":
tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="knob")
elif tuner == "xgb_rank_binary_itervar":
tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="itervar")
elif tuner == "xgb_rank_binary_curve":
tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="curve")
elif tuner == "ga":
tuner_obj = GATuner(task, pop_size=50)
elif tuner == "random":
tuner_obj = RandomTuner(task)
elif tuner == "gridsearch":
tuner_obj = GridSearchTuner(task)
else:
raise ValueError("Invalid tuner: " + tuner)

tuner_obj.tune(
n_trial=min(tuning_option["trials"], len(task.config_space)),
early_stopping=tuning_option["early_stopping"],
Expand Down

0 comments on commit 8272d93

Please sign in to comment.