Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dask_awkward wrapper to Correction and CompoundCorrection #219

Merged
merged 6 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test =
scipy
awkward >=2.2.2;python_version>"3.7"
awkward <2;python_version<="3.7"
dask-awkward;python_version>"3.7"
dask-awkward >=2024.1.1;python_version>"3.7"
dev =
pytest >=4.6
pre-commit
Expand Down
85 changes: 79 additions & 6 deletions src/correctionlib/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import correctionlib._core
import correctionlib.version

_version_two = version.parse("2")
_min_version_ak = version.parse("2.0.0")
_min_version_dak = version.parse("2024.1.1")


def open_auto(filename: str) -> str:
Expand Down Expand Up @@ -58,9 +59,9 @@ def _call_as_numpy(
) -> Any:
import awkward

if version.parse(awkward.__version__) < _version_two:
if version.parse(awkward.__version__) < _min_version_ak:
raise RuntimeError(
f"""imported awkward is version {awkward.__version__} < 2.0.0
f"""imported awkward is version {awkward.__version__} < {str(_min_version_ak)}
If you cannot upgrade, try doing: ak.flatten(arrays) -> result = correction(arrays) -> ak.unflatten(result, counts)
"""
)
Expand Down Expand Up @@ -130,6 +131,49 @@ def _wrap_awkward(
return awkward.transform(tocall, *array_args)


def _call_dask_correction(
    correction: Any,
    *args: Union["numpy.ndarray[Any, Any]", str, int, float],
) -> Any:
    """Evaluate ``correction`` on one set of arguments via ``_wrap_awkward``.

    This is the callable handed to ``dask_awkward.map_partitions`` by
    ``_wrap_dask_awkward``.

    Args:
        correction: Object exposing ``_base.evalv``.
        *args: Inputs forwarded unchanged to ``_wrap_awkward``.

    Returns:
        Whatever ``_wrap_awkward(correction._base.evalv, *args)`` returns.
    """
    return _wrap_awkward(correction._base.evalv, *args)


def _wrap_dask_awkward(
    correction: Any,
    *args: Union["numpy.ndarray[Any, Any]", str, int, float],
) -> Any:
    """Evaluate ``correction`` over dask_awkward inputs with ``map_partitions``.

    Args:
        correction: Object exposing ``_base.evalv`` and ``_name``. A
            ``dask.delayed`` wrapper of it is stored on the object as
            ``_delayed_correction`` the first time this function sees it.
        *args: Inputs to the correction; any ``dask_awkward.Array`` among
            them is replaced by its ``_meta`` when computing the output meta.

    Returns:
        The result of ``dask_awkward.map_partitions`` applied to
        ``_call_dask_correction``.

    Raises:
        RuntimeError: If the imported dask_awkward is older than
            ``_min_version_dak``.
    """
    # Imported lazily so dask is only required when this path is taken.
    import dask.delayed
    import dask_awkward

    if version.parse(dask_awkward.__version__) < _min_version_dak:
        raise RuntimeError(
            f"""imported dask_awkward is version {dask_awkward.__version__} < {str(_min_version_dak)}
This version of dask_awkward includes several useful bugfixes and functionality extensions.
Please upgrade dask_awkward.
"""
        )

    # Wrap the correction in dask.delayed once and reuse it on later calls.
    if not hasattr(correction, "_delayed_correction"):
        setattr(  # noqa: B010
            correction,
            "_delayed_correction",
            dask.delayed(correction),
        )

    # Run the eager awkward path on the metadata arrays to obtain the meta
    # passed to map_partitions below.
    correction_meta = _wrap_awkward(
        correction._base.evalv,
        *(arg._meta if isinstance(arg, dask_awkward.Array) else arg for arg in args),
    )

    return dask_awkward.map_partitions(
        _call_dask_correction,
        correction._delayed_correction,
        *args,
        meta=correction_meta,
        label=correction._name,
    )


class Correction:
"""High-level correction evaluator object

Expand Down Expand Up @@ -174,12 +218,22 @@ def evaluate(
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
# TODO: create a ufunc with numpy.vectorize in constructor?
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
raise TypeError(
"Correctionlib does not yet handle dask.array collections. "
"If you require this functionality (i.e. you cannot or do "
"not want to use dask_awkward/awkward arrays) please open an "
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
)
try:
vargs = [
numpy.asarray(arg)
for arg in args
if not isinstance(arg, (str, int, float))
]
except NotImplementedError:
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
lgray marked this conversation as resolved.
Show resolved Hide resolved
return _wrap_dask_awkward(self, *args) # type: ignore
except (ValueError, TypeError):
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
return _wrap_awkward(self._base.evalv, *args) # type: ignore
Expand Down Expand Up @@ -242,9 +296,28 @@ def evaluate(
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
# TODO: create a ufunc with numpy.vectorize in constructor?
vargs = [
numpy.asarray(arg) for arg in args if not isinstance(arg, (str, int, float))
]
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
raise TypeError(
"Correctionlib does not yet handle dask.array collections. "
"if you require this functionality (i.e. you cannot or do "
"not want to use dask_awkward/awkward arrays) please open an "
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
)
try:
vargs = [
numpy.asarray(arg)
for arg in args
if not isinstance(arg, (str, int, float))
]
except NotImplementedError:
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
return _wrap_dask_awkward(self, *args) # type: ignore
except (ValueError, TypeError):
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
return _wrap_awkward(self._base.evalv, *args) # type: ignore
except Exception as err:
raise err

if vargs:
bargs = numpy.broadcast_arrays(*vargs)
oshape = bargs[0].shape
Expand Down
3 changes: 1 addition & 2 deletions tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def test_highlevel_dask(cset):
x = awkward.unflatten(numpy.ones(6), [3, 2, 1])
dx = dask_awkward.from_awkward(x, 3)

evaluate = dask_awkward.map_partitions(
sf.evaluate,
evaluate = sf.evaluate(
dx,
1.0,
)
Expand Down
Loading