Skip to content

Commit

Permalink
Fix stripe issue in landcover level34 (#175)
Browse files Browse the repository at this point in the history
* set high ue but valid data as non-veg

* move water_season from level1 to condition of wo frequency in level34

* fix typo

* cover the edge cases of landcover level1

* make nodata condition explicit in ue

* better organize the logic block of fc masking

* normalize water frequency

* fix typo

---------

Co-authored-by: Emma Ai <[email protected]>
  • Loading branch information
emmaai and Emma Ai authored Dec 4, 2024
1 parent 7e56476 commit d04b1c0
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 190 deletions.
143 changes: 101 additions & 42 deletions odc/stats/plugins/lc_fc_wo_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,34 @@ def native_transform(self, xx):
5. Drop the WOfS band
"""

# clear and dry pixels not mask against bit 4: terrain high slope,
# valid and dry pixels not mask against bit 4: terrain high slope,
# bit 3: terrain shadow, and
# bit 2: low solar angle
valid = (xx["water"] & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0
valid = (xx["water"].data & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0

# clear and wet pixels not mask against bit 2: low solar angle
wet = (xx["water"] & ~(1 << 2)) == 128
# clear wet pixels not mask against bit 2: low solar angle
wet = (xx["water"].data & ~(1 << 2)) == 128

# clear dry pixels
clear = xx["water"].data == 0

# get "valid" wo pixels, both dry and wet used in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
Expand All @@ -67,37 +88,64 @@ def native_transform(self, xx):
raw_mask = mask_cleanup(
raw_mask, mask_filters=self.cloud_filters.get(key)
)
valid &= ~raw_mask
wet &= ~raw_mask

valid = expr_eval(
"where(b>0, 0, a)",
{"a": valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="uint8",
)
wet_clear = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_clear, "b": raw_mask.data},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)
wet_valid = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)
xx = xx.drop_vars(["water"])

# get valid wo pixels, both dry and wet
data = expr_eval(
"where(a|b, a, _nan)",
{"a": wet.data, "b": valid.data},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# Pick out the fc pixels that have an unmixing error of less than the threshold
valid &= xx["ue"] < self.ue_threshold
valid = expr_eval(
"where(b<_v, a, 0)",
{"a": valid, "b": xx["ue"].data},
name="get_low_ue",
dtype="bool",
**{"_v": self.ue_threshold},
)
xx = xx.drop_vars(["ue"])
valid = xr.DataArray(valid, dims=xx["pv"].dims, coords=xx["pv"].coords)

xx = keep_good_only(xx, valid, nodata=NODATA)
xx = to_float(xx, dtype="float32")

xx["wet"] = xr.DataArray(data, dims=wet.dims, coords=wet.coords)
xx["wet_valid"] = xr.DataArray(
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
)
xx["wet_clear"] = xr.DataArray(
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
)

return xx

def fuser(self, xx):

wet = xx["wet"]
wet_valid = xx["wet_valid"]
wet_clear = xx["wet_clear"]

xx = _xr_fuse(xx.drop_vars(["wet"]), partial(_fuse_mean_np, nodata=np.nan), "")
xx = _xr_fuse(
xx.drop_vars(["wet_valid", "wet_clear"]),
partial(_fuse_mean_np, nodata=np.nan),
"",
)

xx["wet"] = _nodata_fuser(wet, nodata=np.nan)
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)

return xx

Expand All @@ -123,7 +171,7 @@ def _veg_or_not(self, xx: xr.Dataset):
# mark water freq >= 0.5 as 0
data = expr_eval(
"where(a>0, 0, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_valid"].data, "b": data},
name="get_veg",
dtype="uint8",
)
Expand All @@ -134,25 +182,25 @@ def _water_or_not(self, xx: xr.Dataset):
# mark water freq > 0.5 as 1
data = expr_eval(
"where(a>0.5, 1, 0)",
{"a": xx["wet"].data},
{"a": xx["wet_clear"].data},
name="get_water",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_clear"].data, "b": data},
name="get_water",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _max_consecutive_months(self, data, nodata):
nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
def _max_consecutive_months(self, data, nodata, normalize=False):
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

for t in data:
# +1 if not nodata
Expand Down Expand Up @@ -180,23 +228,34 @@ def _max_consecutive_months(self, data, nodata):
dtype="uint8",
)

# mark nodata
nan_mask = expr_eval(
"where(a==nodata, b, False)",
{"a": t, "b": nan_mask},
name="mark_nodata",
dtype="bool",
# total valid
total = expr_eval(
"where(a==nodata, b, b+1)",
{"a": t, "b": total},
name="get_total_valid",
dtype="uint8",
**{"nodata": nodata},
)

# mark nodata
max_count = expr_eval(
"where(a, nodata, b)",
{"a": nan_mask, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)
if normalize:
max_count = expr_eval(
"where(a<=0, nodata, b/a*12)",
{"a": total, "b": max_count},
name="normalize_max_count",
dtype="float32",
**{"nodata": int(nodata)},
)
max_count = da.ceil(max_count).astype("uint8")
else:
max_count = expr_eval(
"where(a<=0, nodata, b)",
{"a": total, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)

return max_count

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
Expand All @@ -207,15 +266,15 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
max_count_veg = self._max_consecutive_months(data, NODATA)

data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA)
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)


Expand Down
69 changes: 17 additions & 52 deletions odc/stats/plugins/lc_veg_class_a1.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
saltpan_threshold: Optional[int] = None,
water_threshold: Optional[float] = None,
veg_threshold: Optional[int] = None,
water_seasonality_threshold: Optional[float] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -70,9 +69,6 @@ def __init__(
)
self.water_threshold = water_threshold if water_threshold is not None else 0.2
self.veg_threshold = veg_threshold if veg_threshold is not None else 2
self.water_seasonality_threshold = (
water_seasonality_threshold if water_seasonality_threshold else 0.25
)
self.output_classes = output_classes

def fuser(self, xx):
Expand Down Expand Up @@ -165,23 +161,26 @@ def l3_class(self, xx: xr.Dataset):
},
)

# all unmarked values (0) is terrestrial veg

# all unmarked values (0) and 255 > veg >= 2 is terrestrial veg
l3_mask = expr_eval(
"where(a<=0, m, a)",
{"a": l3_mask},
"where((a<=0)&(b>=2)&(b<nodata), m, a)",
{"a": l3_mask, "b": xx["veg_frequency"].data},
name="mark_veg",
dtype="uint8",
**{"m": self.output_classes["terrestrial_veg"]},
**{
"m": self.output_classes["terrestrial_veg"],
"nodata": (
NODATA
if xx["veg_frequency"].attrs["nodata"]
!= xx["veg_frequency"].attrs["nodata"]
else xx["veg_frequency"].attrs["nodata"]
),
},
)

# mark nodata if any source is nodata
# issues:
# - nodata information from non-indexed datasets missing

# Mask nans with NODATA
# Mask nans and pixels where none of the classes is applicable
l3_mask = expr_eval(
"where((a!=a), nodata, e)",
"where((a!=a)|(e<=0), nodata, e)",
{
"a": si5,
"e": l3_mask,
Expand All @@ -191,49 +190,15 @@ def l3_class(self, xx: xr.Dataset):
**{"nodata": NODATA},
)

# Now add the water frequency
# Divide water frequency into following classes:
# 0 --> 0
# (0,0.25] --> 1
# (0.25,1] --> 2

water_seasonality = expr_eval(
"where((a > 0) & (a <= wt), 1, a)",
{"a": xx["frequency"].data},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a > wt) & (a <= 1), 2, b)",
{"a": xx["frequency"].data, "b": water_seasonality},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a != a), nodata, a)",
{
"a": water_seasonality,
},
name="mark_nodata",
dtype="uint8",
**{"nodata": NODATA},
)

return l3_mask, water_seasonality
return l3_mask

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
l3_mask, water_seasonality = self.l3_class(xx)
l3_mask = self.l3_class(xx)
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
)
for k, v in zip(self.measurements, [l3_mask.squeeze(0)])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
25 changes: 13 additions & 12 deletions tests/test_landcover_plugin_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,12 @@ def test_native_transform(fc_wo_dataset, bits):
np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
)
result = np.where(out_xx["wet"].data == out_xx["wet"].data)
result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
for a, b in zip(expected_valid, result):
assert (a == b).all()

expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
result = np.where(out_xx["wet"].data == 1)
result = np.where(out_xx["wet_valid"].data == 1)

for a, b in zip(expected_valid, result):
assert (a == b).all()
Expand Down Expand Up @@ -391,11 +391,11 @@ def test_water_or_not(fc_wo_dataset):
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
valid_index = (
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
np.array([0, 0, 1, 1, 2, 2, 2]),
np.array([3, 6, 2, 6, 0, 2, 2]),
np.array([2, 3, 5, 6, 2, 1, 4]),
)
expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
expected_value = np.array([0, 0, 0, 1, 1, 1, 0])
i = 0
for idx in zip(*valid_index):
assert yy[idx] == expected_value[i]
Expand Down Expand Up @@ -423,14 +423,15 @@ def test_reduce(fc_wo_dataset):

expected_value = np.array(
[
[0, 255, 1, 255, 255, 255, 255],
[0, 255, 255, 0, 255, 255, 255],
[255, 1, 255, 255, 0, 0, 255],
[255, 255, 12, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 12, 255, 255, 0, 0, 255],
[255, 255, 0, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 1],
]
[255, 255, 255, 255, 255, 255, 255],
[255, 255, 255, 0, 255, 255, 12],
],
dtype="uint8",
)

assert (xx.water_frequency.data == expected_value).all()
Expand Down
Loading

0 comments on commit d04b1c0

Please sign in to comment.