[apps/regression] Factorize Model::simplifiedExpression

This commit is contained in:
Léa Saviot
2019-06-28 15:55:50 +02:00
committed by Émilie Feral
parent e8be088478
commit 5a79d26191
10 changed files with 120 additions and 113 deletions

View File

@@ -48,33 +48,6 @@ Layout CubicModel::layout() {
return m_layout;
}
Expression CubicModel::simplifiedExpression(double * modelCoefficients, Poincare::Context * context) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
Expression addChildren[] = {
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(3.0))),
Multiplication::Builder(
Number::DecimalNumber(b),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
Multiplication::Builder(
Number::DecimalNumber(c),
Symbol::Builder('x')),
Number::DecimalNumber(d)
};
// a*x^3+b*x^2+c*x+d
Expression result = Addition::Builder(addChildren, 4);
PoincareHelpers::Simplify(&result, *context);
return result;
}
double CubicModel::evaluate(double * modelCoefficients, double x) const {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
@@ -104,4 +77,30 @@ double CubicModel::partialDerivate(double * modelCoefficients, int derivateCoeff
return 0.0;
}
Expression CubicModel::expression(double * modelCoefficients) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
Expression addChildren[] = {
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(3.0))),
Multiplication::Builder(
Number::DecimalNumber(b),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
Multiplication::Builder(
Number::DecimalNumber(c),
Symbol::Builder('x')),
Number::DecimalNumber(d)
};
// a*x^3+b*x^2+c*x+d
Expression result = Addition::Builder(addChildren, 4);
return result;
}
}

View File

@@ -9,12 +9,13 @@ class CubicModel : public Model {
public:
using Model::Model;
Poincare::Layout layout() override;
Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context) override;
I18n::Message formulaMessage() const override { return I18n::Message::CubicRegressionFormula; }
double evaluate(double * modelCoefficients, double x) const override;
double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const override;
int numberOfCoefficients() const override { return 4; }
int bannerLinesCount() const override { return 4; }
private:
Poincare::Expression expression(double * modelCoefficients) override;
};
}

View File

@@ -15,6 +15,14 @@ void Model::tidy() {
m_layout = Layout();
}
Poincare::Expression Model::simplifiedExpression(double * modelCoefficients, Poincare::Context * context) {
Expression e = expression(modelCoefficients);
if (!e.isUninitialized()) {
PoincareHelpers::Simplify(&e, *context);
}
return e;
}
double Model::levelSet(double * modelCoefficients, double xMin, double step, double xMax, double y, Poincare::Context * context) {
Expression yExpression = Number::DecimalNumber(y);
PoincareHelpers::Simplify(&yExpression, *context);

View File

@@ -30,8 +30,7 @@ public:
virtual Poincare::Layout layout() = 0;
// Reinitialize m_layout to empty the pool
void tidy();
// simplifiedExpression is overrided only by Models that override levelSet
virtual Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context) { return Poincare::Expression(); }
Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context);
virtual I18n::Message formulaMessage() const = 0;
virtual double evaluate(double * modelCoefficients, double x) const = 0;
virtual double levelSet(double * modelCoefficients, double xMin, double step, double xMax, double y, Poincare::Context * context);
@@ -45,6 +44,7 @@ protected:
Poincare::Layout m_layout;
private:
// Model attributes
virtual Poincare::Expression expression(double * modelCoefficients) { return Poincare::Expression(); } // expression is overrided only by Models that do not override levelSet
virtual double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const = 0;
// Levenberg-Marquardt

View File

@@ -40,27 +40,6 @@ Layout QuadraticModel::layout() {
return m_layout;
}
Expression QuadraticModel::simplifiedExpression(double * modelCoefficients, Poincare::Context * context) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
// a*x^2+b*x+c
Expression addChildren[] = {
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
Multiplication::Builder(
Number::DecimalNumber(b),
Symbol::Builder('x')),
Number::DecimalNumber(c)
};
Expression result = Addition::Builder(addChildren, 3);
PoincareHelpers::Simplify(&result, *context);
return result;
}
double QuadraticModel::evaluate(double * modelCoefficients, double x) const {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
@@ -85,4 +64,24 @@ double QuadraticModel::partialDerivate(double * modelCoefficients, int derivateC
return 0.0;
}
Expression QuadraticModel::expression(double * modelCoefficients) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
// a*x^2+b*x+c
Expression addChildren[] = {
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
Multiplication::Builder(
Number::DecimalNumber(b),
Symbol::Builder('x')),
Number::DecimalNumber(c)
};
Expression result = Addition::Builder(addChildren, 3);
return result;
}
}

View File

