Skip to content

Commit

Permalink
Handle Option correctly
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <honnix@users.noreply.github.com>
  • Loading branch information
honnix committed Nov 22, 2023
1 parent 8ec7a78 commit 288ce3e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ import org.flyte.flytekitscala.SdkLiteralTypes.{
}

// The constructor is reflectedly invoked so it cannot be an inner class
case class ScalarNested(foo: String, bar: String)
case class ScalarNested(
foo: String,
bar: Option[String],
nestedNested: Option[ScalarNestedNested]
)
case class ScalarNestedNested(foo: String, bar: Option[String])

class SdkScalaTypeTest {

Expand Down Expand Up @@ -178,7 +183,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofStringValue("bar"),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand All @@ -196,7 +209,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -218,7 +235,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -245,7 +266,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofStringValue("bar"),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand Down Expand Up @@ -285,7 +314,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -301,7 +334,11 @@ class SdkScalaTypeTest {
"blob" -> SdkBindingDataFactory.of(blob),
"generic" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[ScalarNested](),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
).asJava

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,41 +297,35 @@ object SdkLiteralTypes {
): S = {
val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)

def valueToParamValue(value: Any, param: Symbol): Any = {
def valueToParamValue0(value: Any, param: Symbol): Any = {
if (param.typeSignature =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (param.typeSignature =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (param.typeSignature =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (param.typeSignature =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (param.typeSignature =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (param.typeSignature <:< typeOf[Product]) {
val typeTag = createTypeTag(param.typeSignature)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(param.typeSignature)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
} else {
value
}
}

if (param.typeSignature <:< typeOf[Option[Any]]) {
def valueToParamValue(value: Any, tpe: Type): Any = {
if (tpe =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (tpe =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (tpe =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (tpe =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (tpe =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (tpe <:< typeOf[Option[Any]]) { // this has to be before Product check
Some(
valueToParamValue0(
valueToParamValue(
value,
param.typeSignature.dealias.typeArgs.head.typeSymbol
tpe.dealias.typeArgs.head
)
)
} else if (tpe <:< typeOf[Product]) {
val typeTag = createTypeTag(tpe)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(tpe)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
} else {
valueToParamValue0(value, param)
value
}
}

Expand Down Expand Up @@ -371,7 +365,7 @@ object SdkLiteralTypes {
s"Map is missing required parameter named $paramName"
)
)
valueToParamValue(value, param)
valueToParamValue(value, param.typeSignature.dealias)
})

constructorMirror(constructorArgs: _*).asInstanceOf[S]
Expand Down

0 comments on commit 288ce3e

Please sign in to comment.