Skip to content

Commit

Permalink
make recommended changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kgoebber committed Jul 24, 2023
1 parent 3e7f2de commit d4ca452
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 45 deletions.
21 changes: 12 additions & 9 deletions src/metpy/calc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,38 +782,41 @@ def take(indexer):


@exporter.export
def find_local_extrema(var, nsize, extrema):
def find_local_extrema(var, nsize=15, extrema='max'):
r"""Find the local extreme (max/min) values of a 2D array.
Parameters
----------
var : `xarray.DataArray`
var : `numpy.array`
The variable to locate the local extrema using the nearest method
from the maximum_filter or minimum_filter from the scipy.ndimage module.
nsize : int
The minimum number of grid points between each local extrema.
Default value is 15.
extrema: str
The value 'max' for local maxima or 'min' for local minima.
Default value is 'max'.
Returns
-------
var_extrema: `xarray.DataArray`
The values of the local extrema with other values as NaNs
extrema_mask: `numpy.array`
The boolean array of the local extrema.
See Also
--------
:func:`~metpy.plots.plot_local_extrema`
"""
from scipy.ndimage import maximum_filter, minimum_filter
if extrema not in ['max', 'min']:
raise ValueError('Invalid input for "extrema". Valid options are "max" or "min".')

if extrema == 'max':
extreme_val = maximum_filter(var.values, nsize, mode='nearest')
extreme_val = maximum_filter(var, nsize, mode='nearest')
elif extrema == 'min':
extreme_val = minimum_filter(var.values, nsize, mode='nearest')
return var.where(extreme_val == var.values)
extreme_val = minimum_filter(var, nsize, mode='nearest')
else:
raise ValueError(f'Invalid value for "extrema": {extrema}. '
'Valid options are "max" or "min".')
return var == extreme_val


@exporter.export
Expand Down
43 changes: 28 additions & 15 deletions src/metpy/plots/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,31 @@ def normalize(x):
return res


def plot_local_extrema(ax, extreme_vals, symbol, plot_val=True, **kwargs):
"""Plot the local extreme (max/min) values of an array.
def plot_local_extrema(ax, extrema_mask, vals, x, y, symbol, plot_val=True, **kwargs):
"""Plot the local extreme (max/min) values of a 2D array.
The behavior of the plotting will have the symbol horizontal/vertical alignment
be center/bottom and any value plotted will be center/top. The text size of plotted
be center/bottom and any value plotted will be center/top. The default text size of plotted
values is 0.65 of the symbol size.
Parameters
----------
ax : `matplotlib.axes`
ax: `matplotlib.axes`
The axes which to plot the local extrema
extreme_vals : `xarray.DataArray`
The DataArray that contains the variable local extrema
extrema_mask : `numpy.array`
A boolean array that contains the variable local extrema
vals : `numpy.array`
The variable associated with the extrema_mask
x : `numpy.array`
The x-dimension variable associated with the extrema_vals
y : `numpy.array`
The y-dimension variable associated with the extrema_vals
symbol : str
The text or other string to plot at the local extrema location
plot_val : bool
plot_val: bool
Whether to plot the local extreme value (default is True)
textsize: int (optional)
Size of plotted extreme values, Default is 0.65 * size
Returns
-------
Expand Down Expand Up @@ -332,17 +340,22 @@ def plot_local_extrema(ax, extreme_vals, symbol, plot_val=True, **kwargs):
if plot_val:
kwargs.pop('verticalalignment')
size = kwargs.pop('size')
textsize = size * .65
textsize = kwargs.pop('textsize', size * 0.65)

