Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src: attr: quantization refactor (part 2) #2570

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/common/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,18 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine,
// Check zero points
if (!attr->zero_points_.has_default_values()) {
const auto &zp = attr->zero_points_;
int mask_src = 0, mask_wei = 0, mask_dst = 0;
zp.get(DNNL_ARG_SRC, &mask_src);
zp.get(DNNL_ARG_WEIGHTS, &mask_wei);
zp.get(DNNL_ARG_DST, &mask_dst);

VCHECK_CONV_UNIMPL((mask_src == 0 || mask_src == 1 << 1)
&& (mask_wei == 0)
&& (mask_dst == 0 || mask_dst == 1 << 1),

VCHECK_CONV_UNIMPL(IMPLICATION(!zp.has_default_values(DNNL_ARG_SRC),
utils::one_of(zp.get_mask(DNNL_ARG_SRC),
0, 1 << 1)),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_CONV_UNIMPL(
IMPLICATION(!zp.has_default_values(DNNL_ARG_WEIGHTS),
zp.get_mask(DNNL_ARG_WEIGHTS) == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_CONV_UNIMPL(IMPLICATION(!zp.has_default_values(DNNL_ARG_DST),
utils::one_of(zp.get_mask(DNNL_ARG_DST),
0, 1 << 1)),
VERBOSE_UNSUPPORTED_ZP_CFG);
}

Expand Down
17 changes: 11 additions & 6 deletions src/common/deconvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,18 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
// Check zero points
if (!attr->zero_points_.has_default_values()) {
const auto &zp = attr->zero_points_;
int mask_src = 0, mask_dst = 0;
zp.get(DNNL_ARG_SRC, &mask_src);
zp.get(DNNL_ARG_DST, &mask_dst);

VCHECK_DECONV_UNIMPL(zp.has_default_values(DNNL_ARG_WEIGHTS)
&& (mask_src == 0 || mask_src == 1 << 1)
&& (mask_dst == 0 || mask_dst == 1 << 1),
VCHECK_DECONV_UNIMPL(
IMPLICATION(!zp.has_default_values(DNNL_ARG_SRC),
utils::one_of(
zp.get_mask(DNNL_ARG_SRC), 0, 1 << 1)),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_DECONV_UNIMPL(zp.has_default_values(DNNL_ARG_WEIGHTS),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_DECONV_UNIMPL(
IMPLICATION(!zp.has_default_values(DNNL_ARG_DST),
utils::one_of(
zp.get_mask(DNNL_ARG_DST), 0, 1 << 1)),
VERBOSE_UNSUPPORTED_ZP_CFG);
}

Expand Down
128 changes: 72 additions & 56 deletions src/common/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,70 +175,86 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine,
// Check zero points
if (!attr->zero_points_.has_default_values()) {
const auto &zp = attr->zero_points_;
int mask_src = 0, mask_wei = 0, mask_dst = 0;
zp.get(DNNL_ARG_SRC, &mask_src);
zp.get(DNNL_ARG_WEIGHTS, &mask_wei);
zp.get(DNNL_ARG_DST, &mask_dst);

VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K,
src_qmask_M + src_qmask_K),
VERBOSE_UNSUPPORTED_ZP_CFG);
// Masks for weights zero points can be any - skipping them.
VCHECK_MATMUL_UNIMPL(mask_dst == 0
|| (desc.dst_desc.ndims == 2 && mask_dst == 1 << 1),
VERBOSE_UNSUPPORTED_ZP_CFG);
dim_t src_zero_point_group_k = 1;
if (!zp.has_default_values(DNNL_ARG_SRC)) {
const int mask_src = zp.get_mask(DNNL_ARG_SRC);

if (utils::one_of(zp.get_data_type(DNNL_ARG_WEIGHTS), data_type::s4,
data_type::u4)) {
dim_t k = desc.weights_desc.dims[ndims_wei - 2];
dim_t n = desc.weights_desc.dims[ndims_wei - 1];
VCHECK_MATMUL_UNIMPL(
IMPLICATION(mask_wei & wei_qmask_K, k % 2 == 0),
VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K,
src_qmask_M + src_qmask_K),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_MATMUL_UNIMPL(
IMPLICATION(mask_wei & wei_qmask_N, n % 2 == 0),

if (!zp.get(DNNL_ARG_SRC).has_default_groups()) {
if (mask_src & src_qmask_K)
src_zero_point_group_k = zp.get_group(DNNL_ARG_SRC, 1);
}

// Due to hardware specifics, groups should be multiple of 32.
VCHECK_MATMUL_UNIMPL(IMPLICATION(src_zero_point_group_k > 1,
src_zero_point_group_k % 32 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
}

// Check dependency between zps.
// Source zps groups are supported for int8 source and must divide
// or be divided by weights groups when both are greater than 1.
const auto src_zp_group_k = (mask_src & src_qmask_K)
&& zp.get_groups_ndims(DNNL_ARG_SRC) > 0
? zp.get_groups(DNNL_ARG_SRC)[1]
: 1;
const auto wei_zp_group_k = (mask_wei & wei_qmask_K)
&& zp.get_groups_ndims(DNNL_ARG_WEIGHTS) > 0
? zp.get_groups(DNNL_ARG_WEIGHTS)[0]
: 1;
const bool groups_are_divisible
= IMPLICATION(src_zp_group_k > 1 && wei_zp_group_k > 1,
(src_zp_group_k % wei_zp_group_k == 0)
|| (wei_zp_group_k % src_zp_group_k == 0));
VCHECK_MATMUL_UNIMPL(IMPLICATION(src_zp_group_k > 1,
src_is_int8 && groups_are_divisible),
VERBOSE_UNSUPPORTED_ZP_CFG);
dim_t wei_zero_point_group_k = 1;
dim_t wei_zero_point_group_n = 1;
if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) {
const int mask_wei = zp.get_mask(DNNL_ARG_WEIGHTS);

// Groups per N are solely for weights decompression as it's impossible
// to get performant kernel for a single `k` element in chain for
// regular quantized case.
const auto wei_zp_group_n = (mask_wei & wei_qmask_N)
&& zp.get_groups_ndims(DNNL_ARG_WEIGHTS) > 0
? zp.get_groups(DNNL_ARG_WEIGHTS)[1]
: 1;
VCHECK_MATMUL_UNIMPL(
IMPLICATION(wei_zp_group_n > 1, attr->fpmath_.apply_to_int_),
VERBOSE_UNSUPPORTED_ZP_CFG);
// Masks for weights zero_points can be any - skipping them.

// Due to hardware specifics, groups should be multiple of 32.
VCHECK_MATMUL_UNIMPL(
IMPLICATION(src_zp_group_k > 1, src_zp_group_k % 32 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_MATMUL_UNIMPL(
IMPLICATION(wei_zp_group_k > 1, wei_zp_group_k % 32 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_MATMUL_UNIMPL(
IMPLICATION(wei_zp_group_n > 1, wei_zp_group_n % 32 == 0),
if (!zp.get(DNNL_ARG_WEIGHTS).has_default_groups()) {
if (mask_wei & wei_qmask_K)
wei_zero_point_group_k = zp.get_group(DNNL_ARG_WEIGHTS, 0);
if (mask_wei & wei_qmask_N)
wei_zero_point_group_n = zp.get_group(DNNL_ARG_WEIGHTS, 1);
}

// Groups per N are solely for weights decompression as it's
// impossible to get performant kernel for a single `k` element in
// chain for regular quantized case.
VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_n > 1,
attr->fpmath_.apply_to_int_),
VERBOSE_UNSUPPORTED_ZP_CFG);

// Due to hardware specifics, groups should be multiple of 32.
VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_k > 1,
wei_zero_point_group_k % 32 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_zero_point_group_n > 1,
wei_zero_point_group_n % 32 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);

if (utils::one_of(zp.get_data_type(DNNL_ARG_WEIGHTS), data_type::s4,
data_type::u4)) {
dim_t k = desc.weights_desc.dims[ndims_wei - 2];
dim_t n = desc.weights_desc.dims[ndims_wei - 1];
VCHECK_MATMUL_UNIMPL(
IMPLICATION(mask_wei & wei_qmask_K, k % 2 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
VCHECK_MATMUL_UNIMPL(
IMPLICATION(mask_wei & wei_qmask_N, n % 2 == 0),
VERBOSE_UNSUPPORTED_ZP_CFG);
}
}

if (!zp.has_default_values(DNNL_ARG_DST)) {
const int mask_dst = zp.get_mask(DNNL_ARG_DST);

VCHECK_MATMUL_UNIMPL(mask_dst == 0
|| (desc.dst_desc.ndims == 2 && mask_dst == 1 << 1),
VERBOSE_UNSUPPORTED_ZP_CFG);
}

// Check dependency between zero_points.
// Source zero_points groups are supported for int8 source and must
// divide or be divided by weights groups when both are greater than 1.
const bool groups_are_divisible = IMPLICATION(
src_zero_point_group_k > 1 && wei_zero_point_group_k > 1,
(src_zero_point_group_k % wei_zero_point_group_k == 0)
|| (wei_zero_point_group_k % src_zero_point_group_k
== 0));
VCHECK_MATMUL_UNIMPL(IMPLICATION(src_zero_point_group_k > 1,
src_is_int8 && groups_are_divisible),
VERBOSE_UNSUPPORTED_ZP_CFG);
}

Expand Down
46 changes: 3 additions & 43 deletions src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,45 +69,6 @@ status_t rnn_create_time_scales_t::set(
return status::success;
}

status_t zero_points_t::get(int arg, int *mask, data_type_t *dt) const {
if (mask) *mask = get_mask(arg);
if (dt) *dt = get_data_type(arg);
return status::success;
}

int zero_points_t::get(int arg) const {
return get_mask(arg);
}

status_t zero_points_t::set(int arg, int mask, int ndims, const dims_t groups,
data_type_t data_type) {
const bool supported_arg
= utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST);
if (!supported_arg) return status::unimplemented;

switch (arg) {
case DNNL_ARG_SRC:
is_set_src = true;
mask_src = mask;
data_type_src = data_type;
group_ndims_src = ndims;
utils::array_copy(group_dims_src, groups, group_ndims_src);
break;
case DNNL_ARG_WEIGHTS:
is_set_wei = true;
mask_wei = mask;
data_type_wei = data_type;
group_ndims_wei = ndims;
utils::array_copy(group_dims_wei, groups, group_ndims_wei);
break;
case DNNL_ARG_DST:
is_set_dst = true;
mask_dst = mask;
break;
}
return status::success;
}

status_t dropout_t::set_default_formats(const memory_desc_t *dst_md) {
auto is_any_or_undef = [](format_kind_t kind) {
return one_of(kind, dnnl_format_kind_any, dnnl_format_kind_undef);
Expand Down Expand Up @@ -561,9 +522,8 @@ status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,

status_t dnnl_primitive_attr_set_zero_points_mask(
primitive_attr_t *attr, int arg, int mask) {
bool ok = attr && mask >= 0;
if (!ok) return invalid_arguments;

VCHECK_ATTR(attr, VERBOSE_NULL_ARG);
VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask");
return attr->zero_points_.set(arg, mask);
}

Expand All @@ -585,7 +545,7 @@ status_t dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr,
VCHECK_ATTR(IMPLICATION(ndims, validate_dims(ndims, group_dims)),
VERBOSE_BAD_PARAM, "group_dims");

return attr->zero_points_.set(arg, mask, ndims, group_dims, data_type);
return attr->zero_points_.set(arg, mask, data_type, ndims, group_dims);
}

status_t dnnl_primitive_attr_get_rounding(
Expand Down
35 changes: 35 additions & 0 deletions src/common/primitive_attr_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,40 @@ std::string scales_t::get_verbose() const {
return s;
}

size_t zero_points_t::get_hash() const {
size_t seed = 0;
// Go through zero_points for all arguments.
for (const auto &e : zero_points_) {
seed = hash_combine(seed, e.first);
seed = hash_combine(seed, e.second.get_hash());
}
return seed;
}

void zero_points_t::serialize(serialization_stream_t &sstream) const {
for (const auto &e : zero_points_) {
sstream.write(&e.first);
e.second.serialize(sstream);
}
}

std::string zero_points_t::get_verbose() const {
std::string s;
std::string empty_delim, attr_delim = "+";
std::string delim = empty_delim;
for (const auto &zero_point : zero_points_) {
const auto &q = zero_point.second;
if (q.has_default_values()) continue;

int arg = zero_point.first;
s.append(delim)
.append(arg2str(arg))
.append(":")
.append(q.get_verbose());
delim = attr_delim;
}
return s;
}

} // namespace impl
} // namespace dnnl
Loading
Loading