From 5530eefd750f0cfe5ade3dba788e276043b7626b Mon Sep 17 00:00:00 2001 From: Danijel Schiavuzzi Date: Fri, 1 Jul 2022 12:39:08 +0100 Subject: [PATCH] [opt-tech/redshift-fake-driver#73] Implement the HEADER option in the UNLOAD command, and the IGNOREHEADER option in the COPY command As per the official Redshift docs (https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html), the HEADER options enables Redshift to add the header (column names) to the CSV files produced by the UNLOAD command. Similarly, as per the official COPY command docs (https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html), the IGNOREHEADER option enables Redshift to skip the first N lines in the input CSV file (ie. the CSV header). --- .../jp/ne/opt/redshiftfake/CopyCommand.scala | 1 + .../jp/ne/opt/redshiftfake/Interceptor.scala | 5 ++- .../ne/opt/redshiftfake/UnloadCommand.scala | 3 +- .../parse/CopyCommandParser.scala | 3 ++ .../parse/UnloadCommandParser.scala | 5 ++- .../jp/ne/opt/redshiftfake/read/Reader.scala | 2 +- .../jp/ne/opt/redshiftfake/write/Writer.scala | 5 ++- .../ne/opt/redshiftfake/IntegrationTest.scala | 4 ++- .../parse/CopyCommandParserTest.scala | 31 +++++++++++++++++++ .../parse/UnloadCommandParserTest.scala | 15 ++++++--- 10 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/main/scala/jp/ne/opt/redshiftfake/CopyCommand.scala b/src/main/scala/jp/ne/opt/redshiftfake/CopyCommand.scala index a054848..a8ca7dc 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/CopyCommand.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/CopyCommand.scala @@ -17,6 +17,7 @@ case class CopyCommand( emptyAsNull: Boolean, delimiter: Char, nullAs: String, + ignoreHeader: Int, compression: FileCompressionParameter ) { val qualifiedTableName = schemaName match { diff --git a/src/main/scala/jp/ne/opt/redshiftfake/Interceptor.scala b/src/main/scala/jp/ne/opt/redshiftfake/Interceptor.scala index 35de594..99a37f5 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/Interceptor.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/Interceptor.scala @@ -104,6 +104,9 @@ trait UnloadInterceptor extends Interceptor { val jdbcType = JdbcType.valueOf(resultSet.getMetaData.getColumnType(index)) Extractor(jdbcType) } + val columnNames = (1 to columnCount).map { index => + resultSet.getMetaData.getColumnName(index) + } val rows = Iterator.continually(resultSet).takeWhile(_.next()).map { rs => val row = extractors.zipWithIndex.map { case (extractor, i) => @@ -112,7 +115,7 @@ trait UnloadInterceptor extends Interceptor { Row(row) } - new Writer(command, s3Service).write(rows.toList) + new Writer(command, s3Service).write(columnNames, rows.toList) } } } diff --git a/src/main/scala/jp/ne/opt/redshiftfake/UnloadCommand.scala b/src/main/scala/jp/ne/opt/redshiftfake/UnloadCommand.scala index 1085e04..f9a6318 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/UnloadCommand.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/UnloadCommand.scala @@ -11,5 +11,6 @@ case class UnloadCommand( credentials: Credentials, createManifest: Boolean, delimiter: Char, - addQuotes: Boolean + addQuotes: Boolean, + header: Boolean ) diff --git a/src/main/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParser.scala b/src/main/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParser.scala index 26fb1a2..d93174e 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParser.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParser.scala @@ -61,6 +61,8 @@ class CopyCommandParser extends BaseParser { s"$any*(?i)BZIP2".r } + private[this] val ignoreHeaderParser = s"$any*(?i)IGNOREHEADER$space+AS".r ~> "'" ~> """[^']*""".r <~ "'" <~ s"$any*".r + def parse(query: String): Option[CopyCommand] = { val result = parse( ("(?i)COPY".r ~> tableNameParser) ~ @@ -81,6 +83,7 @@ class CopyCommandParser extends BaseParser { parse(emptyAsNullParser, dataConversionParameters).successful, parse(delimiterParser, dataConversionParameters).getOrElse('|'), parse(nullAsParser, dataConversionParameters).getOrElse("\u000e"), + parse(ignoreHeaderParser, dataConversionParameters).getOrElse("0").toInt, parseFileCompression(dataConversionParameters) ) diff --git a/src/main/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParser.scala b/src/main/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParser.scala index 3180951..2d0dacd 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParser.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParser.scala @@ -38,6 +38,8 @@ class UnloadCommandParser extends BaseParser with QueryCompatibility { private[this] val addQuotesParser = s"$any*(?i)ADDQUOTES$any*".r + private[this] val headerParser = s"$any*(?i)HEADER$any*".r + private[this] val manifestParser = s"$any*(?i)MANIFEST$any*".r private[this] val statementParser = "(?i)UNLOAD".r ~> selectStatementParser ^^ { s => @@ -57,7 +59,8 @@ class UnloadCommandParser extends BaseParser with QueryCompatibility { auth, parse(manifestParser, unloadOptions).successful, parse(delimiterParser, unloadOptions).getOrElse('|'), - parse(addQuotesParser, unloadOptions).successful + parse(addQuotesParser, unloadOptions).successful, + parse(headerParser, unloadOptions).successful ) }, query diff --git a/src/main/scala/jp/ne/opt/redshiftfake/read/Reader.scala b/src/main/scala/jp/ne/opt/redshiftfake/read/Reader.scala index 3c116e5..9577811 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/read/Reader.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/read/Reader.scala @@ -38,7 +38,7 @@ class Reader(copyCommand: CopyCommand, columnDefinitions: Seq[ColumnDefinition], case CopyFormat.Manifest(_) | CopyFormat.Default => (for { content <- contents - line <- content.trim.lines + line <- content.trim.lines.drop(copyCommand.ignoreHeader) } yield { val csvReader = new CsvReader(line, copyCommand.delimiter, copyCommand.nullAs) csvReader.toRow diff --git a/src/main/scala/jp/ne/opt/redshiftfake/write/Writer.scala b/src/main/scala/jp/ne/opt/redshiftfake/write/Writer.scala index 867757a..16859a3 100644 --- a/src/main/scala/jp/ne/opt/redshiftfake/write/Writer.scala +++ b/src/main/scala/jp/ne/opt/redshiftfake/write/Writer.scala @@ -21,10 +21,13 @@ class Writer(unloadCommand: UnloadCommand, s3Service: S3Service) { val quoting: Quoting = if (unloadCommand.addQuotes) QUOTE_ALL else QUOTE_NONE } - def write(rows: Seq[Row]): Unit = { + def write(columnNames: Seq[String], rows: Seq[Row]): Unit = { val stream = new ByteArrayOutputStream() using(CSVWriter.open(stream)(csvFormat)) { csvWriter => + if (unloadCommand.header) { + csvWriter.writeRow(columnNames) + } rows.foreach { row => csvWriter.writeRow(row.columns.map(_.rawValue.getOrElse(""))) } diff --git a/src/test/scala/jp/ne/opt/redshiftfake/IntegrationTest.scala b/src/test/scala/jp/ne/opt/redshiftfake/IntegrationTest.scala index 38c30c8..161edd2 100644 --- a/src/test/scala/jp/ne/opt/redshiftfake/IntegrationTest.scala +++ b/src/test/scala/jp/ne/opt/redshiftfake/IntegrationTest.scala @@ -39,7 +39,8 @@ class IntegrationTest extends fixture.FlatSpec s"""unload ('select c, count(*), sum(a) from foo where b = true group by c order by c') to '${Global.s3Scheme}foo/unloaded_' |credentials 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY' |manifest - |addquotes""".stripMargin + |addquotes + |header""".stripMargin ) conn.createStatement().execute( @@ -47,6 +48,7 @@ class IntegrationTest extends fixture.FlatSpec |credentials 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY' |manifest |removequotes + |ignoreheader as '1' |dateformat 'auto'""".stripMargin ) diff --git a/src/test/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParserTest.scala b/src/test/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParserTest.scala index bd437f4..49cc403 100644 --- a/src/test/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParserTest.scala +++ b/src/test/scala/jp/ne/opt/redshiftfake/parse/CopyCommandParserTest.scala @@ -27,6 +27,7 @@ class CopyCommandParserTest extends FlatSpec { emptyAsNull = false, delimiter = '|', nullAs = "\u000e", + ignoreHeader = 0, compression = FileCompressionParameter.None ) @@ -193,6 +194,32 @@ class CopyCommandParserTest extends FlatSpec { assert(new CopyCommandParser().parse(command).map(_.nullAs) == Some("\u000e")) } + it should "parse 'IGNOREHEADER AS' from COPY command" in { + val command = + s""" + |COPY "public"."mytable" + |FROM '${Global.s3Scheme}some-bucket/path/to/unloaded_manifest.json' + |CREDENTIALS 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY' + |IGNOREHEADER AS '1' + |MANIFEST + |""".stripMargin + + assert(new CopyCommandParser().parse(command).map(_.ignoreHeader) == Some(1)) + } + + it should "set default 'IGNOREHEADER AS' correctly" in { + val command = + s""" + |COPY "public"."mytable" + |FROM '${Global.s3Scheme}some-bucket/path/to/unloaded_manifest.json' + |CREDENTIALS 'aws_access_key_id=AKIAXXXXXXXXXXXXXXX;aws_secret_access_key=YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY' + |FORMAT AS CSV + |MANIFEST + |""".stripMargin + + assert(new CopyCommandParser().parse(command).map(_.ignoreHeader) == Some(0)) + } + it should "parse aws_role_arn from COPY command" in { val command = s""" @@ -212,6 +239,7 @@ class CopyCommandParserTest extends FlatSpec { emptyAsNull = false, delimiter = '|', nullAs = "\u000e", + ignoreHeader = 0, compression = FileCompressionParameter.None ) @@ -242,6 +270,7 @@ class CopyCommandParserTest extends FlatSpec { emptyAsNull = false, delimiter = '|', nullAs = "\u000e", + ignoreHeader = 0, compression = FileCompressionParameter.None ) @@ -269,6 +298,7 @@ class CopyCommandParserTest extends FlatSpec { emptyAsNull = false, delimiter = '|', nullAs = "\u000e", + ignoreHeader = 0, compression = FileCompressionParameter.None ) @@ -294,6 +324,7 @@ class CopyCommandParserTest extends FlatSpec { emptyAsNull = false, delimiter = '|', nullAs = "\u000e", + ignoreHeader = 0, compression = FileCompressionParameter.None ) diff --git a/src/test/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParserTest.scala b/src/test/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParserTest.scala index 44be234..f673b7a 100644 --- a/src/test/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParserTest.scala +++ b/src/test/scala/jp/ne/opt/redshiftfake/parse/UnloadCommandParserTest.scala @@ -21,7 +21,8 @@ class UnloadCommandParserTest extends FlatSpec { ), createManifest = false, delimiter = '|', - addQuotes = false + addQuotes = false, + header = false ) assert(new UnloadCommandParser().parse(command) == Some(expected)) @@ -88,7 +89,8 @@ class UnloadCommandParserTest extends FlatSpec { credentials = Credentials.WithRole("arn:aws:iam::12345:role/some-role"), createManifest = false, delimiter = '|', - addQuotes = false + addQuotes = false, + header = false ) assert(new UnloadCommandParser().parse(command) == Some(expected)) @@ -109,7 +111,8 @@ class UnloadCommandParserTest extends FlatSpec { credentials = Credentials.WithTemporaryToken("some_access_key_id", "some_secret_access_key", "some_session_token"), createManifest = false, delimiter = '|', - addQuotes = false + addQuotes = false, + header = false ) assert(new UnloadCommandParser().parse(command) == Some(expected)) @@ -132,7 +135,8 @@ class UnloadCommandParserTest extends FlatSpec { credentials = Credentials.WithTemporaryToken("some_access_key_id", "some_secret_access_key", someSessionToken), createManifest = false, delimiter = '|', - addQuotes = false + addQuotes = false, + header = false ) assert(new UnloadCommandParser().parse(command) == Some(expected)) @@ -152,7 +156,8 @@ class UnloadCommandParserTest extends FlatSpec { credentials = Credentials.WithRole("arn:aws:iam::12345:role/some-role"), createManifest = false, delimiter = '|', - addQuotes = false + addQuotes = false, + header = false ) assert(new UnloadCommandParser().parse(command) == Some(expected))