Picking changes from the fix for SPARK-32709 (#39)
* Picking changes from the fix for SPARK-32709
Applied patch from PR: apache#33432

* Fixing compile issues

* Fix compile failure in the Spark Hive module
s-pedamallu authored Aug 16, 2021
1 parent b4400c7 commit cefa493
Showing 11 changed files with 230 additions and 52 deletions.
@@ -92,6 +92,33 @@ abstract class FileCommitProtocol extends Logging {
*/
def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String

/**
* Notifies the commit protocol to add a new file, and gets back the full path that should be
* used. Must be called on the executors when running tasks.
*
* Note that the returned temp file may have an arbitrary path. The commit protocol only
* promises that the file will be at the location specified by the arguments after job commit.
*
* The "dir" parameter specifies the sub-directory within the base path, used to specify
* partitioning. The "spec" parameter specifies the file name. The rest are left to the commit
* protocol implementation to decide.
*
* Important: it is the caller's responsibility to add uniquely identifying content to "spec"
* if a task is going to write out multiple files to the same dir. The file commit protocol only
* guarantees that files written by different tasks will not conflict.
*
* @since 3.2.0
*/
def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
if (spec.prefix.isEmpty) {
newTaskTempFile(taskContext, dir, spec.suffix)
} else {
throw new UnsupportedOperationException(s"${getClass.getSimpleName}.newTaskTempFile does " +
s"not support file name prefix: ${spec.prefix}")
}
}

/**
* Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
* Depending on the implementation, there may be weaker guarantees around adding files this way.
@@ -103,6 +130,31 @@ abstract class FileCommitProtocol extends Logging {
def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String

/**
* Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
* Depending on the implementation, there may be weaker guarantees around adding files this way.
*
* The "absoluteDir" parameter specifies the final absolute directory of the file. The "spec"
* parameter specifies the file name. The rest are left to the commit protocol implementation to
* decide.
*
* Important: it is the caller's responsibility to add uniquely identifying content to "spec"
* if a task is going to write out multiple files to the same dir. The file commit protocol only
* guarantees that files written by different tasks will not conflict.
*
* @since 3.2.0
*/
def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
if (spec.prefix.isEmpty) {
newTaskTempFileAbsPath(taskContext, absoluteDir, spec.suffix)
} else {
throw new UnsupportedOperationException(
s"${getClass.getSimpleName}.newTaskTempFileAbsPath does not support file name prefix: " +
s"${spec.prefix}")
}
}

/**
* Commits a task after the writes succeed. Must be called on the executors when running tasks.
*/
@@ -140,6 +192,15 @@ object FileCommitProtocol extends Logging {

object EmptyTaskCommitMessage extends TaskCommitMessage(null)

/**
* The specification for a Spark output file name.
* This is used by [[FileCommitProtocol]] to create the full path of the output file.
*
* @param prefix Prefix of file.
* @param suffix Suffix of file.
*/
final case class FileNameSpec(prefix: String, suffix: String)

/**
* Instantiates a FileCommitProtocol using the given className.
*/
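For reference, a minimal standalone sketch of the fallback behaviour of the overloads added above: when the requested prefix is empty, the FileNameSpec variant delegates to the legacy suffix-only method, otherwise it reports that the implementation does not support prefixes. The trait and demo names below are illustrative only, not the real Spark classes.

// Illustrative sketch only; FileNameSpec is redefined locally so the snippet compiles on its own.
final case class FileNameSpec(prefix: String, suffix: String)

trait LegacyCommitProtocol {
  // Pre-3.2 style API: only a suffix/extension can be requested.
  def newTaskTempFile(dir: Option[String], ext: String): String

  // New-style API: falls back to the legacy method when no prefix is requested,
  // otherwise signals that this implementation cannot honor a prefix.
  def newTaskTempFile(dir: Option[String], spec: FileNameSpec): String =
    if (spec.prefix.isEmpty) {
      newTaskTempFile(dir, spec.suffix)
    } else {
      throw new UnsupportedOperationException(
        s"${getClass.getSimpleName} does not support file name prefix: ${spec.prefix}")
    }
}

object FileNameSpecDemo extends App {
  val protocol = new LegacyCommitProtocol {
    def newTaskTempFile(dir: Option[String], ext: String): String =
      dir.fold("")(_ + "/") + "part-00000" + ext
  }
  // Empty prefix: routed through the legacy code path.
  println(protocol.newTaskTempFile(Some("dt=2021-08-16"), FileNameSpec("", ".parquet")))
  // A non-empty prefix would throw UnsupportedOperationException for this implementation.
}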
@@ -118,7 +118,12 @@ class HadoopMapReduceCommitProtocol(

override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
val filename = getFilename(taskContext, ext)
newTaskTempFile(taskContext, dir, FileNameSpec("", ext))
}

override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], spec: FileNameSpec): String = {
val filename = getFilename(taskContext, spec)

val stagingDir: Path = committer match {
// For FileOutputCommitter it has its own staging path called "work path".
@@ -141,7 +146,12 @@

override def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
val filename = getFilename(taskContext, ext)
newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext))
}

override def newTaskTempFileAbsPath(
taskContext: TaskAttemptContext, absoluteDir: String, spec: FileNameSpec): String = {
val filename = getFilename(taskContext, spec)
val absOutputPath = new Path(absoluteDir, filename).toString

// Include a UUID here to prevent file collisions for one task writing to different dirs.
@@ -152,12 +162,12 @@
tmpOutputPath
}

protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
protected def getFilename(taskContext: TaskAttemptContext, spec: FileNameSpec): String = {
// The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
f"part-$split%05d-$jobId$ext"
f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"
}

override def setupJob(jobContext: JobContext): Unit = {
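To make the getFilename change above concrete, here is a standalone sketch of the resulting part-file names; the jobId, split number and extensions are made-up values and FileNameSpec is redefined locally.

// Illustrative sketch of the part-file naming produced by getFilename above.
object PartFileNameDemo extends App {
  final case class FileNameSpec(prefix: String, suffix: String)

  def getFilename(split: Int, jobId: String, spec: FileNameSpec): String =
    f"${spec.prefix}part-$split%05d-$jobId${spec.suffix}"

  val jobId = "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"  // made-up job UUID
  // Non-Hive writes keep an empty prefix, so existing names are unchanged.
  println(getFilename(3, jobId, FileNameSpec("", ".c000.snappy.parquet")))
  // A Hive-compatible bucketed write passes a "<bucketId>_0_" style prefix.
  println(getFilename(3, jobId, FileNameSpec("00003_0_", ".c000")))
}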
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory}

import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

/**
@@ -122,20 +123,20 @@ class PathOutputCommitProtocol(
*
* @param taskContext task context
* @param dir optional subdirectory
* @param ext file extension
* @param spec file naming specification
* @return a path as a string
*/
override def newTaskTempFile(
taskContext: TaskAttemptContext,
dir: Option[String],
ext: String): String = {
spec: FileNameSpec): String = {

val workDir = committer.getWorkPath
val parent = dir.map {
d => new Path(workDir, d)
}.getOrElse(workDir)
val file = new Path(parent, getFilename(taskContext, ext))
logTrace(s"Creating task file $file for dir $dir and ext $ext")
val file = new Path(parent, getFilename(taskContext, spec))
logTrace(s"Creating task file $file for dir $dir and spec $spec")
file.toString
}

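As a rough sketch of the path layout newTaskTempFile builds above, the snippet below composes a work directory, an optional partition sub-directory and a spec-derived file name; java.nio.file stands in for Hadoop's Path and all values are invented.

// Illustrative sketch of the temp-file path composition; not the real committer code.
import java.nio.file.{Path, Paths}

object TaskTempPathDemo extends App {
  val workDir: Path = Paths.get("/warehouse/t/_temporary/0/_temporary/attempt_0")
  val dir: Option[String] = Some("dt=2021-08-16")    // optional partition sub-directory
  val filename = "00003_0_part-00003-<jobId>.c000"   // prefix + part name + suffix

  val parent = dir.map(d => workDir.resolve(d)).getOrElse(workDir)
  println(parent.resolve(filename))
}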
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.FileNameSpec
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -157,7 +158,7 @@ class DynamicPartitionDataWriter(
private val isPartitioned = description.partitionColumns.nonEmpty

/** Flag saying whether or not the data to be written out is bucketed. */
private val isBucketed = description.bucketIdExpression.isDefined
protected val isBucketed = description.bucketSpec.isDefined

assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's either
@@ -196,7 +197,8 @@ class DynamicPartitionDataWriter(
/** Given an input row, returns the corresponding `bucketId` */
private lazy val getBucketId: InternalRow => Int = {
val proj =
UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression),
description.allColumns)
row => proj(row).getInt(0)
}

@@ -222,17 +224,23 @@

val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")

// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketIdStr.c$fileCounter%03d" +
// The prefix and suffix must be in a form that matches our bucketing format.
// See BucketingUtils.
val prefix = bucketId match {
case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id)
case _ => ""
}
val suffix = f"$bucketIdStr.c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)
val fileNameSpec = FileNameSpec(prefix, suffix)

val customPath = partDir.flatMap { dir =>
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
}
val currentPath = if (customPath.isDefined) {
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec)
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec)
}

currentWriter = description.outputWriterFactory.newInstance(
@@ -277,6 +285,16 @@ class DynamicPartitionDataWriter(
}
}

/**
* Bucketing specification for all the write tasks.
*
* @param bucketIdExpression Expression to calculate bucket id based on bucket column(s).
* @param bucketFileNamePrefix Prefix of output file name based on bucket id.
*/
case class WriterBucketSpec(
bucketIdExpression: Expression,
bucketFileNamePrefix: Int => String)

/** A shared job description for all the write tasks. */
class WriteJobDescription(
val uuid: String, // prevent collision between different (appending) write jobs
@@ -285,7 +303,7 @@ class WriteJobDescription(
val allColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val bucketIdExpression: Option[Expression],
val bucketSpec: Option[WriterBucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
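To illustrate the naming logic added to DynamicPartitionDataWriter above, the standalone sketch below assembles a prefix and suffix into a FileNameSpec for a bucketed file; bucketIdToString mimics the BucketingUtils-style "_%05d" suffix and every value is illustrative.

// Illustrative sketch of building FileNameSpec for a bucketed output file.
object BucketedFileNameDemo extends App {
  final case class FileNameSpec(prefix: String, suffix: String)
  // Mirrors WriterBucketSpec, keeping only the naming function needed here.
  final case class WriterBucketSpec(bucketFileNamePrefix: Int => String)

  def bucketIdToString(id: Int): String = f"_$id%05d"        // BucketingUtils-style

  val hiveCompatible = WriterBucketSpec(id => f"$id%05d_0_") // Hive/Presto/Trino naming
  val sparkNative    = WriterBucketSpec(_ => "")             // unchanged Spark naming

  def fileNameSpec(w: WriterBucketSpec, bucketId: Int, fileCounter: Int, ext: String) =
    FileNameSpec(
      w.bucketFileNamePrefix(bucketId),
      f"${bucketIdToString(bucketId)}.c$fileCounter%03d$ext")

  println(fileNameSpec(hiveCompatible, 3, 0, ".parquet")) // prefix "00003_0_", suffix "_00003.c000.parquet"
  println(fileNameSpec(sparkNative, 3, 0, ".parquet"))    // empty prefix, same suffix
}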
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String
@@ -113,12 +114,33 @@ object FileFormatWriter extends Logging {
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan

val bucketIdExpression = bucketSpec.map { spec =>
val writerBucketSpec = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression

if (options.getOrElse(DDLUtils.HIVE_PROVIDER, "false") == "true") {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get a wrong bucket id when the hash value
// of the columns is negative. See the Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix follows the Hive, Presto and Trino convention, so this
// makes sure a Hive bucketed table written by Spark can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
// expression, so that we can guarantee the data distribution is the same between shuffle
// and bucketed data source, which enables us to only shuffle one side when joining a
// bucketed table and a normal one.
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
.partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -139,7 +161,7 @@
allColumns = outputSpec.outputColumns,
dataColumns = dataColumns,
partitionColumns = partitionColumns,
bucketIdExpression = bucketIdExpression,
bucketSpec = writerBucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
@@ -150,7 +172,8 @@
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
val requiredOrdering =
partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
Expand Down Expand Up @@ -265,7 +288,7 @@ object FileFormatWriter extends Logging {
if (sparkPartitionId != 0 && !iterator.hasNext) {
// In case of empty job, leave first partition to save meta for file format like parquet.
new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
} else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
} else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
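The Hive branch above builds the bucket id as Catalyst expressions (BitwiseAnd of HiveHash with Int.MaxValue, then Pmod by the bucket count). A standalone sketch of the same arithmetic on plain Ints, with an arbitrary hash value, shows why the mask matters; it does not reimplement HiveHash itself.

// Illustrative sketch of the Hive-compatible bucket id arithmetic.
object HiveBucketIdDemo extends App {
  // Mirrors BitwiseAnd(HiveHash(...), Int.MaxValue) followed by Pmod(_, numBuckets).
  def hiveBucketId(hiveHash: Int, numBuckets: Int): Int =
    (hiveHash & Int.MaxValue) % numBuckets

  // The mask keeps the id non-negative even when the hash is negative.
  println(hiveBucketId(-7, 4))  // 1
  println(-7 % 4)               // -3: plain % on a negative hash gives a wrong bucket id

  // File name prefix following the Hive/Trino convention, e.g. "00001_0_".
  val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
  println(fileNamePrefix(hiveBucketId(-7, 4)))
}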
@@ -129,7 +129,7 @@ abstract class FileWriteBuilder(
allColumns = allColumns,
dataColumns = allColumns,
partitionColumns = Seq.empty,
bucketIdExpression = None,
bucketSpec = None,
path = pathName,
customPartitionLocations = Map.empty,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)