From 6ca989e734ec566feca5904e97ae98a05da5b7ce Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 24 Nov 2023 14:30:20 +0100 Subject: [PATCH] Update docs about return value of metric function --- src/setfit/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 0bb91a41..c97b143f 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -151,8 +151,8 @@ class Trainer(ColumnMappingMixin): function when a `trial` is passed. metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): The metric to use for evaluation. If a string is provided, we treat it as the metric - name and load it with default settings. - If a callable is provided, it must take two arguments (`y_pred`, `y_test`). + name and load it with default settings. If a callable is provided, it must take two arguments + (`y_pred`, `y_test`) and return a dictionary with metric keys to values. metric_kwargs (`Dict[str, Any]`, *optional*): Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". For example useful for providing an averaging strategy for computing f1 in a multi-label setting.