Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comet-ml experiment management SDK integration #105

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions example/7B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ wandb:
run_name: "" # your wandb run name
key: "" # your wandb api key
offline: False

comet_ml:
project_name: "" # your Comet project name (mandatory if you would like to enable Comet experiment management)
workspace: null # your Comet workspace name (optional)
experiment_key: null # the Comet experiment identifier to be used for logging (optional)
api_key: null # your Comet API key (mandatory for online mode, optional for offline mode)
online: True # flag to control if data should be sent to the Comet server during the run
37 changes: 37 additions & 0 deletions finetune/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,42 @@ def __post_init__(self) -> None:
if self.experiment_name is None:
raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.")

@dataclass
class CometMLArgs(Serializable):
    """Configuration options for the Comet ML experiment management SDK, used to track metrics, parameters, etc.

    See https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/#comet_ml.start for more details about
    configuration options.

    To enable Comet experiment management, set `comet_ml.project_name` in the configuration file.

    Args:
        project_name (str, optional): Your Comet project name. Mandatory if you would like to enable the Comet
            experiment management SDK.
        workspace (str, optional): Comet workspace name. If not provided, uses the default workspace.
        experiment_key (str, optional): The experiment identifier to be used for logging. This is used either to
            append data to an existing experiment or to control the key of new experiments (for example to match
            another identifier). Must be an alphanumeric string whose length is between 32 and 50 characters.
        api_key (str, optional): Comet API key. It's recommended to configure the API key with `comet login`.
        online (bool): If True, the data will be logged to the Comet server, otherwise it will be stored locally
            in an offline experiment. Default is ``True``.
    """
    project_name: Optional[str] = None
    workspace: Optional[str] = None
    experiment_key: Optional[str] = None
    api_key: Optional[str] = None
    online: bool = True

    def __post_init__(self) -> None:
        # A None project_name means Comet logging is disabled; nothing to validate.
        if self.project_name is None:
            return

        # Validate the configuration before probing the optional dependency, so a
        # misconfigured (empty) project name is reported as a config error even on
        # machines where comet-ml is not installed.
        if len(self.project_name) == 0:
            raise ValueError("`comet_ml.project_name` must not be an empty string.")

        try:
            import comet_ml  # noqa: F401
        except ImportError as err:
            # Chain the original error so the root cause of the failed import stays visible.
            raise ImportError(
                "`comet-ml` is not installed. Either `pip install comet-ml` or set `comet_ml.project_name` to None.") from err


@dataclass
Expand Down Expand Up @@ -91,6 +127,7 @@ class TrainArgs(Serializable):
# logging
wandb: WandbArgs = field(default_factory=WandbArgs)
mlflow: MLFlowArgs = field(default_factory=MLFlowArgs)
comet_ml: CometMLArgs = field(default_factory=CometMLArgs)

# LoRA
lora: Optional[LoraArgs] = field(default_factory=LoraArgs)
Expand Down
24 changes: 23 additions & 1 deletion finetune/monitoring/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from torch.utils.tensorboard import SummaryWriter

from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
from finetune.args import MLFlowArgs, TrainArgs, WandbArgs, CometMLArgs
from finetune.utils import TrainState

logger = logging.getLogger("metrics_logger")
Expand Down Expand Up @@ -109,6 +109,7 @@ def __init__(
is_master: bool,
wandb_args: WandbArgs,
mlflow_args: MLFlowArgs,
comet_ml_args: CometMLArgs,
config: Optional[Dict[str, Any]] = None,
):
self.dst_dir = dst_dir
Expand All @@ -130,6 +131,7 @@ def __init__(
)
self.is_wandb = wandb_args.project is not None
self.is_mlflow = mlflow_args.tracking_uri is not None
self.is_comet = comet_ml_args.project_name is not None

if self.is_wandb:
import wandb
Expand Down Expand Up @@ -162,6 +164,17 @@ def __init__(

self.mlflow_log = mlflow.log_metric

if self.is_comet:
import comet_ml

self.comet_experiment = comet_ml.start(
api_key=comet_ml_args.api_key,
project_name=comet_ml_args.project_name,
workspace=comet_ml_args.workspace,
experiment_key=comet_ml_args.experiment_key,
online=comet_ml_args.online,
)

def log(self, metrics: Dict[str, Union[float, int]], step: int):
if not self.is_master:
return
Expand All @@ -179,6 +192,10 @@ def log(self, metrics: Dict[str, Union[float, int]], step: int):
if self.is_mlflow:
self.mlflow_log(f"{self.tag}.{key}", value, step=step)

if self.is_comet:
with self.comet_experiment.context_manager(self.tag):
self.comet_experiment.log_metric(key, value, step=step, include_context=True)

if self.is_wandb:
# grouping in wandb is done with /
self.wandb_log(
Expand Down Expand Up @@ -218,6 +235,11 @@ def close(self):

mlflow.end_run()

if self.is_comet and self.comet_experiment is not None:

# to make sure everything is logged to Comet at this moment explicitly
self.comet_experiment.end()

def __del__(self):
if self.summary_writer is not None:
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _train(
is_master=get_rank() == 0,
wandb_args=args.wandb,
mlflow_args=args.mlflow,
comet_ml_args=args.comet_ml,
config=dataclasses.asdict(args),
)
exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger"))
Expand All @@ -120,6 +121,7 @@ def _train(
is_master=get_rank() == 0,
wandb_args=args.wandb,
mlflow_args=args.mlflow,
comet_ml_args=args.comet_ml,
config=dataclasses.asdict(args),
)
exit_stack.enter_context(logged_closing(eval_logger, "eval_logger"))
Expand Down