Skip to content

Commit

Permalink
Refactor AVG and SUM non-dynamic to separate impls
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 committed Jan 24, 2025
1 parent c9259e4 commit 0b4a7a9
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,35 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Aggregation
import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.AccumulatorAvg
import org.partiql.spi.function.builtins.internal.AccumulatorAvgDecimal
import org.partiql.spi.function.builtins.internal.AccumulatorAvgDouble
import org.partiql.spi.function.builtins.internal.AccumulatorAvgDynamic
import org.partiql.spi.types.PType

// TODO: This needs to be formalized. See https://github.com/partiql/partiql-lang-kotlin/issues/1659
/**
* TODO: This needs to be formalized. See https://github.com/partiql/partiql-lang-kotlin/issues/1659
* Return types are mostly implementation-defined. Follows what postgresql does for the non-dynamic cases.
*
* Return type for tinyint, smallint, integer, bigint, decimal, numeric -> decimal
* Return type for float and double precision -> double precision
* Return type for dynamic:
* - if all values are exact numeric -> decimal
* - otherwise -> double precision
*/
private val AVG_DECIMAL = DefaultDecimal.DECIMAL

internal val Agg_AVG__INT8__INT8 = Aggregation.static(
name = "avg",
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.tinyint())),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__INT16__INT16 = Aggregation.static(
name = "avg",
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.smallint())),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__INT32__INT32 = Aggregation.static(
Expand All @@ -32,7 +43,7 @@ internal val Agg_AVG__INT32__INT32 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.integer()),
),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__INT64__INT64 = Aggregation.static(
Expand All @@ -42,7 +53,7 @@ internal val Agg_AVG__INT64__INT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.bigint()),
),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__NUMERIC__NUMERIC = Aggregation.static(
Expand All @@ -52,7 +63,7 @@ internal val Agg_AVG__NUMERIC__NUMERIC = Aggregation.static(
parameters = arrayOf(
Parameter("value", DefaultNumeric.NUMERIC),
),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
Expand All @@ -62,17 +73,17 @@ internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
parameters = arrayOf(
Parameter("value", AVG_DECIMAL),
),
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
accumulator = ::AccumulatorAvgDecimal,
)

internal val Agg_AVG__FLOAT32__FLOAT32 = Aggregation.static(

name = "avg",
returns = PType.real(),
returns = PType.doublePrecision(),
parameters = arrayOf(
Parameter("value", PType.real()),
),
accumulator = { AccumulatorAvg(PType.doublePrecision()) },
accumulator = ::AccumulatorAvgDouble,
)

internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
Expand All @@ -82,7 +93,7 @@ internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.doublePrecision()),
),
accumulator = { AccumulatorAvg(PType.doublePrecision()) },
accumulator = ::AccumulatorAvgDouble,
)

internal val Agg_AVG__ANY__ANY = Aggregation.static(
Expand All @@ -92,5 +103,5 @@ internal val Agg_AVG__ANY__ANY = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.dynamic()),
),
accumulator = { AccumulatorAvg(PType.dynamic()) },
accumulator = ::AccumulatorAvgDynamic,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,33 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Aggregation
import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.AccumulatorSum
import org.partiql.spi.function.builtins.internal.AccumulatorSumBigInt
import org.partiql.spi.function.builtins.internal.AccumulatorSumDecimal
import org.partiql.spi.function.builtins.internal.AccumulatorSumDouble
import org.partiql.spi.function.builtins.internal.AccumulatorSumDynamic
import org.partiql.spi.types.PType

/**
* TODO: This needs to be formalized. See https://github.com/partiql/partiql-lang-kotlin/issues/1659
* Return types are mostly implementation-defined. Follows what postgresql does for the non-dynamic cases.
*
* Return type for tinyint, smalllint, integer -> bigint
* Return type for bigint, decimal -> decimal
* Return type for numeric -> numeric
* Return type for float and double precision -> double precision
* Return type for dynamic:
* - if all values are integer or smaller -> bigint
* - if all values are exact numeric (all integral + decimal/numeric) -> decimal
* - otherwise -> double precision
*/

internal val Agg_SUM__INT8__INT8 = Aggregation.static(
name = "sum",
returns = PType.bigint(),
parameters = arrayOf(
Parameter("value", PType.tinyint()),
),
accumulator = { AccumulatorSum(PType.bigint()) },
accumulator = ::AccumulatorSumBigInt
)

