Skip to content

Commit

Permalink
[opt-tech#73] Implement the HEADER option in the UNLOAD command, and …
Browse files Browse the repository at this point in the history
…the IGNOREHEADER option in the COPY command

As per the official Redshift docs (https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html), the
HEADER option enables Redshift to add the header (column names) to the CSV files produced by the
UNLOAD command.

Similarly, as per the official COPY command docs (https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html),
the IGNOREHEADER option enables Redshift to skip the first N lines in the input CSV file (i.e. the CSV
header).
  • Loading branch information
Danijel Schiavuzzi committed Jul 1, 2022
1 parent 5eb359c commit 5530eef
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/main/scala/jp/ne/opt/redshiftfake/CopyCommand.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ case class CopyCommand(
emptyAsNull: Boolean,
delimiter: Char,
nullAs: String,
ignoreHeader: Int,
compression: FileCompressionParameter
) {
val qualifiedTableName = schemaName match {
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/jp/ne/opt/redshiftfake/Interceptor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ trait UnloadInterceptor extends Interceptor {
val jdbcType = JdbcType.valueOf(resultSet.getMetaData.getColumnType(index))
Extractor(jdbcType)
}
val columnNames = (1 to columnCount).map { index =>
resultSet.getMetaData.getColumnName(index)
}

val rows = Iterator.continually(resultSet).takeWhile(_.next()).map { rs =>
val row = extractors.zipWithIndex.map { case (extractor, i) =>
Expand All @@ -112,7 +115,7 @@ trait UnloadInterceptor extends Interceptor {
Row(row)
}

new Writer(command, s3Service).write(rows.toList)
new Writer(command, s3Service).write(columnNames, rows.toList)
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/jp/ne/opt/redshiftfake/UnloadCommand.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ case class UnloadCommand(
credentials: Credentials,
createManifest: Boolean,
delimiter: Char,
addQuotes: Boolean
addQuotes: Boolean,
header: Boolean
)
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class CopyCommandParser extends BaseParser {
s"$any*(?i)BZIP2".r
}

/**
 * Extracts the row count given to the COPY `IGNOREHEADER` option.
 *
 * Redshift documents the syntax as `IGNOREHEADER [ AS ] number_rows`, so both the `AS`
 * keyword and the single quotes around the number are optional here. The quoted
 * `IGNOREHEADER AS '1'` form is still accepted. Only decimal digits are captured, so a
 * non-numeric value makes the option fail to parse instead of yielding arbitrary text.
 */
private[this] val ignoreHeaderParser =
  s"$any*(?i)IGNOREHEADER($space+AS)?".r ~> opt("'") ~> """\d+""".r <~ opt("'") <~ s"$any*".r

def parse(query: String): Option[CopyCommand] = {
val result = parse(
("(?i)COPY".r ~> tableNameParser) ~
Expand All @@ -81,6 +83,7 @@ class CopyCommandParser extends BaseParser {
parse(emptyAsNullParser, dataConversionParameters).successful,
parse(delimiterParser, dataConversionParameters).getOrElse('|'),
parse(nullAsParser, dataConversionParameters).getOrElse("\u000e"),
parse(ignoreHeaderParser, dataConversionParameters).getOrElse("0").toInt,
parseFileCompression(dataConversionParameters)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class UnloadCommandParser extends BaseParser with QueryCompatibility {

private[this] val addQuotesParser = s"$any*(?i)ADDQUOTES$any*".r

// Succeeds when the UNLOAD options text contains the HEADER keyword, matched
// case-insensitively (assuming `any` matches an arbitrary character — TODO confirm).
// NOTE(review): this looks like a plain substring match, so "header" appearing inside
// another option's value would presumably also enable the flag — verify that is acceptable.
private[this] val headerParser = s"$any*(?i)HEADER$any*".r

private[this] val manifestParser = s"$any*(?i)MANIFEST$any*".r

private[this] val statementParser = "(?i)UNLOAD".r ~> selectStatementParser ^^ { s =>
Expand All @@ -57,7 +59,8 @@ class UnloadCommandParser extends BaseParser with QueryCompatibility {
auth,
parse(manifestParser, unloadOptions).successful,
parse(delimiterParser, unloadOptions).getOrElse('|'),
parse(addQuotesParser, unloadOptions).successful
parse(addQuotesParser, unloadOptions).successful,
parse(headerParser, unloadOptions).successful
)
},
query
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/jp/ne/opt/redshiftfake/read/Reader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Reader(copyCommand: CopyCommand, columnDefinitions: Seq[ColumnDefinition],
case CopyFormat.Manifest(_) | CopyFormat.Default =>
(for {
content <- contents
line <- content.trim.lines
line <- content.trim.lines.drop(copyCommand.ignoreHeader)
} yield {
val csvReader = new CsvReader(line, copyCommand.delimiter, copyCommand.nullAs)
csvReader.toRow
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/jp/ne/opt/redshiftfake/write/Writer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ class Writer(unloadCommand: UnloadCommand, s3Service: S3Service) {
val quoting: Quoting = if (unloadCommand.addQuotes) QUOTE_ALL else QUOTE_NONE
}

def write(rows: Seq[Row]): Unit = {
def write(columnNames: Seq[String], rows: Seq[Row]): Unit = {
val stream = new ByteArrayOutputStream()

using(CSVWriter.open(stream)(csvFormat)) { csvWriter =>
if (unloadCommand.header) {
csvWriter.writeRow(columnNames)
}
rows.foreach { row =>
csvWriter.writeRow(row.columns.map(_.rawValue.getOrElse("")))
}
Expand Down
4 changes: 3 additions & 1 deletion src/test/scala/jp/ne/opt/redshiftfake/IntegrationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ class IntegrationTest extends fixture.FlatSpec
s"""unload ('select c, count(*), sum(a) from foo where b = true group by c order by c') to '${Global.s3Scheme}foo/unloaded_'
|credentials 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY'
|manifest
|addquotes""".stripMargin
|addquotes
|header""".stripMargin
)

conn.createStatement().execute(
s"""copy bar from '${Global.s3Scheme}foo/unloaded_manifest'
|credentials 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY'
|manifest
|removequotes
|ignoreheader as '1'
|dateformat 'auto'""".stripMargin
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CopyCommandParserTest extends FlatSpec {
emptyAsNull = false,
delimiter = '|',
nullAs = "\u000e",
ignoreHeader = 0,
compression = FileCompressionParameter.None
)

Expand Down Expand Up @@ -193,6 +194,32 @@ class CopyCommandParserTest extends FlatSpec {
assert(new CopyCommandParser().parse(command).map(_.nullAs) == Some("\u000e"))
}

// An explicit IGNOREHEADER AS '1' clause must surface as ignoreHeader = 1 on the parsed command.
it should "parse 'IGNOREHEADER AS' from COPY command" in {
  val copyStatement =
    s"""
       |COPY "public"."mytable"
       |FROM '${Global.s3Scheme}some-bucket/path/to/unloaded_manifest.json'
       |CREDENTIALS 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY'
       |IGNOREHEADER AS '1'
       |MANIFEST
       |""".stripMargin

  val parsedIgnoreHeader = new CopyCommandParser().parse(copyStatement).map(_.ignoreHeader)

  assert(parsedIgnoreHeader == Some(1))
}

// Without an IGNOREHEADER clause the parsed command must report ignoreHeader = 0.
it should "set default 'IGNOREHEADER AS' correctly" in {
  val copyStatement =
    s"""
       |COPY "public"."mytable"
       |FROM '${Global.s3Scheme}some-bucket/path/to/unloaded_manifest.json'
       |CREDENTIALS 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY'
       |FORMAT AS CSV
       |MANIFEST
       |""".stripMargin

  val parsedIgnoreHeader = new CopyCommandParser().parse(copyStatement).map(_.ignoreHeader)

  assert(parsedIgnoreHeader == Some(0))
}

it should "parse aws_role_arn from COPY command" in {
val command =
s"""
Expand All @@ -212,6 +239,7 @@ class CopyCommandParserTest extends FlatSpec {
emptyAsNull = false,
delimiter = '|',
nullAs = "\u000e",
ignoreHeader = 0,
compression = FileCompressionParameter.None
)

Expand Down Expand Up @@ -242,6 +270,7 @@ class CopyCommandParserTest extends FlatSpec {
emptyAsNull = false,
delimiter = '|',
nullAs = "\u000e",
ignoreHeader = 0,
compression = FileCompressionParameter.None
)

Expand Down Expand Up @@ -269,6 +298,7 @@ class CopyCommandParserTest extends FlatSpec {
emptyAsNull = false,
delimiter = '|',
nullAs = "\u000e",
ignoreHeader = 0,
compression = FileCompressionParameter.None
)

Expand All @@ -294,6 +324,7 @@ class CopyCommandParserTest extends FlatSpec {
emptyAsNull = false,
delimiter = '|',
nullAs = "\u000e",
ignoreHeader = 0,
compression = FileCompressionParameter.None
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class UnloadCommandParserTest extends FlatSpec {
),
createManifest = false,
delimiter = '|',
addQuotes = false
addQuotes = false,
header = false
)

assert(new UnloadCommandParser().parse(command) == Some(expected))
Expand Down Expand Up @@ -88,7 +89,8 @@ class UnloadCommandParserTest extends FlatSpec {
credentials = Credentials.WithRole("arn:aws:iam::12345:role/some-role"),
createManifest = false,
delimiter = '|',
addQuotes = false
addQuotes = false,
header = false
)

assert(new UnloadCommandParser().parse(command) == Some(expected))
Expand All @@ -109,7 +111,8 @@ class UnloadCommandParserTest extends FlatSpec {
credentials = Credentials.WithTemporaryToken("some_access_key_id", "some_secret_access_key", "some_session_token"),
createManifest = false,
delimiter = '|',
addQuotes = false
addQuotes = false,
header = false
)

assert(new UnloadCommandParser().parse(command) == Some(expected))
Expand All @@ -132,7 +135,8 @@ class UnloadCommandParserTest extends FlatSpec {
credentials = Credentials.WithTemporaryToken("some_access_key_id", "some_secret_access_key", someSessionToken),
createManifest = false,
delimiter = '|',
addQuotes = false
addQuotes = false,
header = false
)

assert(new UnloadCommandParser().parse(command) == Some(expected))
Expand All @@ -152,7 +156,8 @@ class UnloadCommandParserTest extends FlatSpec {
credentials = Credentials.WithRole("arn:aws:iam::12345:role/some-role"),
createManifest = false,
delimiter = '|',
addQuotes = false
addQuotes = false,
header = false
)

assert(new UnloadCommandParser().parse(command) == Some(expected))
Expand Down

0 comments on commit 5530eef

Please sign in to comment.