Commit

resolved merge conflict with master in SnappySessionState
hemanthmeka committed Aug 2, 2018
2 parents 4be42d3 + 5708864 commit 32f1dbd
Showing 400 changed files with 6,332 additions and 1,708 deletions.
@@ -18,12 +18,13 @@
package io.snappydata.cluster

import java.io.File
import java.math.BigDecimal
import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, SQLException, Statement}

import com.gemstone.gemfire.distributed.DistributedMember

import scala.collection.mutable
import scala.collection.JavaConverters._

import com.gemstone.gemfire.distributed.internal.membership.InternalDistributedMember
import com.gemstone.gemfire.internal.cache.PartitionedRegion
import com.pivotal.gemfirexd.internal.engine.Misc
@@ -34,7 +35,6 @@ import io.snappydata.test.dunit.{AvailablePortHelper, SerializableRunnable}
import junit.framework.TestCase
import org.apache.commons.io.FileUtils
import org.junit.Assert

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -858,6 +858,44 @@ class QueryRoutingDUnitTest(val s: String)
session.dropTable(table)
}

def testSNAP2247(): Unit = {
val serverHostPort = AvailablePortHelper.getRandomAvailableTCPPort
vm2.invoke(classOf[ClusterManagerTestBase], "startNetServer", serverHostPort)
val conn = DriverManager.getConnection(
"jdbc:snappydata://localhost:" + serverHostPort)
val st = conn.createStatement()
try {
val conn = DriverManager.getConnection(
"jdbc:snappydata://localhost:" + serverHostPort)

val st = conn.createStatement()
st.execute(s"create table trade.securities " +
s"(sec_id int not null, symbol varchar(10) not null, " +
s"price decimal (30, 20), exchange varchar(10) not null, " +
s"tid int, constraint sec_pk primary key (sec_id), " +
s"constraint sec_uq unique (symbol, exchange), constraint exc_ch check " +
s"(exchange in ('nasdaq', 'nye', 'amex', 'lse', 'fse', 'hkse', 'tse'))) " +
s"ENABLE CONCURRENCY CHECKS")

val ps = conn.prepareStatement(s"select price, symbol, exchange from trade.securities" +
s" where (price<? or price >=?) and tid =? order by CASE when exchange ='nasdaq'" +
s" then symbol END desc, CASE when exchange in('nye', 'amex') then sec_id END desc," +
s" CASE when exchange ='lse' then symbol END asc, CASE when exchange ='fse' then" +
s" sec_id END desc, CASE when exchange ='hkse' then symbol END asc," +
s" CASE when exchange ='tse' then symbol END desc")

ps.setBigDecimal(1, new BigDecimal("0.02"))
ps.setBigDecimal(2, new BigDecimal("20.02"))
ps.setInt(3, 3)

ps.execute()
assert(!ps.getResultSet.next())
} finally {
st.execute(s"drop table trade.securities")
conn.close()
}
}

