Skip to content

Commit

Permalink
Merge pull request #46 from pyiron/multiple_dispatch
Browse files Browse the repository at this point in the history
Allow multiple dispatch
  • Loading branch information
samwaseda authored Sep 13, 2024
2 parents 5b7bdf4 + cc95f2e commit dbed2f0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 44 deletions.
78 changes: 37 additions & 41 deletions elaston/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,50 +160,46 @@ def get_output_units_from_type_hints(func):
return get_units_from_type_hints(func).get("return", None)


def units(outputs=None, inputs=None):
"""
Decorator to handle units in functions.
def units(func=None, *, outputs=None, inputs=None):
# Perform initial checks
_check_inputs_and_outputs(inputs, outputs)
# If func is None, this means the decorator is called with parentheses
if func is None:
# Return the actual decorator that expects the function
def decorator(func):
return _units_decorator(func, inputs, outputs)

Parameters
----------
outputs : str, list, tuple, callable, optional
Output units. If a string, it should be a valid unit. If a list or
tuple, it should contain valid units. If a callable, it should return a
valid unit.
inputs : dict, optional
return decorator
else:
# The decorator is called without parentheses, so func is the actual function
return _units_decorator(func, inputs, outputs)

Returns
-------
callable
"""
_check_inputs_and_outputs(inputs, outputs)

def decorator(func):
nonlocal inputs, outputs
if inputs is None:
inputs = get_input_units_from_type_hints(func)
if outputs is None:
outputs = get_output_units_from_type_hints(func)

@wraps(func)
def wrapper(*args, **kwargs):
ureg = _get_ureg(args, kwargs)
if ureg is None:
return func(*args, **kwargs)
bound_args = _get_input_args(func, *args, **kwargs)
if outputs is not None:
output_units = _get_output_units(outputs, bound_args, ureg)
result = func(**_pint_to_value(bound_args, inputs))
if outputs is not None and output_units is not None:
if isinstance(output_units, Unit):
return result * output_units
else:
return tuple([res * out for res, out in zip(result, output_units)])
return result

return wrapper

return decorator
def _units_decorator(func, inputs, outputs):

# If inputs or outputs are None, set them based on the function signature
if inputs is None:
inputs = get_input_units_from_type_hints(func)
if outputs is None:
outputs = get_output_units_from_type_hints(func)

@wraps(func)
def wrapper(*args, **kwargs):
ureg = _get_ureg(args, kwargs)
if ureg is None:
return func(*args, **kwargs)
bound_args = _get_input_args(func, *args, **kwargs)
if outputs is not None:
output_units = _get_output_units(outputs, bound_args, ureg)
result = func(**_pint_to_value(bound_args, inputs))
if outputs is not None and output_units is not None:
if isinstance(output_units, Unit):
return result * output_units
else:
return tuple([res * out for res, out in zip(result, output_units)])
return result

return wrapper


def optional_units(*args):
Expand Down
21 changes: 18 additions & 3 deletions tests/unit/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
from pint import UnitRegistry


@units
def get_speed_multiple_dispatch(
distance: Float["meter"], time: Float["second"]
) -> Float["meter/second"]:
return distance / time


@units()
def get_speed_ints(
distance: Int["meter"], time: Int["second"]
) -> Int["meter/second"]:
def get_speed_ints(distance: Int["meter"], time: Int["second"]) -> Int["meter/second"]:
return distance / time


Expand Down Expand Up @@ -103,6 +108,16 @@ def test_type_hinting(self):
get_speed_ints(1 * ureg.meter, 1 * ureg.millisecond).magnitude, int(1e3)
)

def test_multiple_dispatch(self):
ureg = UnitRegistry()
self.assertAlmostEqual(
get_speed_multiple_dispatch(1 * ureg.meter, 1 * ureg.second).magnitude, 1
)
self.assertAlmostEqual(
get_speed_multiple_dispatch(1 * ureg.meter, 1 * ureg.millisecond).magnitude,
1e3,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit dbed2f0

Please sign in to comment.