Skip to content

Commit

Permalink
Fix benchmark code
Browse files Browse the repository at this point in the history
  • Loading branch information
gutzbenj committed Jan 24, 2025
1 parent 69c7104 commit da274fb
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 80 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
- Drop click-params dependency
- Make pyarrow a polars extra

### Fix
- Fix benchmark code

## 0.102.0 - 2025-01-17

### Feature
Expand Down
44 changes: 29 additions & 15 deletions benchmarks/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dataclasses import dataclass, field
from zoneinfo import ZoneInfo

import matplotlib.pyplot as plt
import polars as pl
import utm
from scipy import interpolate
Expand Down Expand Up @@ -47,16 +46,15 @@ class Data:


def request_weather_data(
parameter: str,
parameters: list[tuple[str, str, str]],
lat: float,
lon: float,
distance: float,
start_date: dt.datetime,
end_date: dt.datetime,
):
stations = DwdObservationRequest(
parameters=parameter,
resolution="hourly",
parameters=parameters,
start_date=start_date,
end_date=end_date,
)
Expand Down Expand Up @@ -109,29 +107,45 @@ def interpolate_data(latitude: float, longitude: float, data: Data):


def visualize_points(data: Data):
fig, ax = plt.subplots()
ax.scatter(data.utm_y, data.utm_x, color=data.colors)

try:
import plotly.graph_objects as go
except ImportError as e:
raise ImportError("Please install extra `plotting` with wetterdienst[plotting]") from e
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=data.utm_y,
y=data.utm_x,
mode="markers",
marker=dict(color=data.colors),
text=[f"id:{station}\nval:{value : .2f}\n" for station, value in zip(data.station_ids, data.values)],
)
)
for i, (station, value) in enumerate(zip(data.station_ids, data.values)):
ax.annotate(
f"id:{station}\nval:{value : .2f}\n",
(data.utm_y[i], data.utm_x[i]),
horizontalalignment="center",
verticalalignment="bottom",
fig.add_trace(
go.Scatter(
x=[data.utm_y[i]],
y=[data.utm_x[i]],
mode="markers+text",
marker=dict(color=data.colors[i]),
text=f"id:{station}\nval:{value : .2f}\n",
textposition="top center",
)
)
fig.update_layout(showlegend=False)
if "PYTEST_CURRENT_TEST" not in os.environ:
plt.show()
fig.show()


def main():
parameter = [("temperature_air_mean_2m", "temperature_air")]
parameters = [("hourly", "temperature_air", "temperature_air_mean_2m")]
latitude = 50.0
longitude = 8.9
distance = 21.0
start_date = dt.datetime(2022, 1, 1, tzinfo=ZoneInfo("UTC"))
end_date = dt.datetime(2022, 1, 20, tzinfo=ZoneInfo("UTC"))

data = request_weather_data(parameter, latitude, longitude, distance, start_date, end_date)
data = request_weather_data(parameters, latitude, longitude, distance, start_date, end_date)
interpolate_data(latitude, longitude, data)
visualize_points(data)

Expand Down
74 changes: 44 additions & 30 deletions benchmarks/interpolation_over_time.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
# Copyright (C) 2018-2023, earthobservations developers.
# Distributed under the MIT License. See LICENSE for more info.
import datetime as dt
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from sklearn.feature_selection import r_regression
from sklearn.metrics import root_mean_squared_error

from wetterdienst import Parameter
from wetterdienst.provider.dwd.observation import (
DwdObservationRequest,
DwdObservationResolution,
)

plt.style.use("ggplot")


def get_interpolated_df(parameter: str, start_date: datetime, end_date: datetime) -> pl.DataFrame:
def get_interpolated_df(
parameters: tuple[str, str, str], start_date: dt.datetime, end_date: dt.datetime
) -> pl.DataFrame:
stations = DwdObservationRequest(
parameters=parameter,
resolution=DwdObservationResolution.HOURLY,
parameters=parameters,
start_date=start_date,
end_date=end_date,
)
return stations.interpolate(latlon=(50.0, 8.9)).df


def get_regular_df(parameter: str, start_date: datetime, end_date: datetime, exclude_stations: list) -> pl.DataFrame:
def get_regular_df(
parameters: tuple[str, str, str], start_date: dt.datetime, end_date: dt.datetime, exclude_stations: list
) -> pl.DataFrame:
stations = DwdObservationRequest(
parameters=parameter,
resolution=DwdObservationResolution.HOURLY,
parameters=parameters,
start_date=start_date,
end_date=end_date,
)
Expand All @@ -56,36 +53,53 @@ def get_corr(regular_values: pl.Series, interpolated_values: pl.Series) -> float
).item()


