Skip to content

Commit

Permalink
gpu: sycl: binary: add support for broadcasting
Browse files: browse the repository at this point in the history
  • Loading branch information
t4c1 authored and dzarukin committed May 1, 2024
1 parent 5521d4c commit e177c3c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
27 changes: 24 additions & 3 deletions src/gpu/sycl/binary_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace sycl {

struct binary_kernel_vec_t {
static constexpr int vec_len = 8;
static constexpr int max_supported_ndims = 5;

binary_kernel_vec_t(const sycl_binary_conf_t &conf,
sycl_in_memory_arg_t &src0, sycl_in_memory_arg_t &src1,
Expand Down Expand Up @@ -63,8 +64,17 @@ struct binary_kernel_vec_t {
? load_float_value(scales_dt_, src1_scale_ptr(), 0)
: 1.f);

if (sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size) {
dims_t dims, off;
bool any_broadcast = false;
for (int i = 0; i < max_supported_ndims; i++) {
dims[i] = (i < src0_md().ndims()) ? src0_md().dims()[i] : 1;
if (i < src0_md().ndims()) {
any_broadcast |= conf_.broadcast_dims[i];
}
}
if (!any_broadcast
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size) {
for (int i = 0; i < conf_.block_size / vec_len; i++) {
auto src0_vec = load_float_vec<vec_len>(
src0_md().data_type(), src0_ptr(), vec_base_idx + i);
Expand All @@ -90,10 +100,21 @@ struct binary_kernel_vec_t {
for (int i = 0; i < conf_.block_size; i++) {
int idx = base_idx + i;
if (idx < conf_.wk_size) {
utils::l_dims_by_l_offset(
off, idx, dims, max_supported_ndims);

for (int i = 0; i < max_supported_ndims; i++) {
if (conf_.broadcast_dims[i] && i < src0_md().ndims()) {
off[i] = 0;
}
}

int idx1 = src1_md().off_v(off);

auto src0 = load_float_value(
src0_md().data_type(), src0_ptr(), idx);
auto src1 = load_float_value(
src1_md().data_type(), src1_ptr(), idx);
src1_md().data_type(), src1_ptr(), idx1);
auto dst = load_float_value(
dst_md().data_type(), dst_ptr(), idx);

Expand Down
14 changes: 1 addition & 13 deletions src/gpu/sycl/ref_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,13 @@ struct ref_binary_t : public sycl_gpu_primitive_t {

const bool ok = set_default_params() == status::success
&& check_data_types(src0_d, src1_d, dst_d)
&& check_formats(src0_d, src1_d, dst_d) && is_tensor_op()
&& check_formats(src0_d, src1_d, dst_d)
&& attr()->has_default_values(
sm::scales_runtime | sm::post_ops)
&& IMPLICATION(!attr()->scales_.has_default_values(),
check_scales_mask())
&& post_ops_ok();
if (!ok) return status::unimplemented;
// TODO: extend sycl device info to check supported sub-group sizes.
auto *sycl_engine
= utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
const auto supported_sub_group_sizes
= sycl_engine->device()
.template get_info<
::sycl::info::device::sub_group_sizes>();
if (!std::any_of(supported_sub_group_sizes.cbegin(),
supported_sub_group_sizes.cend(),
[](size_t size) { return size == 32; })) {
return status::unimplemented;
}

return init_conf();
}
Expand Down

0 comments on commit e177c3c

Please sign in to comment.