diff --git a/apps/regression/model/logistic_model.cpp b/apps/regression/model/logistic_model.cpp index 5ffe02a48..f74db8f3d 100644 --- a/apps/regression/model/logistic_model.cpp +++ b/apps/regression/model/logistic_model.cpp @@ -1,4 +1,5 @@ #include "logistic_model.h" +#include "../store.h" #include #include #include @@ -81,4 +82,16 @@ double LogisticModel::partialDerivate(double * modelCoefficients, int derivateCo return 0.0; } +void LogisticModel::specializedInitCoefficientsForFit(double * modelCoefficients, double defaultValue, Store * store, int series) const { + assert(store != nullptr && series >= 0 && series < Store::k_numberOfSeries && !store->seriesIsEmpty(series)); + modelCoefficients[0] = defaultValue; + modelCoefficients[1] = defaultValue; + /* If the data is a standard logistic function, the ordinates are between 0 + * and c. Twice the standard vertical deviation is a rough estimate of c + * that is "close enough" to c to seed the coefficient, without being too + * dependent on outliers.*/ + modelCoefficients[2] = 2.0*store->standardDeviationOfColumn(series, 1); +} + + } diff --git a/apps/regression/model/logistic_model.h b/apps/regression/model/logistic_model.h index aa338e88e..fbdca27e9 100644 --- a/apps/regression/model/logistic_model.h +++ b/apps/regression/model/logistic_model.h @@ -15,6 +15,8 @@ public: double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const override; int numberOfCoefficients() const override { return 3; } int bannerLinesCount() const override { return 3; } +private: + void specializedInitCoefficientsForFit(double * modelCoefficients, double defaultValue, Store * store, int series) const override; }; }