diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp index 9a9209b3db32..83cc8b101374 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp @@ -82,11 +82,12 @@ static void populateHloFeatures(Operation *op, InputFeatures &features) { } } -static void populateFeatures(Operation *op, const Dialect *stablehloDialect, +static void populateFeatures(Operation *op, const Dialect *chloDialect, + const Dialect *stablehloDialect, const Dialect *tosaDialect, InputFeatures &features) { Dialect *d = op->getDialect(); - if (d == stablehloDialect) { + if (d == stablehloDialect || d == chloDialect) { features.hasStableHLO = true; return populateHloFeatures(op, features); } @@ -101,14 +102,15 @@ void AutoInputConversionPipelinePass::runOnOperation() { MLIRContext *ctxt = &getContext(); InputFeatures features; + const Dialect *chloDialect = ctxt->getLoadedDialect("chlo"); const Dialect *stablehloDialect = ctxt->getLoadedDialect("stablehlo"); const Dialect *tosaDialect = ctxt->getLoadedDialect("tosa"); - if (!stablehloDialect && !tosaDialect) { + if (!chloDialect && !stablehloDialect && !tosaDialect) { return; } auto res = module.walk([&](Operation *op) { - populateFeatures(op, stablehloDialect, tosaDialect, features); + populateFeatures(op, chloDialect, stablehloDialect, tosaDialect, features); if (features.hasStableHLO && features.hasTOSA) { module.emitError("not yet implemented mixture of *HLO and TOSA"); return WalkResult::interrupt();