From 48cfa71effa930cf661bf52eb525a106f7772750 Mon Sep 17 00:00:00 2001
From: Kirill Kouzoubov
Date: Sun, 5 Nov 2023 12:42:44 +1100
Subject: [PATCH] Avoid aliasing in MPUChunk.merge

Merging two chunks should not modify the inputs, or return a chunk that
aliases mutable members of the inputs.
---
 odc/geo/cog/_mpu.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/odc/geo/cog/_mpu.py b/odc/geo/cog/_mpu.py
index 6e5b1f6e..a40be9cd 100644
--- a/odc/geo/cog/_mpu.py
+++ b/odc/geo/cog/_mpu.py
@@ -3,6 +3,7 @@
 """
 from __future__ import annotations
 
+from copy import copy
 from functools import partial
 from uuid import uuid4
 from typing import (
@@ -87,7 +88,7 @@ def __init__(
         observed: Optional[List[Tuple[int, Any]]] = None,
         is_final: bool = False,
         lhs_keep: int = 0,
-        tk: str | None = None
+        tk: str | None = None,
     ) -> None:
         if tk is None:
             tk = uuid4().hex
@@ -104,6 +105,19 @@ def __init__(
         # if supplying data must also supply observed
         assert data is None or (observed is not None and len(observed) > 0)
 
+    def clone(self) -> "MPUChunk":
+        return MPUChunk(
+            self.nextPartId,
+            self.write_credits,
+            copy(self.data),
+            copy(self.left_data),
+            copy(self.parts),
+            copy(self.observed),
+            self.is_final,
+            self.lhs_keep,
+            self.tk,
+        )
+
     def __dask_tokenize__(self):
         return (
             "MPUChunk",
@@ -155,8 +169,8 @@ def merge(
                 lhs.nextPartId,
                 lhs.write_credits + rhs.write_credits,
                 lhs.data + rhs.data,
-                lhs.left_data,
-                lhs.parts,
+                copy(lhs.left_data),
+                copy(lhs.parts),
                 lhs.observed + rhs.observed,
                 rhs.is_final,
                 lhs.lhs_keep,
@@ -165,12 +179,13 @@ def merge(
 
         # Flush `lhs.data + rhs.left_data` if we can
         # or else move it into .left_data
+        lhs = lhs.clone()
         lhs.flush_rhs(write, rhs.left_data)
 
         return MPUChunk(
             rhs.nextPartId,
             rhs.write_credits,
-            rhs.data,
+            copy(rhs.data),
             lhs.left_data,
             lhs.parts + rhs.parts,
             lhs.observed + rhs.observed,
@@ -334,7 +349,9 @@ def from_dask_bag(
         import dask.bag
         from dask.base import tokenize
 
-        tk = tokenize(partId, chunks, writes_per_chunk, lhs_keep, write, spill_sz, split_every)
+        tk = tokenize(
+            partId, chunks, writes_per_chunk, lhs_keep, write, spill_sz, split_every
+        )
         mpus = dask.bag.from_sequence(
             MPUChunk.gen_bunch(
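
Note for reviewers (not part of the patch): below is a minimal, runnable
sketch of the aliasing hazard this commit removes. `Chunk` is a hypothetical
stand-in for `MPUChunk`, not the real class; it only shows why `merge` clones
the left-hand side before doing in-place work, and why mutable members are
copied rather than shared.

from copy import copy
from typing import List

class Chunk:
    """Hypothetical stand-in for MPUChunk (same hazard, fewer fields)."""

    def __init__(self, parts: List[int], data: bytes = b""):
        self.parts = parts
        self.data = data

    def clone(self) -> "Chunk":
        # Shallow copy: `parts` becomes a distinct list object, so the
        # clone can be mutated without touching the original.
        return Chunk(copy(self.parts), self.data)

    @staticmethod
    def merge(lhs: "Chunk", rhs: "Chunk") -> "Chunk":
        lhs = lhs.clone()            # never mutate the caller's chunk
        lhs.parts.extend(rhs.parts)  # in-place work is now safe
        lhs.data = lhs.data + rhs.data
        return lhs                   # shares no containers with inputs

a, b = Chunk([1], b"a"), Chunk([2], b"b")
m = Chunk.merge(a, b)
m.parts.append(3)
assert a.parts == [1] and b.parts == [2]  # inputs are unchanged
assert m.parts is not a.parts             # result does not alias lhs

Note that `copy()` is shallow: the new container is a distinct object whose
elements are still shared with the original. That is presumably sufficient
here as long as merging only rebinds or appends to those containers and never
mutates individual elements in place.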