Implement the topk operation and fix broadcasting (#4)
* Attempt at implementing the topk custom mhlo call

* Implement the topk operation and fix broadcasting
kasper0406 authored Aug 27, 2024
1 parent 4e6cb35 commit db50958
Showing 2 changed files with 90 additions and 25 deletions.
84 changes: 66 additions & 18 deletions stablehlo_coreml/converter.py
@@ -3,6 +3,9 @@
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Function, Program, types
from coremltools.converters.mil._deployment_compatibility import AvailableTarget
from coremltools.converters.mil.mil.ops.defs._utils import (
promote_input_dtypes,
)
from .utils import index_by_slices, update_tensor_by_slice, iterate_indexes_in_shapes

from jaxlib.mlir import ir
@@ -11,16 +14,17 @@
AddOp, SubtractOp, MulOp, DivOp, NegOp, SignOp, AbsOp, ExpOp, Log1pOp, SqrtOp,
ConstantOp, DotGeneralOp, ReshapeOp, BroadcastInDimOp, WhileOp, CompareOp,
ConvertOp, SelectOp, DynamicSliceOp, ReturnOp, ConvolutionOp, MaxOp, RsqrtOp,
TanhOp, ConcatenateOp, TransposeOp, DynamicUpdateSliceOp, SliceOp
TanhOp, ConcatenateOp, TransposeOp, DynamicUpdateSliceOp, SliceOp, CustomCallOp,
IotaOp
)
from jaxlib.mlir.dialects.mhlo import (TopKOp)
from jax._src.lib.mlir.dialects import hlo

import numpy as np

from typing import List, Optional
import inspect
from functools import partial, reduce
import operator


def convert(module, minimum_deployment_target: AvailableTarget):
@@ -298,6 +302,7 @@ def op_sqrt(self, context: TranscriptionContext, op: SqrtOp):
@register_stablehlo_op
def op_constant(self, context: TranscriptionContext, op: ConstantOp):
constant = np.array(op.value)
constant = np.reshape(constant, op.result.type.shape)
context.add_variable(op.result.get_name(), constant)

@register_stablehlo_op
@@ -317,9 +322,7 @@ def op_dot_general(self, context: TranscriptionContext, op: DotGeneralOp):
rhs = context[op.rhs.get_name()]

def multiply(lst: List):
if len(lst) == 0:
return 1
return reduce(lambda a, b: int(a) * int(b), lst)
return reduce(lambda a, b: int(a) * int(b), lst, 1)

def last_column_dot(lhs, rhs):
# TODO: Figure out if we need to special case broadcasting dims
@@ -416,22 +419,20 @@ def op_reshape(self, context: TranscriptionContext, op: ReshapeOp):

@register_stablehlo_op
def op_broadcast_in_dim(self, context: TranscriptionContext, op: BroadcastInDimOp):
# TODO(knielsen): Consider if this is actually correct!
# CoreML seems to auto-broadcast along the lines of numpy. Therefore this
# explicit broadcasting op is not necessary.
x = context[op.operand.get_name()]

# We handle one special case where the broadcast functions as a reshape
op_elements = reduce(operator.mul, op.result.type.shape, 1)
x_elements = reduce(operator.mul, x.shape, 1)
if op_elements == x_elements:
# We know that the only possibility is for data to be added, so this is likely a reshape
x = mb.reshape(x=x, shape=op.result.type.shape)
reshaped_operand_shape = [1] * len(op.result.type.shape)
for i, op_shape in enumerate(op.operand.type.shape):
result_idx = op.broadcast_dimensions[i]
reshaped_operand_shape[result_idx] = op_shape

# Another special case. If we are broadcasting a constant in all directions, just change the shape
if len(op.broadcast_dimensions) == 0 and len(x.shape) == 0:
dtype = types.nptype_from_builtin(self.__resolve_type(x))
x = mb.mul(x=x, y=np.ones(op.result.type.shape, dtype=dtype))
x = mb.reshape(x=x, shape=reshaped_operand_shape)
for result_dim, current_shape in enumerate(reshaped_operand_shape):
if current_shape != op.result.type.shape[result_dim]:
assert current_shape == 1
                # Replicate data along dimension `result_dim` until the result dimension is filled up
values = [x] * op.result.type.shape[result_dim]
x = mb.concat(values=values, axis=result_dim)

context.add_variable(op.result.get_name(), x)
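
As an illustration (not part of the diff), here is a minimal numpy sketch of the reshape-then-replicate strategy used in op_broadcast_in_dim above, assuming a hypothetical broadcast of a (3, 1) operand into a (2, 3, 4) result with broadcast_dimensions = [1, 2]:

import numpy as np

# Hypothetical example values (not taken from the diff): operand shape (3, 1),
# result shape (2, 3, 4), broadcast_dimensions = [1, 2].
operand = np.arange(3).reshape(3, 1)
result_shape = (2, 3, 4)
broadcast_dimensions = [1, 2]

# Step 1: reshape the operand so that every result dimension is present,
# using size 1 wherever the operand contributes no data.
reshaped_operand_shape = [1] * len(result_shape)
for i, dim_size in enumerate(operand.shape):
    reshaped_operand_shape[broadcast_dimensions[i]] = dim_size
x = operand.reshape(reshaped_operand_shape)  # shape (1, 3, 1)

# Step 2: replicate along every dimension that still has size 1 but needs to be
# larger, mirroring the mb.concat loop in op_broadcast_in_dim.
for result_dim, current_shape in enumerate(reshaped_operand_shape):
    if current_shape != result_shape[result_dim]:
        x = np.concatenate([x] * result_shape[result_dim], axis=result_dim)

assert x.shape == result_shape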

@@ -677,9 +678,56 @@ def op_tanh(self, context: TranscriptionContext, op: TanhOp):
@register_stablehlo_op
def op_concatenate(self, context: TranscriptionContext, op: ConcatenateOp):
values = [context[input.get_name()] for input in op.inputs]
values = promote_input_dtypes(values)
mil_res = mb.concat(values=values, axis=op.dimension.value)
context.add_variable(op.result.get_name(), mil_res)

@register_stablehlo_op
def op_iota(self, context: TranscriptionContext, op: IotaOp):
iota_dim = int(op.iota_dimension)
tensor_shape = op.result.type.shape
vec_shape = [tensor_shape[dim] if dim == iota_dim else 1 for dim in range(len(tensor_shape))]
res = np.reshape(np.arange(tensor_shape[iota_dim]), vec_shape) * np.ones(tensor_shape)
context.add_variable(op.result.get_name(), res)
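
For reference (not part of the diff), a small numpy check of what op_iota produces, using an assumed (2, 3) result with iota_dimension = 1:

import numpy as np

# Assumed example values: result shape (2, 3), iota dimension 1.
tensor_shape = (2, 3)
iota_dim = 1

# Same construction as op_iota above: a vector of indices broadcast to the full shape.
vec_shape = [tensor_shape[dim] if dim == iota_dim else 1 for dim in range(len(tensor_shape))]
res = np.reshape(np.arange(tensor_shape[iota_dim]), vec_shape) * np.ones(tensor_shape)

# res == [[0., 1., 2.],
#         [0., 1., 2.]]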

@register_stablehlo_op
def op_custom_call(self, context: TranscriptionContext, op: CustomCallOp):
if op.call_target_name.value.startswith("mhlo."):
mapped_op = None
op_impl = None
match op.call_target_name.value:
case "mhlo.topk":
mapped_op = TopKOp
op_impl = self._op_mhlo_topk

if not mapped_op:
raise ValueError(f"mhlo op '{op.call_target_name.value}' is not implemented")
if not op_impl:
raise ValueError(f"mhlo op '{op.call_target_name.value}' does not have an implementation")

mhlo_attributes = {attr.name: attr.attr for attr in list(op.attributes["mhlo.attributes"])}
delegate_op = partial(mapped_op, **mhlo_attributes, loc=op.location)(*op.operands)

            # We have to handle the results manually, as the current API does not allow naming
            # the `delegate_op` results according to the custom call results
mil_results = op_impl(context, delegate_op)
for (custom_call_result, mil_result) in zip(op.results, mil_results):
context.add_variable(custom_call_result.get_name(), mil_result)

return

raise ValueError(f"Custom call is not supported: {op.call_target_name}")

def _op_mhlo_topk(self, context: TranscriptionContext, op: TopKOp):
"""
        This is an MHLO op and follows a slightly different pattern, since it is invoked
        through a custom call. It returns its results directly, as we currently cannot
        rename the results of the TopKOp.
"""
x = context[op.operand.get_name()]
mil_res = mb.topk(x=x, k=op.k.value, ascending=not op.largest.value)
return mil_res
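
For context (not part of the diff), a hedged sketch of where the mhlo.topk custom call comes from: with the JAX version this commit targets, jax.lax.top_k is lowered to a stablehlo custom_call with call_target_name "mhlo.topk", which op_custom_call above re-creates as a TopKOp and hands to _op_mhlo_topk. One way to inspect this (shapes and the printed output below are illustrative only):

import jax
import jax.numpy as jnp

def top3(x):
    # jax.lax.top_k returns (values, indices) for the k largest entries.
    return jax.lax.top_k(x, k=3)

# Lowered.as_text() prints the StableHLO module; with a JAX version that emits the
# custom call, it contains a line along the lines of:
#   stablehlo.custom_call @mhlo.topk(...) {mhlo.attributes = {k = 3 : i64, largest = true}}
print(jax.jit(top3).lower(jnp.zeros((5, 10), dtype=jnp.float32)).as_text())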

def __invoke_hlo_function(self, context: TranscriptionContext, func_name: str, hlo_params, hlo_func_body, cml_args):
# Enter variable context for the function call
context.push_function(func_name)
31 changes: 24 additions & 7 deletions tests/test_jax.py
@@ -11,13 +11,12 @@
from coremltools.converters.mil.testing_utils import compare_backend
from coremltools.converters.mil.mil import Program, Block

from functools import partial

def test_addition():
def plus(x, y):
return jnp.add(x, y)

run_and_compare(plus, (jnp.float32(1), jnp.float32(1)))
run_and_compare(plus, (jnp.zeros((2, 2, 2)), jnp.zeros((2, 2, 2))))
def test_addition():
run_and_compare(jnp.add, (jnp.float32(1), jnp.float32(1)))
run_and_compare(jnp.add, (jnp.zeros((2, 2, 2)), jnp.zeros((2, 2, 2))))


def test_tensor_multiplication():
@@ -85,12 +84,18 @@ def full_tensor_product_4_1(lhs, rhs):
run_and_compare(full_tensor_product_4_1, (jnp.zeros(((2, 2, 2, 3))), jnp.zeros((2,))))


def test_topk():
input_shape = (3, 5, 10)
run_and_compare(partial(jax.lax.top_k, k=3), (jnp.zeros(input_shape),))


def jax_export(jax_func, input_spec):
def compute_input_shapes(input_specs):
shapes = []
for input_spec in input_specs:
if isinstance(input_spec, (list, tuple)):
shapes.append(compute_input_shapes(input_spec))
# We only unwrap the shapes for one level
shapes.append(input_spec)
else:
shapes.append(jax.ShapeDtypeStruct(input_spec.shape, input_spec.dtype))
return shapes
@@ -178,7 +183,19 @@ def run_and_compare(jax_func, input_spec, max_complexity: int = 10_000):
"max allowed complexity is {max_complexity}"
)

cml_model = ct.convert(mil_program, source="milinternal", minimum_deployment_target=ct.target.iOS18)
pipeline = ct.PassPipeline.DEFAULT
# We temporarily avoid fp16 conversions in tests because of https://github.com/apple/coremltools/issues/2324
passes_to_remove = [
'common::add_fp16_cast'
]
pipeline.remove_passes(passes_to_remove)

cml_model = ct.convert(
mil_program,
source="milinternal",
minimum_deployment_target=ct.target.iOS18,
pass_pipeline=pipeline,
)

    # Generate random inputs that match the cml_model input spec
cml_input_key_values = {}