internal val Agg_SUM__INT16__INT16 = Aggregation.static(
Expand All @@ -23,7 +40,7 @@ internal val Agg_SUM__INT16__INT16 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.smallint()),
),
accumulator = { AccumulatorSum(PType.bigint()) },
accumulator = ::AccumulatorSumBigInt
)

internal val Agg_SUM__INT32__INT32 = Aggregation.static(
Expand All @@ -32,7 +49,7 @@ internal val Agg_SUM__INT32__INT32 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.integer()),
),
accumulator = { AccumulatorSum(PType.bigint()) },
accumulator = ::AccumulatorSumBigInt
)

internal val Agg_SUM__INT64__INT64 = Aggregation.static(
Expand All @@ -41,7 +58,7 @@ internal val Agg_SUM__INT64__INT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.bigint())
),
accumulator = { AccumulatorSum(DefaultDecimal.DECIMAL) },
accumulator = { AccumulatorSumDecimal(DefaultDecimal.DECIMAL) },
)

internal val Agg_SUM__NUMERIC__NUMERIC = Aggregation.static(
Expand All @@ -50,16 +67,16 @@ internal val Agg_SUM__NUMERIC__NUMERIC = Aggregation.static(
parameters = arrayOf(
Parameter("value", DefaultNumeric.NUMERIC),
),
accumulator = { AccumulatorSum(DefaultNumeric.NUMERIC) },
accumulator = { AccumulatorSumDecimal(DefaultNumeric.NUMERIC) },
)

internal val Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
name = "sum",
returns = DefaultDecimal.DECIMAL,
parameters = arrayOf(
Parameter("value", PType.decimal(38, 19)), // TODO: Rewrite aggregations using new function modeling.
Parameter("value", DefaultDecimal.DECIMAL), // TODO: Rewrite aggregations using new function modeling.
),
accumulator = { AccumulatorSum(DefaultDecimal.DECIMAL) },
accumulator = { AccumulatorSumDecimal(DefaultDecimal.DECIMAL) },
)

internal val Agg_SUM__FLOAT32__FLOAT32 = Aggregation.static(
Expand All @@ -68,7 +85,7 @@ internal val Agg_SUM__FLOAT32__FLOAT32 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.real())
),
accumulator = { AccumulatorSum(PType.real()) },
accumulator = { AccumulatorSumDouble() },
)

internal val Agg_SUM__FLOAT64__FLOAT64 = Aggregation.static(
Expand All @@ -77,7 +94,7 @@ internal val Agg_SUM__FLOAT64__FLOAT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.doublePrecision()),
),
accumulator = { AccumulatorSum(PType.doublePrecision()) },
accumulator = { AccumulatorSumDouble() },
)

internal val Agg_SUM__ANY__ANY = Aggregation.static(
Expand All @@ -86,5 +103,5 @@ internal val Agg_SUM__ANY__ANY = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.dynamic()),
),
accumulator = ::AccumulatorSum,
accumulator = ::AccumulatorSumDynamic,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.Accumulator
import org.partiql.spi.function.builtins.internal.AccumulatorAnySome
import org.partiql.spi.function.builtins.internal.AccumulatorAvg
import org.partiql.spi.function.builtins.internal.AccumulatorAvgDynamic
import org.partiql.spi.function.builtins.internal.AccumulatorCount
import org.partiql.spi.function.builtins.internal.AccumulatorDistinct
import org.partiql.spi.function.builtins.internal.AccumulatorEvery
import org.partiql.spi.function.builtins.internal.AccumulatorMax
import org.partiql.spi.function.builtins.internal.AccumulatorMin
import org.partiql.spi.function.builtins.internal.AccumulatorSum
import org.partiql.spi.function.builtins.internal.AccumulatorSumDynamic
import org.partiql.spi.types.PType
import org.partiql.spi.utils.FunctionUtils
import org.partiql.spi.value.Datum
Expand Down Expand Up @@ -49,13 +49,13 @@ internal abstract class Fn_COLL_AGG__BAG__ANY(
}
}

object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSum)
object SUM_ALL : Fn_COLL_AGG__BAG__ANY("coll_sum_all", false, ::AccumulatorSumDynamic)

object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSum)
object SUM_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_sum_distinct", true, ::AccumulatorSumDynamic)

object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvg)
object AVG_ALL : Fn_COLL_AGG__BAG__ANY("coll_avg_all", false, ::AccumulatorAvgDynamic)

object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvg)
object AVG_DISTINCT : Fn_COLL_AGG__BAG__ANY("coll_avg_distinct", true, ::AccumulatorAvgDynamic)

