Fix check write for spark 3.5 unit tests #6793

Open: wants to merge 4 commits into base: master
Changes from 3 commits
@@ -16,10 +16,13 @@
*/
package org.apache.spark.sql

import java.util.Collections

import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions}
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, LogicalQueryStage}
import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SQLTestData.TestData
@@ -99,15 +102,24 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest
withListener(sql(sqlString))(callback)
}

def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): Unit = {
def withListener(df: => DataFrame)(
callback: DataWritingCommand => Unit,
failIfNotCallback: Boolean = true): Unit = {
val writes = Collections.synchronizedList(new java.util.ArrayList[DataWritingCommand]())

val listener = new QueryExecutionListener {
override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}

override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
qe.executedPlan match {
case write: DataWritingCommandExec => callback(write.cmd)
case _ =>
def collectWrite(plan: SparkPlan): Unit = {
plan match {
case write: DataWritingCommandExec =>
writes.add(write.cmd)
case a: AdaptiveSparkPlanExec => collectWrite(a.executedPlan)
case _ =>
}
}
collectWrite(qe.executedPlan)
}
}
spark.listenerManager.register(listener)
@@ -117,5 +129,20 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest
} finally {
spark.listenerManager.unregister(listener)
}
if (failIfNotCallback && writes.isEmpty) {
fail("No write command found")
}
writes.forEach(callback(_))
}

def collectRebalancePartitions(plan: LogicalPlan): Seq[RebalancePartitions] = {
def collect(p: LogicalPlan): Seq[RebalancePartitions] = {
p.flatMap {
case r: RebalancePartitions => Seq(r)
case s: LogicalQueryStage => collect(s.logicalPlan)

Contributor: Why do we need to match LogicalQueryStage?

Member Author: This test case gets a logical plan like this: (plan screenshot omitted)

Contributor: Is it a Spark issue? IMO, LogicalQueryStage should not exist in query execution.

Member Author: I guess it's because the logical plan is obtained from write.cmd.

case _ => Nil
}
}
collect(plan)
}
}
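
Note (not part of the diff): a minimal sketch of how the reworked withListener and the new collectRebalancePartitions helper are meant to be used together; the SQL and table names below are illustrative assumptions, not taken from this PR.

// Illustrative usage sketch only; assumes hypothetical tables src_t and target_t
// exist in the test session and that the rebalance-before-write rule is enabled.
withListener("INSERT INTO target_t SELECT * FROM src_t") { write =>
  // `write` is each DataWritingCommand collected from the executed plan,
  // including commands nested under AdaptiveSparkPlanExec on Spark 3.5.
  assert(collectRebalancePartitions(write).nonEmpty)
}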
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartit
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable

import org.apache.kyuubi.sql.KyuubiSQLConf

@@ -31,17 +30,15 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
test("check rebalance exists") {
def check(df: => DataFrame, expectedRebalanceNum: Int = 1): Unit = {
withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
var rebalanceNum = 0
withListener(df) { write =>
assert(write.collect {
case r: RebalancePartitions => r
}.size == expectedRebalanceNum)
rebalanceNum += collectRebalancePartitions(write).size
}
assert(rebalanceNum == expectedRebalanceNum)
}
withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "false") {
withListener(df) { write =>
assert(write.collect {
case r: RebalancePartitions => r
}.isEmpty)
assert(collectRebalancePartitions(write).isEmpty)
}
}
}
@@ -97,29 +94,30 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
}

test("check rebalance does not exists") {
def check(df: DataFrame): Unit = {
def checkQuery(df: => DataFrame): Unit = {
assert(collectRebalancePartitions(df.queryExecution.analyzed).isEmpty)
}
def checkWrite(df: => DataFrame): Unit = {
withListener(df) { write =>
assert(write.collect {
case r: RebalancePartitions => r
}.isEmpty)
assert(collectRebalancePartitions(write).isEmpty)
}
}

withSQLConf(
KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true",
KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
// test no write command
check(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
check(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
checkQuery(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
checkQuery(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))

// test not supported plan
withTable("tmp1") {
sql(s"CREATE TABLE tmp1 (c1 int) PARTITIONED BY (c2 string)")
check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
"SELECT /*+ repartition(10) */ * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) ORDER BY c1"))
check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) LIMIT 10"))
}
}
@@ -128,13 +126,13 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
Seq("USING PARQUET", "").foreach { storage =>
withTable("tmp1") {
sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
}

withTable("tmp1") {
sql(s"CREATE TABLE tmp1 (c1 int) $storage")
check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
checkWrite(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
}
}
}
@@ -143,27 +141,30 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
test("test dynamic partition write") {
def checkRepartitionExpression(sqlString: String): Unit = {
withListener(sqlString) { write =>
assert(write.isInstanceOf[InsertIntoHiveTable])
assert(write.collect {
case r: RebalancePartitions if r.partitionExpressions.size == 1 =>
assert(r.partitionExpressions.head.asInstanceOf[Attribute].name === "c2")
r
}.size == 1)
val rebalancePartitions = collectRebalancePartitions(write)
assert(rebalancePartitions.size == 1)
assert(rebalancePartitions.head.partitionExpressions.size == 1 &&
rebalancePartitions.head.partitionExpressions.head.asInstanceOf[Attribute].name === "c2")
}
}

withSQLConf(
KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true",
KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
Seq("USING PARQUET", "").foreach { storage =>
withTable("tmp1") {
sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT 1 as c1, 'a' as c2 ")
}
withTable("tmp2") {
sql(s"CREATE TABLE tmp2 (c1 int, c2 string)")
sql(s"INSERT INTO tmp2 SELECT 1, 'a'")

withTable("tmp1") {
checkRepartitionExpression(
"CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT 1 as c1, 'a' as c2")
withTable("tmp1") {
sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT c1, c2 from tmp2")
}

withTable("tmp1") {
checkRepartitionExpression(
"CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT c1, c2 from tmp2")
}
}
}
}
@@ -177,9 +178,7 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
withTable("t") {
withListener("CREATE TABLE t STORED AS parquet AS SELECT 1 as a") { write =>
assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
assert(write.collect {
case _: RebalancePartitions => true
}.size == 1)
assert(collectRebalancePartitions(write).size == 1)
}
}
}
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFun
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, EqualTo, Expression, ExpressionEvalHelper, Literal, NullsLast, SortOrder}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Project, Sort}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFiles}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -245,7 +245,15 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel
planHasRepartition: Boolean,
resHasSort: Boolean): Unit = {
def checkSort(plan: LogicalPlan): Unit = {
assert(plan.isInstanceOf[Sort] === resHasSort)
def collectSort(plan: LogicalPlan): Option[Sort] = {
plan match {
case sort: Sort => Some(sort)
case f: WriteFiles => collectSort(f.child)
case _ => None
}
}
val sortOpt = collectSort(plan)
assert(sortOpt.isDefined === resHasSort)
plan match {
case sort: Sort =>
val colArr = cols.split(",")
@@ -332,19 +340,20 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel
assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHadoopFsRelationCommand])
checkSort(df1.queryExecution.analyzed.children.head)

withListener(
s"""
|CREATE TABLE zorder_t4 USING PARQUET
|TBLPROPERTIES (
| 'kyuubi.zorder.enabled' = '$enabled',
| 'kyuubi.zorder.cols' = '$cols')
|
|SELECT $repartition * FROM
|VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4)
|""".stripMargin) { write =>
assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
checkSort(write.query)
}
// TODO: CreateDataSourceTableAsSelectCommand is not supported

Contributor: Why is CreateDataSourceTableAsSelectCommand not supported? CreateDataSourceTableAsSelectCommand should call InsertIntoHadoopFsRelationCommand internally.

Member Author: InsertIntoHadoopFsRelationCommand's catalogTable will be empty.

// withListener(
// s"""
// |CREATE TABLE zorder_t4 USING PARQUET
// |TBLPROPERTIES (
// | 'kyuubi.zorder.enabled' = '$enabled',
// | 'kyuubi.zorder.cols' = '$cols')
// |
// |SELECT $repartition * FROM
// |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4)
// |""".stripMargin) { write =>
// assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
// checkSort(write.query)
// }
}
}
}
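
Note (not part of the diff): a minimal sketch of how one could probe what the listener receives for a CTAS on Spark 3.5, to check the observation above that catalogTable is empty; the table name is a hypothetical placeholder and the println output is only for inspection.

// Illustrative probe only, not part of this PR; `zorder_ctas_probe` is a hypothetical table name.
withListener("CREATE TABLE zorder_ctas_probe USING PARQUET AS SELECT 1 AS c1") { write =>
  write match {
    case cmd: InsertIntoHadoopFsRelationCommand =>
      // If the observation above holds, catalogTable is None here, which is why
      // the zorder CTAS assertions were disabled in this change.
      println(s"catalogTable = ${cmd.catalogTable}")
    case other =>
      println(s"listener saw: ${other.getClass.getSimpleName}")
  }
}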