def visualize(parameter: str, unit: str, regular_df: pl.DataFrame, interpolated_df: pl.DataFrame):
def visualize(parameter: tuple[str, str, str], unit: str, regular_df: pl.DataFrame, interpolated_df: pl.DataFrame):
try:
import plotly.graph_objects as go
except ImportError as e:
raise ImportError("Please install extra `plotting` with wetterdienst[plotting]") from e

rmse = get_rmse(regular_df.get_column("value"), interpolated_df.get_column("value"))
corr = get_corr(regular_df.get_column("value"), interpolated_df.get_column("value"))
factor = 0.5
plt.figure(figsize=(factor * 19.2, factor * 10.8))
plt.plot(regular_df.get_column("date"), regular_df.get_column("value"), color="red", label="regular")
plt.plot(
interpolated_df.get_column("date"),
interpolated_df.get_column("value"),
color="black",
label="interpolated",
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=regular_df.get_column("date"),
y=regular_df.get_column("value"),
mode="lines",
name="regular",
line=dict(color="red"),
)
)
ylabel = f"{parameter.lower()} [{unit}]"
plt.ylabel(ylabel)
fig.add_trace(
go.Scatter(
x=interpolated_df.get_column("date"),
y=interpolated_df.get_column("value"),
mode="lines",
name="interpolated",
line=dict(color="black"),
)
)

ylabel = f"{parameter[-1].lower()} [{unit}]"
title = (
f"rmse: {np.round(rmse, 2)}, corr: {np.round(corr, 2)}\n"
f"station_ids: {interpolated_df.get_column('taken_station_ids').to_list()[0]}"
)
plt.title(title)
plt.legend()
plt.tight_layout()

fig.update_layout(
title=title, xaxis_title="Date", yaxis_title=ylabel, legend=dict(x=0, y=1), margin=dict(l=40, r=40, t=40, b=40)
)

if "PYTEST_CURRENT_TEST" not in os.environ:
plt.show()
fig.show()


def main():
parameter = Parameter.TEMPERATURE_AIR_MEAN_2M.name
parameter = ("hourly", "air_temperature", "temperature_air_mean_2m")
unit = "K"
start_date = datetime(2022, 3, 1)
end_date = datetime(2022, 3, 31)
start_date = dt.datetime(2022, 3, 1)
end_date = dt.datetime(2022, 3, 31)
interpolated_df = get_interpolated_df(parameter, start_date, end_date)
exclude_stations = interpolated_df.get_column("taken_station_ids")[0]
regular_df = get_regular_df(parameter, start_date, end_date, exclude_stations)
Expand Down
18 changes: 7 additions & 11 deletions benchmarks/interpolation_precipitation_difference.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
# Copyright (C) 2018-2023, earthobservations developers.
# Distributed under the MIT License. See LICENSE for more info.
from datetime import datetime
import datetime as dt

import polars as pl

from wetterdienst import Parameter
from wetterdienst.provider.dwd.observation import (
DwdObservationRequest,
DwdObservationResolution,
)

LATLON = (52.52, 13.40)


def get_interpolated_df(start_date: datetime, end_date: datetime) -> pl.DataFrame:
def get_interpolated_df(start_date: dt.datetime, end_date: dt.datetime) -> pl.DataFrame:
stations = DwdObservationRequest(
parameters=Parameter.PRECIPITATION_HEIGHT,
resolution=DwdObservationResolution.DAILY,
parameters=[("daily", "climate_summary", "precipitation_height")],
start_date=start_date,
end_date=end_date,
)
return stations.interpolate(latlon=LATLON).df


def get_regular_df(start_date: datetime, end_date: datetime, exclude_stations: list) -> pl.DataFrame:
def get_regular_df(start_date: dt.datetime, end_date: dt.datetime, exclude_stations: list) -> pl.DataFrame:
stations = DwdObservationRequest(
parameters=Parameter.PRECIPITATION_HEIGHT.name,
resolution=DwdObservationResolution.DAILY,
parameters=[("daily", "climate_summary", "precipitation_height")],
start_date=start_date,
end_date=end_date,
)
Expand All @@ -46,8 +42,8 @@ def calculate_percentage_difference(df: pl.DataFrame, text: str = "") -> float:


def main():
start_date = datetime(2021, 1, 1)
end_date = datetime(2022, 1, 1)
start_date = dt.datetime(2021, 1, 1)
end_date = dt.datetime(2022, 1, 1)
interpolated_df = get_interpolated_df(start_date, end_date)
print(interpolated_df)
exclude_stations = interpolated_df.get_column("taken_station_ids")[0]
Expand Down
Loading

0 comments on commit da274fb

Please sign in to comment.