From df3c561bc644675634c4e4b34cc7bb269131aef8 Mon Sep 17 00:00:00 2001 From: kgoebber Date: Mon, 24 Jul 2023 08:55:09 -0500 Subject: [PATCH] make recommended changes --- src/metpy/calc/tools.py | 23 ++++++++++-------- src/metpy/plots/_util.py | 45 ++++++++++++++++++++++------------- tests/calc/test_calc_tools.py | 30 +++++++++++------------ tests/plots/test_util.py | 14 ++++++----- 4 files changed, 65 insertions(+), 47 deletions(-) diff --git a/src/metpy/calc/tools.py b/src/metpy/calc/tools.py index 0c43103edad..de59c304e56 100644 --- a/src/metpy/calc/tools.py +++ b/src/metpy/calc/tools.py @@ -782,38 +782,41 @@ def take(indexer): @exporter.export -def find_local_extrema(var, nsize, extrema): +def find_local_extrema(var, nsize=15, extrema='max'): r"""Find the local extreme (max/min) values of a 2D array. Parameters ---------- - var : `xarray.DataArray` + var : `numpy.array` The variable to locate the local extrema using the nearest method from the maximum_filter or minimum_filter from the scipy.ndimage module. nsize : int The minimum number of grid points between each local extrema. + Default value is 15. extrema: str The value 'max' for local maxima or 'min' for local minima. + Default value is 'max'. Returns ------- - var_extrema: `xarray.DataArray` - The values of the local extrema with other values as NaNs + extrema_mask: `numpy.array` + The boolean array of the local extrema. See Also -------- - :func:`~metpy.plots.plot_local_extrema` + :func:`~metpy.plot.plot_local_extrema` """ from scipy.ndimage import maximum_filter, minimum_filter - if extrema not in ['max', 'min']: - raise ValueError('Invalid input for "extrema". Valid options are "max" or "min".') if extrema == 'max': - extreme_val = maximum_filter(var.values, nsize, mode='nearest') + extreme_val = maximum_filter(var, nsize, mode='nearest') elif extrema == 'min': - extreme_val = minimum_filter(var.values, nsize, mode='nearest') - return var.where(extreme_val == var.values) + extreme_val = minimum_filter(var, nsize, mode='nearest') + else: + raise ValueError(f'Invalid value for "extrema": {extrema}. ' + 'Valid options are "max" or "min".') + return var == extreme_val @exporter.export diff --git a/src/metpy/plots/_util.py b/src/metpy/plots/_util.py index 829c5f775fc..68c832ea195 100644 --- a/src/metpy/plots/_util.py +++ b/src/metpy/plots/_util.py @@ -285,23 +285,31 @@ def normalize(x): return res -def plot_local_extrema(ax, extreme_vals, symbol, plot_val=True, **kwargs): - """Plot the local extreme (max/min) values of an array. +def plot_local_extrema(ax, extrema_mask, vals, x, y, symbol, plot_val=True, **kwargs): + """Plot the local extreme (max/min) values of a 2D array. The behavior of the plotting will have the symbol horizontal/vertical alignment - be center/bottom and any value plotted will be center/top. The text size of plotted + be center/bottom and any value plotted will be center/top. The default text size of plotted values is 0.65 of the symbol size. Parameters ---------- - ax : `matplotlib.axes` + ax: `matplotlib.axes` The axes which to plot the local extrema - extreme_vals : `xarray.DataArray` - The DataArray that contains the variable local extrema + extrema_mask : `numpy.array` + A boolean array that contains the variable local extrema + vals : `numpy.array` + The variable associated with the extrema_mask + x : `numpy.array` + The x-dimension variable associated with the extrema_vals + y : `numpy.array` + The y-dimension variable associated with the extrema_vals symbol : str The text or other string to plot at the local extrema location - plot_val : bool + plot_val: bool Whether to plot the local extreme value (default is True) + textsize: int (optional) + Size of plotted extreme values, Default is 0.65 * size Returns ------- @@ -332,17 +340,22 @@ def plot_local_extrema(ax, extreme_vals, symbol, plot_val=True, **kwargs): if plot_val: kwargs.pop('verticalalignment') size = kwargs.pop('size') - textsize = size * .65 + textsize = kwargs.pop('textsize', size * 0.65) - stack_vals = extreme_vals.stack(x=[extreme_vals.metpy.x.name, extreme_vals.metpy.y.name]) - for extrema in stack_vals[stack_vals.notnull()]: - x = extrema[extreme_vals.metpy.x.name].values - y = extrema[extreme_vals.metpy.y.name].values + extreme_vals = vals[extrema_mask] + if x.ndim == 1: + xx, yy = np.meshgrid(x, y) + else: + xx = x + yy = y + extreme_x = xx[extrema_mask] + extreme_y = yy[extrema_mask] + for extrema, ex_x, ex_y in zip(extreme_vals, extreme_x, extreme_y): if plot_val: - ax.text(x, y, symbol, clip_on=True, clip_box=ax.bbox, size=size, + ax.text(ex_x, ex_y, symbol, clip_on=True, clip_box=ax.bbox, size=size, verticalalignment='bottom', **kwargs) - ax.text(x, y, f'{extrema.values:.0f}', clip_on=True, clip_box=ax.bbox, - size=textsize, verticalalignment='top', **kwargs) + ax.text(ex_x, ex_y, f'{extrema:.0f}', clip_on=True, clip_box=ax.bbox, size=textsize, + verticalalignment='top', **kwargs) else: - ax.text(x, y, symbol, clip_on=True, clip_box=ax.bbox, size=size, + ax.text(ex_x, ex_y, symbol, clip_on=True, clip_box=ax.bbox, size=size, **kwargs) diff --git a/tests/calc/test_calc_tools.py b/tests/calc/test_calc_tools.py index 7b3f2dbebce..7e493fe91d2 100644 --- a/tests/calc/test_calc_tools.py +++ b/tests/calc/test_calc_tools.py @@ -498,21 +498,21 @@ def local_extrema_data(): def test_find_local_extrema(local_extrema_data): """Test find_local_extrema function for maximum.""" - local_max = find_local_extrema(local_extrema_data, 3, 'max') - local_min = find_local_extrema(local_extrema_data, 3, 'min') - - max_truth = np.array([[np.nan, np.nan, np.nan, np.nan, np.nan], - [101637.19, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, 101212.8]]) - min_truth = np.array([[np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, np.nan], - [np.nan, np.nan, np.nan, np.nan, 101159.93], - [np.nan, np.nan, np.nan, np.nan, np.nan]]) - assert_array_almost_equal(local_max.data, max_truth) - assert_array_almost_equal(local_min.data, min_truth) + local_max = find_local_extrema(local_extrema_data.values, 3, 'max') + local_min = find_local_extrema(local_extrema_data.values, 3, 'min') + + max_truth = np.array([[False, False, False, False, False], + [True, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, True]]) + min_truth = np.array([[False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, True], + [False, False, False, False, False]]) + assert_array_almost_equal(local_max, max_truth) + assert_array_almost_equal(local_min, min_truth) with pytest.raises(ValueError): find_local_extrema(local_extrema_data, 3, 'large') diff --git a/tests/plots/test_util.py b/tests/plots/test_util.py index 7ab3a5b3e21..6aee06ba4c5 100644 --- a/tests/plots/test_util.py +++ b/tests/plots/test_util.py @@ -163,19 +163,21 @@ def test_plot_extrema(): """Test plotting of local max/min values.""" data = xr.open_dataset(get_test_data('GFS_test.nc', as_file_obj=False)) - mslp = data.Pressure_reduced_to_MSL_msl.squeeze() - relmax2d = find_local_extrema(mslp, 10, 'max').metpy.convert_units('hPa') - relmin2d = find_local_extrema(mslp, 15, 'min').metpy.convert_units('hPa') + mslp = data.Pressure_reduced_to_MSL_msl.squeeze().metpy.convert_units('hPa') + relmax2d = find_local_extrema(mslp.values, 10, 'max') + relmin2d = find_local_extrema(mslp.values, 15, 'min') fig = plt.figure(figsize=(8., 8.)) ax = fig.add_subplot(1, 1, 1) # Plot MSLP clevmslp = np.arange(800., 1120., 4) - ax.contour(mslp.lon, mslp.lat, mslp.metpy.convert_units('hPa'), + ax.contour(mslp.lon, mslp.lat, mslp, clevmslp, colors='k', linewidths=1.25, linestyles='solid') - plot_local_extrema(ax, relmax2d, 'H', plot_val=False, color='tab:red') - plot_local_extrema(ax, relmin2d, 'L', color='tab:blue') + plot_local_extrema(ax, relmax2d, mslp.values, mslp.lon, mslp.lat, + 'H', plot_val=False, color='tab:red') + plot_local_extrema(ax, relmin2d, mslp.values, mslp.lon, mslp.lat, + 'L', color='tab:blue') return fig