def limitInsertRows(numRows: Int, serverHostPort: Int, tableName: String): Unit = {

val conn = DriverManager.getConnection(
cluster/src/dunit/scala/org/apache/spark/sql/TPCHDUnitTest.scala
36 changes: 17 additions & 19 deletions
@@ -349,14 +349,14 @@ object TPCHUtils extends Logging {
def validateResult(snc: SQLContext, isSnappy: Boolean, isTokenization: Boolean = false): Unit = {
val sc: SparkContext = snc.sparkContext

val fineName = if (!isTokenization) {
val fileName = if (!isTokenization) {
if (isSnappy) "Result_Snappy.out" else "Result_Spark.out"
} else {
"Result_Snappy_Tokenization.out"
}

val resultFileStream: FileOutputStream = new FileOutputStream(new File(fineName))
val resultOutputStream: PrintStream = new PrintStream(resultFileStream)
val resultsLogFileStream: FileOutputStream = new FileOutputStream(new File(fileName))
val resultsLogStream: PrintStream = new PrintStream(resultsLogFileStream)

// scalastyle:off
for (query <- queries) {
@@ -366,24 +366,25 @@
val expectedFile = sc.textFile(getClass.getResource(
s"/TPCH/RESULT/Snappy_$query.out").getPath)

val queryFileName = if (isSnappy) s"1_Snappy_$query.out" else s"1_Spark_$query.out"
val actualFile = sc.textFile(queryFileName)
//val queryFileName = if (isSnappy) s"1_Snappy_$query.out" else s"1_Spark_$query.out"
val queryResultsFileName = if (isSnappy) s"1_Snappy_Q${query}_Results.out" else s"1_Spark_Q${query}_Results.out"
val actualFile = sc.textFile(queryResultsFileName)

val expectedLineSet = expectedFile.collect().toList.sorted
val actualLineSet = actualFile.collect().toList.sorted

if (!actualLineSet.equals(expectedLineSet)) {
if (!(expectedLineSet.size == actualLineSet.size)) {
resultOutputStream.println(s"For $query " +
resultsLogStream.println(s"For $query " +
s"result count mismatched observed with " +
s"expected ${expectedLineSet.size} and actual ${actualLineSet.size}")
} else {
for ((expectedLine, actualLine) <- expectedLineSet zip actualLineSet) {
if (!expectedLine.equals(actualLine)) {
resultOutputStream.println(s"For $query result mismatched observed")
resultOutputStream.println(s"Expected : $expectedLine")
resultOutputStream.println(s"Found : $actualLine")
resultOutputStream.println(s"-------------------------------------")
resultsLogStream.println(s"For $query result mismatched observed")
resultsLogStream.println(s"Expected : $expectedLine")
resultsLogStream.println(s"Found : $actualLine")
resultsLogStream.println(s"-------------------------------------")
}
}
}
@@ -399,16 +400,16 @@ object TPCHUtils extends Logging {
val actualLineSet = secondRunFile.collect().toList.sorted

if (actualLineSet.equals(expectedLineSet)) {
resultOutputStream.println(s"For $query result matched observed")
resultOutputStream.println(s"-------------------------------------")
resultsLogStream.println(s"For $query result matched observed")
resultsLogStream.println(s"-------------------------------------")
}
}
}
// scalastyle:on
resultOutputStream.close()
resultFileStream.close()
resultsLogStream.close()
resultsLogFileStream.close()

val resultOutputFile = sc.textFile(fineName)
val resultOutputFile = sc.textFile(fileName)

if(!isTokenization) {
assert(resultOutputFile.count() == 0,
@@ -433,11 +434,8 @@ object TPCHUtils extends Logging {
fileName: String = ""): Unit = {
snc.sql(s"set spark.sql.crossJoin.enabled = true")

// queries.foreach(query => TPCH_Snappy.execute(query, snc,
// isResultCollection, isSnappy, warmup = warmup,
// runsForAverage = runsForAverage, avgPrintStream = System.out))
queries.foreach(query => QueryExecutor.execute(query, snc, isResultCollection,
isSnappy, isDynamic = isDynamic, warmup = warmup, runsForAverage = runsForAverage,
avgPrintStream = System.out))
avgTimePrintStream = System.out))
}
}
cluster/src/main/scala/io/snappydata/gemxd/SparkSQLExecuteImpl.scala
146 changes: 78 additions & 68 deletions
@@ -38,6 +38,7 @@ import com.pivotal.gemfirexd.internal.snappy.{LeadNodeExecutionContext, SparkSQL
import io.snappydata.{Constant, QueryHint}

import org.apache.spark.serializer.{KryoSerializerPool, StructTypeSerializer}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.collection.Utils
import org.apache.spark.sql.types._
@@ -88,17 +89,9 @@ class SparkSQLExecuteImpl(val sql: String,
private[this] lazy val colTypes = getColumnTypes

// check for query hint to serialize complex types as JSON strings
private[this] val complexTypeAsJson = session.getPreviousQueryHints.get(
QueryHint.ComplexTypeAsJson.toString) match {
case null => true
case v => Misc.parseBoolean(v)
}
private[this] val complexTypeAsJson = SparkSQLExecuteImpl.getJsonProperties(session)

private val (allAsClob, columnsAsClob) = session.getPreviousQueryHints.get(
QueryHint.ColumnsAsClob.toString) match {
case null => (false, Set.empty[String])
case v => Utils.parseColumnsAsClob(v)
}
private val (allAsClob, columnsAsClob) = SparkSQLExecuteImpl.getClobProperties(session)

override def packRows(msg: LeadNodeExecutorMsg,
snappyResultHolder: SnappyResultHolder): Unit = {
@@ -121,7 +114,8 @@
CachedDataFrame.localBlockStoreResultHandler(rddId, bm),
CachedDataFrame.localBlockStoreDecoder(querySchema.length, bm))
hdos.clearForReuse()
writeMetaData(srh)
SparkSQLExecuteImpl.writeMetaData(srh, hdos, tableNames, nullability, getColumnNames,
getColumnTypes, getColumnDataTypes)

var id = 0
for (block <- partitionBlocks) {
@@ -191,77 +185,53 @@
override def serializeRows(out: DataOutput, hasMetadata: Boolean): Unit =
SparkSQLExecuteImpl.serializeRows(out, hasMetadata, hdos)

private lazy val (tableNames, nullability) = getTableNamesAndNullability

def getTableNamesAndNullability: (Array[String], Array[Boolean]) = {
var i = 0
val output = df.queryExecution.analyzed.output
val tables = new Array[String](output.length)
val nullables = new Array[Boolean](output.length)
output.foreach { a =>
val fn = a.qualifiedName
val dotIdx = fn.lastIndexOf('.')
if (dotIdx > 0) {
tables(i) = fn.substring(0, dotIdx)
} else {
tables(i) = ""
}
nullables(i) = a.nullable
i += 1
}
(tables, nullables)
}

private def writeMetaData(srh: SnappyResultHolder): Unit = {
val hdos = this.hdos
// indicates that the metadata is being packed too
srh.setHasMetadata()
DataSerializer.writeStringArray(tableNames, hdos)
DataSerializer.writeStringArray(getColumnNames, hdos)
DataSerializer.writeBooleanArray(nullability, hdos)
for (i <- colTypes.indices) {
val (tp, precision, scale) = colTypes(i)
InternalDataSerializer.writeSignedVL(tp, hdos)
tp match {
case StoredFormatIds.SQL_DECIMAL_ID =>
InternalDataSerializer.writeSignedVL(precision, hdos) // precision
InternalDataSerializer.writeSignedVL(scale, hdos) // scale
case StoredFormatIds.SQL_VARCHAR_ID |
StoredFormatIds.SQL_CHAR_ID =>
// Write the size as precision
InternalDataSerializer.writeSignedVL(precision, hdos)
case StoredFormatIds.REF_TYPE_ID =>
// Write the DataType
hdos.write(KryoSerializerPool.serialize((kryo, out) =>
StructTypeSerializer.writeType(kryo, out, querySchema(i).dataType)))
case _ => // ignore for others
}
}
}
private lazy val (tableNames, nullability) = SparkSQLExecuteImpl.
getTableNamesAndNullability(df.queryExecution.analyzed.output)

def getColumnNames: Array[String] = {
querySchema.fieldNames
}

private def getColumnTypes: Array[(Int, Int, Int)] =
querySchema.map(f => getSQLType(f)).toArray
querySchema.map(f => SparkSQLExecuteImpl.getSQLType(f.dataType, complexTypeAsJson,
Some(f.metadata), Some(f.name), Some(allAsClob), Some(columnsAsClob))).toArray

private def getColumnDataTypes: Array[DataType] =
querySchema.map(_.dataType).toArray
}

object SparkSQLExecuteImpl {

private def getSQLType(f: StructField): (Int, Int, Int) = {
val dataType = f.dataType
def getJsonProperties(session: SnappySession): Boolean = session.getPreviousQueryHints.get(
QueryHint.ComplexTypeAsJson.toString) match {
case null => true
case v => Misc.parseBoolean(v)
}

def getClobProperties(session: SnappySession): (Boolean, Set[String]) =
session.getPreviousQueryHints.get(QueryHint.ColumnsAsClob.toString) match {
case null => (false, Set.empty[String])
case v => Utils.parseColumnsAsClob(v)
}

def getSQLType(dataType: DataType, complexTypeAsJson: Boolean,
metaData: Option[Metadata] = None, metaName: Option[String] = None,
allAsClob: Option[Boolean] = None, columnsAsClob: Option[Set[String]] = None): (Int,
Int, Int) = {
dataType match {
case IntegerType => (StoredFormatIds.SQL_INTEGER_ID, -1, -1)
case StringType =>
case StringType if metaData.isDefined =>
TypeUtilities.getMetadata[String](Constant.CHAR_TYPE_BASE_PROP,
f.metadata) match {
metaData.get) match {
case Some(base) if base != "CLOB" =>
lazy val size = TypeUtilities.getMetadata[Long](
Constant.CHAR_TYPE_SIZE_PROP, f.metadata)
Constant.CHAR_TYPE_SIZE_PROP, metaData.get)
lazy val varcharSize = size.getOrElse(
Constant.MAX_VARCHAR_SIZE.toLong).toInt
lazy val charSize = size.getOrElse(
Constant.MAX_CHAR_SIZE.toLong).toInt
if (allAsClob ||
(columnsAsClob.nonEmpty && columnsAsClob.contains(f.name))) {
if (allAsClob.get ||
(columnsAsClob.get.nonEmpty && columnsAsClob.get.contains(metaName.get))) {
if (base != "STRING") {
if (base == "VARCHAR") {
(StoredFormatIds.SQL_VARCHAR_ID, varcharSize, -1)
@@ -282,6 +252,7 @@

case _ => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
}
case StringType => (StoredFormatIds.SQL_CLOB_ID, -1, -1) // CLOB
case LongType => (StoredFormatIds.SQL_LONGINT_ID, -1, -1)
case TimestampType => (StoredFormatIds.SQL_TIMESTAMP_ID, -1, -1)
case DateType => (StoredFormatIds.SQL_DATE_ID, -1, -1)
@@ -302,9 +273,48 @@
case _ => (StoredFormatIds.REF_TYPE_ID, -1, -1)
}
}
}

object SparkSQLExecuteImpl {
def getTableNamesAndNullability(output: Seq[expressions.Attribute]):
(Seq[String], Seq[Boolean]) = {
output.map { a =>
val fn = a.qualifiedName
val dotIdx = fn.lastIndexOf('.')
if (dotIdx > 0) {
(fn.substring(0, dotIdx), a.nullable)
} else {
("", a.nullable)
}
}.unzip
}

def writeMetaData(srh: SnappyResultHolder, hdos: GfxdHeapDataOutputStream,
tableNames: Seq[String], nullability: Seq[Boolean], columnNames: Seq[String],
colTypes: Seq[(Int, Int, Int)], dataTypes: Seq[DataType]): Unit = {
// indicates that the metadata is being packed too
srh.setHasMetadata()
DataSerializer.writeStringArray(tableNames.toArray, hdos)
DataSerializer.writeStringArray(columnNames.toArray, hdos)
DataSerializer.writeBooleanArray(nullability.toArray, hdos)
for (i <- colTypes.indices) {
val (tp, precision, scale) = colTypes(i)
InternalDataSerializer.writeSignedVL(tp, hdos)
tp match {
case StoredFormatIds.SQL_DECIMAL_ID =>
InternalDataSerializer.writeSignedVL(precision, hdos) // precision
InternalDataSerializer.writeSignedVL(scale, hdos) // scale
case StoredFormatIds.SQL_VARCHAR_ID |
StoredFormatIds.SQL_CHAR_ID =>
// Write the size as precision
InternalDataSerializer.writeSignedVL(precision, hdos)
case StoredFormatIds.REF_TYPE_ID =>
// Write the DataType
hdos.write(KryoSerializerPool.serialize((kryo, out) =>
StructTypeSerializer.writeType(kryo, out, dataTypes(i))))
case _ => // ignore for others
}
}
}

def getContextOrCurrentClassLoader: ClassLoader =
Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader)

0 comments on commit 32f1dbd
