
patched score function #3

Merged
merged 1 commit on Nov 21, 2024
2 changes: 0 additions & 2 deletions mlflow/getml/__init__.py
@@ -1,14 +1,14 @@
import os

Check failure on line 1 in mlflow/getml/__init__.py (GitHub Actions / lint): [*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `/autoformat` to fix this error.
import logging
import pathlib

from typing import Any, Literal, Union

Check failure on line 5 in mlflow/getml/__init__.py (GitHub Actions / lint): `typing.Literal` imported but unused. See https://docs.astral.sh/ruff/rules/F401 for how to fix this error.

import yaml

import mlflow
from mlflow import pyfunc
from mlflow.models import Model, ModelSignature, ModelInputExample

Check failure on line 11 in mlflow/getml/__init__.py (GitHub Actions / lint): `mlflow.models.ModelSignature` imported but unused; consider removing, adding to `__all__`, or using a redundant alias. See https://docs.astral.sh/ruff/rules/F401 for how to fix this error.
Check failure on line 11 in mlflow/getml/__init__.py (GitHub Actions / lint): `mlflow.models.ModelInputExample` imported but unused; consider removing, adding to `__all__`, or using a redundant alias. See https://docs.astral.sh/ruff/rules/F401 for how to fix this error.
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
@@ -38,7 +38,7 @@
from mlflow.utils.autologging_utils import autologging_integration
from mlflow.utils.requirements_utils import _get_pinned_requirement

from .autologging import autolog as _autolog

Check failure on line 41 in mlflow/getml/__init__.py (GitHub Actions / lint): Prefer absolute imports over relative imports. See https://docs.astral.sh/ruff/rules/TID252 for how to fix this error.
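(For reference: since the module added here lives at mlflow/getml/autologging.py, the absolute form that TID252 asks for would presumably be

    from mlflow.getml.autologging import autolog as _autolog  # absolute path assumed from the file location

rather than the relative `from .autologging import autolog as _autolog` above.)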

FLAVOR_NAME = "getml"

@@ -196,7 +196,7 @@


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(

Check failure on line 199 in mlflow/getml/__init__.py (GitHub Actions / lint): Missing argument description in the docstring for `log_model`: `getml_pipeline`. See https://docs.astral.sh/ruff/rules/D417 for how to fix this error.
getml_pipeline,
artifact_path,
conda_env=None,
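A minimal sketch of how the D417 failure above is usually addressed, by describing `getml_pipeline` in the docstring's Args section; the wording is assumed, not taken from this PR, and the remaining parameters are omitted:

    def log_model(
        getml_pipeline,
        artifact_path,
        conda_env=None,
    ):
        """Log a getML pipeline as an MLflow artifact for the current run.

        Args:
            getml_pipeline: The fitted getml.Pipeline to log (description assumed).
            artifact_path: Run-relative artifact path under which to log the model.
            conda_env: Optional conda environment specification.
        """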
@@ -302,12 +302,10 @@


def _load_model(path):
import getml

Check failure on line 305 in mlflow/getml/__init__.py (GitHub Actions / lint): [*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `/autoformat` to fix this error.
import shutil

- import pdb

- pdb.set_trace()
with open(os.path.join(path, "getml.yaml")) as f:
getml_settings = yaml.safe_load(f.read())

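For context on the loading path touched above: assuming this flavor registers a pyfunc wrapper (the `from mlflow import pyfunc` import suggests it does), a logged pipeline would be reloaded roughly like this; the run id and the input frame are placeholders:

    import mlflow

    model_uri = "runs:/<run_id>/model"  # placeholder: run id and artifact path of the logged pipeline
    loaded_model = mlflow.pyfunc.load_model(model_uri)
    predictions = loaded_model.predict(population_df)  # population_df: a pandas DataFrame (placeholder)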
70 changes: 58 additions & 12 deletions mlflow/getml/autologging.py
@@ -1,14 +1,15 @@
import json

Check failure on line 1 in mlflow/getml/autologging.py (GitHub Actions / lint): [*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `/autoformat` to fix this error.
from dataclasses import dataclass, field
from typing import Any
import threading

import mlflow
from mlflow.utils import gorilla

Check failure on line 7 in mlflow/getml/autologging.py (GitHub Actions / lint): [*] `mlflow.utils.gorilla` imported but unused. Run `ruff --fix .` or comment `/autoformat` to fix this error.
from mlflow.utils.autologging_utils import safe_patch
from mlflow.utils.autologging_utils.client import MlflowAutologgingQueueingClient
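For reference, the two lint failures above (import ordering and the unused `mlflow.utils.gorilla` import) would roughly resolve to the following import block once ruff's isort rules are applied; this is a sketch of the expected fix, not the committed one:

    import json
    import threading
    from dataclasses import dataclass, field
    from typing import Any

    import mlflow
    from mlflow.utils.autologging_utils import safe_patch
    from mlflow.utils.autologging_utils.client import MlflowAutologgingQueueingClient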



@dataclass
class LogInfo:
params: dict[str, Any] = field(default_factory=dict)
@@ -30,7 +31,7 @@
log_post_training_metrics=True,
):
flavor_name = "getml"
import getml

Check failure on line 34 in mlflow/getml/autologging.py (GitHub Actions / lint): [*] Import block is un-sorted or un-formatted. Run `ruff --fix .` or comment `/autoformat` to fix this error.
from dataclasses import fields, dataclass, is_dataclass

def _patch_pipeline_method(flavor_name, class_def, func_name, patched_fn, manage_run):
@@ -50,7 +51,7 @@
"feature_learners",
"feature_selectors",
"predictors",
- "loss_function",
+ "share_selected_features",
)
pipeline_informations = {}

@@ -63,15 +64,27 @@
for field in fields(v):
field_value = getattr(v, field.name)
if isinstance(field_value, (frozenset, set)):
- field_value = json.dumps(list(field_value))
+ try:
+     field_value = json.dumps(list(field_value))
+ except:
+     print("Error in converting frozenset to list")
elif isinstance(field_value, getml.feature_learning.FastProp):
field_value = field_value.__class__.__name__
elif not isinstance(field_value, str):
- field_value = json.dumps(field_value)
+ try:
+     field_value = json.dumps(field_value)
+ except:
+     print("Error in converting field_value to json")
+     print(field_value)

pipeline_informations[f"{parameter_name}.{name}.{field.name}"] = (
field_value
)
# else:
# value_name = values.__class__.__name__
# pipeline_informations[parameter_name] = value_name
elif isinstance(values, str):
pipeline_informations[parameter_name] = values
else:
value_name = values.__class__.__name__
pipeline_informations[parameter_name] = value_name
tags = [str(t) for t in getml_pipeline.tags]
return LogInfo(params=pipeline_informations, tags=dict(zip(tags, tags)))
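To make the flattening convention above concrete, here is a minimal, simplified sketch of the same serialization rule (standalone, with the try/except and the getml-specific branches omitted); values are hypothetical:

    import json

    def to_param_value(value):
        # Mirrors the logic above: sets become JSON lists (sorted here for determinism),
        # strings pass through, and everything else is JSON-encoded before being
        # logged as an MLflow param.
        if isinstance(value, (frozenset, set)):
            return json.dumps(sorted(value))
        if isinstance(value, str):
            return value
        return json.dumps(value)

    # Keys are flattened as "<parameter>.<name>.<field>", e.g. "feature_learners.FastProp.num_features" (hypothetical).
    print(to_param_value(frozenset({"b", "a"})))  # '["a", "b"]'
    print(to_param_value(0.9))                    # '0.9'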

@@ -85,14 +98,14 @@
scores = getml_pipeline.scores

if getml_pipeline.is_classification:
- metrics["auc"] = scores.auc
- metrics["accuracy"] = scores.accuracy
- metrics["cross_entropy"] = scores.cross_entropy
+ metrics["train_auc"] = round(scores.auc,2)
+ metrics["train_accuracy"] = round(scores.accuracy, 2)
+ metrics["train_cross_entropy"] = round(scores.cross_entropy, 4)

if getml_pipeline.is_regression:
- metrics["mae"] = scores.mae
- metrics["rmse"] = scores.rmse
- metrics["rsquared"] = scores.rsquared
+ metrics["train_mae"] = scores.mae
+ metrics["train_rmse"] = scores.rmse
+ metrics["train_rsquared"] = round(scores.rsquared, 2)

# for feature in getml_pipeline.features:
# metrics[f"{feature.name}.importance"] = json.dumps(feature.importance)
@@ -150,6 +163,10 @@
assert (active_run := mlflow.active_run())
run_id = active_run.info.run_id
pipeline_log_info = _extract_pipeline_informations(self)
# with open("my_dict.json", "w") as f:
# json.dump(pipeline_log_info.params, f)
# mlflow.log_artifact("my_dict.json")
# mlflow.log_dict(pipeline_log_info.params, 'params.json')
autologging_client.log_params(
run_id=run_id,
params=pipeline_log_info.params,
@@ -186,6 +203,27 @@

autologging_client.flush(synchronous=True)
return fit_output

def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):

target = self.data_model.population.roles.target[0]
pop_df = args[0].population.to_pandas()
pop_df["predictions"] = self.predict(*args)
pop_df['predictions'] = pop_df.round({'predictions': 0})['predictions'].astype(bool)
pop_df[target] = pop_df[target].astype(bool)

mlflow.evaluate(
data = pop_df,
targets=target,
predictions="predictions",
model_type=["regressor" if self.is_regression else "classifier"][0],
evaluators=["default"],
)

score_output = original(self, *args, **kwargs)

return score_output


_patch_pipeline_method(
flavor_name=flavor_name,
@@ -194,3 +232,11 @@
patched_fn=patched_fit_mlflow,
manage_run=True,
)

_patch_pipeline_method(
flavor_name=flavor_name,
class_def=getml.pipeline.Pipeline,
func_name="score",
patched_fn=patched_score_method,
manage_run=True,
)
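Taken together, a rough end-to-end sketch of what these patches enable once autologging is turned on; the pipeline construction and data containers below are placeholders, and `mlflow.getml.autolog` is assumed to be the public entry point that ultimately installs the patched methods:

    import getml
    import mlflow
    import mlflow.getml

    mlflow.getml.autolog()  # assumed entry point; installs the patched fit/score methods shown above

    with mlflow.start_run():
        pipeline = getml.Pipeline(...)   # placeholder pipeline definition
        pipeline.fit(train_container)    # params, tags and training metrics logged by patched_fit_mlflow
        pipeline.score(test_container)   # predictions additionally evaluated via mlflow.evaluate in patched_score_method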