Skip to content

Commit

Permalink
New transformation API
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-miklaucic committed Apr 12, 2024
1 parent a4e8f84 commit b31eab0
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 167 deletions.
14 changes: 12 additions & 2 deletions examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from eins import Reductions as R # noqa: N817
from eins import Transformations as T
from eins.common_types import Array
from eins.namespaces import ElementwiseOps

# Set this to 'jax', 'numpy', or 'torch'
BACKEND = 'torch'
Expand Down Expand Up @@ -56,6 +57,15 @@ def test_close(a: Array, b: Array):
assert diffs.max() < EPSILON, f'{a.shape} != {b.shape}, {R.mean(a)}, {R.mean(b)}, {diffs.max()}' # noqa: S101


# Softmax
x = randn(5, 4)

op = EinsOp('a b', transform={'b': ('softmax', ElementwiseOps.from_func(lambda x: x + 2))})
y = op(x)
print(op)
y2 = T.Softmax(temperature=1)(x, axis=1)
test_close(y, y2)

# Splitting
x, y = randn(3, 4), randn(5, 4)
z2 = xp.concatenate((x, y), axis=0)
Expand All @@ -65,8 +75,8 @@ def test_close(a: Array, b: Array):

# Concatenation

z1 = EinsOp('a c, b c -> a+b c')(x, y)
test_close(z1, z2)
# z1 = EinsOp('a c, b c -> a+b c')(x, y)
# test_close(z1, z2)

# Simple matrix multiplication
x = randn(32, 64)
Expand Down
153 changes: 31 additions & 122 deletions network.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/eins/combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, arr1: Array, arr2: Array) -> Array:
arrs = (arr1, arr2)
out = arrs
combines = 0
for op in self.ops:
for op in self.ops[::-1]:
if isinstance(op, Combination):
combines += 1
if combines > 1:
Expand Down
8 changes: 7 additions & 1 deletion src/eins/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Split,
Tensor,
Tile,
Transform,
Transpose,
)

