diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java index 21be392f8f1..92f8ec7cb33 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java @@ -81,6 +81,7 @@ import org.apache.calcite.sql.SqlUpdate; import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.fun.SqlInternalOperators; +import org.apache.calcite.sql.fun.SqlMinMaxAggFunction; import org.apache.calcite.sql.fun.SqlSingleValueAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -577,6 +578,8 @@ protected Builder buildAggregate(Aggregate e, Builder builder, RelDataType aggCallRelDataType = aggCall.getType(); if (aggCall.getAggregation() instanceof SqlSingleValueAggFunction) { aggCallSqlNode = dialect.rewriteSingleValueExpr(aggCallSqlNode, aggCallRelDataType); + } else if (aggCall.getAggregation() instanceof SqlMinMaxAggFunction) { + aggCallSqlNode = dialect.rewriteMaxMinExpr(aggCallSqlNode, aggCallRelDataType); } addSelect(selectList, aggCallSqlNode, e.getRowType()); } diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java index a013c861e55..686170afcb5 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java @@ -872,6 +872,13 @@ public SqlNode rewriteSingleValueExpr(SqlNode aggCall, RelDataType relDataType) return aggCall; } + /** With x as BOOLEAN column, rewrite MAX(x)/MIN(x) as BOOL_OR(x)/BOOL_AND(x) + * for certain database variants (Postgres and Redshift, currently). + */ + public SqlNode rewriteMaxMinExpr(SqlNode aggCall, RelDataType relDataType) { + return aggCall; + } + /** * Returns the SqlNode for emulating the null direction for the given field * or null if no emulation needs to be done. diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java index 9dd41938a15..2292780250d 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java @@ -173,12 +173,16 @@ public PostgresqlSqlDialect(Context context) { timeUnitNode.getParserPosition()); SqlFloorFunction.unparseDatetimeFunction(writer, call2, "DATE_TRUNC", false); break; - default: super.unparseCall(writer, call, leftPrec, rightPrec); } } + @Override public SqlNode rewriteMaxMinExpr(SqlNode aggCall, RelDataType relDataType) { + RedshiftSqlDialect redshiftSqlDialect = new RedshiftSqlDialect(DEFAULT_CONTEXT); + return redshiftSqlDialect.rewriteMaxMinExpr(aggCall, relDataType); + } + @Override public boolean supportsGroupByLiteral() { return false; } diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java index 4e94977ce30..01719bc2ab7 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java @@ -20,11 +20,15 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rel.type.RelDataTypeSystemImpl; +import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlDataTypeSpec; import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlUserDefinedTypeNameSpec; import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlTypeName; @@ -107,6 +111,21 @@ public RedshiftSqlDialect(Context context) { SqlParserPos.ZERO); } + @Override public SqlNode rewriteMaxMinExpr(SqlNode aggCall, RelDataType relDataType) { + // The behavior of this method depends on the argument type, + // and whether it is MIN/MAX + final SqlTypeName type = relDataType.getSqlTypeName(); + final boolean isMax = aggCall.getKind() == SqlKind.MAX; + // If the type is BOOLEAN, create a new call to the correct operator + if (type == SqlTypeName.BOOLEAN) { + final SqlOperator op = isMax ? SqlLibraryOperators.BOOL_OR : SqlLibraryOperators.BOOL_AND; + final SqlNode operand = ((SqlBasicCall) aggCall).operand(0); + return op.createCall(SqlParserPos.ZERO, operand); + } + // Otherwise, just return as it arrived + return aggCall; + } + @Override public boolean supportsGroupByLiteral() { return false; } diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 8e4fb1d0820..d5b00ec700d 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -6677,12 +6677,37 @@ private void checkLiteral2(String expression, String expected) { * Add BITAND_AGG, BITOR_AGG functions (enabled in Snowflake library). */ @Test void testBitOrAgg() { final String query = "select bit_or(\"product_id\")\n" - + "from \"product\""; + + "from \"product\""; final String expectedSnowflake = "SELECT BITOR_AGG(\"product_id\")\n" - + "FROM \"foodmart\".\"product\""; + + "FROM \"foodmart\".\"product\""; sql(query).withLibrary(SqlLibrary.SNOWFLAKE).withSnowflake().ok(expectedSnowflake); } + /** Test case for + * [CALCITE-6220] + * Rewrite MIN/MAX(bool) as BOOL_AND/BOOL_OR for Postgres, Redshift. */ + @Test void testMaxMinOnBooleanColumn() { + final String query = "select max(\"brand_name\" = 'a'), " + + "min(\"brand_name\" = 'a'), " + + "min(\"brand_name\")\n" + + "from \"product\""; + final String expected = "SELECT MAX(\"brand_name\" = 'a'), " + + "MIN(\"brand_name\" = 'a'), " + + "MIN(\"brand_name\")\n" + + "FROM \"foodmart\".\"product\""; + final String expectedPostgres = "SELECT BOOL_OR(\"brand_name\" = 'a'), " + + "BOOL_AND(\"brand_name\" = 'a'), " + + "MIN(\"brand_name\")\n" + + "FROM \"foodmart\".\"product\""; + final String expectedRedshift = "SELECT BOOL_OR(\"brand_name\" = 'a'), " + + "BOOL_AND(\"brand_name\" = 'a'), " + + "MIN(\"brand_name\")\n" + + "FROM \"foodmart\".\"product\""; + sql(query).ok(expected); + sql(query).withPostgresql().ok(expectedPostgres); + sql(query).withRedshift().ok(expectedRedshift); + } + /** Test case for * [CALCITE-6156] * Add ENDSWITH, STARTSWITH functions (enabled in Postgres, Snowflake libraries). */