object MIN_ALL : Fn_COLL_AGG__BAG__ANY("coll_min_all", false, ::AccumulatorMin)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,62 +1,88 @@
package org.partiql.spi.function.builtins.internal

import org.partiql.spi.function.builtins.DefaultDecimal
import org.partiql.spi.types.PType
import org.partiql.spi.utils.FunctionUtils.checkIsNumberType
import org.partiql.spi.utils.FunctionUtils.nullToTargetType
import org.partiql.spi.utils.NumberUtils.AccumulatorType
import org.partiql.spi.utils.NumberUtils.MATH_CONTEXT
import org.partiql.spi.utils.NumberUtils.add
import org.partiql.spi.utils.NumberUtils.bigDecimalOf
import org.partiql.spi.utils.NumberUtils.numberValue
import org.partiql.spi.utils.NumberUtils.toTargetType
import org.partiql.spi.value.Datum
import java.math.BigDecimal

// TODO docs + further cleanup
internal class AccumulatorAvg(
private val targetType: PType = PType.dynamic(),
) : Accumulator() {
internal class AccumulatorAvgDecimal : Accumulator() {
private var sum: BigDecimal = BigDecimal.ZERO
private var count: Long = 0L
private var init = false

override fun nextValue(value: Datum) {
checkIsNumberType(funcName = "AVG", value = value)
if (!init) {
init = true
}
val arg1 = bigDecimalOf(value.numberValue(), MATH_CONTEXT)
sum = sum.add(arg1, MATH_CONTEXT)
count += 1L
}

override fun value(): Datum = when (count) {
0L -> Datum.nullValue(DefaultDecimal.DECIMAL)
else -> Datum.decimal(bigDecimalOf(sum).divide(bigDecimalOf(count), MATH_CONTEXT))
}
}

internal class AccumulatorAvgDouble : Accumulator() {
private var sum: Double = 0.0
private var count: Long = 0L
private var init = false

override fun nextValue(value: Datum) {
checkIsNumberType(funcName = "AVG", value = value)
val arg1 = value.double
if (!init) {
init = true
}
sum += arg1
count += 1L
}

override fun value(): Datum = when (count) {
0L -> Datum.nullValue(PType.doublePrecision())
else -> Datum.doublePrecision(sum / count.toDouble())
}
}

internal class AccumulatorAvgDynamic : Accumulator() {
private var sum: Number? = null
private var count: Long = 0L
private var dynamicSumType: PType? = targetType
private var accumulatorType: AccumulatorType? = null

override fun nextValue(value: Datum) {
checkIsNumberType(funcName = "AVG", value = value)
when (targetType.code()) {
PType.DECIMAL -> {
if (sum == null) {
sum = BigDecimal.ZERO
if (sum == null) {
sum = when (value.type.code()) {
PType.REAL, PType.DOUBLE -> {
accumulatorType = AccumulatorType.APPROX
0.0
}
}
PType.DOUBLE -> {
if (sum == null) {
sum = 0.0
else -> {
accumulatorType = AccumulatorType.DECIMAL
BigDecimal.ZERO
}
}
PType.DYNAMIC -> if (sum == null) {
dynamicSumType = when (value.type.code()) {
PType.REAL, PType.DOUBLE -> {
sum = BigDecimal.ZERO
PType.doublePrecision()
}
PType.TINYINT, PType.SMALLINT, PType.INTEGER, PType.BIGINT, PType.DECIMAL, PType.NUMERIC -> {
sum = BigDecimal.ZERO
PType.decimal()
}
else -> error("Unexpected type: ${value.type}")
}
} else {
when (value.type.code()) {
PType.REAL, PType.DOUBLE -> {
dynamicSumType = PType.doublePrecision()
}
}
} else {
if (value.type.code() == PType.REAL || value.type.code() == PType.DOUBLE) {
accumulatorType = AccumulatorType.APPROX
}
}
sum = add(sum!!, value, dynamicSumType!!)
sum = add(sum!!, value, accumulatorType!!)
count += 1L
}

override fun value(): Datum = when (count) {
0L -> nullToTargetType(targetType)
0L -> Datum.nullValue(PType.dynamic())
else -> {
when (sum) {
is BigDecimal -> {
Expand All @@ -66,7 +92,7 @@ internal class AccumulatorAvg(
(sum!!.toDouble()) / count.toDouble()
}
else -> error("Sum should be BigDecimal or Double")
}.toTargetType(targetType)
}.toTargetType(PType.dynamic())
}
}
}
Loading

0 comments on commit 0b4a7a9

Please sign in to comment.