Skip to content

Commit

Permalink
xe: jit: gemm: rearrange runtime dim checks
Browse files Browse the repository at this point in the history
This reverts commit d22e43f and splits
the runtime dimension checks to avoid undefined behavior.
  • Loading branch information
atkassen committed Jan 23, 2025
1 parent fbe6b3c commit 05f8f7a
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,26 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) {
c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
}

VDISPATCH_GEMM_SC(
attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG);

VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL);

// LIMITATIONS:
// - batch is not supported for unpacked inputs.
// - runtime dims are not supported
bool limits_ok
= !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k());

VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED);

// Must check runtime dimensions before calling `set_default_formats` to
// avoid UB.
VDISPATCH_GEMM_SC(
set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG);

VDISPATCH_GEMM_SC(
attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG);

VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL);

// set_default_formats` determines a/b/c packing, so it must be called
// prior to this.
if (!packed_a())
limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL)
&& (d->batch() == 1);
Expand All @@ -85,10 +95,6 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) {
&& (d->batch() == 1);
if (!packed_c())
limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL);
VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED);

VDISPATCH_GEMM_SC(
set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG);

auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops;

Expand All @@ -97,6 +103,7 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) {
bool arch_ok = utils::one_of(arch, arch_t::xe_hp, arch_t::xe_hpg,
arch_t::xe_hpc, arch_t::xe2, arch_t::xe3);

VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED);
VDISPATCH_GEMM((dt_float_ok || dt_int_ok), VERBOSE_UNSUPPORTED_DT_CFG);
VDISPATCH_GEMM(arch_ok, VERBOSE_UNSUPPORTED_ARCH, "gpu");
VDISPATCH_GEMM(
Expand Down

0 comments on commit 05f8f7a

Please sign in to comment.