Skip to content

Commit

Permalink
Adds Datum#row() (#1729)
Browse files Browse the repository at this point in the history
Co-authored-by: Alan Cai <caialan@amazon.com>
  • Loading branch information
johnedquinn and alancai98 authored Jan 24, 2025
1 parent 8c5f0b4 commit da32279
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.partiql.eval.internal.helpers

import org.partiql.eval.internal.operator.rex.CastTable
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
import java.math.BigInteger
Expand Down Expand Up @@ -39,11 +40,27 @@ internal object ValueUtility {
return this.lower().check(type)
}
if (!this.isNull) {
throw PErrors.unexpectedTypeException(type, listOf(this.type))
throw PErrors.unexpectedTypeException(this.type, listOf(type))
}
return Datum.nullValue(type)
}

/**
* Specifically checks for struct, or coerce rows to structs. Same functionality as [check].
*/
fun Datum.checkStruct(): Datum {
if (this.type.code() == PType.VARIANT) {
return this.lower().checkStruct()
}
if (this.type.code() == PType.STRUCT) {
return this
}
if (this.type.code() == PType.ROW) {
return CastTable.cast(this, PType.struct())
}
return this.check(PType.struct())
}

/**
* Returns the underlying string value of a PartiQL text value
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.partiql.spi.types.PType.DYNAMIC
import org.partiql.spi.types.PType.INTEGER
import org.partiql.spi.types.PType.NUMERIC
import org.partiql.spi.types.PType.REAL
import org.partiql.spi.types.PType.ROW
import org.partiql.spi.types.PType.SMALLINT
import org.partiql.spi.types.PType.STRING
import org.partiql.spi.types.PType.STRUCT
Expand Down Expand Up @@ -101,6 +102,7 @@ internal object CastTable {
registerReal()
registerDoublePrecision()
registerStruct()
registerRow()
registerString()
registerBag()
registerList()
Expand Down Expand Up @@ -415,6 +417,15 @@ internal object CastTable {
*/
private fun registerStruct() {
register(STRUCT, STRUCT) { x, _ -> x }
register(STRUCT, ROW) { x, _ -> Datum.row(x.fields.asSequence().toList()) }
}

