Skip to content

Commit

Permalink
Update on "Use c10 version of half/bfloat16 in executorch"
Browse files Browse the repository at this point in the history
Accomplished by importing relevant files from c10 into
executorch/runtime/core/portable_type/c10, and then using `using` in
the top-level ExecuTorch headers. This approach should keep the
ExecuTorch build hermetic for embedded use cases. In the future, we
should add a CI job to ensure the c10 files stay identical to the
PyTorch ones.

Differential Revision: [D66106969](https://our.internmc.facebook.com/intern/diff/D66106969/)

[ghstack-poisoned]
  • Loading branch information
Github Executorch committed Feb 5, 2025
2 parents 25ada5b + 3e62e23 commit ad42e94
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 26 deletions.
56 changes: 37 additions & 19 deletions backends/arm/test/misc/test_tosa_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
"TOSA-0.80+MI+8k",
"TOSA-0.80+BI+u55",
]
test_valid_1_00_strings = [
"TOSA-1.00.0+INT+FP+fft",
"TOSA-1.00.0+FP+bf16+fft",
"TOSA-1.00.0+INT+int4+cf",
"TOSA-1.00.0+FP+cf+bf16+8k",
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
test_valid_1_0_strings = [
"TOSA-1.0.0+INT+FP+fft",
"TOSA-1.0.0+FP+bf16+fft",
"TOSA-1.0.0+INT+int4+cf",
"TOSA-1.0.0+FP+cf+bf16+8k",
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
"TOSA-1.0+INT+FP+fft",
"TOSA-1.0+FP+bf16+fft",
"TOSA-1.0+INT+int4+cf",
"TOSA-1.0+FP+cf+bf16+8k",
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
]

test_valid_1_00_extensions = {
test_valid_1_0_extensions = {
"INT": ["int16", "int4", "var", "cf"],
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
}
Expand All @@ -40,19 +46,19 @@
"TOSA-0.80+8k",
"TOSA-0.80+BI+MI",
"TOSA-0.80+BI+U55",
"TOSA-1.00.0+fft",
"TOSA-1.00.0+fp+bf16+fft",
"TOSA-1.00.0+INT+INT4+cf",
"TOSA-1.00.0+BI",
"TOSA-1.00.0+FP+FP+INT",
"TOSA-1.00.0+FP+CF+bf16",
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
"TOSA-1.0.0+fft",
"TOSA-1.0.0+fp+bf16+fft",
"TOSA-1.0.0+INT+INT4+cf",
"TOSA-1.0.0+BI",
"TOSA-1.0.0+FP+FP+INT",
"TOSA-1.0.0+FP+CF+bf16",
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
]

test_compile_specs = [
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
]

test_compile_specs_no_version = [
Expand All @@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
assert isinstance(tosa_spec, Tosa_0_80)
assert tosa_spec.profile in ["BI", "MI"]

@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
def test_version_string_1_00(self, version_string: str):
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
def test_version_string_1_0(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
Expand All @@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):

for profile in tosa_spec.profiles:
assert [
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
]

@parameterized.expand(test_invalid_strings) # type: ignore[misc]
Expand All @@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)

assert tosa_spec is None

@parameterized.expand(test_valid_0_80_strings)
def test_correct_string_representation_0_80(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_0_80)
assert f"{tosa_spec}" == version_string

@parameterized.expand(test_valid_1_0_strings)
def test_correct_string_representation_1_0(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
assert f"{tosa_spec}" == version_string
13 changes: 9 additions & 4 deletions backends/arm/tosa_specification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -14,7 +14,9 @@
import re
from typing import List

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
CompileSpec,
)
from packaging.version import Version


Expand Down Expand Up @@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
def __repr__(self):
extensions = ""
if self.level_8k:
extensions += "+8K"
extensions += "+8k"
if self.is_U55_subset:
extensions += "+u55"
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
Expand Down Expand Up @@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
return "".join(["+" + e for e in self.extensions])

def __repr__(self):
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
extensions = self._get_extensions_string()
if self.level_8k:
extensions += "+8k"
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"

def __hash__(self) -> int:
return hash(str(self.version) + self._get_profiles_string())
Expand Down
44 changes: 41 additions & 3 deletions codegen/tools/gen_all_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
return real_path


def _raise_if_check_prim_ops_fail(options):

# Error out if we have more than one targets registering prim ops.
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
assert (
options.DEBUG_ONLY_check_prim_ops[0] == "@"
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."

prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
with open(prim_ops_targets_file, "r") as file:
prim_ops_targets = file.read().split()
if len(prim_ops_targets) > 1:
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = (
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
+ "\nTo find out the dependency chain, run the following command: "
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
)
raise Exception(error)


def main(argv: List[Any]) -> None:
"""This binary generates 3 files:
Expand Down Expand Up @@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
default=False,
required=False,
)
parser.add_argument(
"--DEBUG-ONLY-check-prim-ops",
"--DEBUG_ONLY_check_prim_ops",
help=(
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
),
required=False,
)
options = parser.parse_args(argv)

_raise_if_check_prim_ops_fail(options)

# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
# 1. a yaml file containing selected ops (could be empty), or
# 2. a non-empty list of yaml files in the `model_file_list_path` or
Expand Down Expand Up @@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
debug_info_2 = ",".join(
model_dict["operators"][op_name]["debug_info"]
)
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
if "//" not in debug_info_1 and "//" not in debug_info_2:
error += "\nWe can't determine what BUCK targets these model files belong to."
tail = "."
else:
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
tail = "as well as results from BUCK commands listed above."

error += (
Expand Down
1 change: 1 addition & 0 deletions shim/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def executorch_ops_check(
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
"--allow_include_all_overloads " +
"--check_ops_not_overlapping " +
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
define_static_target = False,
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),
Expand Down

0 comments on commit ad42e94

Please sign in to comment.