diff --git a/CHANGELOG.md b/CHANGELOG.md index 29ae7be61..37e86d1b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ - Drop click-params dependency - Make pyarrow a polars extra +### Fix +- Fix benchmark code + ## 0.102.0 - 2025-01-17 ### Feature diff --git a/benchmarks/interpolation.py b/benchmarks/interpolation.py index 1b123da39..be6f2098e 100644 --- a/benchmarks/interpolation.py +++ b/benchmarks/interpolation.py @@ -5,7 +5,6 @@ from dataclasses import dataclass, field from zoneinfo import ZoneInfo -import matplotlib.pyplot as plt import polars as pl import utm from scipy import interpolate @@ -47,7 +46,7 @@ class Data: def request_weather_data( - parameter: str, + parameters: list[tuple[str, str, str]], lat: float, lon: float, distance: float, @@ -55,8 +54,7 @@ def request_weather_data( end_date: dt.datetime, ): stations = DwdObservationRequest( - parameters=parameter, - resolution="hourly", + parameters=parameters, start_date=start_date, end_date=end_date, ) @@ -109,29 +107,45 @@ def interpolate_data(latitude: float, longitude: float, data: Data): def visualize_points(data: Data): - fig, ax = plt.subplots() - ax.scatter(data.utm_y, data.utm_x, color=data.colors) - + try: + import plotly.graph_objects as go + except ImportError as e: + raise ImportError("Please install extra `plotting` with wetterdienst[plotting]") from e + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=data.utm_y, + y=data.utm_x, + mode="markers", + marker=dict(color=data.colors), + text=[f"id:{station}\nval:{value : .2f}\n" for station, value in zip(data.station_ids, data.values)], + ) + ) for i, (station, value) in enumerate(zip(data.station_ids, data.values)): - ax.annotate( - f"id:{station}\nval:{value : .2f}\n", - (data.utm_y[i], data.utm_x[i]), - horizontalalignment="center", - verticalalignment="bottom", + fig.add_trace( + go.Scatter( + x=[data.utm_y[i]], + y=[data.utm_x[i]], + mode="markers+text", + marker=dict(color=data.colors[i]), + text=f"id:{station}\nval:{value : .2f}\n", + textposition="top center", + ) ) + fig.update_layout(showlegend=False) if "PYTEST_CURRENT_TEST" not in os.environ: - plt.show() + fig.show() def main(): - parameter = [("temperature_air_mean_2m", "temperature_air")] + parameters = [("hourly", "temperature_air", "temperature_air_mean_2m")] latitude = 50.0 longitude = 8.9 distance = 21.0 start_date = dt.datetime(2022, 1, 1, tzinfo=ZoneInfo("UTC")) end_date = dt.datetime(2022, 1, 20, tzinfo=ZoneInfo("UTC")) - data = request_weather_data(parameter, latitude, longitude, distance, start_date, end_date) + data = request_weather_data(parameters, latitude, longitude, distance, start_date, end_date) interpolate_data(latitude, longitude, data) visualize_points(data) diff --git a/benchmarks/interpolation_over_time.py b/benchmarks/interpolation_over_time.py index 33d94f829..9a2230092 100644 --- a/benchmarks/interpolation_over_time.py +++ b/benchmarks/interpolation_over_time.py @@ -1,37 +1,34 @@ # Copyright (C) 2018-2023, earthobservations developers. # Distributed under the MIT License. See LICENSE for more info. +import datetime as dt import os -from datetime import datetime -import matplotlib.pyplot as plt import numpy as np import polars as pl from sklearn.feature_selection import r_regression from sklearn.metrics import root_mean_squared_error -from wetterdienst import Parameter from wetterdienst.provider.dwd.observation import ( DwdObservationRequest, - DwdObservationResolution, ) -plt.style.use("ggplot") - -def get_interpolated_df(parameter: str, start_date: datetime, end_date: datetime) -> pl.DataFrame: +def get_interpolated_df( + parameters: tuple[str, str, str], start_date: dt.datetime, end_date: dt.datetime +) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=parameter, - resolution=DwdObservationResolution.HOURLY, + parameters=parameters, start_date=start_date, end_date=end_date, ) return stations.interpolate(latlon=(50.0, 8.9)).df -def get_regular_df(parameter: str, start_date: datetime, end_date: datetime, exclude_stations: list) -> pl.DataFrame: +def get_regular_df( + parameters: tuple[str, str, str], start_date: dt.datetime, end_date: dt.datetime, exclude_stations: list +) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=parameter, - resolution=DwdObservationResolution.HOURLY, + parameters=parameters, start_date=start_date, end_date=end_date, ) @@ -56,36 +53,53 @@ def get_corr(regular_values: pl.Series, interpolated_values: pl.Series) -> float ).item() -def visualize(parameter: str, unit: str, regular_df: pl.DataFrame, interpolated_df: pl.DataFrame): +def visualize(parameter: tuple[str, str, str], unit: str, regular_df: pl.DataFrame, interpolated_df: pl.DataFrame): + try: + import plotly.graph_objects as go + except ImportError as e: + raise ImportError("Please install extra `plotting` with wetterdienst[plotting]") from e + rmse = get_rmse(regular_df.get_column("value"), interpolated_df.get_column("value")) corr = get_corr(regular_df.get_column("value"), interpolated_df.get_column("value")) - factor = 0.5 - plt.figure(figsize=(factor * 19.2, factor * 10.8)) - plt.plot(regular_df.get_column("date"), regular_df.get_column("value"), color="red", label="regular") - plt.plot( - interpolated_df.get_column("date"), - interpolated_df.get_column("value"), - color="black", - label="interpolated", + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=regular_df.get_column("date"), + y=regular_df.get_column("value"), + mode="lines", + name="regular", + line=dict(color="red"), + ) ) - ylabel = f"{parameter.lower()} [{unit}]" - plt.ylabel(ylabel) + fig.add_trace( + go.Scatter( + x=interpolated_df.get_column("date"), + y=interpolated_df.get_column("value"), + mode="lines", + name="interpolated", + line=dict(color="black"), + ) + ) + + ylabel = f"{parameter[-1].lower()} [{unit}]" title = ( f"rmse: {np.round(rmse, 2)}, corr: {np.round(corr, 2)}\n" f"station_ids: {interpolated_df.get_column('taken_station_ids').to_list()[0]}" ) - plt.title(title) - plt.legend() - plt.tight_layout() + + fig.update_layout( + title=title, xaxis_title="Date", yaxis_title=ylabel, legend=dict(x=0, y=1), margin=dict(l=40, r=40, t=40, b=40) + ) + if "PYTEST_CURRENT_TEST" not in os.environ: - plt.show() + fig.show() def main(): - parameter = Parameter.TEMPERATURE_AIR_MEAN_2M.name + parameter = ("hourly", "air_temperature", "temperature_air_mean_2m") unit = "K" - start_date = datetime(2022, 3, 1) - end_date = datetime(2022, 3, 31) + start_date = dt.datetime(2022, 3, 1) + end_date = dt.datetime(2022, 3, 31) interpolated_df = get_interpolated_df(parameter, start_date, end_date) exclude_stations = interpolated_df.get_column("taken_station_ids")[0] regular_df = get_regular_df(parameter, start_date, end_date, exclude_stations) diff --git a/benchmarks/interpolation_precipitation_difference.py b/benchmarks/interpolation_precipitation_difference.py index e706f19c8..d9defc27a 100644 --- a/benchmarks/interpolation_precipitation_difference.py +++ b/benchmarks/interpolation_precipitation_difference.py @@ -1,32 +1,28 @@ # Copyright (C) 2018-2023, earthobservations developers. # Distributed under the MIT License. See LICENSE for more info. -from datetime import datetime +import datetime as dt import polars as pl -from wetterdienst import Parameter from wetterdienst.provider.dwd.observation import ( DwdObservationRequest, - DwdObservationResolution, ) LATLON = (52.52, 13.40) -def get_interpolated_df(start_date: datetime, end_date: datetime) -> pl.DataFrame: +def get_interpolated_df(start_date: dt.datetime, end_date: dt.datetime) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=Parameter.PRECIPITATION_HEIGHT, - resolution=DwdObservationResolution.DAILY, + parameters=[("daily", "climate_summary", "precipitation_height")], start_date=start_date, end_date=end_date, ) return stations.interpolate(latlon=LATLON).df -def get_regular_df(start_date: datetime, end_date: datetime, exclude_stations: list) -> pl.DataFrame: +def get_regular_df(start_date: dt.datetime, end_date: dt.datetime, exclude_stations: list) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=Parameter.PRECIPITATION_HEIGHT.name, - resolution=DwdObservationResolution.DAILY, + parameters=[("daily", "climate_summary", "precipitation_height")], start_date=start_date, end_date=end_date, ) @@ -46,8 +42,8 @@ def calculate_percentage_difference(df: pl.DataFrame, text: str = "") -> float: def main(): - start_date = datetime(2021, 1, 1) - end_date = datetime(2022, 1, 1) + start_date = dt.datetime(2021, 1, 1) + end_date = dt.datetime(2022, 1, 1) interpolated_df = get_interpolated_df(start_date, end_date) print(interpolated_df) exclude_stations = interpolated_df.get_column("taken_station_ids")[0] diff --git a/benchmarks/summary_over_time.py b/benchmarks/summary_over_time.py index 915ec7b9d..c1458d144 100644 --- a/benchmarks/summary_over_time.py +++ b/benchmarks/summary_over_time.py @@ -1,32 +1,27 @@ # Copyright (C) 2018-2023, earthobservations developers. # Distributed under the MIT License. See LICENSE for more info. +import datetime as dt import os -from datetime import datetime -import matplotlib.pyplot as plt import polars as pl -from wetterdienst import Parameter from wetterdienst.provider.dwd.observation import ( DwdObservationRequest, - DwdObservationResolution, ) -def get_summarized_df(start_date: datetime, end_date: datetime, lat, lon) -> pl.DataFrame: +def get_summarized_df(start_date: dt.datetime, end_date: dt.datetime, lat, lon) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=Parameter.TEMPERATURE_AIR_MEAN_2M, - resolution=DwdObservationResolution.DAILY, + parameters=[("daily", "climate_summary", "temperature_air_mean_2m")], start_date=start_date, end_date=end_date, ) return stations.summarize(latlon=(lat, lon)).df -def get_regular_df(start_date: datetime, end_date: datetime, station_id) -> pl.DataFrame: +def get_regular_df(start_date: dt.datetime, end_date: dt.datetime, station_id) -> pl.DataFrame: stations = DwdObservationRequest( - parameters=Parameter.TEMPERATURE_AIR_MEAN_2M, - resolution=DwdObservationResolution.DAILY, + parameters=[("daily", "climate_summary", "temperature_air_mean_2m")], start_date=start_date, end_date=end_date, ) @@ -35,8 +30,8 @@ def get_regular_df(start_date: datetime, end_date: datetime, station_id) -> pl.D def main(): - start_date = datetime(1934, 1, 1) - end_date = datetime(1980, 12, 31) + start_date = dt.datetime(1934, 1, 1) + end_date = dt.datetime(1980, 12, 31) lat = 51.0221 lon = 13.8470 @@ -53,22 +48,82 @@ def main(): regular_df_01051 = get_regular_df(start_date, end_date, "01051") regular_df_05282 = get_regular_df(start_date, end_date, "05282") - fig, ax = plt.subplots(nrows=5, tight_layout=True, sharex=True, sharey=True) + try: + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError as e: + raise ImportError("Please install extra `plotting` with wetterdienst[plotting]") from e + + fig = make_subplots(rows=5, shared_xaxes=True, subplot_titles=("Summarized", "01050", "01051", "01048", "05282")) + + fig.add_trace( + go.Scatter( + x=summarized_df.get_column("date"), + y=summarized_df.get_column("value"), + mode="markers", + marker=dict(color=summarized_df.get_column("color")), + name="summarized", + ), + row=1, + col=1, + ) - summarized_df.to_pandas().plot("date", "value", c="color", label="summarized", kind="scatter", ax=ax[0], s=5) - regular_df_01050.to_pandas().plot("date", "value", color="yellow", label="01050", ax=ax[1]) - regular_df_01051.to_pandas().plot("date", "value", color="blue", label="01051", ax=ax[2]) - regular_df_01048.to_pandas().plot("date", "value", color="green", label="01048", ax=ax[3]) - regular_df_05282.to_pandas().plot("date", "value", color="pink", label="05282", ax=ax[4]) + fig.add_trace( + go.Scatter( + x=regular_df_01050.get_column("date"), + y=regular_df_01050.get_column("value"), + mode="lines", + line=dict(color="yellow"), + name="01050", + ), + row=2, + col=1, + ) - ax[0].set_ylabel(None) + fig.add_trace( + go.Scatter( + x=regular_df_01051.get_column("date"), + y=regular_df_01051.get_column("value"), + mode="lines", + line=dict(color="blue"), + name="01051", + ), + row=3, + col=1, + ) + + fig.add_trace( + go.Scatter( + x=regular_df_01048.get_column("date"), + y=regular_df_01048.get_column("value"), + mode="lines", + line=dict(color="green"), + name="01048", + ), + row=4, + col=1, + ) + + fig.add_trace( + go.Scatter( + x=regular_df_05282.get_column("date"), + y=regular_df_05282.get_column("value"), + mode="lines", + line=dict(color="pink"), + name="05282", + ), + row=5, + col=1, + ) + + fig.update_layout( + title="Comparison of summarized values and single stations for temperature_air_mean_2m", + legend_title="Stations", + showlegend=True, + ) - title = "Comparison of summarized values and single stations for\ntemperature_air_mean_2m" - plt.suptitle(title) - plt.legend() - plt.tight_layout() if "PYTEST_CURRENT_TEST" not in os.environ: - plt.show() + fig.show() if __name__ == "__main__":