Expand All @@ -36,7 +37,10 @@ def do(
xp = array_namespace(*x)

try:
if isinstance(op, Reshape):
if op.is_identity_for(ins):
# no-op
return x
elif isinstance(op, Reshape):
new_shape = tuple(map(self.constr.value_of, op.new_shape))
return [xp.reshape(x[0], shape=new_shape)]
elif isinstance(op, Transpose):
Expand Down Expand Up @@ -112,6 +116,8 @@ def do(
elif isinstance(op, Reduce):
# print(ins, _outs, op, id(_outs[0]))
return [op.method(x[0], axis=ins[0].axes.index(op.axis))]
elif isinstance(op, Transform):
return [op.method(x[0], axis=ins[0].axes.index(op.axis))]
else:
msg = 'Op not supported: ' + str(op)
raise TypeError(msg)
Expand Down
125 changes: 97 additions & 28 deletions src/eins/einsop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,41 @@
from typing import AnyStr, Callable, Mapping, MutableMapping, Optional, Sequence, Union

from eins.combination import (
ARRAY_COMBINE_OPS,
Combination,
CombineLiteral,
CompositeCombination,
CustomCombination,
parse_combination,
)
from eins.combination import (
ops as _combination_ops,
)
from eins.common_types import Array
from eins.concrete import ArrayBackend
from eins.elementwise import ElementwiseLiteral, ElementwiseOp, parse_elementwise
from eins.elementwise import ops as _elementwise_ops
from eins.parsing import Constant, Symbol
from eins.program import Program
from eins.program import Program, TransformProgram
from eins.reduction import (
ARRAY_REDUCE_OPS,
CompositeReduction,
CustomReduction,
Reduction,
ReductionLiteral,
parse_reduction,
)
from eins.reduction import (
ops as _reduction_ops,
)
from eins.strategy import BaseStrategy
from eins.symbolic import Tensor
from eins.transformation import Transformation, TransformationLiteral, parse_transformation
from eins.transformation import (
CompositeTransformation,
CustomTransformation,
Transformation,
TransformationLiteral,
parse_transformation,
)
from eins.transformation import ops as _transformation_ops

ElementwiseKind = Union[ElementwiseLiteral, Callable, ElementwiseOp]
# use AnyStr to ensure autocomplete works
Expand All @@ -41,8 +53,8 @@
ReductionKind, Sequence[Union[ElementwiseKind, TransformationKind, ReductionKind]]
]
ReduceArg = Union[GeneralReductionKind, Mapping[str, GeneralReductionKind]]

CombineArg = Union[CombinationKind, Sequence[Union[ElementwiseKind, CombinationKind]]]
TransformArg = Union[TransformationKind, Sequence[Union[ElementwiseKind, TransformationKind]]]


def _parse_reduce_arg(reduce: GeneralReductionKind) -> Reduction:
Expand All @@ -56,7 +68,7 @@ def _parse_reduce_arg(reduce: GeneralReductionKind) -> Reduction:
return reduce_parse

msg = f'Cannot parse reduction {reduce}. Valid literals are: ' + ', '.join(
ARRAY_REDUCE_OPS + ARRAY_COMBINE_OPS
list(_reduction_ops) + list(_combination_ops)
)
raise ValueError(msg)
else:
Expand Down Expand Up @@ -107,6 +119,51 @@ def _parse_reduce_arg(reduce: GeneralReductionKind) -> Reduction:
return CompositeReduction(tuple(ops))


def _parse_transform_arg(transform: TransformArg) -> Transformation:
if isinstance(transform, Transformation):
return transform
elif isinstance(transform, Callable):
return CustomTransformation(transform)
elif isinstance(transform, str):
combo_parse = parse_transformation(transform)
if combo_parse is not None:
return combo_parse

msg = f'Cannot parse transformation {transform}. Valid literals: {", ".join(_transformation_ops)}'
raise ValueError(msg)
else:
ops = []
for op in transform:
if isinstance(op, (ElementwiseOp, Transformation)):
ops.append(op)
continue

if isinstance(op, Callable):
# callables are ambiguous here
msg = f"""
User-supplied function in transform={transform} is ambiguous: either write a custom lambda combining
these operations or explicitly create objects using e.g., eins.ElementwiseOps.from_func().
"""
raise ValueError(msg)

op_parse = parse_transformation(op)
if op_parse is not None:
ops.append(op_parse)
continue

op_parse = parse_elementwise(op)
if op_parse is not None:
ops.append(op_parse)
continue

msg = f'Cannot parse operation {op} in {ops}. Valid literals: ' + (
'\n'.join([', '.join(list(_transformation_ops) + list(_elementwise_ops))])
)
raise ValueError(msg)

return CompositeTransformation(tuple(ops))


def _parse_combine_arg(combine: CombineArg) -> Combination:
if isinstance(combine, Combination):
return combine
Expand All @@ -117,9 +174,7 @@ def _parse_combine_arg(combine: CombineArg) -> Combination:
if combo_parse is not None:
return combo_parse

msg = (
f'Cannot parse reduction {combine}. Valid literals are: {", ".join(ARRAY_COMBINE_OPS)}'
)
msg = f'Cannot parse combination {combine}. Valid literals: {", ".join(_combination_ops)}'
raise ValueError(msg)
else:
ops = []
Expand All @@ -146,13 +201,8 @@ def _parse_combine_arg(combine: CombineArg) -> Combination:
ops.append(op_parse)
continue

msg = f'Cannot parse operation {op} in {ops}. Valid literals are: ' + (
'\n'.join(
[
', '.join(typing.get_args(ops))
for ops in (CombineLiteral, ElementwiseLiteral)
]
)
msg = f'Cannot parse operation {op} in {ops}. Valid literals: ' + (
'\n'.join([', '.join(list(_combination_ops) + list(_elementwise_ops))])
)
raise ValueError(msg)

Expand All @@ -169,6 +219,7 @@ def __init__(
*,
reduce: ReduceArg = 'sum',
combine: CombineArg = 'multiply',
transform: Optional[Mapping[str, TransformArg]] = None,
symbol_values: Optional[Mapping[str, int]] = None,
):
"""
Expand All @@ -183,8 +234,8 @@ def __init__(
multiplication, and `'batch (size size) channels -> batch size size channels'` unpacks a
batch of square images. Any amount of whitespace can surround `->` and `,`.
reduce: function f(Array, axis: int) → Array, Reduction, str, or mapping from axes to
previous
reduce: function f(Array, axis: int) → Array, Reduction, str, sequence of elementwise ops,
transformations, and reductions, or mapping from axes to previous
Describes how axes that appear in the input but not the output are eliminated: use
[eins.Reductions] to get an autocomplete-friendly list of options. The default is
`'sum'`, like in `einsum`. Common alternatives are `'mean'`, `'std'`, `'max'`, and
Expand All @@ -206,8 +257,8 @@ def __init__(
'c': 'min'}` has two meanings, depending on which happens first. Instead, you can pass
`'a b c -> a b -> a'`, which forces a specific order.
combine: function f(Array, Array) → Array, Combination, str, or mapping from axes to
previous
combine: function f(Array, Array) → Array, Combination, str, sequence of elementwise ops and
a combination, or mapping from axes to previous
Describes how the elements of different input tensors are combined: use
[eins.Combinations] to get an autocomplete-friendly list of options. The default is
`'multiply'`, which is what `einsum` does. This can be a list of elementwise operations
Expand All @@ -218,30 +269,48 @@ def __init__(
same shape as the two inputs. `eins` makes no guarantees about the order combinations
are performed, so this function should be commutative and associative.
transform: mapping from axes to one of: function f(Array, axis: int) → Array,
Transformation, str, or sequence of elementwise ops and transforms
symbol_values: mapping from symbols to integers or None
An alternative to using = to specify axis values.
"""
if '->' not in op:
msg = f'Einsop "{op}" has no "->", which is required'
if transform is None:
transform = {}

if '->' not in op and len(transform) == 0:
msg = f'Einsop "{op}" has no "->", which is required unless transform is given.'
raise ValueError(msg)
elif '->' in op and len(transform) > 0:
msg = f'Einsop "{op}" has "->", which is not allowed with transform.'
raise ValueError(msg)

self.is_transform = len(transform) > 0

self.op_str = op

if isinstance(reduce, Mapping):
self.reduce = {k: _parse_reduce_arg(v) for k, v in reduce.items()}
if self.is_transform:
self.transform = {k: _parse_transform_arg(v) for k, v in transform.items()}
self.program = TransformProgram.parse(self.op_str, transform=self.transform)
else:
self.reduce = _parse_reduce_arg(reduce)
if isinstance(reduce, Mapping):
self.reduce = {k: _parse_reduce_arg(v) for k, v in reduce.items()}
else:
self.reduce = _parse_reduce_arg(reduce)

self.combine = _parse_combine_arg(combine)
self.combine = _parse_combine_arg(combine)

self.program = Program.parse(self.op_str, combine=self.combine, reduce=self.reduce)
self.program = Program.parse(self.op_str, combine=self.combine, reduce=self.reduce)

self.symbol_values = symbol_values or {}
for k, v in self.symbol_values.items():
self.program.constr.add_constraint(Symbol(k), Constant(v))

def __repr__(self) -> str:
return f'EinsOp({self.op_str}, reduce={self.reduce}, combine={self.combine})'
if self.is_transform:
return f'EinsOp({self.op_str}, transform={self.transform})'
else:
return f'EinsOp({self.op_str}, reduce={self.reduce}, combine={self.combine})'

def __str__(self) -> str:
names = {}
Expand Down
2 changes: 1 addition & 1 deletion src/eins/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __call__(self, arr: Array) -> Array:
raise ValueError(msg) from None


@dataclass
@dataclass(frozen=True, unsafe_hash=True)
class CustomElementwiseOp(ElementwiseOp):
"""Elementwise operation defined by user.
Expand Down
Loading

0 comments on commit b31eab0

Please sign in to comment.