diff --git a/apps/regression/graph_controller.cpp b/apps/regression/graph_controller.cpp index 494866e2d..25c01bee8 100644 --- a/apps/regression/graph_controller.cpp +++ b/apps/regression/graph_controller.cpp @@ -12,7 +12,7 @@ GraphController::GraphController(Responder * parentResponder, ButtonRowControlle m_crossCursorView(), m_roundCursorView(), m_bannerView(), - m_view(store, m_cursor, &m_bannerView, &m_crossCursorView), + m_view(store, m_cursor, &m_bannerView, &m_crossCursorView, this), m_store(store), m_initialisationParameterController(this, m_store), m_predictionParameterController(this, m_store, m_cursor, this), diff --git a/apps/regression/graph_view.cpp b/apps/regression/graph_view.cpp index 6088ab40e..15a3348c5 100644 --- a/apps/regression/graph_view.cpp +++ b/apps/regression/graph_view.cpp @@ -1,15 +1,19 @@ #include "graph_view.h" +#include "model/model.h" +#include +#include #include using namespace Shared; namespace Regression { -GraphView::GraphView(Store * store, CurveViewCursor * cursor, BannerView * bannerView, View * cursorView) : +GraphView::GraphView(Store * store, CurveViewCursor * cursor, BannerView * bannerView, View * cursorView, Responder * controller) : CurveView(store, cursor, bannerView, cursorView), m_store(store), m_xLabels{}, - m_yLabels{} + m_yLabels{}, + m_controller(controller) { } @@ -23,12 +27,16 @@ void GraphView::drawRect(KDContext * ctx, KDRect rect) const { for (int series = 0; series < Store::k_numberOfSeries; series++) { if (!m_store->seriesIsEmpty(series)) { KDColor color = Palette::DataColor[series]; - float regressionParameters[2] = {(float)m_store->slope(series), (float)m_store->yIntercept(series)}; + Model * seriesModel = m_store->modelForSeries(series); + double coefficients[Model::k_maxNumberOfCoefficients]; + Poincare::Context * globContext = const_cast(static_cast(m_controller->app()->container()))->globalContext(); + seriesModel->fit(m_store, series, coefficients, globContext); drawCurve(ctx, rect, [](float abscissa, void * model, void * context) { - float * params = (float *)model; - return params[0]*abscissa+params[1]; + Model * regressionModel = static_cast(model); + double * regressionCoefficients = static_cast(context); + return (float)regressionModel->evaluate(regressionCoefficients, abscissa); }, - regressionParameters, nullptr, color); + seriesModel, coefficients, color); for (int index = 0; index < m_store->numberOfPairsOfSeries(series); index++) { drawDot(ctx, rect, m_store->get(series, 0, index), m_store->get(series, 1, index), color); } diff --git a/apps/regression/graph_view.h b/apps/regression/graph_view.h index de60b5e91..511397fc6 100644 --- a/apps/regression/graph_view.h +++ b/apps/regression/graph_view.h @@ -10,16 +10,16 @@ namespace Regression { class GraphView : public Shared::CurveView { public: - GraphView(Store * store, Shared::CurveViewCursor * cursor, Shared::BannerView * bannerView, View * cursorView); + GraphView(Store * store, Shared::CurveViewCursor * cursor, Shared::BannerView * bannerView, View * cursorView, Responder * controller); void drawRect(KDContext * ctx, KDRect rect) const override; private: char * label(Axis axis, int index) const override; Store * m_store; char m_xLabels[k_maxNumberOfXLabels][Poincare::PrintFloat::bufferSizeForFloatsWithPrecision(Constant::ShortNumberOfSignificantDigits)]; char m_yLabels[k_maxNumberOfYLabels][Poincare::PrintFloat::bufferSizeForFloatsWithPrecision(Constant::ShortNumberOfSignificantDigits)]; + Responder * m_controller; }; } - #endif diff --git a/apps/regression/model/model.h b/apps/regression/model/model.h index 3b2577e95..291bac41d 100644 --- a/apps/regression/model/model.h +++ b/apps/regression/model/model.h @@ -24,6 +24,7 @@ public: }; static constexpr int k_numberOfModels = 9; static constexpr int k_maxNumberOfCoefficients = 5; + virtual ~Model() = default; virtual double evaluate(double * modelCoefficients, double x) const = 0; virtual void fit(Store * store, int series, double * modelCoefficients, Poincare::Context * context); private: diff --git a/apps/regression/store.cpp b/apps/regression/store.cpp index 291947aef..b1d5b584f 100644 --- a/apps/regression/store.cpp +++ b/apps/regression/store.cpp @@ -1,4 +1,13 @@ #include "store.h" +#include "model/cubic_model.h" +#include "model/exponential_model.h" +#include "model/linear_model.h" +#include "model/logarithmic_model.h" +#include "model/logistic_model.h" +#include "model/power_model.h" +#include "model/quadratic_model.h" +#include "model/quartic_model.h" +#include "model/trigonometric_model.h" #include #include #include @@ -11,6 +20,8 @@ namespace Regression { static inline float max(float x, float y) { return (x>y ? x : y); } static inline float min(float x, float y) { return (x= 0 && series < k_numberOfSeries); m_regressionTypes[series] = type; } + Model * modelForSeries(int series) { + assert(series >= 0 && series < k_numberOfSeries); + assert((int)m_regressionTypes[series] >= 0 && (int)m_regressionTypes[series] < Model::k_numberOfModels); + return m_regressionModels[(int)m_regressionTypes[series]]; + } /* Return the series index of the closest regression at abscissa x, above * ordinate y if direction > 0, below otherwise */ int closestVerticalRegression(int direction, float x, float y, int currentRegressionSeries); @@ -60,6 +70,7 @@ private: float maxValueOfColumn(int series, int i) const; float minValueOfColumn(int series, int i) const; Model::Type m_regressionTypes[k_numberOfSeries]; + Model * m_regressionModels[Model::k_numberOfModels]; }; typedef double (Store::*ArgCalculPointer)(int, int) const;