diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/Curve.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/Curve.kt index 0d505e9ad..d0ef91549 100644 --- a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/Curve.kt +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/Curve.kt @@ -1,7 +1,6 @@ package nebulosa.curve.fitting import org.apache.commons.math3.analysis.UnivariateFunction -import org.apache.commons.math3.analysis.polynomials.PolynomialFunction fun interface Curve : UnivariateFunction { @@ -9,7 +8,8 @@ fun interface Curve : UnivariateFunction { companion object { - internal fun DoubleArray.curvePoints(): Collection { + @JvmStatic + fun DoubleArray.curvePoints(): Collection { val points = ArrayList(size / 2) for (i in indices step 2) { @@ -18,8 +18,5 @@ fun interface Curve : UnivariateFunction { return points } - - @Suppress("NOTHING_TO_INLINE") - internal inline fun DoubleArray.polynomial() = PolynomialFunction(this) } } diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/FittedCurve.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/FittedCurve.kt index 0fe2f0bf6..916e220bd 100644 --- a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/FittedCurve.kt +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/FittedCurve.kt @@ -5,10 +5,4 @@ interface FittedCurve : Curve { val minimum: CurvePoint val rSquared: Double - - val minimumX - get() = minimum.x - - val minimumY - get() = minimum.y } diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/PolynomialCurve.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/PolynomialCurve.kt new file mode 100644 index 000000000..ccc153551 --- /dev/null +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/PolynomialCurve.kt @@ -0,0 +1,10 @@ +package nebulosa.curve.fitting + +import org.apache.commons.math3.analysis.UnivariateFunction + +interface PolynomialCurve : Curve { + + val polynomial: UnivariateFunction + + override fun value(x: Double) = polynomial.value(x) +} diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/QuadraticFitting.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/QuadraticFitting.kt index f8dbf690c..8d0d30345 100644 --- a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/QuadraticFitting.kt +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/QuadraticFitting.kt @@ -1,21 +1,18 @@ package nebulosa.curve.fitting -import nebulosa.curve.fitting.Curve.Companion.polynomial import org.apache.commons.math3.analysis.UnivariateFunction +import org.apache.commons.math3.analysis.polynomials.PolynomialFunction import org.apache.commons.math3.fitting.PolynomialCurveFitter data object QuadraticFitting : CurveFitting { data class Curve( - private val poly: UnivariateFunction, + override val polynomial: UnivariateFunction, override val minimum: CurvePoint, override val rSquared: Double, - ) : FittedCurve { + ) : FittedCurve, PolynomialCurve - override fun value(x: Double) = poly.value(x) - } - - override fun calculate(points: Collection) = with(FITTER.fit(points).polynomial()) { + override fun calculate(points: Collection) = with(PolynomialFunction(FITTER.fit(points))) { val rSquared = RSquared.calculate(points, this) val minimumX = coefficients[1] / (-2.0 * coefficients[2]) val minimumY = value(minimumX) diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/RSquared.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/RSquared.kt index ae7b71479..cb2d70493 100644 --- a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/RSquared.kt +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/RSquared.kt @@ -6,6 +6,7 @@ import kotlin.math.pow object RSquared { + @JvmStatic fun calculate(points: Collection, function: UnivariateFunction): Double { val descriptiveStatistics = DescriptiveStatistics(points.size) val predictedValues = DoubleArray(points.size) diff --git a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/TrendLine.kt b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/TrendLine.kt index 3966f897b..de735fc17 100644 --- a/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/TrendLine.kt +++ b/nebulosa-curve-fitting/src/main/kotlin/nebulosa/curve/fitting/TrendLine.kt @@ -13,19 +13,16 @@ data class TrendLine(val points: Collection) : LinearCurve { points.forEach { regression.addData(it.x, it.y) } } - override val slope = regression.slope.zeroIfNaN() + override val slope = regression.slope.let { if (it.isNaN()) 0.0 else it } - override val intercept = regression.intercept.zeroIfNaN() + override val intercept = regression.intercept.let { if (it.isNaN()) 0.0 else it } - override val rSquared = regression.rSquare.zeroIfNaN() + override val rSquared = regression.rSquare.let { if (it.isNaN()) 0.0 else it } - override fun value(x: Double) = regression.predict(x) + override fun value(x: Double) = if (points.isEmpty()) 0.0 else regression.predict(x) companion object { @JvmStatic val ZERO = TrendLine() - - @Suppress("NOTHING_TO_INLINE") - private inline fun Double.zeroIfNaN() = if (isNaN()) 0.0 else this } } diff --git a/nebulosa-curve-fitting/src/test/kotlin/HyperbolicFittingTest.kt b/nebulosa-curve-fitting/src/test/kotlin/HyperbolicFittingTest.kt index 6709f570e..22830f606 100644 --- a/nebulosa-curve-fitting/src/test/kotlin/HyperbolicFittingTest.kt +++ b/nebulosa-curve-fitting/src/test/kotlin/HyperbolicFittingTest.kt @@ -14,8 +14,8 @@ class HyperbolicFittingTest : StringSpec(), CurveFitting { calculate(1000.0, 18.0, 1100.0, 0.0, 1200.0, 0.0) } diff --git a/nebulosa-curve-fitting/src/test/kotlin/QuadraticFittingTest.kt b/nebulosa-curve-fitting/src/test/kotlin/QuadraticFittingTest.kt index 6746d690b..6839b07b6 100644 --- a/nebulosa-curve-fitting/src/test/kotlin/QuadraticFittingTest.kt +++ b/nebulosa-curve-fitting/src/test/kotlin/QuadraticFittingTest.kt @@ -17,8 +17,8 @@ class QuadraticFittingTest : StringSpec(), CurveFitting ) curve(5.0) shouldBeExactly 2.0 - curve.minimumX shouldBe (5.0 plusOrMinus 1e-12) - curve.minimumY shouldBe (2.0 plusOrMinus 1e-12) + curve.minimum.x shouldBe (5.0 plusOrMinus 1e-12) + curve.minimum.y shouldBe (2.0 plusOrMinus 1e-12) curve.rSquared shouldBe (1.0 plusOrMinus 1e-12) } } diff --git a/nebulosa-curve-fitting/src/test/kotlin/TrendLineTest.kt b/nebulosa-curve-fitting/src/test/kotlin/TrendLineTest.kt index 76fa13c1f..45a69d859 100644 --- a/nebulosa-curve-fitting/src/test/kotlin/TrendLineTest.kt +++ b/nebulosa-curve-fitting/src/test/kotlin/TrendLineTest.kt @@ -6,7 +6,7 @@ class TrendLineTest : StringSpec() { init { "no points" { - val line = TrendLine() + val line = TrendLine.ZERO line.slope shouldBeExactly 0.0 line.intercept shouldBeExactly 0.0