Skip to content

Commit

Permalink
Browse files
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 13, 2024
1 parent 21b89c3 commit 9b4efc6
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
13 changes: 4 additions & 9 deletions src/common/primitive_attr_quant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,10 @@ struct scales_t : public c_compatible {
CHECK(scales_[arg].set(mask, data_type, group_ndims, group_dims));
return status::success;
}

// Used directly in the following scenarios:
// * CPU GeMM-based matmul to transfer scale as `alpha` arg, referred as
// `gemm_applies_output_scales_`.
status_t erase(int arg) {
if (!check_arg(arg)) return status::invalid_arguments;
const auto it = scales_.find(arg);
if (it != scales_.end()) scales_.erase(it);
return status::success;
    // Use this interface with `default_quant_entry` when you need to remove
    // a specific scale.
status_t set(int arg, const quant_entry_t &other) {
return scales_[arg].set(other);
}

bool has_default_values(const std::vector<int> &supported_args = {}) const {
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/matmul/gemm_bf16_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ status_t gemm_bf16_matmul_t<dst_type>::pd_t::check_and_configure_attributes(
= attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias();

if (params_.gemm_applies_output_scales_) {
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.erase(DNNL_ARG_SRC),
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set(
DNNL_ARG_SRC, default_quant_entry()),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.erase(DNNL_ARG_WEIGHTS),
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set(
DNNL_ARG_WEIGHTS, default_quant_entry()),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
6 changes: 4 additions & 2 deletions src/cpu/matmul/gemm_f32_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,11 @@ status_t gemm_f32_matmul_t::pd_t::configure_attributes() {
params_.gemm_applies_output_scales_
= attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias();
if (params_.gemm_applies_output_scales_) {
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.erase(DNNL_ARG_SRC),
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set(
DNNL_ARG_SRC, default_quant_entry()),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.erase(DNNL_ARG_WEIGHTS),
VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set(
DNNL_ARG_WEIGHTS, default_quant_entry()),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 2 additions & 2 deletions src/cpu/ref_fused_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ struct ref_fused_convolution_fwd_t : public primitive_t {
auto &scale
= attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | arg);
if (!scale.has_default_values())
CHECK(attr_1x1.scales_.erase(
DNNL_ARG_ATTR_POST_OP_DW | arg));
CHECK(attr_1x1.scales_.set(DNNL_ARG_ATTR_POST_OP_DW | arg,
default_quant_entry()));
}
// erase post-ops after fusion as they will be handled separately
auto &e = attr_1x1.post_ops_.entry_;
Expand Down

0 comments on commit 9b4efc6

Please sign in to comment.