Skip to content

Commit

Permalink
Stripes: Replace matplotlib by plotly
Browse files Browse the repository at this point in the history
  • Loading branch information
gutzbenj committed Jan 22, 2025
1 parent 5c52af7 commit 45efa7f
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 113 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ interpolation = [
"utm>=0.7,<1",
]
plotting = [
"kaleido>=0.2.1",
"matplotlib>=3.3,<4",
"plotly>=5.24.1",
]
mysql = [
"mysqlclient>=2,<3",
Expand Down
12 changes: 12 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 5 additions & 12 deletions wetterdienst/ui/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import cloup
from cloup import Section
from cloup.constraints import If, RequireExactly, accept_none
from PIL import Image

from wetterdienst import Settings, Wetterdienst, __appname__, __version__
from wetterdienst.ui.core import (
Expand Down Expand Up @@ -1279,12 +1278,10 @@ def stripes_values(
if not target.name.lower().endswith(fmt):
raise click.ClickException(f"'target' must have extension '{fmt}'")

import matplotlib.pyplot as plt

set_logging_level(debug)

try:
buf = _plot_stripes(
fig = _plot_stripes(
kind=kind,
station_id=station,
name=name,
Expand All @@ -1294,21 +1291,16 @@ def stripes_values(
show_title=show_title,
show_years=show_years,
show_data_availability=show_data_availability,
fmt=fmt,
dpi=dpi,
)
except Exception as e:
log.exception(e)
raise click.ClickException(str(e)) from e

if target:
image = Image.open(buf, formats=["png"])
plt.imshow(image)
plt.axis("off")
plt.savefig(target, dpi=300, bbox_inches="tight")
fig.write_image(target, fmt, scale=dpi / 100)
return

click.echo(buf.getvalue(), nl=False)
click.echo(fig.to_image(fmt, scale=dpi / 100), nl=False)


@stripes.command("interactive")
Expand All @@ -1318,7 +1310,8 @@ def interactive(debug: bool):

try:
from wetterdienst.ui.streamlit.stripes import app
except ImportError:
except ImportError as e:
log.exception(e)
log.error("Please install the stripes extras from stripes/requirements.txt")
sys.exit(1)

Expand Down
152 changes: 67 additions & 85 deletions wetterdienst/ui/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import json
import logging
import sys
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import TYPE_CHECKING, Literal

import polars as pl
Expand All @@ -20,6 +18,8 @@
from wetterdienst.util.ui import read_list

if TYPE_CHECKING:
from plotly.graph_objs import Figure

from wetterdienst.core.timeseries.request import TimeseriesRequest
from wetterdienst.core.timeseries.result import (
InterpolatedValuesResult,
Expand Down Expand Up @@ -563,9 +563,7 @@ def _plot_stripes(
show_title: bool = True,
show_years: bool = True,
show_data_availability: bool = True,
fmt: str | Literal["png", "jpg", "svg", "pdf"] = "png",
dpi: int = 100,
) -> BytesIO:
) -> Figure:
"""Create warming stripes for station in Germany.
Code similar to: https://www.s4f-freiburg.de/temperaturstreifen/
"""
Expand All @@ -576,18 +574,12 @@ def _plot_stripes(
raise ValueError("start_year must be less than end_year")
if name_threshold < 0 or name_threshold > 1:
raise ValueError("name_threshold must be between 0.0 and 1.0")
if dpi <= 0:
raise ValueError("dpi must be more than 0")

import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go

request = CLIMATE_STRIPES_CONFIG[kind]["request"]()
cmap = CLIMATE_STRIPES_CONFIG[kind]["color_map"]

matplotlib.use("agg")
color_map = plt.get_cmap(cmap)

if station_id:
stations = request.filter_by_station_id(station_id)
elif name:
Expand All @@ -614,9 +606,6 @@ def _plot_stripes(
),
pl.when(pl.col("value").is_not_null()).then(-0.02).otherwise(None).alias("availability"),
)
df = df.with_columns(
pl.col("value_scaled").map_elements(color_map, return_dtype=pl.List(pl.Float64)).alias("color")
)

if start_year:
df = df.filter(pl.col("date").dt.year().ge(start_year))
Expand All @@ -626,84 +615,77 @@ def _plot_stripes(
if len(df) == 1:
raise ValueError("At least two years are required to create warming stripes.")

fig, ax = plt.subplots(tight_layout=True)

df_without_nulls = df.drop_nulls("value")

ax.bar(
df_without_nulls.get_column("date").dt.year(),
1.0,
width=1.0,
color=df_without_nulls.get_column("color"),
fig = go.Figure()

# Add bar trace
fig.add_trace(
go.Bar(
x=df_without_nulls.get_column("date").dt.year(),
y=[1.0] * len(df_without_nulls),
marker=dict(color=df_without_nulls.get_column("value_scaled"), colorscale=cmap, cmin=0, cmax=1),
width=1.0,
)
)
ax.set_axis_off()

# Add scatter trace for data availability
if show_data_availability:
ax.scatter(
df.get_column("date").dt.year(),
df.get_column("availability"),
color="gold",
marker=",",
s=0.5,
)
ax.plot(
df.get_column("date").dt.year(),
df.get_column("availability"),
color="gold",
fig.add_trace(
go.Scatter(
x=df.get_column("date").dt.year(),
y=df.get_column("availability"),
mode="lines",
marker=dict(color="gold", size=5),
line=dict(color="gold"),
)
)
ax.text(
df.get_column("date").dt.year().min(),
-0.03,
"data availability",
ha="left",
va="top",
color="gold",
fig.add_annotation(
x=df.get_column("date").dt.year().min(),
xanchor="left",
y=-0.05,
text="data availability",
showarrow=False,
align="right",
font=dict(color="gold"),
)
ax.text(0.5, -0.04, "Source: Deutscher Wetterdienst", ha="center", va="center", transform=ax.transAxes)

# Add source text
fig.add_annotation(
x=0.5, y=-0.05, text="Source: Deutscher Wetterdienst", showarrow=False, xref="paper", yref="paper"
)
if show_title:
ax.set_title(f"""Climate stripes ({kind}) for {station_dict["name"]}, Germany ({station_dict["station_id"]})""")
fig.update_layout(
title=f"Climate stripes ({kind}) for {station_dict['name']}, Germany ({station_dict['station_id']})"
)
if show_years:
ax.text(0.05, -0.05, df.get_column("date").min().year, ha="center", va="center", transform=ax.transAxes)
ax.text(0.95, -0.05, df.get_column("date").max().year, ha="center", va="center", transform=ax.transAxes)

buf = BytesIO()
plt.savefig(buf, format=fmt, dpi=dpi, bbox_inches="tight")
plt.close(fig)
buf.seek(0)

return buf


def _thread_safe_plot_stripes(
kind: Literal["temperature", "precipitation"],
station_id: str | None = None,
name: str | None = None,
start_year: int | None = None,
end_year: int | None = None,
name_threshold: float = 0.9,
show_title: bool = True,
show_years: bool = True,
show_data_availability: bool = True,
fmt: str | Literal["png", "jpg", "svg", "pdf"] = "png",
dpi: int = 100,
) -> BytesIO:
"""Thread-safe wrapper for _plot_warming_stripes because matplotlib is not thread-safe."""
with ThreadPoolExecutor(1) as executor:
return executor.submit(
lambda: _plot_stripes(
kind=kind,
station_id=station_id,
name=name,
start_year=start_year,
end_year=end_year,
name_threshold=name_threshold,
show_title=show_title,
show_years=show_years,
show_data_availability=show_data_availability,
fmt=fmt,
dpi=dpi,
)
).result()
fig.add_annotation(
x=0.05,
y=-0.05,
text=str(df.get_column("date").min().year),
showarrow=False,
xref="paper",
yref="paper",
xanchor="right",
)
fig.add_annotation(
x=0.95,
y=-0.05,
text=str(df.get_column("date").max().year),
showarrow=False,
xref="paper",
yref="paper",
xanchor="left",
)
fig.update_layout(
plot_bgcolor="white",
xaxis=dict(
showticklabels=False,
),
yaxis=dict(range=[None, 1], showticklabels=False),
showlegend=False,
margin=dict(l=10, r=10, t=30, b=30),
)
return fig


def set_logging_level(debug: bool):
Expand Down
8 changes: 3 additions & 5 deletions wetterdienst/ui/restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ValuesRequest,
ValuesRequestRaw,
_get_stripes_stations,
_thread_safe_plot_stripes,
_plot_stripes,
get_interpolate,
get_stations,
get_summarize,
Expand Down Expand Up @@ -518,7 +518,7 @@ def stripes_values(
)

try:
buf = _thread_safe_plot_stripes(
fig = _plot_stripes(
kind=kind,
station_id=station,
name=name,
Expand All @@ -528,14 +528,12 @@ def stripes_values(
show_title=show_title,
show_years=show_years,
show_data_availability=show_data_availability,
fmt=fmt,
dpi=dpi,
)
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail=str(e)) from e
media_type = f"image/{fmt}"
return Response(content=buf.getvalue(), media_type=media_type)
return Response(content=fig.to_image(fmt, scale=dpi / 100), media_type=media_type)


def start_service(listen_address: str | None = None, reload: bool | None = False): # pragma: no cover
Expand Down
2 changes: 0 additions & 2 deletions wetterdienst/ui/streamlit/climate_stripes/requirements.txt

This file was deleted.

Loading

0 comments on commit 45efa7f

Please sign in to comment.