Skip to content

Commit

Permalink
Arm backend: Bugfix in TosaSpecification representation (#8137)
Browse files Browse the repository at this point in the history
Fix up TOSA 1.0 class string reprsentation handling and add testcases
for the fixed functionality.

Signed-off-by: Per Åstrand <[email protected]>
  • Loading branch information
per authored Feb 5, 2025
1 parent e63c923 commit 62e49ce
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
56 changes: 37 additions & 19 deletions backends/arm/test/misc/test_tosa_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
"TOSA-0.80+MI+8k",
"TOSA-0.80+BI+u55",
]
test_valid_1_00_strings = [
"TOSA-1.00.0+INT+FP+fft",
"TOSA-1.00.0+FP+bf16+fft",
"TOSA-1.00.0+INT+int4+cf",
"TOSA-1.00.0+FP+cf+bf16+8k",
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
test_valid_1_0_strings = [
"TOSA-1.0.0+INT+FP+fft",
"TOSA-1.0.0+FP+bf16+fft",
"TOSA-1.0.0+INT+int4+cf",
"TOSA-1.0.0+FP+cf+bf16+8k",
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
"TOSA-1.0+INT+FP+fft",
"TOSA-1.0+FP+bf16+fft",
"TOSA-1.0+INT+int4+cf",
"TOSA-1.0+FP+cf+bf16+8k",
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
]

test_valid_1_00_extensions = {
test_valid_1_0_extensions = {
"INT": ["int16", "int4", "var", "cf"],
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
}
Expand All @@ -40,19 +46,19 @@
"TOSA-0.80+8k",
"TOSA-0.80+BI+MI",
"TOSA-0.80+BI+U55",
"TOSA-1.00.0+fft",
"TOSA-1.00.0+fp+bf16+fft",
"TOSA-1.00.0+INT+INT4+cf",
"TOSA-1.00.0+BI",
"TOSA-1.00.0+FP+FP+INT",
"TOSA-1.00.0+FP+CF+bf16",
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
"TOSA-1.0.0+fft",
"TOSA-1.0.0+fp+bf16+fft",
"TOSA-1.0.0+INT+INT4+cf",
"TOSA-1.0.0+BI",
"TOSA-1.0.0+FP+FP+INT",
"TOSA-1.0.0+FP+CF+bf16",
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
]

test_compile_specs = [
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
]

test_compile_specs_no_version = [
Expand All @@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
assert isinstance(tosa_spec, Tosa_0_80)
assert tosa_spec.profile in ["BI", "MI"]

@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
def test_version_string_1_00(self, version_string: str):
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
def test_version_string_1_0(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
Expand All @@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):

for profile in tosa_spec.profiles:
assert [
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
]

@parameterized.expand(test_invalid_strings) # type: ignore[misc]
Expand All @@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)

assert tosa_spec is None

@parameterized.expand(test_valid_0_80_strings)
def test_correct_string_representation_0_80(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_0_80)
assert f"{tosa_spec}" == version_string

@parameterized.expand(test_valid_1_0_strings)
def test_correct_string_representation_1_0(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
assert f"{tosa_spec}" == version_string
13 changes: 9 additions & 4 deletions backends/arm/tosa_specification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -14,7 +14,9 @@
import re
from typing import List

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
CompileSpec,
)
from packaging.version import Version


Expand Down Expand Up @@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
def __repr__(self):
extensions = ""
if self.level_8k:
extensions += "+8K"
extensions += "+8k"
if self.is_U55_subset:
extensions += "+u55"
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
Expand Down Expand Up @@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
return "".join(["+" + e for e in self.extensions])

def __repr__(self):
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
extensions = self._get_extensions_string()
if self.level_8k:
extensions += "+8k"
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"

def __hash__(self) -> int:
return hash(str(self.version) + self._get_profiles_string())
Expand Down

0 comments on commit 62e49ce

Please sign in to comment.