@@ -9,12 +9,13 @@ class QuadraticModel : public Model {
public:
using Model::Model;
Poincare::Layout layout() override;
Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context) override;
I18n::Message formulaMessage() const override { return I18n::Message::QuadraticRegressionFormula; }
double evaluate(double * modelCoefficients, double x) const override;
double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const override;
int numberOfCoefficients() const override { return 3; }
int bannerLinesCount() const override { return 3; }
private:
Poincare::Expression expression(double * modelCoefficients) override;
};
}

View File

@@ -56,43 +56,6 @@ Layout QuarticModel::layout() {
return m_layout;
}
Expression QuarticModel::simplifiedExpression(double * modelCoefficients, Poincare::Context * context) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
double e = modelCoefficients[4];
Expression addChildren[] = {
// a*x^4
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(4.0))),
// b*x^3
Multiplication::Builder(
Number::DecimalNumber(b),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(3.0))),
// c*x^2
Multiplication::Builder(
Number::DecimalNumber(c),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
// d*x
Multiplication::Builder(
Number::DecimalNumber(d),
Symbol::Builder('x')),
// e
Number::DecimalNumber(e)
};
Expression result = Addition::Builder(addChildren, 5);
PoincareHelpers::Simplify(&result, *context);
return result;
}
double QuarticModel::evaluate(double * modelCoefficients, double x) const {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
@@ -127,4 +90,40 @@ double QuarticModel::partialDerivate(double * modelCoefficients, int derivateCoe
return 0.0;
}
Expression QuarticModel::expression(double * modelCoefficients) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
double e = modelCoefficients[4];
Expression addChildren[] = {
// a*x^4
Multiplication::Builder(
Number::DecimalNumber(a),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(4.0))),
// b*x^3
Multiplication::Builder(
Number::DecimalNumber(b),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(3.0))),
// c*x^2
Multiplication::Builder(
Number::DecimalNumber(c),
Power::Builder(
Symbol::Builder('x'),
Decimal::Builder(2.0))),
// d*x
Multiplication::Builder(
Number::DecimalNumber(d),
Symbol::Builder('x')),
// e
Number::DecimalNumber(e)
};
Expression result = Addition::Builder(addChildren, 5);
return result;
}
}

View File

@@ -9,12 +9,13 @@ class QuarticModel : public Model {
public:
using Model::Model;
Poincare::Layout layout() override;
Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context) override;
I18n::Message formulaMessage() const override { return I18n::Message::QuarticRegressionFormula; }
double evaluate(double * modelCoefficients, double x) const override;
double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const override;
int numberOfCoefficients() const override { return 5; }
int bannerLinesCount() const override { return 4; }
private:
Poincare::Expression expression(double * modelCoefficients) override;
};
}

View File

@@ -24,27 +24,6 @@ Layout TrigonometricModel::layout() {
return m_layout;
}
Expression TrigonometricModel::simplifiedExpression(double * modelCoefficients, Poincare::Context * context) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
// a*sin(bx+c)+d
Expression result =
Addition::Builder(
Multiplication::Builder(
Number::DecimalNumber(a),
Sine::Builder(
Addition::Builder(
Multiplication::Builder(
Number::DecimalNumber(b),
Symbol::Builder('x')),
Number::DecimalNumber(c)))),
Number::DecimalNumber(d));
PoincareHelpers::Simplify(&result, *context);
return result;
}
double TrigonometricModel::evaluate(double * modelCoefficients, double x) const {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
@@ -79,4 +58,24 @@ double TrigonometricModel::partialDerivate(double * modelCoefficients, int deriv
return 0.0;
}
Expression TrigonometricModel::expression(double * modelCoefficients) {
double a = modelCoefficients[0];
double b = modelCoefficients[1];
double c = modelCoefficients[2];
double d = modelCoefficients[3];
// a*sin(bx+c)+d
Expression result =
Addition::Builder(
Multiplication::Builder(
Number::DecimalNumber(a),
Sine::Builder(
Addition::Builder(
Multiplication::Builder(
Number::DecimalNumber(b),
Symbol::Builder('x')),
Number::DecimalNumber(c)))),
Number::DecimalNumber(d));
return result;
}
}

View File

@@ -9,15 +9,15 @@ class TrigonometricModel : public Model {
public:
using Model::Model;
Poincare::Layout layout() override;
Poincare::Expression simplifiedExpression(double * modelCoefficients, Poincare::Context * context) override;
I18n::Message formulaMessage() const override { return I18n::Message::TrigonometricRegressionFormula; }
double evaluate(double * modelCoefficients, double x) const override;
double partialDerivate(double * modelCoefficients, int derivateCoefficientIndex, double x) const override;
int numberOfCoefficients() const override { return 4; }
int bannerLinesCount() const override { return 4; }
private:
Poincare::Expression expression(double * modelCoefficients) override;
};
}
#endif