Add more tests for flax nnx layers (#10)
* Add additional tests

* Fix tests and add a few additional tests
kasper0406 authored Oct 20, 2024
1 parent 7b5fc6b commit 5c50a20
Showing 3 changed files with 139 additions and 10 deletions.
11 changes: 10 additions & 1 deletion stablehlo_coreml/converter.py
@@ -16,7 +16,8 @@
Log1pOp, SqrtOp, ConstantOp, DotGeneralOp, ReshapeOp, BroadcastInDimOp, WhileOp,
CompareOp, ConvertOp, SelectOp, DynamicSliceOp, ReturnOp, ConvolutionOp, MinOp,
MaxOp, RsqrtOp, TanhOp, SineOp, CosineOp, TanOp, Atan2Op, ConcatenateOp, TransposeOp,
DynamicUpdateSliceOp, SliceOp, CustomCallOp, IotaOp, ReduceOp, OrOp, AndOp, ReverseOp
DynamicUpdateSliceOp, SliceOp, CustomCallOp, IotaOp, ReduceOp, OrOp, AndOp, ReverseOp,
IsFiniteOp,
)
from jaxlib.mlir.dialects.mhlo import (TopKOp)
from jax._src.lib.mlir.dialects import hlo
@@ -762,6 +763,14 @@ def op_reverse(self, context: TranscriptionContext, op: ReverseOp):
mil_res = mb.reverse(x=x, axes=np.array(op.dimensions, dtype=np.int32))
context.add_variable(op.result.get_name(), mil_res)

@register_stablehlo_op
def op_isfinite(self, context: TranscriptionContext, op: IsFiniteOp):
x = context[op.x.get_name()]
# All finite numbers will have abs(x) < inf
infinity = np.array(np.inf, dtype=types.nptype_from_builtin(self.__resolve_type(x)))
mil_res = mb.less(x=mb.abs(x=x), y=infinity)
context.add_variable(op.result.get_name(), mil_res)
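
A minimal sketch of the identity this lowering relies on, assuming IEEE comparison semantics (abs of +/-inf is inf, and any comparison involving nan is False):

import jax.numpy as jnp

x = jnp.array([20.0, -12.23, jnp.inf, -jnp.inf, jnp.nan])
# abs(x) < inf is False for +/-inf and for nan, so it matches isfinite elementwise
assert bool(((jnp.abs(x) < jnp.inf) == jnp.isfinite(x)).all())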

@register_stablehlo_op
def op_reduce(self, context: TranscriptionContext, op: ReduceOp):
# HLO reductions can be arbitrarily complex and define a custom function
101 changes: 100 additions & 1 deletion tests/test_flax.py
@@ -2,11 +2,13 @@
from flax import nnx
import jax.numpy as jnp

from tests.test_jax import run_and_compare
from tests.test_jax import run_and_compare, run_and_compare_specific_input

from tests.flax_blocks import ResidualConv, Encoder, UNet, UNetWithXlstm
from tests.flax_xlstm import sLSTMCell, sLSTMBlock, mLSTMCell, mLSTMBlock, xLSTMModule, xLSTM

from functools import partial


def test_flax_nnx_linear():
class TestLinear(nnx.Module):
@@ -384,3 +386,100 @@ def test_unet_with_xlstm():
model.eval()

run_and_compare(nnx.jit(model), (carry, x, ))


def test_activations():
example_input = (jnp.zeros((20,)),)
run_and_compare(nnx.celu, example_input)
run_and_compare(nnx.elu, example_input)
run_and_compare(nnx.gelu, example_input)
run_and_compare(nnx.glu, example_input)
run_and_compare(nnx.hard_sigmoid, example_input)
run_and_compare(nnx.hard_silu, example_input)
run_and_compare(nnx.hard_swish, example_input)
run_and_compare(nnx.hard_tanh, example_input)
run_and_compare(nnx.leaky_relu, example_input)
run_and_compare(nnx.log_sigmoid, example_input)
run_and_compare(nnx.log_softmax, example_input)
run_and_compare(nnx.logsumexp, example_input)
run_and_compare(nnx.relu, example_input)
run_and_compare(nnx.selu, example_input)
run_and_compare(nnx.sigmoid, example_input)
run_and_compare(nnx.silu, example_input)
run_and_compare(nnx.soft_sign, example_input)
run_and_compare(nnx.softmax, example_input)
run_and_compare(nnx.softplus, example_input)
run_and_compare(nnx.standardize, example_input)
run_and_compare(nnx.swish, example_input)
run_and_compare(nnx.tanh, example_input)

run_and_compare_specific_input(partial(nnx.one_hot, num_classes=3), (jnp.array([0, 1, 2]), ))
run_and_compare_specific_input(partial(nnx.one_hot, num_classes=5), (jnp.array([4, 0, 1, 0]), ))


def test_attention():
class TestAttention(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.layer = nnx.MultiHeadAttention(
num_heads=4,
in_features=5,
qkv_features=16,
decode=False,
rngs=rngs,
)

def __call__(self, q, k, v):
return self.layer(q, k, v)

shape = (4, 3, 2, 5)
input_spec = (jnp.zeros(shape), jnp.zeros(shape), jnp.zeros(shape))
run_and_compare(nnx.jit(TestAttention(nnx.Rngs(0))), input_spec)

@nnx.jit
def create_masks(length):
attention_mask = nnx.make_attention_mask(length, length)
causal_mask = nnx.make_causal_mask(length)
return nnx.combine_masks(attention_mask, causal_mask)

run_and_compare(create_masks, (jnp.zeros((5, 20)), ))


# This test currently makes Python crash due to https://github.com/llvm/llvm-project/pull/113064
# def test_embed():
# model = nnx.Embed(num_embeddings=10, features=5, rngs=nnx.Rngs(0))
# example_input = (jnp.array([[1, 5, 3], [9, 3, 0]], dtype=jnp.int32), )
# run_and_compare_specific_input(nnx.jit(model), example_input)


def test_nnx_einsum():
layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
example_input = (jnp.zeros((16, 11, 2)), )
run_and_compare(nnx.jit(layer), example_input)


def test_batch_norm_infer():
layer = nnx.BatchNorm(num_features=10, momentum=0.9, epsilon=1e-5, rngs=nnx.Rngs(0))
layer.eval()
example_input = (jnp.zeros((20, 10)), )
run_and_compare(nnx.jit(layer), example_input)


def test_layer_norm_infer():
layer = nnx.LayerNorm(num_features=10, rngs=nnx.Rngs(0))
layer.eval()
example_input = (jnp.zeros((20, 10)), )
run_and_compare(nnx.jit(layer), example_input)


def test_rms_norm_infer():
layer = nnx.RMSNorm(num_features=10, rngs=nnx.Rngs(0))
layer.eval()
example_input = (jnp.zeros((20, 10)), )
run_and_compare(nnx.jit(layer), example_input)


def test_group_norm_infer():
layer = nnx.GroupNorm(num_features=10, num_groups=2, rngs=nnx.Rngs(0))
layer.eval()
example_input = (jnp.zeros((20, 10)), )
run_and_compare(nnx.jit(layer), example_input)
37 changes: 29 additions & 8 deletions tests/test_jax.py
Expand Up @@ -154,6 +154,13 @@ def test_trigonmetry():
run_and_compare(jnp.atan2, (jnp.zeros((50, 20)), jnp.zeros((50, 20)),))


def test_is_finite():
input = (jnp.array([20.0, -12.23, jnp.inf, -jnp.inf, jnp.nan], dtype=jnp.float16), )
run_and_compare_specific_input(jnp.isfinite, input)
run_and_compare_specific_input(jnp.isinf, input)
run_and_compare_specific_input(jnp.isnan, input)


def jax_export(jax_func, input_spec):
def compute_input_shapes(input_specs):
shapes = []
@@ -233,16 +240,15 @@ def count_block(block: Block):
return total_complexity


def run_and_compare(jax_func, input_spec, max_complexity: int = 10_000):
def run_and_compare_specific_input(jax_func, inputs, max_complexity: int = 10_000):
"""
Converts the given `jax_func` to a CoreML model.
Both models will be run on the provided `inputs`.
If the CoreML model and `jax_func` do not agree on the output, an error will be raised.
The resulting CoreML model will be returned.
"""

jax_func = jax.jit(jax_func)
exported = jax_export(jax_func, input_spec)
exported = jax_export(jax_func, inputs)
context = jax_mlir.make_ir_context()
hlo_module = ir.Module.parse(exported.mlir_module(), context=context)
# print(f"HLO module: {hlo_module}")
@@ -272,15 +278,12 @@ def run_and_compare(jax_func, input_spec, max_complexity: int = 10_000):
# Generate random inputs that match the cml_model input spec
cml_input_key_values = {}
jax_input_values = []
key = jax.random.PRNGKey(0)
for input_name, input_shape in zip(cml_model.input_description, exported.in_avals):
key, value_key = jax.random.split(key, num=2)
input_value = generate_random_from_shape(input_shape, value_key)
for input_name, input_value in zip(cml_model.input_description, flatten(inputs)):
cml_input_key_values[input_name] = input_value
jax_input_values.append(input_value)

# Transform the input to match the Jax model, and call it
jax_input_values = __nest_flat_jax_input_to_input_spec(input_spec, jax_input_values)
jax_input_values = __nest_flat_jax_input_to_input_spec(inputs, jax_input_values)
expected_output = jax_func(*jax_input_values)

# TODO(knielsen): Is there a nicer way of doing this?
@@ -297,6 +300,24 @@ def run_and_compare(jax_func, input_spec, max_complexity: int = 10_000):
return cml_model


def run_and_compare(jax_func, input_specification, max_complexity: int = 10_000):
"""
Converts the given `jax_func` to a CoreML model.
The model will be tested with randomly generated data matching the shapes in `input_specification`.
If the CoreML model and `jax_func` do not agree on the output, an error will be raised.
The resulting CoreML model will be returned.
"""
flat_inputs = []
key = jax.random.PRNGKey(0)
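# Split the key per input so each leaf of the specification gets an independent random value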
for input_spec in flatten(input_specification):
key, value_key = jax.random.split(key, num=2)
input_value = generate_random_from_shape(input_spec, value_key)
flat_inputs.append(input_value)

inputs = __nest_flat_jax_input_to_input_spec(input_specification, flat_inputs)
return run_and_compare_specific_input(jax_func, inputs, max_complexity=max_complexity)
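
For illustration (function and values chosen arbitrarily): `run_and_compare` draws random inputs for the given shapes, while `run_and_compare_specific_input` runs on exact values, which matters for inputs such as inf/nan or integer class ids:

run_and_compare(jnp.tanh, (jnp.zeros((4, 8)),))
run_and_compare_specific_input(jnp.isfinite, (jnp.array([1.0, jnp.inf, jnp.nan]),))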


def get_model_instruction_types(cml_model) -> List[str]:
def collect_ops(ops: List) -> List[str]:
collected_ops = []
