[poincare] Fix PredictionInterval

This commit is contained in:
Émilie Feral
2018-08-28 16:36:10 +02:00
parent 6452d496b0
commit 13076dfd02
7 changed files with 99 additions and 75 deletions

View File

@@ -98,6 +98,7 @@ objs += $(addprefix poincare/src/,\
parenthesis.o\
power.o\
print_float.o\
prediction_interval.o\
preferences.o\
product.o\
randint.o\

View File

@@ -167,6 +167,7 @@
#include <poincare/matrix_complex.h>
#include <poincare/multiplication.h>
#include <poincare/power.h>
#include <poincare/prediction_interval.h>
#include <poincare/randint.h>
#include <poincare/random.h>
#include <poincare/subtraction.h>

View File

@@ -2,32 +2,59 @@
#define POINCARE_PREDICTION_INTERVAL_H
#include <poincare/layout_helper.h>
#include <poincare/static_hierarchy.h>
#include <poincare/expression.h>
#include <poincare/serialization_helper.h>
namespace Poincare {
class PredictionInterval : public StaticHierarchy<2> {
using StaticHierarchy<2>::StaticHierarchy;
class PredictionIntervalNode : public ExpressionNode {
public:
Type type() const override;
int polynomialDegree(char symbolName) const override;
private:
/* Layout */
LayoutRef createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override {
return LayoutHelper::Prefix(this, floatDisplayMode, numberOfSignificantDigits, name());
static PredictionIntervalNode * FailedAllocationStaticNode();
PredictionIntervalNode * failedAllocationStaticNode() override { return FailedAllocationStaticNode(); }
// TreeNode
size_t size() const override { return sizeof(PredictionIntervalNode); }
int numberOfChildren() const override { return 2; }
#if POINCARE_TREE_LOG
virtual void logNodeName(std::ostream & stream) const override {
stream << "PredictionInterval";
}
#endif
// ExpressionNode
// Properties
Type type() const override { return Type::PredictionInterval; }
int polynomialDegree(char symbolName) const override { return -1; }
private:
// Layout
LayoutReference createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override {
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, name());
}
const char * name() const { return "prediction95"; }
/* Simplification */
// Simplification
Expression shallowReduce(Context& context, Preferences::AngleUnit angleUnit) const override;
/* Evaluation */
Evaluation<float> approximate(Expression::SinglePrecision p, Context& context, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, angleUnit); }
Evaluation<double> * approximate(Expression::DoublePrecision p, Context& context, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, angleUnit); }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context& context, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context& context, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context& context, Preferences::AngleUnit angleUnit) const;
};
class PredictionInterval : public Expression {
public:
PredictionInterval() : Expression(TreePool::sharedPool()->createTreeNode<PredictionIntervalNode>()) {}
PredictionInterval(const PredictionIntervalNode * n) : Expression(n) {}
PredictionInterval(Expression child1, Expression child2) : PredictionInterval() {
replaceChildAtIndexInPlace(0, child1);
replaceChildAtIndexInPlace(1, child2);
}
// Expression
Expression shallowReduce(Context& context, Preferences::AngleUnit angleUnit) const;
};
}
#endif

View File

