Skip to content

Commit

Permalink
Fill veg nodata due to wet and ue with spatial mode (#178)
Browse files Browse the repository at this point in the history
* fill veg nodata with mode

* Revert "fix nodata ring along the water (#177)"

This reverts commit 78d8c62.

---------

Co-authored-by: Emma Ai <[email protected]>
  • Loading branch information
emmaai and Emma Ai authored Dec 17, 2024
1 parent 78d8c62 commit 0cf4d97
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 102 deletions.
86 changes: 86 additions & 0 deletions odc/stats/plugins/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
import operator
import numpy as np
import dask
from osgeo import gdal, ogr, osr
from functools import partial


def rasterize_vector_mask(
Expand Down Expand Up @@ -140,3 +142,87 @@ def generate_numexpr_expressions(rules_df, final_class_column, previous):
expressions = sorted(expressions, key=len)

return expressions


def numpy_mode_exclude_nodata(values, target_value, exclude_values):
"""
Compute the mode of an array using NumPy, excluding nodata.
:param values: A flattened 1D array representing the neighborhood.
:param target_value: The value to be replaced
:param exclude_values: A list or set of values to exclude from the mode calculation.
:return: The mode of the array (smallest value in case of ties), excluding nodata.
"""

valid_mask = ~(
np.isin(values, list(set(exclude_values) | {target_value})) | np.isnan(values)
)
valid_values = values[valid_mask]
if len(valid_values) == 0:
return target_value
unique_vals, counts = np.unique(valid_values, return_counts=True)
max_count = counts.max()
# select the smallest value among ties
mode_value = unique_vals[counts == max_count].min()
return mode_value


def process_nodata_pixels(block, target_value, exclude_values, max_radius):
"""
Replace nodata pixels in a block with the mode of their 3x3 neighborhood.
:param block : numpy.ndarray The 2D array chunk.
:param target_value: The value to be replaced
:param exclude_values: A list or set of values to exclude from the mode calculation.
:param max_radius: maximum size of neighbourhood
:return: numpy.ndarray The modified block where nodata pixels are replaced.
"""
result = block.copy()
nodata_indices = np.argwhere(block == target_value)

for i, j in nodata_indices:
# start from the smallest/nearest neighbourhood
# stop once finding the valid value otherwise expand till the max_radius
for radius in range(1, max_radius + 1):
i_min, i_max = max(0, i - radius), min(block.shape[0], i + radius + 1)
j_min, j_max = max(0, j - radius), min(block.shape[1], j + radius + 1)

neighborhood = block[i_min:i_max, j_min:j_max].flatten()
tmp = numpy_mode_exclude_nodata(neighborhood, target_value, exclude_values)
if np.isnan(tmp) | (tmp == target_value):
continue
result[i, j] = tmp
break

return result


def replace_nodata_with_mode(
arr, target_value, exclude_values=None, neighbourhood_size=3
):
"""
Replace nodata-valued pixels in a Dask array with the mode of their neighborhood,
processing only the nodata pixels.
:param arr: A 2D Dask array.
:param target_value: The value to be replaced
:param exclude_values: A list or set of values to exclude from the mode calculation.
:param neighbourhood_size: the size of neighbourhood, e.g., 3:= 3*3 block, 5:=5*5 block
:return: A Dask array where nodata-valued pixels have been replaced.
"""
if exclude_values is None:
exclude_values = set()

radius = neighbourhood_size // 2
process_func = partial(
process_nodata_pixels,
target_value=target_value,
exclude_values=exclude_values,
max_radius=radius,
)
# Use map_overlap to handle edges and target only the nodata pixels
result = arr.map_overlap(
process_func,
depth=(radius, radius),
boundary="nearest",
dtype=arr.dtype,
trim=True,
)
return result
134 changes: 100 additions & 34 deletions odc/stats/plugins/lc_fc_wo_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,58 +62,61 @@ def native_transform(self, xx):
# clear wet pixels not mask against bit 2: low solar angle
wet = (xx["water"].data & ~(1 << 2)) == 128

# get "valid" wo pixels, both dry and wet used in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# clear dry pixels
clear = xx["water"].data == 0

# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
if self.cloud_filters.get(key) is not None:
raw_mask = (xx["water"] & val) > 0
raw_mask = mask_cleanup(
raw_mask, mask_filters=self.cloud_filters.get(key)
)
wet_valid = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_valid, "b": raw_mask.data},
valid = expr_eval(
"where(b>0, 0, a)",
{"a": valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="float32",
dtype="bool",
)

wet_clear = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_clear, "b": raw_mask.data},
clear = expr_eval(
"where(b>0, 0, a)",
{"a": clear, "b": raw_mask.data},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
dtype="bool",
)
wet = expr_eval(
"where(b>0, 0, a)",
{"a": wet, "b": raw_mask.data},
name="get_wet_pixels",
dtype="bool",
)

xx = xx.drop_vars(["water"])

# Pick out the fc pixels that
# have an unmixing error of less than the threshold for dry
# and ignore ue for wet
# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# get "valid" wo pixels, both dry and wet
# to remark nodata reason in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# Pick out the fc pixels that have an unmixing error of less than the threshold
valid = expr_eval(
"where((b>=_v)&(a<=0)|(a!=a), 0, 1)",
"where(b<_v, a, 0)",
{"a": valid, "b": xx["ue"].data},
name="get_low_ue_wet",
name="get_low_ue",
dtype="bool",
**{"_v": self.ue_threshold},
)
Expand All @@ -126,12 +129,16 @@ def native_transform(self, xx):
xx["wet_clear"] = xr.DataArray(
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
)
xx["wet_valid"] = xr.DataArray(
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
)

return xx

def fuser(self, xx):

wet_clear = xx["wet_clear"]
wet_valid = xx["wet_valid"]

xx = _xr_fuse(
xx.drop_vars(["wet_clear"]),
Expand All @@ -140,6 +147,7 @@ def fuser(self, xx):
)

xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)

return xx

Expand Down Expand Up @@ -183,6 +191,59 @@ def _water_or_not(self, xx: xr.Dataset):
)
return data

def _wet_or_not(self, xx: xr.Dataset):
# mark water freq >= 0.5 as 1
data = expr_eval(
"where(a>0, 1, 0)",
{"a": xx["wet_valid"].data},
name="get_wet",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet_valid"].data, "b": data},
name="get_wet",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _wet_valid_percent(self, data, nodata):
wet = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

for t in data:
# +1 if not nodata
wet = expr_eval(
"where(a==nodata, b, a+b)",
{"a": t, "b": wet},
name="get_wet",
dtype="uint8",
**{"nodata": nodata},
)

# total valid
total = expr_eval(
"where(a==nodata, b, b+1)",
{"a": t, "b": total},
name="get_total_valid",
dtype="uint8",
**{"nodata": nodata},
)

wet = expr_eval(
"where(a<=0, nodata, b/a*100)",
{"a": total, "b": wet},
name="normalize_max_count",
dtype="float32",
**{"nodata": int(nodata)},
)

wet = da.ceil(wet).astype("uint8")
return wet

def _max_consecutive_months(self, data, nodata, normalize=False):
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
Expand Down Expand Up @@ -276,11 +337,16 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

data = self._wet_or_not(xx)
wet_percent = self._wet_valid_percent(data, NODATA)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
for k, v in zip(
self.measurements, [max_count_veg, max_count_water, wet_percent]
)
}
coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
Loading

0 comments on commit 0cf4d97

Please sign in to comment.