Skip to content

Commit

Permalink
[ET-VK][ez][buck] Simplify test buck file
Browse files Browse the repository at this point in the history
## Context

The targets file for the op tests define a binary and test rule for each c++ file; instead of manually defining these rules each time, create a helper function to condense the code.

Differential Revision: [D67992066](https://our.internmc.facebook.com/intern/diff/D67992066/)

ghstack-source-id: 260809480
Pull Request resolved: #7577
  • Loading branch information
SS-JIA committed Jan 9, 2025
1 parent d988fb5 commit 66aa5ed
Showing 1 changed file with 47 additions and 140 deletions.
187 changes: 47 additions & 140 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,44 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_test_targets(test_name, extra_deps = [], src_file = None, is_fbcode = False):
deps_list = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
] + extra_deps

src_file_str = src_file if src_file else "{}.cpp".format(test_name)

runtime.cxx_binary(
name = "{}_bin".format(test_name),
srcs = [
src_file_str,
],
compiler_flags = [
"-Wno-unused-variable",
],
define_static_target = False,
deps = deps_list,
)

runtime.cxx_test(
name = test_name,
srcs = [
src_file_str,
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = deps_list,
)


def define_common_targets(is_fbcode = False):
if is_fbcode:
return
Expand Down Expand Up @@ -82,19 +120,6 @@ def define_common_targets(is_fbcode = False):
default_outs = ["."],
)

runtime.cxx_binary(
name = "compute_graph_op_tests_bin",
srcs = [
":generated_op_correctness_tests_cpp[op_tests.cpp]",
],
define_static_target = False,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_binary(
name = "compute_graph_op_benchmarks_bin",
srcs = [
Expand All @@ -111,135 +136,17 @@ def define_common_targets(is_fbcode = False):
],
)

runtime.cxx_test(
name = "compute_graph_op_tests",
srcs = [
":generated_op_correctness_tests_cpp[op_tests.cpp]",
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
define_test_targets(
"compute_graph_op_tests",
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
)

runtime.cxx_binary(
name = "sdpa_test_bin",
srcs = [
"sdpa_test.cpp",
],
compiler_flags = [
"-Wno-unused-variable",
],
define_static_target = False,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
],
)

runtime.cxx_test(
name = "sdpa_test",
srcs = [
"sdpa_test.cpp",
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/tensor:tensor",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_binary(
name = "linear_weight_int4_test_bin",
srcs = [
"linear_weight_int4_test.cpp",
],
compiler_flags = [
"-Wno-unused-variable",
],
define_static_target = False,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_test(
name = "linear_weight_int4_test",
srcs = [
"linear_weight_int4_test.cpp",
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
define_test_targets(
"sdpa_test",
extra_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/tensor:tensor",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_binary(
name = "rotary_embedding_test_bin",
srcs = [
"rotary_embedding_test.cpp",
],
compiler_flags = [
"-Wno-unused-variable",
],
define_static_target = False,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
runtime.external_dep_location("libtorch"),
],
)

runtime.cxx_test(
name = "rotary_embedding_test",
srcs = [
"rotary_embedding_test.cpp",
],
contacts = ["[email protected]"],
fbandroid_additional_loaded_sonames = [
"torch-code-gen",
"vulkan_graph_runtime",
"vulkan_graph_runtime_shaderlib",
],
platforms = [ANDROID],
use_instrumentation_test = True,
deps = [
"//third-party/googletest:gtest_main",
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/extension/tensor:tensor",
runtime.external_dep_location("libtorch"),
],
]
)
define_test_targets("linear_weight_int4_test")
define_test_targets("rotary_embedding_test")

0 comments on commit 66aa5ed

Please sign in to comment.