From 05f8f7aa296920992e07914931e66a8c0f9ad87b Mon Sep 17 00:00:00 2001 From: "Kassen, Andrew" Date: Thu, 23 Jan 2025 08:31:44 -0800 Subject: [PATCH] xe: jit: gemm: rearrange runtime dim checks This reverts commit d22e43f9635af48c235fa778223f1e01b5c67a80 and splits the runtime dimension checks to avoid undefined behavior. --- .../intel/jit/gemm/xe_hp_systolic_gemm.cpp | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp b/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp index e0f62d6f843..a9a7a252246 100644 --- a/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp +++ b/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp @@ -67,16 +67,26 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST); } - VDISPATCH_GEMM_SC( - attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG); - - VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL); - // LIMITATIONS: // - batch is not supported for unpacked inputs. // - runtime dims are not supported bool limits_ok = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k()); + + VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + // Must check runtime dimensions before calling `set_default_formats` to + // avoid UB. + VDISPATCH_GEMM_SC( + set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG); + + VDISPATCH_GEMM_SC( + attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG); + + VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL); + + // set_default_formats` determines a/b/c packing, so it must be called + // prior to this. if (!packed_a()) limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL) && (d->batch() == 1); @@ -85,10 +95,6 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { && (d->batch() == 1); if (!packed_c()) limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL); - VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); - - VDISPATCH_GEMM_SC( - set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG); auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops; @@ -97,6 +103,7 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { bool arch_ok = utils::one_of(arch, arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc, arch_t::xe2, arch_t::xe3); + VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); VDISPATCH_GEMM((dt_float_ok || dt_int_ok), VERBOSE_UNSUPPORTED_DT_CFG); VDISPATCH_GEMM(arch_ok, VERBOSE_UNSUPPORTED_ARCH, "gpu"); VDISPATCH_GEMM(