Commutes elementwise, more typefixes
nicholas-miklaucic committed Apr 24, 2024
1 parent 1e91b9b commit 1315ab5
Showing 19 changed files with 431 additions and 82 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -68,7 +68,7 @@ Eins is still in heavy development. Here's a sense of where we're headed.
### Near-Term (weeks)

- [ ] Updating indexing syntax to match `eindex`
- [ ] Unit array to indicate zero-dimensional tensors
- [x] Unit array to indicate zero-dimensional tensors
- [ ] `...` for batching over dynamic numbers of batch axes
- [ ] Specifying intermediate results to control the order of reduction
- [ ] Support `-` and `/`
78 changes: 62 additions & 16 deletions benchmarking.ipynb
@@ -2,47 +2,93 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from eins import EinsOp, Reductions as Red\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np"
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(1024, 256, 3)\n",
"y = torch.randn(1024, 256, 3)\n",
"\n",
"z4 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=Red.l2_norm)(x, -y)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\nicho\\miniforge3\\lib\\site-packages\\jax\\_src\\numpy\\array_methods.py:64: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
" return lax_numpy.astype(arr, dtype)\n"
]
},
{
"data": {
"text/plain": [
"Array(0., dtype=float32)"
]
},
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mz4\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'DTypeLike | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Array'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n",
"Return a bitwise copy of the array, viewed as a new dtype.\n",
"\n",
"This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`.\n",
"\n",
"If the source and target dtype have the same bitwidth, the result has the same\n",
"shape as the input array. If the bitwidth of the target dtype is different\n",
"from the source, the size of the last axis of the result is adjusted\n",
"accordingly.\n",
"\n",
">>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape\n",
"(1, 2, 6)\n",
">>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape\n",
"(1, 2, 2)\n",
"\n",
"Conversions involving booleans are not well-defined in all situations. With\n",
"regards to the shape of result as explained above, booleans are treated as\n",
"having a bitwidth of 8. However, when converting to a boolean array, the input\n",
"should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or\n",
"may change depending on how the result is used.\n",
"\n",
"This conversion is guaranteed and safe:\n",
">>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)\n",
"Array([ True, False, True], dtype=bool)\n",
"\n",
"However, there are no guarantees about the results of any expression involving\n",
"a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`.\n",
"In particular, the results may change between JAX releases and depending on\n",
"the platform. To safely convert such an array to a boolean array, compare it\n",
"with `0`:\n",
"\n",
">>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0\n",
"Array([ True, True, False], dtype=bool)\n",
"\u001b[0;31mFile:\u001b[0m ~/.local/share/hatch/env/virtual/eins/sThVc9L5/eins/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py\n",
"\u001b[0;31mType:\u001b[0m method"
]
}
],
"source": [
"x = jnp.array(np.random.randn(1024, 256, 3))\n",
"y = jnp.array(np.random.randn(1024, 256, 3))\n",
"x: jax.Array = jnp.array(np.random.randn(1024, 256, 3))\n",
"y: jax.Array = jnp.array(np.random.randn(1024, 256, 3))\n",
"\n",
"z4 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=Red.l2_norm)(x, -y)\n",
"\n",
"# Version without eins. Note how easy it would be to write x[:, None, ...] - y[:, :, None, ...],\n",
"# which would lead to the transposed version of the pairwise distances you want.\n",
"z5 = jnp.sqrt(jnp.sum(jnp.square(x[:, :, None, ...] - y[:, None, ...]), axis=-1))\n",
@@ -278,7 +324,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.5"
}
},
"nbformat": 4,
3 changes: 3 additions & 0 deletions development_docs/README.md
@@ -0,0 +1,3 @@
# Development Documentation

This folder is not for end users. It describes how Eins works internally.
57 changes: 57 additions & 0 deletions development_docs/ordering.md
@@ -0,0 +1,57 @@
# Ordering

Ordering operations is relatively complicated. Consider chained matrix multiplication:

```
a b, b c, c d -> a d
```

We can combine this as `a b, b c -> a b c` and then `a b c, c d -> a b c d` or in the other order.
As long as the combination is commutative/associative, this won't matter.

Sometimes, we can save memory by reducing early:

```
a b, b c -> a c
a c, c d -> a d
```
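
For the matmul chain above, a quick NumPy check (a sketch, not how Eins actually plans the computation) confirms the two orders agree:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 5))  # a b
B = rng.standard_normal((5, 6))  # b c
C = rng.standard_normal((6, 7))  # c d

# Combine into the full 'a b c d' array, then reduce b and c at the end...
late = np.einsum('ab,bc,cd->abcd', A, B, C).sum(axis=(1, 2))
# ...or reduce early, pairwise, never materializing the 4-D array.
early = (A @ B) @ C

assert np.allclose(late, early)
```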

But this isn't generally true.

Here, it works because

$$
\begin{aligned}
{AD}_{ad} &= \sum_{b \in B} \sum_{c \in C} {AB}_{ab} {BC}_{bc} {CD}_{cd} \\
&= \sum_{b \in B} {AB}_{ab} \sum_{c \in C} {BC}_{bc} {CD}_{cd} \\
&= \sum_{c \in C} {CD}_{cd} \sum_{b \in B} {AB}_{ab} {BC}_{bc}
\end{aligned}
$$

The reordering of the summations is fine, because we assume that reductions are commutative. But moving a term outside the sum only works because

```
sum(x * y, x * z) = x * sum(y, z)
```

The vast majority of reductions are folded binary operations. The identity above is just the distributive property: for early reduction to be valid, the combine and reduce pair must form a semiring.
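
As a minimal sketch of what "folded" means here, using `functools.reduce` directly:

```python
from functools import reduce
import numpy as np

xs = np.array([1.0, 2.0, 3.0, 4.0])
# sum is the fold of binary add; max is the fold of binary maximum.
assert reduce(np.add, xs) == np.sum(xs)
assert reduce(np.maximum, xs) == np.max(xs)
```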

There are a couple other examples of common reductions that work with this pattern:

- The tropical semiring: max(x + y, x + z) = x + max(y, z), and likewise with min.
- Min/minimum and max/maximum. Note that sum/add and prod/multiply don't work, because reordering changes how many times an element is combined before being reduced. Max and min are idempotent, so repetition doesn't matter.

Many other reductions can be decomposed into elementwise ops and one of these basic central operations. (Nearly all of them can: there are few other valid choices.) We can use that to generalize a bit.

For example, consider $\max(\sqrt{x + y}, \sqrt{x + z}) = \sqrt{x + \max(y, z)}$.

This is true because of a couple things:
- Max-plus is a semiring, as seen above.
- Square root is monotonic, so it commutes with max.
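
A quick numeric check of both facts (values chosen arbitrarily):

```python
import numpy as np

x, y, z = 2.0, -1.0, 3.0
# Max-plus: + distributes over max.
assert max(x + y, x + z) == x + max(y, z)
# sqrt is monotonic, so it commutes with max, and the composed identity holds.
assert np.isclose(max(np.sqrt(x + y), np.sqrt(x + z)), np.sqrt(x + max(y, z)))
```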

We would like this to also hold for the Euclidean-distance version that squares the inputs, but it only does if the squaring is ordered before the max. Such small differences are challenging from a user-interface perspective.

Note that the version of this with logs works cleanly, because both log and exp are monotonic.

Note that there are going to be numerical differences in how these are handled. I'm fine with that.
74 changes: 48 additions & 26 deletions docs/in-depth.md
@@ -15,43 +15,64 @@ many other things. To control the computation that's being performed beyond the
defines four kinds of functions that specify what's actually happening:

### Combinations
Combinations are elementwise functions that combine two arrays into a new array of the same shape. The default,
`'multiply'`, multiplies inputs, as in matrix multiplication.

It's much easier to ensure Eins does what you want when these are commutative and associative. Instead of trying to
specify subtraction, use addition and then negate the input you want to subtract. This gives Eins freedom to optimize
your computation.
Combinations are, mathematically, functions that take two scalars and output a scalar. In Eins,
combinations should be vectorized, taking in two arrays of the same shape and returning an array of
that shape. The most common examples are `np.add` and `np.multiply`.

**Common examples**: `'add'`, `'multiply'`, `'minimum'`, `'maximum'`


!!! danger

Eins assumes that a combination is commutative and associative, and it makes no guarantees about the
order your arrays are combined. If you supply custom functions, that responsibility is yours.
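
For example, instead of asking for a non-commutative subtraction, negate one input and pass the commutative `'add'`. A minimal sketch, using the pairwise-distance pattern from this repo's `examples.py`:

```python
import numpy as np
from eins import EinsOp

x = np.random.randn(8, 6, 32)
y = np.random.randn(8, 6, 32)

# x - y expressed with the commutative 'add' by negating y, then reducing
# over d with the l2 norm to get pairwise distances.
z = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce='l2_norm')(x, -y)
```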

### Reductions
Reductions take a single array and an axis and eliminate that axis. The default, `'sum'`, sums over an axis, as in
matrix multiplication.
Reductions are essentially functions that take in a vector of any size and return a scalar, like
`np.sum`. (These are sometimes called aggregations.) In Eins, they're functions that take an array
and an axis and return an array with that axis removed.

If you pass in a combination, Eins will essentially apply `functools.reduce` and use that combination to reduce the
axis. In general, however, there are more efficient ways of doing the same thing: a folded `'add'` is just a slower
`'sum'`, and a folded `hypot` is just a slower `l2-norm`.
If you pass in a combination, Eins will essentially apply `functools.reduce` and use that
combination to reduce the axis. In general, however, there are more efficient ways of doing the same
thing: a folded `'add'` is just a slower `'sum'`, and a folded `hypot` is just a slower `l2-norm`.
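
A sketch of that equivalence, assuming (as described above) that Eins accepts a combination name for `reduce` and a plain single-array reduction like `'a b -> a'`:

```python
import numpy as np
from eins import EinsOp

x = np.random.randn(4, 5)
# Folding the combination 'add' along b should match the dedicated 'sum'.
folded = EinsOp('a b -> a', reduce='add')(x)
summed = EinsOp('a b -> a', reduce='sum')(x)
assert np.allclose(folded, summed)
```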

**Common examples**: `'sum'`, `'prod'`, `'l2_norm'`, `'min'`, `'max'`.

Note the naming conventions, matching NumPy nomenclature. `np.max(arr, axis=0)` computes the max along axis 0,
eliminating it. `np.maximum(arr1, arr2)` is the elementwise maximum between two arrays.

!!! danger

If you reduce more than once in a program, Eins assumes you know what you're doing and that the
operation would be the same either way, like summing over two axes. If you supply a custom function,
make sure there is only one potential output.



### Elementwise Operations
An elementwise op takes in a single array and returns an array of the same size, applying an operation individually to
each element. Eins doesn't use these explicitly, but you can combine them with combinations or reductions to
ergonomically represent more complex functions.

An elementwise operation should be thought of as a function that takes a scalar and outputs a
scalar. Eins requires that the operation is *vectorized*, so it takes in an array and outputs an
array of the same shape.

**Common examples**: `'log'`, `'exp'`, `'tanh'`, `'square'`, `'sqrt'`
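
Composed with a reduction, elementwise ops can rebuild `l2_norm` from parts; a sketch mirroring this commit's `examples.py`:

```python
import numpy as np
from eins import EinsOp

x = np.random.randn(8, 6, 32)
y = np.random.randn(8, 6, 32)

# ('sqrt', 'sum', 'square'): square each element, sum over d, then take the
# square root. Per examples.py, this matches reduce='l2_norm'.
z = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=('sqrt', 'sum', 'square'))(x, -y)
```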

### Transformations
Named after the `.transform` method in Pandas, transformations take in a single array and `axis`, like reductions, but
they don't eliminate the axis. For example, `np.sort(arr, axis=0)` is different than `np.sort(arr, axis=1)`, but both
return the same shape.

Just like a folded combination becomes a reduction, a *scanned* or *accumulated* combination becomes a transformation.
Note that the way NumPy and other libraries notate these differs from the idea of a scan. `cumprod`, in Eins, is really
just an alias for `cummultiply`, because Eins uses the combination rather than the reduction. If you have an array with elements `[a, b, c, d]` and an operator like `*`, then Eins computes
Named after the `.transform` method in Pandas, transformations should be thought of mathematically
as functions that take in a vector of any size and produce a vector of the same size. Think of
sorting or standardization: you need multiple inputs for standardization to make sense, but at the
end you haven't changed the shape of the array.

In Eins, transformations take in a single array and `axis`, like reductions, but they don't
eliminate the axis. For example, `np.sort(arr, axis=0)` is different than `np.sort(arr, axis=1)`,
but both return an array of the same shape as `arr`.

Just like a folded combination becomes a reduction, a *scanned* or *accumulated* combination becomes
a transformation. Note that the way NumPy and other libraries notate these differs from the idea of
a scan. `cumprod`, in Eins, is really just an alias for `cummultiply`, because Eins uses the
combination rather than the reduction. If you have an array with elements `[a, b, c, d]` and an
operator like `*`, then Eins computes

```python
[a, a * b, (a * b) * c, ((a * b) * c) * d]
```
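
A NumPy analogue of that scan (a sketch of the concept, not Eins's own API):

```python
import numpy as np

x = np.array([2.0, 3.0, 4.0, 5.0])
# The scanned '*' keeps the axis: [a, a*b, a*b*c, a*b*c*d].
assert np.allclose(np.multiply.accumulate(x), np.cumprod(x))
```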
@@ -73,13 +94,14 @@ Similarly, if you wanted to compute root-mean-square error along an axis, you co

### Explicit Function Objects

Eins supports a relatively sophisticated "stringly-typed" input format, as you've seen above. This means you rarely need
any imports beyond `EinsOp`, and plays nicely with any kind of config files or serialization, but it does also make it
harder to know what functions Eins defines or use your own.
Eins supports a relatively sophisticated "stringly-typed" input format, as you've seen above. This
means you rarely need any imports beyond `EinsOp`, and you can easily serialize the description of
the operation, but it does also make it harder to know what functions Eins defines or use your own.

If you prefer, you can instead pass in explicit objects: `Combination`, `Reduction`, `ElementwiseOp`, and
`Transformation`. These are each base classes that you can implement yourself, but it's easiest to use the associated
object exported from the base namespace: `Combinations`, `Reductions`, etc. These namespaces provide an autocomplete-friendly way of using these operations.
If you prefer, you can instead pass in explicit objects: `Combination`, `Reduction`,
`ElementwiseOp`, and `Transformation`. These are each base classes that you can implement yourself,
but it's easiest to use the associated object exported from the base namespace: `Combinations`,
`Reductions`, etc. These namespaces provide an autocomplete-friendly way of using these operations.

Explicit objects are the only way to specify compositions with function syntax. If you pass in a callable to `combine`
or `reduce`, Eins will assume it has the correct signature, but if you pass in `(my_func1, my_func2)` Eins has no way of
4 changes: 2 additions & 2 deletions docs/stylesheets/extra.css
@@ -32,15 +32,15 @@
font-family: "Instrument Sans";
font-style: italic;
font-weight: 400;
src: url("InstrumentSansItalic[wdth\,wght].ttf") format('woff2');
src: url("InstrumentSans-Italic[wdth\,wght].ttf") format('woff2');
font-variation-settings: "wght" 430, "wdth" 90;
}

@font-face {
font-family: "Instrument Sans";
font-style: italic;
font-weight: 700;
src: url("InstrumentSansItalic[wdth\,wght].ttf") format('woff2');
src: url("InstrumentSans-Italic[wdth\,wght].ttf") format('woff2');
font-variation-settings: "wght" 700, "wdth" 90;
}

9 changes: 5 additions & 4 deletions docs/tutorial.md
@@ -239,7 +239,7 @@ That would give us the array of vectors between points, of shape `batch n1 n2 d`
We do this by computing the Euclidean norm: the square root of the sum of the squares of the values
along the axis. This is called the $L_2$ norm, hence the name.

### Batched Custom Loss
### Batched Mean Huber Loss

The literals that Eins accepts are documented properly in the type system, so you should get a handy
autocomplete for a name like `l2_norm`. The time will come when one of those options isn't
@@ -249,7 +249,7 @@ One solution is to simply pass in your own function. Combinations should have tw
arguments and output an array of the same size, and custom reductions should take in a single
positional argument and either `axis` or `dim` as keyword arguments.

```py title="Average Huber Loss"
```py title="Batched Mean Huber Loss"
from torch.nn.functional import huber_loss

EinsOp('batch out_features, batch out_features -> batch',
@@ -265,9 +265,10 @@ shape of the output:

- **Elementwise functions** are basically just functions from real numbers to real numbers that you
can batch arbitrarily. Examples are `np.sin`, `np.abs`, and `np.exp`. They have the signature
`f(Array) -> Array`.
`f(Array) -> Array`, and mathematically they're functions from a scalar to a scalar.
- **Transformations** use an axis, but don't eliminate it. Examples are `np.sort`, `np.flip`,
`np.roll`, and normalization. They have the signature `f(Array, axis: int) -> Array`.
`np.roll`, and normalization. They have the signature `f(Array, axis: int) -> Array`, and
mathematically they're functions from a vector to a vector.
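
A quick NumPy illustration of how these signatures differ:

```python
import numpy as np

x = np.random.randn(4, 5)
np.abs(x).shape           # elementwise: (4, 5); no axis argument
np.sort(x, axis=1).shape  # transformation: still (4, 5), but axis matters
np.sum(x, axis=1).shape   # reduction: (4,); the axis is eliminated
```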

Eins implements a library of these functions to go along with combinations and reductions. Combining
them lets you make new functions that are easy to reason about and framework-agnostic. Passing in a
4 changes: 2 additions & 2 deletions examples.py
@@ -9,7 +9,7 @@
from eins.namespaces import ElementwiseOps

# Set this to 'jax', 'numpy', or 'torch'
BACKEND = 'jax'
BACKEND = 'numpy'

if BACKEND == 'jax':
import jax.numpy as jnp
@@ -137,7 +137,7 @@ def test_close(a: Array, b: Array):
y = randn(8, 6, 32)
op = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce='l2_norm')
z1 = op(x, -y)
z2 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=('sum', 'square'))(x, -y)
z2 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=('sqrt', 'sum', 'square'))(x, -y)
z3 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce='hypot')(x, -y)
z4 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=R.l2_norm)(x, -y)