stack_vals = extreme_vals.stack(x=[extreme_vals.metpy.x.name, extreme_vals.metpy.y.name])
for extrema in stack_vals[stack_vals.notnull()]:
x = extrema[extreme_vals.metpy.x.name].values
y = extrema[extreme_vals.metpy.y.name].values
extreme_vals = vals[extrema_mask]
if x.ndim == 1:
xx, yy = np.meshgrid(x, y)
else:
xx = x
yy = y
extreme_x = xx[extrema_mask]
extreme_y = yy[extrema_mask]
for extrema, ex_x, ex_y in zip(extreme_vals, extreme_x, extreme_y):
if plot_val:
ax.text(x, y, symbol, clip_on=True, clip_box=ax.bbox, size=size,
ax.text(ex_x, ex_y, symbol, clip_on=True, clip_box=ax.bbox, size=size,
verticalalignment='bottom', **kwargs)
ax.text(x, y, f'{extrema.values:.0f}', clip_on=True, clip_box=ax.bbox,
ax.text(ex_x, ex_y, f'{extrema:.0f}', clip_on=True, clip_box=ax.bbox,
size=textsize, verticalalignment='top', **kwargs)
else:
ax.text(x, y, symbol, clip_on=True, clip_box=ax.bbox, size=size,
ax.text(ex_x, ex_y, symbol, clip_on=True, clip_box=ax.bbox, size=size,
**kwargs)
30 changes: 15 additions & 15 deletions tests/calc/test_calc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,21 +498,21 @@ def local_extrema_data():

def test_find_local_extrema(local_extrema_data):
"""Test find_local_extrema function for maximum."""
local_max = find_local_extrema(local_extrema_data, 3, 'max')
local_min = find_local_extrema(local_extrema_data, 3, 'min')

max_truth = np.array([[np.nan, np.nan, np.nan, np.nan, np.nan],
[101637.19, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, 101212.8]])
min_truth = np.array([[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, 101159.93],
[np.nan, np.nan, np.nan, np.nan, np.nan]])
assert_array_almost_equal(local_max.data, max_truth)
assert_array_almost_equal(local_min.data, min_truth)
local_max = find_local_extrema(local_extrema_data.values, 3, 'max')
local_min = find_local_extrema(local_extrema_data.values, 3, 'min')

max_truth = np.array([[False, False, False, False, False],
[True, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, True]])
min_truth = np.array([[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, True],
[False, False, False, False, False]])
assert_array_almost_equal(local_max, max_truth)
assert_array_almost_equal(local_min, min_truth)

with pytest.raises(ValueError):
find_local_extrema(local_extrema_data, 3, 'large')
Expand Down
14 changes: 8 additions & 6 deletions tests/plots/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,21 @@ def test_plot_extrema():
"""Test plotting of local max/min values."""
data = xr.open_dataset(get_test_data('GFS_test.nc', as_file_obj=False))

mslp = data.Pressure_reduced_to_MSL_msl.squeeze()
relmax2d = find_local_extrema(mslp, 10, 'max').metpy.convert_units('hPa')
relmin2d = find_local_extrema(mslp, 15, 'min').metpy.convert_units('hPa')
mslp = data.Pressure_reduced_to_MSL_msl.squeeze().metpy.convert_units('hPa')
relmax2d = find_local_extrema(mslp.values, 10, 'max')
relmin2d = find_local_extrema(mslp.values, 15, 'min')

fig = plt.figure(figsize=(8., 8.))
ax = fig.add_subplot(1, 1, 1)

# Plot MSLP
clevmslp = np.arange(800., 1120., 4)
ax.contour(mslp.lon, mslp.lat, mslp.metpy.convert_units('hPa'),
ax.contour(mslp.lon, mslp.lat, mslp,
clevmslp, colors='k', linewidths=1.25, linestyles='solid')

plot_local_extrema(ax, relmax2d, 'H', plot_val=False, color='tab:red')
plot_local_extrema(ax, relmin2d, 'L', color='tab:blue')
plot_local_extrema(ax, relmax2d, mslp.values, mslp.lon, mslp.lat,
'H', plot_val=False, color='tab:red')
plot_local_extrema(ax, relmin2d, mslp.values, mslp.lon, mslp.lat,
'L', color='tab:blue')

return fig

0 comments on commit d4ca452

Please sign in to comment.