Skip to content

Commit

Permalink
Error out if registering prim ops multiple times
Browse files Browse the repository at this point in the history
Differential Revision: D69090850

Pull Request resolved: #8172
  • Loading branch information
larryliu0820 authored Feb 4, 2025
1 parent 81f7c4f commit e63c923
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
44 changes: 41 additions & 3 deletions codegen/tools/gen_all_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
return real_path


def _raise_if_check_prim_ops_fail(options):

# Error out if we have more than one targets registering prim ops.
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
assert (
options.DEBUG_ONLY_check_prim_ops[0] == "@"
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."

prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
with open(prim_ops_targets_file, "r") as file:
prim_ops_targets = file.read().split()
if len(prim_ops_targets) > 1:
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = (
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
+ "\nTo find out the dependency chain, run the following command: "
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
)
raise Exception(error)


def main(argv: List[Any]) -> None:
"""This binary generates 3 files:
Expand Down Expand Up @@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
default=False,
required=False,
)
parser.add_argument(
"--DEBUG-ONLY-check-prim-ops",
"--DEBUG_ONLY_check_prim_ops",
help=(
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
),
required=False,
)
options = parser.parse_args(argv)

_raise_if_check_prim_ops_fail(options)

# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
# 1. a yaml file containing selected ops (could be empty), or
# 2. a non-empty list of yaml files in the `model_file_list_path` or
Expand Down Expand Up @@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
debug_info_2 = ",".join(
model_dict["operators"][op_name]["debug_info"]
)
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
# Yellow bold: \033[33;1m
# Red bold: \033[31;1m
# Green bold: \033[32;1m
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
if "//" not in debug_info_1 and "//" not in debug_info_2:
error += "\nWe can't determine what BUCK targets these model files belong to."
tail = "."
else:
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
tail = "as well as results from BUCK commands listed above."

error += (
Expand Down
1 change: 1 addition & 0 deletions shim/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def executorch_ops_check(
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
"--allow_include_all_overloads " +
"--check_ops_not_overlapping " +
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
define_static_target = False,
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),
Expand Down

0 comments on commit e63c923

Please sign in to comment.