diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index a17eb0628..ac05983cc 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -455,137 +455,6 @@ def apply_configuration( return new_mlir -def apply_params_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) - ) - embeddable = indent( - get_transform_function_mmt(problem_size, f"match_op", configuration), " " - ) - return modified, embeddable - - -def apply_params_conv( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - conv_dims = ConvDimInfo.from_problem_size(problem_size) - modified = indent( - get_transform_function_conv( - problem_size, - f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_conv_tile_sizes(configuration) - ) - embeddable = indent( - get_transform_function_conv(problem_size, f"match_op", configuration), - " ", - ) - return modified, embeddable - - -def apply_params_contract( - problem_size: ProblemSize, - tile_dims: str, - template: list[str], - configuration: Configuration, -) -> tuple[str, str]: - # TODO: Generate transform function. 
- return ( - apply_configuration( - template, configuration, get_contract_tile_sizes(configuration, tile_dims) - ), - "", - ) - - -def apply_params_batch_matmul( - problem_size: ProblemSize, - tile_dims: str, - template: list[str], - configuration: Configuration, -) -> tuple[str, str]: - tune_logger.info(f"{configuration}") - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_batch_matmul( - problem_size, - tile_dims, - f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_contract_tile_sizes(configuration, tile_dims) - ) - - embeddable = indent( - get_transform_function_batch_matmul( - problem_size, tile_dims, f"match_op", configuration - ), - " ", - ) - return modified, embeddable - - -def apply_params_batch_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - B = problem_size.matmul_size.B - modified = indent( - get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - get_transform_function_batch_mmt(problem_size, f"match_op", configuration), - " ", - ) - return modified, embeddable - - -def apply_params_broadcast_rhs_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - return modified, 
embeddable - - def parse_tensor_type(tensor_type: str) -> ShapedType: shape_match = re.search(MlirRegex.tensor_type, tensor_type) assert shape_match @@ -598,254 +467,6 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(dims, str_to_elem_ty[elem]) -def get_shapes_mmt(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], rhs_shaped_type.shape[0], lhs_shaped_type.shape[1] - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - - assert False, "Shape not found" - - -def get_shapes_conv(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = 
rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - - -def get_shapes_contract( - template: list[str], lhs_dims: str, rhs_dims: str -) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = 
math.prod( - val if dim == "m" else 1 - for dim, val in zip(lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - - -def get_shapes_batch_matmul( - template: list[str], lhs_dims: str, rhs_dims: str -) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, LHS)) - B0 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RHS)) - B1 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RES)) - M = math.prod(val if dim == "m" else 1 for dim, val in zip(lhs_dims, LHS)) - N = math.prod(val if dim == "n" else 1 for dim, val in zip(rhs_dims, RHS)) - K0 = 
math.prod(val if dim == "k" else 1 for dim, val in zip(lhs_dims, LHS)) - K1 = math.prod(val if dim == "k" else 1 for dim, val in zip(rhs_dims, RHS)) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - - -def get_shapes_batch_mmt(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - - def is_broadcast_rhs_mmt_op(line: str) -> bool: if "linalg.generic" not in line: return False @@ -862,51 +483,6 @@ def is_broadcast_rhs_mmt_op(line: str) -> bool: return True -def is_broadcast_rhs_mmt(template: list[str]) -> bool: - return any(is_broadcast_rhs_mmt_op(line) for line in template) - - -def 
get_shapes_broadcast_rhs_mmt(template: list[str]) -> ProblemSize: - for line in template: - if not is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.res_type.element_type != intrinsic.output_type: @@ -1122,22 +698,41 @@ def parse_mlir(mlir_text: str) -> ir.Module: @dataclass -class CandidateGenFn: - get_shapes_fn: Optional[Callable[[list[str]], ProblemSize]] = None - apply_params_fn: Optional[ - Callable[[ProblemSize, list[str], Configuration], tuple[str, str]] - ] = None +class TFMLIR: + """Transformation of MLIR context""" + + template: list[str] + modified: str + embeddable: str class DispatchTuner(ABC): @abstractmethod - def supports(self, mlir: str) -> bool: + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" pass @abstractmethod - def 
get_candidate_gen_fn(self) -> CandidateGenFn: + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of the operation.""" pass + @abstractmethod + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + """Apply parameter transformations to the operation.""" + pass + + +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + class DispatchTunerRegistry: def __init__(self): @@ -1147,111 +742,531 @@ def register(self, dispatch_tuners: list[DispatchTuner]) -> None: for dispatch_tuner in dispatch_tuners: self.registry.add(dispatch_tuner) - def get_candidate_gen_fn(self, mlir: str) -> CandidateGenFn: + def find_handler(self, op_name: str) -> DispatchTuner: for dispatch_tuner in self.registry: - if dispatch_tuner.supports(mlir): - return dispatch_tuner.get_candidate_gen_fn() - + if dispatch_tuner.supports(op_name): + return dispatch_tuner assert False, "Not supported" class MmtTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "matmul_transpose_b" in mlir + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, 
rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_mmt, apply_params_mmt) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_mmt( + problem_size, f"match_mmt_{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_mmt_tile_sizes(configuration) + ) + embeddable = indent( + get_transform_function_mmt(problem_size, f"match_op", configuration), " " + ) + return TFMLIR(template, modified, embeddable) class ConvTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "conv_2d_nhwc_hwcf" in mlir + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = 
parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_conv, apply_params_conv) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + conv_dims = ConvDimInfo.from_problem_size(problem_size) + modified = indent( + get_transform_function_conv( + problem_size, + f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_conv_tile_sizes(configuration) + ) + embeddable = indent( + get_transform_function_conv(problem_size, f"match_op", configuration), + " ", + ) + return TFMLIR(template, modified, embeddable) class ContractionTuner(DispatchTuner): - def __init__( - self, lhs_dims: str, rhs_dims: str, tile_dims: str, mlir_template: str - ): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): self.lhs_dims = lhs_dims self.rhs_dims = rhs_dims self.tile_dims = tile_dims - self.mlir_template = mlir_template - - def supports(self, mlir: str) -> 
bool: - return "matmul_like" in mlir - - def get_candidate_gen_fn(self) -> CandidateGenFn: - if is_broadcast_rhs_mmt(self.mlir_template): - get_shapes_fn = get_shapes_broadcast_rhs_mmt - apply_params_fn = apply_params_broadcast_rhs_mmt - else: - get_shapes_fn = lambda template: get_shapes_contract( - template, self.lhs_dims, self.rhs_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + 
if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" ) - apply_params_fn = lambda ps, template, config: apply_params_contract( - ps, self.tile_dims, template, config + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + def apply_params_broadcast_rhs_mmt( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, 
f"match_op", configuration + ), + " ", + ) + return TFMLIR(template, modified, embeddable) + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + if self.is_broadcast_rhs_mmt(template): + return self.apply_params_broadcast_rhs_mmt( + problem_size, template, configuration ) - return CandidateGenFn(get_shapes_fn, apply_params_fn) + + # TODO: Generate transform function. + return TFMLIR( + template, + apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ), + "", + ) class BatchMmtTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "batch_matmul_transpose_b" in mlir + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + 
res_shaped_type, + DispatchKind.batch_mmt, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_batch_mmt, apply_params_batch_mmt) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + B = problem_size.matmul_size.B + modified = indent( + get_transform_function_batch_mmt( + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_batch_mmt(problem_size, f"match_op", configuration), + " ", + ) + return TFMLIR(template, modified, embeddable) class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): self.lhs_dims = lhs_dims self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) - def supports(self, mlir: str) -> bool: - return "batch_matmul" in mlir + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert 
res_shaped_type.rank() == lhs_shaped_type.rank() - def get_candidate_gen_fn(self) -> CandidateGenFn: - get_shapes_fn = lambda template: get_shapes_batch_matmul( - template, self.lhs_dims, self.rhs_dims + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_batch_matmul( + problem_size, + self.tile_dims, + f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), ) - apply_params_fn = lambda ps, template, config: apply_params_batch_matmul( - ps, self.tile_dims, template, config + + embeddable = indent( + get_transform_function_batch_matmul( + problem_size, self.tile_dims, f"match_op", configuration + ), + " ", ) - return CandidateGenFn(get_shapes_fn, apply_params_fn) + return TFMLIR(template, modified, 
embeddable) def walk_callback_get_fn( op: ir.Operation, - candidate_gen_fn: CandidateGenFn, + walk_result: OpWalkResult, dispatch_tuner_registry: DispatchTunerRegistry, ) -> ir.WalkResult: if op.name == "util.func": func_name = str(op.opview.sym_name) - searched_fn = dispatch_tuner_registry.get_candidate_gen_fn(func_name) - candidate_gen_fn.get_shapes_fn = searched_fn.get_shapes_fn - candidate_gen_fn.apply_params_fn = searched_fn.apply_params_fn - if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: - return ir.WalkResult.INTERRUPT + walk_result.was_interrupted = True + walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) + return ir.WalkResult.INTERRUPT return ir.WalkResult.ADVANCE def walk_mlir_op( mlir_module: ir.Module, - candidate_gen_fn: CandidateGenFn, dispatch_tuner_registry: DispatchTunerRegistry, -): +) -> OpWalkResult: + walk_result = OpWalkResult() for op in mlir_module.body.operations: op.walk( - lambda op: walk_callback_get_fn( - op, candidate_gen_fn, dispatch_tuner_registry - ), + lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), ir.WalkOrder.POST_ORDER, ) - if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: + if walk_result.was_interrupted: break + return walk_result def tune( @@ -1281,40 +1296,34 @@ def tune( with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) - candidate_gen_fn = CandidateGenFn() dispatch_tuner_registry = DispatchTunerRegistry() dispatch_tuner_registry.register( [ MmtTuner(), ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims, mlir_template), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), ] ) - walk_mlir_op(mlir_module, candidate_gen_fn, dispatch_tuner_registry) - - get_shapes_fn = candidate_gen_fn.get_shapes_fn - apply_params_fn = candidate_gen_fn.apply_params_fn + walk_result = walk_mlir_op(mlir_module, 
dispatch_tuner_registry) - problem_size = get_shapes_fn(mlir_template) + dispatch_tuner = walk_result.dispatch_tuner + problem_size = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) - configs = [] for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): if i >= limit: break tune_logger.info(f"Solution #{i+1}: {config}") configs.append(config) - new_mlir, embeddable_tuning = apply_params_fn( - problem_size, mlir_template, config - ) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(new_mlir) + f.write(tf_mlir.modified) with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(embeddable_tuning) + f.write(tf_mlir.embeddable) with open(path.join(output, "configs.pkl"), "wb") as file: pickle.dump(configs, file)