Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK][ez][buck] Simplify test buck file #7593

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() {
}
}

void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
if (shader.requires_shader_int16) {
if (!adapter_p_->supports_int16_shader_types()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16);
}
}
if (shader.requires_16bit_storage) {
if (!adapter_p_->supports_16bit_storage_buffers()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE);
}
}
if (shader.requires_8bit_storage) {
if (!adapter_p_->supports_8bit_storage_buffers()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
}
}
}

vkapi::DescriptorSet Context::get_descriptor_set(
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& local_workgroup_size,
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class Context final {
}
}

void check_device_capabilities(const vkapi::ShaderInfo& shader);

vkapi::DescriptorSet get_descriptor_set(
const vkapi::ShaderInfo&,
const utils::uvec3&,
Expand Down
25 changes: 25 additions & 0 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
if "codegen-nosub" in input_text:
return input_text

# Remove extension requirement so that generated ShaderInfo does not mark it
input_text = input_text.replace(
"#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", ""
)
input_text = input_text.replace("u16vec", "ivec")
input_text = input_text.replace("uint16_t", "int")
return input_text
Expand Down Expand Up @@ -791,6 +795,9 @@ class ShaderInfo:
weight_storage_type: str = ""
bias_storage_type: str = ""
register_for: Optional[Tuple[str, List[str]]] = None
requires_shader_int16_ext: bool = False
requires_16bit_storage_ext: bool = False
requires_8bit_storage_ext: bool = False


def getName(filePath: str) -> str:
Expand Down Expand Up @@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
return (matches_list[0], matches_list[1:])


def isExtensionRequireLine(lineStr: str) -> bool:
extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require"
return re.search(extension_require_id, lineStr) is not None


typeIdMapping = {
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
Expand Down Expand Up @@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info.bias_storage_type = getBiasStorageType(line)
if isRegisterForLine(line):
shader_info.register_for = findRegisterFor(line)
if isExtensionRequireLine(line):
if "GL_EXT_shader_explicit_arithmetic_types_int16" in line:
shader_info.requires_shader_int16_ext = True
if "GL_EXT_shader_16bit_storage" in line:
shader_info.requires_16bit_storage_ext = True
if "GL_EXT_shader_8bit_storage" in line:
shader_info.requires_8bit_storage_ext = True

return shader_info

Expand Down Expand Up @@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->

shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))

def to_cpp_str(val: bool):
return "true" if val else "false"

shader_info_args = [
f'"{name}"',
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
tile_size,
to_cpp_str(shader_info.requires_shader_int16_ext),
to_cpp_str(shader_info.requires_16bit_storage_ext),
to_cpp_str(shader_info.requires_8bit_storage_ext),
]

shader_info_str = textwrap.indent(
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();
vkapi::PipelineBarrier pipeline_barrier{};

context->check_device_capabilities(shader_);

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
Expand Down
8 changes: 6 additions & 2 deletions backends/vulkan/runtime/vk_api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ std::string Adapter::stringize() const {
ss << " deviceType: " << device_type << std::endl;
ss << " deviceName: " << properties.deviceName << std::endl;

#define PRINT_BOOL(value, name) \
ss << " " << std::left << std::setw(36) << #name << value << std::endl;

#define PRINT_PROP(struct, name) \
ss << " " << std::left << std::setw(36) << #name << struct.name \
<< std::endl;
Expand Down Expand Up @@ -298,12 +301,13 @@ std::string Adapter::stringize() const {
ss << " }" << std::endl;
#endif /* VK_KHR_8bit_storage */

#ifdef VK_KHR_shader_float16_int8
ss << " Shader 16bit and 8bit Features {" << std::endl;
PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16)
#ifdef VK_KHR_shader_float16_int8
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
ss << " }" << std::endl;
#endif /* VK_KHR_shader_float16_int8 */
ss << " }" << std::endl;

const VkPhysicalDeviceMemoryProperties& mem_props =
physical_device_.memory_properties;
Expand Down
31 changes: 31 additions & 0 deletions backends/vulkan/runtime/vk_api/Exception.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
what_ = oss.str();
}

//
// ShaderNotSupportedError
//

std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
switch (result) {
case VulkanExtension::SHADER_INT16:
out << "shaderInt16";
break;
case VulkanExtension::INT16_STORAGE:
out << "VK_KHR_16bit_storage";
break;
case VulkanExtension::INT8_STORAGE:
out << "VK_KHR_8bit_storage";
break;
}
return out;
}

ShaderNotSupportedError::ShaderNotSupportedError(
std::string shader_name,
VulkanExtension extension)
: shader_name_(std::move(shader_name)), extension_{extension} {
std::ostringstream oss;
oss << "Shader " << shader_name_ << " ";
oss << "not compatible with device. ";
oss << "Missing support for extension or physical device feature: ";
oss << extension_;
what_ = oss.str();
}

} // namespace vkapi
} // namespace vkcompute
21 changes: 21 additions & 0 deletions backends/vulkan/runtime/vk_api/Exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,26 @@ class Error : public std::exception {
}
};

enum class VulkanExtension : uint8_t {
SHADER_INT16,
INT16_STORAGE,
INT8_STORAGE,
};

class ShaderNotSupportedError : public std::exception {
public:
ShaderNotSupportedError(std::string shader_name, VulkanExtension extension);

private:
std::string shader_name_;
VulkanExtension extension_;
std::string what_;

public:
const char* what() const noexcept override {
return what_.c_str();
}
};

} // namespace vkapi
} // namespace vkcompute
10 changes: 8 additions & 2 deletions backends/vulkan/runtime/vk_api/Shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo(
const uint32_t* const spirv_bin,
const uint32_t size,
std::vector<VkDescriptorType> layout,
const utils::uvec3 tile_size)
const utils::uvec3 tile_size,
const bool requires_shader_int16_ext,
const bool requires_16bit_storage_ext,
const bool requires_8bit_storage_ext)
: src_code{
spirv_bin,
size,
},
kernel_name{std::move(name)},
kernel_layout{std::move(layout)},
out_tile_size(tile_size) {
out_tile_size(tile_size),
requires_shader_int16(requires_shader_int16_ext),
requires_16bit_storage(requires_16bit_storage_ext),
requires_8bit_storage(requires_8bit_storage_ext) {
}

bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
Expand Down
8 changes: 7 additions & 1 deletion backends/vulkan/runtime/vk_api/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ struct ShaderInfo final {

// Shader Metadata
utils::uvec3 out_tile_size{1u, 1u, 1u};
bool requires_shader_int16 = false;
bool requires_16bit_storage = false;
bool requires_8bit_storage = false;

explicit ShaderInfo();

Expand All @@ -70,7 +73,10 @@ struct ShaderInfo final {
const uint32_t*,
const uint32_t,
std::vector<VkDescriptorType>,
const utils::uvec3 tile_size);
const utils::uvec3 tile_size,
const bool requires_shader_int16_ext,
const bool requires_16bit_storage_ext,
const bool requires_8bit_storage_ext);

operator bool() const {
return src_code.bin != nullptr;
Expand Down
Loading