/**
* CAST(<row> AS <target>)
*/
private fun registerRow() {
register(ROW, STRUCT) { x, _ -> Datum.struct(x.fields.asSequence().asIterable()) }
register(ROW, ROW) { x, _ -> x }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.partiql.eval.Environment
import org.partiql.eval.ExprValue
import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.eval.internal.helpers.ValueUtility.checkStruct
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum

Expand All @@ -13,7 +14,7 @@ internal class ExprPathKey(
) : ExprValue {

override fun eval(env: Environment): Datum {
val rootEvaluated = root.eval(env).check(PType.struct())
val rootEvaluated = root.eval(env).checkStruct()
val keyEvaluated = key.eval(env).check(PType.string())
if (rootEvaluated.isNull || keyEvaluated.isNull) {
return Datum.nullValue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package org.partiql.eval.internal.operator.rex
import org.partiql.eval.Environment
import org.partiql.eval.ExprValue
import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.types.PType
import org.partiql.eval.internal.helpers.ValueUtility.checkStruct
import org.partiql.spi.value.Datum

internal class ExprPathSymbol(
Expand All @@ -13,7 +12,7 @@ internal class ExprPathSymbol(
) : ExprValue {

override fun eval(env: Environment): Datum {
val struct = root.eval(env).check(PType.struct())
val struct = root.eval(env).checkStruct()
if (struct.isNull) {
return Datum.nullValue()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.Environment
import org.partiql.eval.ExprValue
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.eval.internal.helpers.ValueUtility.checkStruct
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum

Expand All @@ -12,7 +12,7 @@ internal class ExprSpread(

override fun eval(env: Environment): Datum {
val tuples = args.map {
it.eval(env).check(PType.struct())
it.eval(env).checkStruct()
}

// Return NULL if any arguments are NULL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import org.partiql.eval.Environment
import org.partiql.eval.ExprRelation
import org.partiql.eval.ExprValue
import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.types.PType
import org.partiql.eval.internal.helpers.ValueUtility.checkStruct
import org.partiql.spi.value.Datum

/**
Expand All @@ -20,11 +19,6 @@ internal class ExprSubquery(input: ExprRelation, constructor: ExprValue) :
private var _input = input
private var _constructor = constructor

private companion object {
@JvmStatic
private val STRUCT = PType.struct()
}

/**
* TODO simplify
*/
Expand Down Expand Up @@ -55,7 +49,7 @@ internal class ExprSubquery(input: ExprRelation, constructor: ExprValue) :
return null
}
val firstRecord = _input.next()
val tuple = _constructor.eval(env.push(firstRecord)).check(STRUCT)
val tuple = _constructor.eval(env.push(firstRecord)).checkStruct()
if (_input.hasNext()) {
_input.close()
throw PErrors.cardinalityViolationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import org.partiql.eval.ExprRelation
import org.partiql.eval.ExprValue
import org.partiql.eval.internal.helpers.IteratorSupplier
import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.types.PType
import org.partiql.eval.internal.helpers.ValueUtility.checkStruct
import org.partiql.spi.value.Datum

/**
Expand All @@ -19,12 +18,6 @@ internal class ExprSubqueryRow(input: ExprRelation, constructor: ExprValue) :
private var _input = input
private var _constructor = constructor

private companion object {

@JvmStatic
private val STRUCT = PType.struct()
}

override fun eval(env: Environment): Datum {
val tuple = getFirst(env) ?: return Datum.nullValue()
val values = IteratorSupplier { tuple.fields }.map { it.value }
Expand All @@ -41,7 +34,7 @@ internal class ExprSubqueryRow(input: ExprRelation, constructor: ExprValue) :
return null
}
val firstRecord = _input.next()
val tuple = _constructor.eval(env.push(firstRecord)).check(STRUCT)
val tuple = _constructor.eval(env.push(firstRecord)).checkStruct()
if (_input.hasNext()) {
_input.close()
throw PErrors.cardinalityViolationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.partiql.eval.Mode
import org.partiql.eval.compiler.PartiQLCompiler
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
import org.partiql.spi.value.Field
import org.partiql.value.PartiQLValue
import org.partiql.value.bagValue
import org.partiql.value.boolValue
Expand Down Expand Up @@ -1408,6 +1409,42 @@ class PartiQLEvaluatorTest {
)
}

@Test
fun proveThatRowWorksWhenDynamic() {
val tc =
SuccessTestCase(
input = "t.a = 3",
expected = Datum.bool(true),
mode = Mode.STRICT(),
globals = listOf(
Global(
name = "t",
type = PType.dynamic(),
value = Datum.row(Field.of("a", Datum.integer(3)))
)
)
)
tc.run()
}

@Test
fun proveThatRowWorks() {
val tc =
SuccessTestCase(
input = "t.a = 3",
expected = Datum.bool(true),
mode = Mode.STRICT(),
globals = listOf(
Global(
name = "t",
type = PType.row(org.partiql.spi.types.Field.of("a", PType.integer())),
value = Datum.row(Field.of("a", Datum.integer(3)))
),
)
)
tc.run()
}

@Test
// @Disabled
fun developmentTest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.partiql.spi.catalog.Catalog
import org.partiql.spi.catalog.Name
import org.partiql.spi.catalog.Session
import org.partiql.spi.catalog.Table
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
import org.partiql.spi.value.DatumReader
import org.partiql.spi.value.ValueUtils
Expand All @@ -23,9 +24,15 @@ import kotlin.test.assertEquals
*/
class Global(
val name: String,
val value: String,
val type: StaticType = StaticType.ANY,
)
val value: Datum,
val type: PType = PType.dynamic(),
) {
constructor(
name: String,
value: String,
type: StaticType = StaticType.ANY,
) : this(name, DatumReader.ion(value.byteInputStream()).next()!!, fromStaticType(type))
}

public class SuccessTestCase(
val input: String,
Expand Down Expand Up @@ -56,8 +63,8 @@ public class SuccessTestCase(
globals.forEach {
val table = Table.standard(
name = Name.of(it.name),
schema = fromStaticType(it.type),
datum = DatumReader.ion(it.value.byteInputStream()).next()!!
schema = it.type,
datum = it.value
)
define(table)
}
Expand Down Expand Up @@ -112,8 +119,8 @@ public class FailureTestCase(
globals.forEach {
val table = Table.standard(
name = Name.of(it.name),
schema = fromStaticType(it.type),
datum = DatumReader.ion(it.value.byteInputStream()).next()!!
schema = it.type,
datum = it.value
)
define(table)
}
Expand Down
6 changes: 6 additions & 0 deletions partiql-spi/api/partiql-spi.api
Original file line number Diff line number Diff line change
Expand Up @@ -590,10 +590,16 @@ public abstract interface class org/partiql/spi/value/Datum : java/lang/Iterable
public static fun numeric (Ljava/math/BigDecimal;II)Lorg/partiql/spi/value/Datum;
public fun pack (Ljava/nio/charset/Charset;)[B
public static fun real (F)Lorg/partiql/spi/value/Datum;
public static fun row ()Lorg/partiql/spi/value/Datum;
public static fun row (Ljava/util/List;)Lorg/partiql/spi/value/Datum;
public static fun row (Ljava/util/List;Ljava/util/List;)Lorg/partiql/spi/value/Datum;
public static fun row (Ljava/util/List;[Lorg/partiql/spi/value/Field;)Lorg/partiql/spi/value/Datum;
public static fun row ([Lorg/partiql/spi/value/Field;)Lorg/partiql/spi/value/Datum;
public static fun smallint (S)Lorg/partiql/spi/value/Datum;
public static fun string (Ljava/lang/String;)Lorg/partiql/spi/value/Datum;
public static fun struct ()Lorg/partiql/spi/value/Datum;
public static fun struct (Ljava/lang/Iterable;)Lorg/partiql/spi/value/Datum;
public static fun struct ([Lorg/partiql/spi/value/Field;)Lorg/partiql/spi/value/Datum;
public static fun time (Ljava/time/LocalTime;I)Lorg/partiql/spi/value/Datum;
public static fun timestamp (Ljava/time/LocalDateTime;I)Lorg/partiql/spi/value/Datum;
public static fun timestampz (Ljava/time/OffsetDateTime;I)Lorg/partiql/spi/value/Datum;
Expand Down
65 changes: 65 additions & 0 deletions partiql-spi/src/main/java/org/partiql/spi/value/Datum.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
import java.math.RoundingMode;
import java.nio.charset.Charset;
import java.time.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/**
* This is a representation of a value in PartiQL's type system. The intention of this modeling is to
Expand Down Expand Up @@ -691,6 +695,15 @@ static Datum struct() {
return struct(Collections.emptyList());
}

/**
* @param values the backing values
* @return a value of type {@link PType#STRUCT}
*/
@NotNull
static Datum struct(@NotNull Field... values) {
return new DatumStruct(Arrays.stream(values).collect(Collectors.toList()));
}

/**
* @param values the backing values
* @return a value of type {@link PType#STRUCT}
Expand All @@ -700,6 +713,58 @@ static Datum struct(@NotNull Iterable<Field> values) {
return new DatumStruct(values);
}

/**
* Returns an empty {@link PType#ROW}
* @return a value of type {@link PType#ROW}
*/
@NotNull
static Datum row() {
return new DatumRow(new ArrayList<>(), PType.row());
}

/**
* This creates a row.
* @param values the backing values
* @return a value of type {@link PType#ROW}
*/
@NotNull
static Datum row(@NotNull Field... values) {
return new DatumRow(Arrays.stream(values).collect(Collectors.toList()));
}

/**
* This creates a row. Use this if you'd like to save on the computational cost of computing the final type.
* @param typeFields the backing type fields
* @param values the backing values
* @return a value of type {@link PType#ROW}
*/
@NotNull
static Datum row(List<org.partiql.spi.types.Field> typeFields, @NotNull Field... values) {
return row(typeFields, Arrays.stream(values).collect(Collectors.toList()));
}

/**
* Creates a row with the given values.
* @param values the backing values
* @return a value of type {@link PType#ROW}
*/
@NotNull
static Datum row(@NotNull List<Field> values) {
return new DatumRow(values);
}

/**
* Creates a row with the given values. Use this if you'd like to save on the computational cost of computing the final type.
* @param typeFields the backing type fields
* @param values the backing values
* @return a value of type {@link PType#ROW}
*/
@NotNull
static Datum row(@NotNull List<org.partiql.spi.types.Field> typeFields, @NotNull List<Field> values) {
PType type = PType.row(typeFields);
return new DatumRow(values, type);
}

/**
* @param value the backing Ion
* @return a value of type {@link PType#VARIANT}
Expand Down
Loading

0 comments on commit da32279

Please sign in to comment.