Skip to content

Commit

Permalink
[api]: Improve Curve Fitting algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagohm committed May 22, 2024
1 parent 29502bc commit 617bf57
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package nebulosa.curve.fitting

import org.apache.commons.math3.analysis.UnivariateFunction
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction

fun interface Curve : UnivariateFunction {

operator fun invoke(x: Double) = value(x)

companion object {

internal fun DoubleArray.curvePoints(): Collection<CurvePoint> {
@JvmStatic
fun DoubleArray.curvePoints(): Collection<CurvePoint> {
val points = ArrayList<CurvePoint>(size / 2)

for (i in indices step 2) {
Expand All @@ -18,8 +18,5 @@ fun interface Curve : UnivariateFunction {

return points
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun DoubleArray.polynomial() = PolynomialFunction(this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,4 @@ interface FittedCurve : Curve {
val minimum: CurvePoint

val rSquared: Double

val minimumX
get() = minimum.x

val minimumY
get() = minimum.y
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package nebulosa.curve.fitting

import org.apache.commons.math3.analysis.UnivariateFunction

interface PolynomialCurve : Curve {

val polynomial: UnivariateFunction

override fun value(x: Double) = polynomial.value(x)
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
package nebulosa.curve.fitting

import nebulosa.curve.fitting.Curve.Companion.polynomial
import org.apache.commons.math3.analysis.UnivariateFunction
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction
import org.apache.commons.math3.fitting.PolynomialCurveFitter

data object QuadraticFitting : CurveFitting<QuadraticFitting.Curve> {

data class Curve(
private val poly: UnivariateFunction,
override val polynomial: UnivariateFunction,
override val minimum: CurvePoint,
override val rSquared: Double,
) : FittedCurve {
) : FittedCurve, PolynomialCurve

override fun value(x: Double) = poly.value(x)
}

override fun calculate(points: Collection<CurvePoint>) = with(FITTER.fit(points).polynomial()) {
override fun calculate(points: Collection<CurvePoint>) = with(PolynomialFunction(FITTER.fit(points))) {
val rSquared = RSquared.calculate(points, this)
val minimumX = coefficients[1] / (-2.0 * coefficients[2])
val minimumY = value(minimumX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import kotlin.math.pow

object RSquared {

@JvmStatic
fun calculate(points: Collection<CurvePoint>, function: UnivariateFunction): Double {
val descriptiveStatistics = DescriptiveStatistics(points.size)
val predictedValues = DoubleArray(points.size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,16 @@ data class TrendLine(val points: Collection<CurvePoint>) : LinearCurve {
points.forEach { regression.addData(it.x, it.y) }
}

override val slope = regression.slope.zeroIfNaN()
override val slope = regression.slope.let { if (it.isNaN()) 0.0 else it }

override val intercept = regression.intercept.zeroIfNaN()
override val intercept = regression.intercept.let { if (it.isNaN()) 0.0 else it }

override val rSquared = regression.rSquare.zeroIfNaN()
override val rSquared = regression.rSquare.let { if (it.isNaN()) 0.0 else it }

override fun value(x: Double) = regression.predict(x)
override fun value(x: Double) = if (points.isEmpty()) 0.0 else regression.predict(x)

companion object {

@JvmStatic val ZERO = TrendLine()

@Suppress("NOTHING_TO_INLINE")
private inline fun Double.zeroIfNaN() = if (isNaN()) 0.0 else this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class HyperbolicFittingTest : StringSpec(), CurveFitting<HyperbolicFitting.Curve
6.0, 3.0, 7.0, 6.0, 8.0, 11.0, 9.0, 18.0,
)

curve.minimumX shouldBe (5.0 plusOrMinus 1e-12)
curve.minimumY shouldBe (1.2 plusOrMinus 1e-12)
curve.minimum.x shouldBe (5.0 plusOrMinus 1e-12)
curve.minimum.y shouldBe (1.2 plusOrMinus 1e-12)
}
"bad data:prevent infinit loop" {
shouldThrow<IllegalArgumentException> { calculate(1000.0, 18.0, 1100.0, 0.0, 1200.0, 0.0) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class QuadraticFittingTest : StringSpec(), CurveFitting<QuadraticFitting.Curve>
)

curve(5.0) shouldBeExactly 2.0
curve.minimumX shouldBe (5.0 plusOrMinus 1e-12)
curve.minimumY shouldBe (2.0 plusOrMinus 1e-12)
curve.minimum.x shouldBe (5.0 plusOrMinus 1e-12)
curve.minimum.y shouldBe (2.0 plusOrMinus 1e-12)
curve.rSquared shouldBe (1.0 plusOrMinus 1e-12)
}
}
Expand Down
2 changes: 1 addition & 1 deletion nebulosa-curve-fitting/src/test/kotlin/TrendLineTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class TrendLineTest : StringSpec() {

init {
"no points" {
val line = TrendLine()
val line = TrendLine.ZERO

line.slope shouldBeExactly 0.0
line.intercept shouldBeExactly 0.0
Expand Down

0 comments on commit 617bf57

Please sign in to comment.