Skip to content

Commit

Permalink
add test for cultivated model
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 6, 2024
1 parent dbf76eb commit 72b9fb4
Show file tree
Hide file tree
Showing 3 changed files with 478 additions and 9 deletions.
1 change: 1 addition & 0 deletions odc/stats/plugins/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, model_path):
def setup(self, worker):
worker.plugin_instance = self
worker.predictors = {}
print(f"registered worker {worker}")

def get_predictor(self):
worker = get_worker()
Expand Down
30 changes: 21 additions & 9 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def predict(self, input_array):
input_array,
bands_indices,
chunks=(
self.chunks["x"],
self.chunks["y"],
input_array.chunks[0],
input_array.chunks[1],
15 + len(bands_indices) - bands_indices["bcdev"] - 1,
),
dtype="float32",
Expand All @@ -254,6 +254,7 @@ def predict(self, input_array):
dtype="float32",
name="cultivated_predict",
)

return cc

def aggregate_results_from_group(self, predict_output):
Expand All @@ -263,10 +264,21 @@ def aggregate_results_from_group(self, predict_output):
# for each pixel
m_size = len(predict_output)
if m_size > 1:
predict_output = da.stack(predict_output).sum(axis=0)
predict_output = da.stack(predict_output)
else:
predict_output = predict_output[0]

predict_output = expr_eval(
"where(a<nodata, 1-a, a)",
{"a": predict_output},
name="invert_output",
dtype="float32",
**{"nodata": NODATA},
)

if m_size > 1:
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
{"a": predict_output},
Expand All @@ -276,19 +288,19 @@ def aggregate_results_from_group(self, predict_output):
)

predict_output = expr_eval(
"where((a>=_l)&(a<nodata), _u, a)",
"where((a>0)&(a<nodata), _u, a)",
{"a": predict_output},
name="output_classes_natural",
name="output_classes_cultivated",
dtype="float32",
**{"_u": self.output_classes["natural"], "nodata": NODATA, "_l": m_size},
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(a<_l, _nu, a)",
"where(a<=0, _nu, a)",
{"a": predict_output},
name="output_classes_cultivated",
name="output_classes_natural",
dtype="uint8",
**{"_nu": self.output_classes["cultivated"], "_l": m_size},
**{"_nu": self.output_classes["natural"]},
)

return predict_output.rechunk(-1, -1)
Expand Down
Loading

0 comments on commit 72b9fb4

Please sign in to comment.