@@ -33,8 +33,8 @@ template<typename T>
Evaluation<T> ConfidenceIntervalNode::templatedApproximate(Context& context, Preferences::AngleUnit angleUnit) const {
Evaluation<T> fInput = childAtIndex(0)->approximate(T(), context, angleUnit);
Evaluation<T> nInput = childAtIndex(1)->approximate(T(), context, angleUnit);
T f = static_cast<Complex<T> >(fInput).toScalar();
T n = static_cast<Complex<T> >(nInput).toScalar();
T f = static_cast<Complex<T> &>(fInput).toScalar();
T n = static_cast<Complex<T> &>(nInput).toScalar();
if (std::isnan(f) || std::isnan(n) || n != (int)n || n < 0 || f < 0 || f > 1) {
return Complex<T>::Undefined();
}
@@ -63,13 +63,13 @@ Expression ConfidenceInterval::shallowReduce(Context& context, Preferences::Angl
}
#endif
if (c0.type() == ExpressionNode::Type::Rational) {
Rational r0 = static_cast<Rational>(c0);
Rational r0 = static_cast<Rational&>(c0);
if (r0.signedIntegerNumerator().isNegative() || Integer::NaturalOrder(r0.signedIntegerNumerator(), r0.integerDenominator()) > 0) {
return Undefined();
}
}
if (c1.type() == ExpressionNode::Type::Rational) {
Rational r1 = static_cast<Rational>(c1);
Rational r1 = static_cast<Rational&>(c1);
if (!r1.integerDenominator().isOne() || r1.signedIntegerNumerator().isNegative()) {
return Undefined();
}
@@ -77,8 +77,8 @@ Expression ConfidenceInterval::shallowReduce(Context& context, Preferences::Angl
if (c0.type() != ExpressionNode::Type::Rational || c1.type() != ExpressionNode::Type::Rational) {
return *this;
}
Rational r0 = static_cast<Rational>(c0);
Rational r1 = static_cast<Rational>(c1);
Rational r0 = static_cast<Rational&>(c0);
Rational r1 = static_cast<Rational&>(c1);
// Compute [r0-1/sqr(r1), r0+1/sqr(r1)]
Expression sqr = Power(r1, Rational(-1, 2));
Matrix matrix;

View File

@@ -140,8 +140,8 @@ lcm { poincare_expression_yylval.expression = LeastCommonMultiple(); return FUNC
ln { poincare_expression_yylval.expression = NaperianLogarithm(); return FUNCTION; }
log { return LOGFUNCTION; }
/*permute { poincare_expression_yylval.expression = new PermuteCoefficient(); return FUNCTION; }
prediction95 { poincare_expression_yylval.expression = new PredictionInterval(); return FUNCTION; }
*/
prediction95 { poincare_expression_yylval.expression = PredictionInterval(); return FUNCTION; }
prediction { poincare_expression_yylval.expression = SimplePredictionInterval(); return FUNCTION; }
product { poincare_expression_yylval.expression = Product(); return FUNCTION; }
quo { poincare_expression_yylval.expression = DivisionQuotient(); return FUNCTION; }

View File

@@ -12,17 +12,33 @@ extern "C" {
namespace Poincare {
ExpressionNode::Type PredictionInterval::type() const {
return Type::PredictionInterval;
PredictionIntervalNode * PredictionIntervalNode::FailedAllocationStaticNode() {
static AllocationFailureExpressionNode<PredictionIntervalNode> failure;
TreePool::sharedPool()->registerStaticNodeIfRequired(&failure);
return &failure;
}
Expression * PredictionInterval::clone() const {
PredictionInterval * a = new PredictionInterval(m_operands, true);
return a;
LayoutReference PredictionIntervalNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return LayoutHelper::Prefix(PredictionInterval(this), floatDisplayMode, numberOfSignificantDigits, name());
}
int PredictionInterval::polynomialDegree(char symbolName) const {
return -1;
Expression PredictionIntervalNode::shallowReduce(Context& context, Preferences::AngleUnit angleUnit) const {
return PredictionInterval(this).shallowReduce(context, angleUnit);
}
template<typename T>
Evaluation<T> PredictionIntervalNode::templatedApproximate(Context& context, Preferences::AngleUnit angleUnit) const {
Evaluation<T> pInput = childAtIndex(0)->approximate(T(), context, angleUnit);
Evaluation<T> nInput = childAtIndex(1)->approximate(T(), context, angleUnit);
T p = static_cast<Complex<T> &>(pInput).toScalar();
T n = static_cast<Complex<T> &>(nInput).toScalar();
if (std::isnan(p) || std::isnan(n) || n != (int)n || n < 0 || p < 0 || p > 1) {
return Complex<T>::Undefined();
}
std::complex<T> operands[2];
operands[0] = std::complex<T>(p - 1.96*std::sqrt(p*(1.0-p))/std::sqrt(n));
operands[1] = std::complex<T>(p + 1.96*std::sqrt(p*(1.0-p))/std::sqrt(n));
return MatrixComplex<T>(operands, 1, 2);
}
Expression PredictionInterval::shallowReduce(Context& context, Preferences::AngleUnit angleUnit) const {
@@ -30,62 +46,44 @@ Expression PredictionInterval::shallowReduce(Context& context, Preferences::Angl
if (e.isUndefinedOrAllocationFailure()) {
return e;
}
Expression * op0 = childAtIndex(0);
Expression * op1 = childAtIndex(1);
Expression op0 = childAtIndex(0);
Expression op1 = childAtIndex(1);
#if MATRIX_EXACT_REDUCING
if (op0->type() == Type::Matrix || op1->type() == Type::Matrix) {
return replaceWith(new Undefined(), true);
if (op0.type() == ExpressionNode::Type::Matrix || op1.type() == ExpressionNode::Type::Matrix) {
return Undefined();
}
#endif
if (op0->type() == Type::Rational) {
Rational * r0 = static_cast<Rational *>(op0);
if (r0->numerator().isNegative() || Integer::NaturalOrder(r0->numerator(), r0->denominator()) > 0) {
return replaceWith(new Undefined(), true);
if (op0.type() == ExpressionNode::Type::Rational) {
Rational r0 = static_cast<Rational &>(op0);
if (r0.sign() == ExpressionNode::Sign::Negative || Integer::NaturalOrder(r0.unsignedIntegerNumerator(), r0.integerDenominator()) > 0) {
return Undefined();
}
}
if (op1->type() == Type::Rational) {
Rational * r1 = static_cast<Rational *>(op1);
if (!r1->denominator().isOne() || r1->numerator().isNegative()) {
return replaceWith(new Undefined(), true);
if (op1.type() == ExpressionNode::Type::Rational) {
Rational r1 = static_cast<Rational &>(op1);
if (!r1.integerDenominator().isOne() || r1.sign() == ExpressionNode::Sign::Negative) {
return Undefined();
}
}
if (op0->type() != Type::Rational || op1->type() != Type::Rational) {
return this;
if (op0.type() != ExpressionNode::Type::Rational || op1.type() != ExpressionNode::Type::Rational) {
return *this;
}
Rational * r0 = static_cast<Rational *>(op0);
Rational * r1 = static_cast<Rational *>(op1);
if (!r1->denominator().isOne() || r1->numerator().isNegative() || r0->numerator().isNegative() || Integer::NaturalOrder(r0->numerator(), r0->denominator()) > 0) {
return replaceWith(new Undefined(), true);
Rational r0 = static_cast<Rational &>(op0);
Rational r1 = static_cast<Rational &>(op1);
if (!r1.integerDenominator().isOne() || r1.sign() == ExpressionNode::Sign::Negative || r0.sign() == ExpressionNode::Sign::Negative || Integer::NaturalOrder(r0.unsignedIntegerNumerator(), r0.integerDenominator()) > 0) {
return Undefined();
}
detachOperand(r0);
detachOperand(r1);
/* [r0-1.96*sqrt(r0*(1-r0)/r1), r0+1.96*sqrt(r0*(1-r0)/r1)]*/
// Compute numerator = r0*(1-r0)
Rational * numerator = new Rational(Rational::Multiplication(*r0, Rational(Integer::Subtraction(r0->denominator(), r0->numerator()), r0->denominator())));
Rational numerator = Rational::Multiplication(r0, Rational(Integer::Subtraction(r0.integerDenominator(), r0.unsignedIntegerNumerator()), r0.integerDenominator()));
// Compute sqr = sqrt(r0*(1-r0)/r1)
Expression * sqr = new Power(new Division(numerator, r1, false), new Rational(1, 2), false);
Expression * m = new Multiplication(new Rational(196, 100), sqr, false);
const Expression * newOperands[2] = {new Addition(r0, new Multiplication(new Rational(-1), m, false), false), new Addition(r0, m, true),};
Expression * matrix = replaceWith(new Matrix(newOperands, 1, 2, false), true);
return matrix->deepReduce(context, angleUnit);
}
template<typename T>
Evaluation<T> PredictionInterval::templatedApproximate(Context& context, Preferences::AngleUnit angleUnit) const {
Evaluation<T> * pInput = childAtIndex(0)->approximate(T(), context, angleUnit);
Evaluation<T> * nInput = childAtIndex(1)->approximate(T(), context, angleUnit);
T p = static_cast<Complex<T> *>(pInput)->toScalar();
T n = static_cast<Complex<T> *>(nInput)->toScalar();
delete pInput;
delete nInput;
if (std::isnan(p) || std::isnan(n) || n != (int)n || n < 0 || p < 0 || p > 1) {
return new Complex<T>(Complex<T>::Undefined());
}
std::complex<T> operands[2];
operands[0] = std::complex<T>(p - 1.96*std::sqrt(p*(1.0-p))/std::sqrt(n));
operands[1] = std::complex<T>(p + 1.96*std::sqrt(p*(1.0-p))/std::sqrt(n));
return new MatrixComplex<T>(operands, 1, 2);
Expression sqr = Power(Division(numerator, r1), Rational(1, 2));
Expression m = Multiplication(Rational(196, 100), sqr);
Matrix matrix;
matrix.addChildAtIndexInPlace(Addition(r0, Multiplication(Rational(-1), m)), 0, 0);
matrix.addChildAtIndexInPlace(Addition(r0, m), 1, 1);
matrix.setDimensions(1, 2);
return matrix.deepReduce(context, angleUnit);
}
}

View File

@@ -43,9 +43,7 @@ QUIZ_CASE(poincare_parse_function) {
assert_parsed_expression_type("permute(10, 4)", ExpressionNode::Type::PermuteCoefficient);
#endif
assert_parsed_expression_type("prediction(0.1, 100)", ExpressionNode::Type::ConfidenceInterval);
#if 0
assert_parsed_expression_type("prediction95(0.1, 100)", ExpressionNode::Type::PredictionInterval);
#endif
assert_parsed_expression_type("product(n, 4, 10)", ExpressionNode::Type::Product);
assert_parsed_expression_type("quo(29, 10)", ExpressionNode::Type::DivisionQuotient);
@@ -174,11 +172,10 @@ QUIZ_CASE(poincare_function_evaluate) {
assert_parsed_expression_evaluates_to<float>("prediction(0.1, 100)", "[[0,0.2]]");
assert_parsed_expression_evaluates_to<double>("prediction(0.1, 100)", "[[0,0.2]]");
#if 0
assert_parsed_expression_evaluates_to<float>("prediction95(0.1, 100)", "[[0.0412,0.1588]]");
assert_parsed_expression_evaluates_to<double>("prediction95(0.1, 100)", "[[0.0412,0.1588]]");
#endif
assert_parsed_expression_evaluates_to<float>("product(2+n*I, 1, 5)", "(-100)-540*I");
assert_parsed_expression_evaluates_to<double>("product(2+n*I, 1, 5)", "(-100)-540*I");