Skip to content

Commit

Permalink
fix over_time function
Browse files Browse the repository at this point in the history
  • Loading branch information
yusufuyanik1 committed Jan 23, 2025
1 parent b817f63 commit 8cd7992
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 25 deletions.
49 changes: 30 additions & 19 deletions python/pdstools/adm/Plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def over_time(
by: Union[pl.Expr, str] = "ModelID",
*,
every: Union[str, timedelta] = "1d",
cumulative: bool = False,
show_changes: bool = False,
query: Optional[QUERY] = None,
facet: Optional[str] = None,
return_df: bool = False,
Expand All @@ -280,8 +280,8 @@ def over_time(
The column to group by, by default "ModelID"
every : Union[str, timedelta], optional
By what time period to group, by default "1d"
cumulative : bool, optional
Whether to take the cumulative value or the absolute one, by default False
show_changes : bool, optional
Whether to show period-over-period changes instead of raw values, by default False
query : Optional[QUERY], optional
The query to apply to the data, by default None
facet : Optional[str], optional
Expand All @@ -299,6 +299,10 @@ def over_time(
"Performance_weighted_average": ":.2", # is not a percentage!
"Positives": ":.d",
"ResponseCount": ":.d",
"SuccessRate_weighted_average_change": ":.4%",
"Performance_weighted_average_change": ":.2",
"Positives_change": ":.d",
"ResponseCount_change": ":.d",
}

if metric == "Performance":
Expand All @@ -315,9 +319,7 @@ def over_time(
columns_to_select.add(facet)

df = (
cdh_utils._apply_query(
self.datamart.model_data.sort(by="SnapshotTime"), query
)
cdh_utils._apply_query(self.datamart.model_data, query)
.sort("SnapshotTime")
.select(list(columns_to_select))
)
Expand All @@ -328,11 +330,10 @@ def over_time(
grouping_columns = [by_col, facet]
else:
grouping_columns = [by_col]
if metric in ["Performance", "SuccessRate"]: # we need to weigh these

if metric in ["Performance", "SuccessRate"]:
df = (
df.group_by_dynamic(
"SnapshotTime", every=every, group_by=grouping_columns
)
df.group_by(grouping_columns + ["SnapshotTime"])
.agg(
(
metric_scaling
Expand All @@ -342,48 +343,58 @@ def over_time(
.sort("SnapshotTime", by_col)
)
metric += "_weighted_average"
elif cumulative:
else:
df = (
df.group_by(grouping_columns + ["SnapshotTime"])
.agg(pl.sum(metric))
.sort(grouping_columns + ["SnapshotTime"])
)
else:

if show_changes:
df = (
df.with_columns(
Delta=pl.col(metric).cast(pl.Int64).diff().over(grouping_columns)
PeriodChange=pl.col(metric).diff().over(grouping_columns)
)
.group_by_dynamic(
"SnapshotTime", every=every, group_by=grouping_columns
)
.agg(Increase=pl.sum("Delta"))
.agg(pl.sum("PeriodChange").alias(f"{metric}_change"))
)
plot_metric = f"{metric}_change"
else:
plot_metric = metric

if return_df:
return df

final_df = df.collect()
unique_facet_values = final_df.select(facet).unique().shape[0]
facet_col_wrap = max(2, int(unique_facet_values**0.5))

facet_col_wrap = None
if facet:
unique_facet_values = final_df.select(facet).unique().shape[0]
facet_col_wrap = max(2, int(unique_facet_values**0.5))

title = "over all models" if facet is None else f"per {facet}"
fig = px.line(
final_df,
x="SnapshotTime",
y=metric,
y=plot_metric,
color=by_col,
hover_data={
by_col: ":.d",
metric: metric_formatting[metric],
plot_metric: metric_formatting[plot_metric],
},
markers=True,
title=f"{metric} over time, per {by_col} {title}",
facet_col=facet,
facet_col_wrap=facet_col_wrap,
template="pega",
)
if metric in ["SuccessRate"]:

if metric.startswith("SuccessRate"):
fig.update_yaxes(tickformat=".2%")
fig.update_layout(yaxis={"rangemode": "tozero"})

return fig

@requires({"ModelID", "Name"})
Expand Down
61 changes: 55 additions & 6 deletions python/tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from pathlib import Path

import plotly.express as px
Expand All @@ -19,6 +20,28 @@ def sample():
)


@pytest.fixture
def sample2():
    """Provide a tiny ADMDatamart: models 1 and 2 snapshotted on three days.

    Rows are ordered day by day (2024-01-01 .. 2024-01-03) with model 1
    before model 2 within each day. Model 1 is always in group "A" and
    model 2 in group "B". ``Positives`` carries the same numbers as
    ``ResponseCount``.
    """
    # Each of the three snapshot days appears twice: once per model.
    snapshot_times = [datetime(2024, 1, day) for day in (1, 2, 3) for _ in range(2)]
    counts = [100, 150, 120, 160, 110, 170]
    columns = {
        "SnapshotTime": snapshot_times,
        "ModelID": [1, 2] * 3,
        "Performance": [0.75, 0.80, 0.78, 0.82, 0.77, 0.85],
        "ResponseCount": counts,
        "Positives": list(counts),
        "Group": ["A", "B"] * 3,
    }
    model_frame = pl.DataFrame(columns).lazy()
    return ADMDatamart(model_df=model_frame)


def test_bubble_chart(sample: ADMDatamart):
df = sample.plot.bubble_chart(return_df=True)

Expand All @@ -31,12 +54,38 @@ def test_bubble_chart(sample: ADMDatamart):
assert plot is not None


def test_over_time(sample: ADMDatamart):
    """Check the default ``over_time`` output on the ``sample`` fixture, as a frame and as a plot."""
    # NOTE(review): a second ``test_over_time`` (taking ``sample2``) follows
    # directly below; presumably this one is the removed, pre-commit side of
    # the diff -- confirm before relying on it.
    df = sample.plot.over_time(return_df=True).collect()
    # The expected shape and value below are tied to the contents of ``sample``.
    assert df.shape == (70, 3)
    # Third column of the first row after sorting by ModelID, rounded to 2 decimals.
    assert round(df.sort("ModelID").row(0)[2], 2) == 55.46
    plot = sample.plot.over_time()
    assert plot is not None
def test_over_time(sample2: ADMDatamart):
    """Exercise ``over_time`` on the two-model, three-day ``sample2`` fixture.

    Covers figure creation for Performance and for ResponseCount, the
    ``Performance_weighted_average_change`` column returned with
    ``show_changes=True``, the per-snapshot ResponseCount values, faceting
    by ``Group``, and the error raised when the query leaves no data.
    """
    fig = sample2.plot.over_time(metric="Performance", by="ModelID")
    assert fig is not None

    fig = sample2.plot.over_time(metric="ResponseCount", by="ModelID")
    assert fig is not None

    performance_changes = (
        sample2.plot.over_time(metric="Performance", show_changes=True, return_df=True)
        .collect()
        .get_column("Performance_weighted_average_change")
        .to_list()
    )
    # These values are differences of a floating-point weighted average, so
    # compare with a tolerance instead of bit-for-bit equality. The first
    # entry of each model is expected to be 0.0.
    assert performance_changes == pytest.approx([0.0, 3.0, -1.0, 0.0, 2.0, 3.0])

    responses_over_time = (
        sample2.plot.over_time(metric="ResponseCount", by="ModelID", return_df=True)
        .collect()
        .get_column("ResponseCount")
        .to_list()
    )
    # Expected order: all snapshots of model 1, then all snapshots of model 2.
    assert responses_over_time == [100.0, 120.0, 110.0, 150.0, 160.0, 170.0]

    fig_faceted = sample2.plot.over_time(
        metric="Performance", by="ModelID", facet="Group"
    )
    assert fig_faceted is not None

    # NOTE(review): ModelID holds integers in ``sample2`` while this filter
    # compares against the string "3" -- confirm the type mismatch is intended.
    with pytest.raises(
        ValueError, match="The given query resulted in no more remaining data."
    ):
        sample2.plot.over_time(query=pl.col("ModelID") == "3")


def test_proposition_success_rates(sample: ADMDatamart):
Expand Down

0 comments on commit 8cd7992

Please sign in to comment.