From d7cac31bd2cce22389124a559b610a6a981d45ee Mon Sep 17 00:00:00 2001 From: Brian Yang <125406446+gpupuck@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:28:55 -0800 Subject: [PATCH] Disable MKL DNN build for ARM only due to JAX PR#23225 (#1195) JAX PR#23225 https://github.com/jax-ml/jax/pull/23225 introduced a JAX build error for ARM: ``` #14 93.17 In file included from external/compute_library/src/core/utils/logging/FilePrinter.cpp:24: 5120#14 93.17 In file included from external/compute_library/arm_compute/core/utils/logging/FilePrinter.h:27: 5121#14 93.17 external/compute_library/arm_compute/core/utils/logging/IPrinter.h:56:34: error: no type named 'string' in namespace 'std' 5122#14 93.17 56 | inline void print(const std::string &msg) 5123#14 93.17 | ~~~~~^ 5124#14 93.17 external/compute_library/arm_compute/core/utils/logging/IPrinter.h:67:44: error: no type named 'string' in namespace 'std' 5125#14 93.17 67 | virtual void print_internal(const std::string &msg) = 0; 5126#14 93.17 | ~~~~~^ 5127#14 93.17 In file included from external/compute_library/src/core/utils/logging/FilePrinter.cpp:24: 5128#14 93.17 external/compute_library/arm_compute/core/utils/logging/FilePrinter.h:47:49: error: non-virtual member function marked 'override' hides virtual member function 5129#14 93.17 47 | void print_internal(const std::string &msg) override; 5130#14 93.17 | ^ 5131#14 93.17 external/compute_library/arm_compute/core/utils/logging/IPrinter.h:67:18: note: hidden overloaded virtual function 'arm_compute::logging::printer::print_internal' declared here: type mismatch at 1st parameter ('const int &' vs 'const std::string &' (aka 'const basic_string &')) 5132#14 93.17 67 | virtual void print_internal(const std::string &msg) = 0; 5133#14 93.17 | ^ 5134#14 93.17 In file included from external/compute_library/src/core/utils/logging/FilePrinter.cpp:24: 5135#14 93.17 external/compute_library/arm_compute/core/utils/logging/FilePrinter.h:36:7: warning: abstract class is marked 'final' [-Wabstract-final-class] 5136#14 93.17 36 | class FilePrinter final : public Printer 5137#14 93.17 | ^ 5138#14 93.17 external/compute_library/arm_compute/core/utils/logging/IPrinter.h:67:18: note: unimplemented pure virtual method 'print_internal' in 'FilePrinter' 5139#14 93.17 67 | virtual void print_internal(const std::string &msg) = 0; 5140#14 93.17 | ^ 5141#14 93.17 1 warning and 3 errors generated. 5142#14 93.40 Target //jaxlib/tools:build_wheel failed to build ``` Internal triage indicates that the issue is deeper than it looks like. For now, we have to workaround the issue by disabling mkl dnn library for ARM, which won't effect us on the CUDA plugin builds. --- .github/container/build-jax.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index 95cf5246b..b179e9456 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -185,6 +185,8 @@ case "${CPU_ARCH}" in ;; "arm64") export CC_OPT_FLAGS="-march=armv8-a" + # ARM ACL build issue introduced in PR#23225 + BUILD_PARAM="${BUILD_PARAM} --disable_mkl_dnn" ;; esac