Fix stripe issue in landcover level34 #175

Merged
merged 8 commits into from
Dec 4, 2024
143 changes: 101 additions & 42 deletions odc/stats/plugins/lc_fc_wo_a0.py
@@ -52,13 +52,34 @@ def native_transform(self, xx):
5. Drop the WOfS band
"""

# clear and dry pixels not mask against bit 4: terrain high slope,
# valid and dry pixels not mask against bit 4: terrain high slope,
# bit 3: terrain shadow, and
# bit 2: low solar angle
valid = (xx["water"] & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0
valid = (xx["water"].data & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0

# clear and wet pixels not mask against bit 2: low solar angle
wet = (xx["water"] & ~(1 << 2)) == 128
# clear wet pixels not mask against bit 2: low solar angle
wet = (xx["water"].data & ~(1 << 2)) == 128

# clear dry pixels
clear = xx["water"].data == 0

# get "valid" wo pixels, both dry and wet used in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
@@ -67,37 +88,64 @@ def native_transform(self, xx):
raw_mask = mask_cleanup(
raw_mask, mask_filters=self.cloud_filters.get(key)
)
valid &= ~raw_mask
wet &= ~raw_mask

valid = expr_eval(
"where(b>0, 0, a)",
{"a": valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="uint8",
)
wet_clear = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_clear, "b": raw_mask.data},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)
wet_valid = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)
xx = xx.drop_vars(["water"])

# get valid wo pixels, both dry and wet
data = expr_eval(
"where(a|b, a, _nan)",
{"a": wet.data, "b": valid.data},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# Pick out the fc pixels that have an unmixing error of less than the threshold
valid &= xx["ue"] < self.ue_threshold
valid = expr_eval(
"where(b<_v, a, 0)",
{"a": valid, "b": xx["ue"].data},
name="get_low_ue",
dtype="bool",
**{"_v": self.ue_threshold},
)
xx = xx.drop_vars(["ue"])
valid = xr.DataArray(valid, dims=xx["pv"].dims, coords=xx["pv"].coords)

xx = keep_good_only(xx, valid, nodata=NODATA)
xx = to_float(xx, dtype="float32")

xx["wet"] = xr.DataArray(data, dims=wet.dims, coords=wet.coords)
xx["wet_valid"] = xr.DataArray(
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
)
xx["wet_clear"] = xr.DataArray(
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
)

return xx

def fuser(self, xx):

wet = xx["wet"]
wet_valid = xx["wet_valid"]
wet_clear = xx["wet_clear"]

xx = _xr_fuse(xx.drop_vars(["wet"]), partial(_fuse_mean_np, nodata=np.nan), "")
xx = _xr_fuse(
xx.drop_vars(["wet_valid", "wet_clear"]),
partial(_fuse_mean_np, nodata=np.nan),
"",
)

xx["wet"] = _nodata_fuser(wet, nodata=np.nan)
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)

return xx
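
For context on the masks introduced above, here is a minimal, self-contained numpy sketch of the bit logic (the toy `water` values and helper names are illustrative assumptions, not code from this PR). It shows why `wet_valid` and `wet_clear` come out as float32 arrays with NaN marking unusable observations:

```python
import numpy as np

# toy WOfS-style water values; the real inputs are dask-backed uint8 arrays
water = np.array([0, 128, 132, 16, 2, 64], dtype="uint8")

IGNORE = np.uint8((1 << 4) | (1 << 3) | (1 << 2))  # terrain high slope, terrain shadow, low solar angle

valid = (water & ~IGNORE) == 0              # dry, terrain/solar-angle flags tolerated
wet = (water & ~np.uint8(1 << 2)) == 128    # wet, low solar angle tolerated
clear = water == 0                          # strictly clear and dry

# 1.0 = wet, 0.0 = dry, NaN = observation unusable for that statistic
wet_valid = np.where(wet | valid, wet, np.nan).astype("float32")  # feeds veg_frequency
wet_clear = np.where(wet | clear, wet, np.nan).astype("float32")  # feeds water_frequency
```

Keeping two separate arrays is what lets `veg_frequency` tolerate the terrain flags while `water_frequency` only counts strictly clear observations.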

Expand All @@ -123,7 +171,7 @@ def _veg_or_not(self, xx: xr.Dataset):
# mark water freq >= 0.5 as 0
data = expr_eval(
"where(a>0, 0, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_valid"].data, "b": data},
name="get_veg",
dtype="uint8",
)
@@ -134,25 +182,25 @@ def _water_or_not(self, xx: xr.Dataset):
# mark water freq > 0.5 as 1
data = expr_eval(
"where(a>0.5, 1, 0)",
{"a": xx["wet"].data},
{"a": xx["wet_clear"].data},
name="get_water",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_clear"].data, "b": data},
name="get_water",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _max_consecutive_months(self, data, nodata):
nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
def _max_consecutive_months(self, data, nodata, normalize=False):
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

for t in data:
# +1 if not nodata
@@ -180,23 +228,34 @@ def _max_consecutive_months(self, data, nodata):
dtype="uint8",
)

# mark nodata
nan_mask = expr_eval(
"where(a==nodata, b, False)",
{"a": t, "b": nan_mask},
name="mark_nodata",
dtype="bool",
# total valid
total = expr_eval(
"where(a==nodata, b, b+1)",
{"a": t, "b": total},
name="get_total_valid",
dtype="uint8",
**{"nodata": nodata},
)

# mark nodata
max_count = expr_eval(
"where(a, nodata, b)",
{"a": nan_mask, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)
if normalize:
max_count = expr_eval(
"where(a<=0, nodata, b/a*12)",
{"a": total, "b": max_count},
name="normalize_max_count",
dtype="float32",
**{"nodata": int(nodata)},
)
max_count = da.ceil(max_count).astype("uint8")
else:
max_count = expr_eval(
"where(a<=0, nodata, b)",
{"a": total, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)

return max_count
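
Roughly, the loop above now also tallies the total number of valid (non-nodata) months per pixel, and the `normalize` branch rescales the longest streak onto a 12-month basis. Below is a per-pixel sketch of that final step, assuming scalar inputs and `nodata=255` (the real code does the same thing on whole dask arrays via `expr_eval` and `da.ceil`):

```python
import numpy as np

def normalize_streak(max_count, total, nodata=255):
    """Rescale the longest consecutive run to a 12-month equivalent."""
    if total <= 0:           # no valid observations at all -> nodata
        return nodata
    return int(np.ceil(max_count / total * 12))

normalize_streak(1, 1)    # wet in the only clear month         -> 12
normalize_streak(1, 12)   # wet in one of twelve clear months   -> 1
normalize_streak(3, 12)   # wet in three of twelve clear months -> 3
```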

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
@@ -207,15 +266,15 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
max_count_veg = self._max_consecutive_months(data, NODATA)

data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA)
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)


69 changes: 17 additions & 52 deletions odc/stats/plugins/lc_veg_class_a1.py
@@ -57,7 +57,6 @@ def __init__(
saltpan_threshold: Optional[int] = None,
water_threshold: Optional[float] = None,
veg_threshold: Optional[int] = None,
water_seasonality_threshold: Optional[float] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -70,9 +69,6 @@
)
self.water_threshold = water_threshold if water_threshold is not None else 0.2
self.veg_threshold = veg_threshold if veg_threshold is not None else 2
self.water_seasonality_threshold = (
water_seasonality_threshold if water_seasonality_threshold else 0.25
)
self.output_classes = output_classes

def fuser(self, xx):
@@ -165,23 +161,26 @@ def l3_class(self, xx: xr.Dataset):
},
)

# all unmarked values (0) is terretrial veg

# all unmarked values (0) and 255 > veg >= 2 is terretrial veg
l3_mask = expr_eval(
"where(a<=0, m, a)",
{"a": l3_mask},
"where((a<=0)&(b>=2)&(b<nodata), m, a)",
{"a": l3_mask, "b": xx["veg_frequency"].data},
name="mark_veg",
dtype="uint8",
**{"m": self.output_classes["terrestrial_veg"]},
**{
"m": self.output_classes["terrestrial_veg"],
"nodata": (
NODATA
if xx["veg_frequency"].attrs["nodata"]
!= xx["veg_frequency"].attrs["nodata"]
else xx["veg_frequency"].attrs["nodata"]
),
},
)

# mark nodata if any source is nodata
# issues:
# - nodata information from non-indexed datasets missing

# Mask nans with NODATA
# Mask nans and pixels where non of classes applicable
l3_mask = expr_eval(
"where((a!=a), nodata, e)",
"where((a!=a)|(e<=0), nodata, e)",
{
"a": si5,
"e": l3_mask,
@@ -191,49 +190,15 @@ def l3_class(self, xx: xr.Dataset):
**{"nodata": NODATA},
)

# Now add the water frequency
# Divide water frequency into following classes:
# 0 --> 0
# (0,0.25] --> 1
# (0.25,1] --> 2

water_seasonality = expr_eval(
"where((a > 0) & (a <= wt), 1, a)",
{"a": xx["frequency"].data},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a > wt) & (a <= 1), 2, b)",
{"a": xx["frequency"].data, "b": water_seasonality},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a != a), nodata, a)",
{
"a": water_seasonality,
},
name="mark_nodata",
dtype="uint8",
**{"nodata": NODATA},
)

return l3_mask, water_seasonality
return l3_mask

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
l3_mask, water_seasonality = self.l3_class(xx)
l3_mask = self.l3_class(xx)
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
)
for k, v in zip(self.measurements, [l3_mask.squeeze(0)])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
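
The net effect of the changes to this plugin: the water-seasonality threshold and output are dropped here, and a pixel is only labeled terrestrial vegetation when it is still unclassified and has a usable `veg_frequency` of at least 2. A minimal numpy sketch of that rule follows (the class code and toy arrays are assumptions for illustration; the plugin evaluates the same expression with `expr_eval` over dask arrays):

```python
import numpy as np

NODATA = 255
TERRESTRIAL_VEG = 110   # assumed stand-in for output_classes["terrestrial_veg"]

l3_mask = np.array([0, 0, 0, 98], dtype="uint8")            # 0 = not yet classified, 98 = toy "already classified"
veg_frequency = np.array([0, 3, NODATA, 5], dtype="uint8")

# unclassified pixels need 2 <= veg_frequency < nodata to become terrestrial veg;
# previously every unclassified pixel fell through to this class
l3_mask = np.where(
    (l3_mask <= 0) & (veg_frequency >= 2) & (veg_frequency < NODATA),
    TERRESTRIAL_VEG,
    l3_mask,
)
# -> [0, 110, 0, 98]
```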
25 changes: 13 additions & 12 deletions tests/test_landcover_plugin_a0.py
@@ -327,12 +327,12 @@ def test_native_transform(fc_wo_dataset, bits):
np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
)
result = np.where(out_xx["wet"].data == out_xx["wet"].data)
result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
for a, b in zip(expected_valid, result):
assert (a == b).all()

expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
result = np.where(out_xx["wet"].data == 1)
result = np.where(out_xx["wet_valid"].data == 1)

for a, b in zip(expected_valid, result):
assert (a == b).all()
@@ -391,11 +391,11 @@ def test_water_or_not(fc_wo_dataset):
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
valid_index = (
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
np.array([0, 0, 1, 1, 2, 2, 2]),
np.array([3, 6, 2, 6, 0, 2, 2]),
np.array([2, 3, 5, 6, 2, 1, 4]),
)
expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
expected_value = np.array([0, 0, 0, 1, 1, 1, 0])
i = 0
for idx in zip(*valid_index):
assert yy[idx] == expected_value[i]
@@ -423,14 +423,15 @@ def test_reduce(fc_wo_dataset):

expected_value = np.array(
[
[0, 255, 1, 255, 255, 255, 255],
[0, 255, 255, 0, 255, 255, 255],
[255, 1, 255, 255, 0, 0, 255],
[255, 255, 12, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 12, 255, 255, 0, 0, 255],
[255, 255, 0, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 1],
]
[255, 255, 255, 255, 255, 255, 255],
[255, 255, 255, 0, 255, 255, 12],
],
dtype="uint8",
)

assert (xx.water_frequency.data == expected_value).all()
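
As a quick sanity check on the new expected values, assuming the normalization described for `_max_consecutive_months`: a cell now reads 12 when the wet streak spans every clear month the pixel had, 0 when it was observed clear but never predominantly wet, and 255 when there was nothing valid to count.

```python
import numpy as np

# illustrative mapping of (longest wet streak, clear months) -> expected cell value
for streak, total in [(1, 1), (1, 12), (0, 4), (0, 0)]:
    value = 255 if total == 0 else int(np.ceil(streak / total * 12))
    print((streak, total), "->", value)   # (1, 1) -> 12, (1, 12) -> 1, (0, 4) -> 0, (0, 0) -> 255
```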