From d988fb5988598b0363126526636c8d3d2b792b8f Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 9 Jan 2025 14:02:33 -0800 Subject: [PATCH 1/2] [ET-VK] Parse required extensions of shaders and check capabilities during dispatch ## Context Now that we are using GLSL/SPIR-V extensions more heavily in our shaders, there is a risk that a particular shader uses an extension that is not supported by the physical device. It is tedious to manually check that all the extensions required by a shader are supported by the device; it would be much more convenient for developers if there were an automated way to perform this check. This diff provides a solution for this. Materially, this has manifested as an issue with our internal CI tests that run on the Android emulator (which uses SwiftShader under the hood). If the emulator tries to compile a shader that requires the `shaderInt16` feature, then the emulator will crash. ## Solution 1. Update `ShaderInfo` to have fields indicating whether certain extensions that require device support are required. 2. Update the `gen_vulkan_spv.py` shader compilation script to parse the GLSL code and log whether the aforementioned extensions are needed in the generated `ShaderInfo`. 3. Introduce a new exception class, `ShaderNotSupportedError`. 4. Before dispatching, check that all extensions required by the shader are supported by the device. If not, throw the new exception class. 5. In the generated operator correctness tests, skip the test if `ShaderNotSupportedError` is thrown. 
Differential Revision: [D67992067](https://our.internmc.facebook.com/intern/diff/D67992067/) ghstack-source-id: 260809479 Pull Request resolved: https://github.com/pytorch/executorch/pull/7576 --- backends/vulkan/runtime/api/Context.cpp | 21 +++++++++++++ backends/vulkan/runtime/api/Context.h | 2 ++ backends/vulkan/runtime/gen_vulkan_spv.py | 25 +++++++++++++++ .../vulkan/runtime/graph/ops/DispatchNode.cpp | 2 ++ .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 2 -- backends/vulkan/runtime/vk_api/Adapter.cpp | 8 +++-- backends/vulkan/runtime/vk_api/Exception.cpp | 31 +++++++++++++++++++ backends/vulkan/runtime/vk_api/Exception.h | 21 +++++++++++++ backends/vulkan/runtime/vk_api/Shader.cpp | 10 ++++-- backends/vulkan/runtime/vk_api/Shader.h | 8 ++++- .../op_tests/utils/gen_correctness_base.py | 5 +++ 11 files changed, 128 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 9517941f36..f425859935 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() { } } +void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { + if (shader.requires_shader_int16) { + if (!adapter_p_->supports_int16_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16); + } + } + if (shader.requires_16bit_storage) { + if (!adapter_p_->supports_16bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE); + } + } + if (shader.requires_8bit_storage) { + if (!adapter_p_->supports_8bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE); + } + } +} + vkapi::DescriptorSet Context::get_descriptor_set( const vkapi::ShaderInfo& shader_descriptor, const utils::uvec3& local_workgroup_size, diff --git 
a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 300fd3995d..0c199c24cc 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -185,6 +185,8 @@ class Context final { } } + void check_device_capabilities(const vkapi::ShaderInfo& shader); + vkapi::DescriptorSet get_descriptor_set( const vkapi::ShaderInfo&, const utils::uvec3&, diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 7d004547a8..7d3d2d5295 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str: if "codegen-nosub" in input_text: return input_text + # Remove extension requirement so that generated ShaderInfo does not mark it + input_text = input_text.replace( + "#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", "" + ) input_text = input_text.replace("u16vec", "ivec") input_text = input_text.replace("uint16_t", "int") return input_text @@ -791,6 +795,9 @@ class ShaderInfo: weight_storage_type: str = "" bias_storage_type: str = "" register_for: Optional[Tuple[str, List[str]]] = None + requires_shader_int16_ext: bool = False + requires_16bit_storage_ext: bool = False + requires_8bit_storage_ext: bool = False def getName(filePath: str) -> str: @@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: return (matches_list[0], matches_list[1:]) +def isExtensionRequireLine(lineStr: str) -> bool: + extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require" + return re.search(extension_require_id, lineStr) is not None + + typeIdMapping = { r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", @@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.bias_storage_type = getBiasStorageType(line) if isRegisterForLine(line): 
shader_info.register_for = findRegisterFor(line) + if isExtensionRequireLine(line): + if "GL_EXT_shader_explicit_arithmetic_types_int16" in line: + shader_info.requires_shader_int16_ext = True + if "GL_EXT_shader_16bit_storage" in line: + shader_info.requires_16bit_storage_ext = True + if "GL_EXT_shader_8bit_storage" in line: + shader_info.requires_8bit_storage_ext = True return shader_info @@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts)) + def to_cpp_str(val: bool): + return "true" if val else "false" + shader_info_args = [ f'"{name}"', f"{name}_bin", str(sizeBytes), shader_info_layouts, tile_size, + to_cpp_str(shader_info.requires_shader_int16_ext), + to_cpp_str(shader_info.requires_16bit_storage_ext), + to_cpp_str(shader_info.requires_8bit_storage_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index a163a0d7ae..63b8798f2c 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; + context->check_device_capabilities(shader_); + std::unique_lock cmd_lock = context->dispatch_lock(); std::array push_constants_data; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 20fb9374be..4a8d741869 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -34,8 +34,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#extension 
GL_EXT_shader_explicit_arithmetic_types_int16 : require - /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 5805d476a3..ec30650ba0 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -256,6 +256,9 @@ std::string Adapter::stringize() const { ss << " deviceType: " << device_type << std::endl; ss << " deviceName: " << properties.deviceName << std::endl; +#define PRINT_BOOL(value, name) \ + ss << " " << std::left << std::setw(36) << #name << value << std::endl; + #define PRINT_PROP(struct, name) \ ss << " " << std::left << std::setw(36) << #name << struct.name \ << std::endl; @@ -298,12 +301,13 @@ std::string Adapter::stringize() const { ss << " }" << std::endl; #endif /* VK_KHR_8bit_storage */ -#ifdef VK_KHR_shader_float16_int8 ss << " Shader 16bit and 8bit Features {" << std::endl; + PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16) +#ifdef VK_KHR_shader_float16_int8 PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16); PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8); - ss << " }" << std::endl; #endif /* VK_KHR_shader_float16_int8 */ + ss << " }" << std::endl; const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index e330c1c079..d26fbd8cb2 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg) what_ = oss.str(); } +// +// ShaderNotSupportedError +// + +std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { + switch (result) { + case VulkanExtension::SHADER_INT16: + out << 
"shaderInt16"; + break; + case VulkanExtension::INT16_STORAGE: + out << "VK_KHR_16bit_storage"; + break; + case VulkanExtension::INT8_STORAGE: + out << "VK_KHR_8bit_storage"; + break; + } + return out; +} + +ShaderNotSupportedError::ShaderNotSupportedError( + std::string shader_name, + VulkanExtension extension) + : shader_name_(std::move(shader_name)), extension_{extension} { + std::ostringstream oss; + oss << "Shader " << shader_name_ << " "; + oss << "not compatible with device. "; + oss << "Missing support for extension or physical device feature: "; + oss << extension_; + what_ = oss.str(); +} + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index ec2f2956a8..a65afb1bcc 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -78,5 +78,26 @@ class Error : public std::exception { } }; +enum class VulkanExtension : uint8_t { + SHADER_INT16, + INT16_STORAGE, + INT8_STORAGE, +}; + +class ShaderNotSupportedError : public std::exception { + public: + ShaderNotSupportedError(std::string shader_name, VulkanExtension extension); + + private: + std::string shader_name_; + VulkanExtension extension_; + std::string what_; + + public: + const char* what() const noexcept override { + return what_.c_str(); + } +}; + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 29774e2f40..e560f37868 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo( const uint32_t* const spirv_bin, const uint32_t size, std::vector layout, - const utils::uvec3 tile_size) + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext) : src_code{ spirv_bin, size, }, 
kernel_name{std::move(name)}, kernel_layout{std::move(layout)}, - out_tile_size(tile_size) { + out_tile_size(tile_size), + requires_shader_int16(requires_shader_int16_ext), + requires_16bit_storage(requires_16bit_storage_ext), + requires_8bit_storage(requires_8bit_storage_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 1e3b2a799f..d9fec65feb 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -62,6 +62,9 @@ struct ShaderInfo final { // Shader Metadata utils::uvec3 out_tile_size{1u, 1u, 1u}; + bool requires_shader_int16 = false; + bool requires_16bit_storage = false; + bool requires_8bit_storage = false; explicit ShaderInfo(); @@ -70,7 +73,10 @@ struct ShaderInfo final { const uint32_t*, const uint32_t, std::vector, - const utils::uvec3 tile_size); + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext); operator bool() const { return src_code.bin != nullptr; diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 3d9aa6aa80..d7e3896945 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -45,8 +45,13 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{ test_suite_template = """ TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{ {create_ref_data} +try {{ {create_and_check_out} }} +catch (const vkcompute::vkapi::ShaderNotSupportedError& e) {{ + GTEST_SKIP() << e.what(); +}} +}} """ From 66aa5ed6e23303720027cfe07341886ab2a357e6 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 9 Jan 2025 14:02:36 -0800 Subject: [PATCH 2/2] [ET-VK][ez][buck] Simplify test buck file ## Context The targets file for the op tests define a 
binary and test rule for each c++ file; instead of manually defining these rules each time, create a helper function to condense the code. Differential Revision: [D67992066](https://our.internmc.facebook.com/intern/diff/D67992066/) ghstack-source-id: 260809480 Pull Request resolved: https://github.com/pytorch/executorch/pull/7577 --- backends/vulkan/test/op_tests/targets.bzl | 187 ++++++---------------- 1 file changed, 47 insertions(+), 140 deletions(-) diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index ab55d5beea..d26f1a805c 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -3,6 +3,44 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps") load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def define_test_targets(test_name, extra_deps = [], src_file = None, is_fbcode = False): + deps_list = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ] + extra_deps + + src_file_str = src_file if src_file else "{}.cpp".format(test_name) + + runtime.cxx_binary( + name = "{}_bin".format(test_name), + srcs = [ + src_file_str, + ], + compiler_flags = [ + "-Wno-unused-variable", + ], + define_static_target = False, + deps = deps_list, + ) + + runtime.cxx_test( + name = test_name, + srcs = [ + src_file_str, + ], + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = deps_list, + ) + + def define_common_targets(is_fbcode = False): if is_fbcode: return @@ -82,19 +120,6 @@ def define_common_targets(is_fbcode = False): default_outs = ["."], ) - runtime.cxx_binary( - name = 
"compute_graph_op_tests_bin", - srcs = [ - ":generated_op_correctness_tests_cpp[op_tests.cpp]", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - runtime.cxx_binary( name = "compute_graph_op_benchmarks_bin", srcs = [ @@ -111,135 +136,17 @@ def define_common_targets(is_fbcode = False): ], ) - runtime.cxx_test( - name = "compute_graph_op_tests", - srcs = [ - ":generated_op_correctness_tests_cpp[op_tests.cpp]", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], + define_test_targets( + "compute_graph_op_tests", + src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" ) - runtime.cxx_binary( - name = "sdpa_test_bin", - srcs = [ - "sdpa_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", - ], - ) - - runtime.cxx_test( - name = "sdpa_test", - srcs = [ - "sdpa_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", - "//executorch/extension/tensor:tensor", - 
runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_binary( - name = "linear_weight_int4_test_bin", - srcs = [ - "linear_weight_int4_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_test( - name = "linear_weight_int4_test", - srcs = [ - "linear_weight_int4_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", + define_test_targets( + "sdpa_test", + extra_deps = [ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_binary( - name = "rotary_embedding_test_bin", - srcs = [ - "rotary_embedding_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_test( - name = "rotary_embedding_test", - srcs = [ - "rotary_embedding_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/tensor:tensor", - runtime.external_dep_location("libtorch"), 
- ], + ] ) + define_test_targets("linear_weight_int4_test") + define_test_targets("rotary_embedding_test")