diff --git a/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp b/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp index e0f62d6f843..c554380cfd2 100644 --- a/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp +++ b/src/gpu/intel/jit/gemm/xe_hp_systolic_gemm.cpp @@ -67,16 +67,26 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST); } - VDISPATCH_GEMM_SC( - attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG); - - VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL); - // LIMITATIONS: // - batch is not supported for unpacked inputs. // - runtime dims are not supported bool limits_ok = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k()); + + VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); + + // Must check runtime dimensions before calling `set_default_formats` to + // avoid undefined behavior. + VDISPATCH_GEMM_SC( + set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG); + + VDISPATCH_GEMM_SC( + attr_.set_default_formats(dst_md(0)), VERBOSE_UNSUPPORTED_TAG); + + VDISPATCH_GEMM(!use_nocopy(), VERBOSE_SKIP_PRIMITIVE_IMPL); + + // `set_default_formats` determines a/b/c packing, so it must be called + // prior to this. if (!packed_a()) limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL) && (d->batch() == 1); @@ -85,10 +95,6 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { && (d->batch() == 1); if (!packed_c()) limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL); - VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); - - VDISPATCH_GEMM_SC( - set_default_formats(d->a_type()), VERBOSE_UNSUPPORTED_TAG); auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops; @@ -97,6 +103,7 @@ status_t xe_hp_systolic_gemm_t::pd_t::init(impl::engine_t *engine) { bool arch_ok = utils::one_of(arch, arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc, arch_t::xe2, arch_t::xe3); + VDISPATCH_GEMM(limits_ok, VERBOSE_RUNTIMEDIM_UNSUPPORTED); VDISPATCH_GEMM((dt_float_ok || dt_int_ok), VERBOSE_UNSUPPORTED_DT_CFG); VDISPATCH_GEMM(arch_ok, VERBOSE_UNSUPPORTED_ARCH, "gpu"); VDISPATCH_GEMM(