diff --git a/poincare/include/poincare/normal_distribution.h b/poincare/include/poincare/normal_distribution.h index 1ec7b98d6..bf459783e 100644 --- a/poincare/include/poincare/normal_distribution.h +++ b/poincare/include/poincare/normal_distribution.h @@ -1,6 +1,7 @@ #ifndef POINCARE_NORMAL_DISTRIBUTION_H #define POINCARE_NORMAL_DISTRIBUTION_H +#include #include namespace Poincare { @@ -11,6 +12,9 @@ public: template static T CumulativeDistributiveFunctionAtAbscissa(T x, T mu, T var); template static T CumulativeDistributiveInverseForProbability(T probability, T mu, T var); template static bool ParametersAreOK(T mu, T var); + /* ExpressionParametersAreOK returns true if the expression could be verified. + * The result of the verification is *result. */ + static bool ExpressionParametersAreOK(bool * result, const Expression & mu, const Expression & var, Context * context); private: /* For the standard normal distribution, P(X < y) > 0.9999995 for y >= 4.892 so the * value displayed is 1. But this is dependent on the fact that we display diff --git a/poincare/src/inv_norm.cpp b/poincare/src/inv_norm.cpp index 3d3f2702b..1098df8cb 100644 --- a/poincare/src/inv_norm.cpp +++ b/poincare/src/inv_norm.cpp @@ -51,40 +51,26 @@ Expression InvNorm::shallowReduce(ExpressionNode::ReductionContext reductionCont Expression c2 = childAtIndex(2); Context * context = reductionContext.context(); - if (c0.deepIsMatrix(context) || c1.deepIsMatrix(context) || c2.deepIsMatrix(context)) { + // Check mu and var + bool muAndVarOK = false; + bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, c1, c2, context); + if (!couldCheckMuAndVar) { + return *this; + } + if (!muAndVarOK) { return replaceWithUndefinedInPlace(); } - if (!c1.isReal(context) || !c2.isReal(context)) { - // If we cannot check that mu and variance are real, return - return *this; + // Check a + if (c0.deepIsMatrix(context)) { + return replaceWithUndefinedInPlace(); } - - { - ExpressionNode::Sign s = c2.sign(context); - if (s == ExpressionNode::Sign::Negative) { - return replaceWithUndefinedInPlace(); - } - // If we cannot check that the variance is positive, return - if (s != ExpressionNode::Sign::Positive) { - return *this; - } - } - - // If we cannot check that the variance is not null, return - if (c2.type() != ExpressionNode::Type::Rational) { - return *this; - } - { - Rational r2 = static_cast(c2); - if (r2.isZero()) { - return replaceWithUndefinedInPlace(); - } - } - if (c0.type() != ExpressionNode::Type::Rational) { return *this; } + + // Special values + // Undef if x < 0 or x > 1 Rational r0 = static_cast(c0); if (r0.isNegative()) { diff --git a/poincare/src/normal_distribution.cpp b/poincare/src/normal_distribution.cpp index 19df31ef1..2d3945331 100644 --- a/poincare/src/normal_distribution.cpp +++ b/poincare/src/normal_distribution.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -38,6 +39,44 @@ bool NormalDistribution::ParametersAreOK(T mu, T var) { && var > (T)0.0; } +bool NormalDistribution::ExpressionParametersAreOK(bool * result, const Expression & mu, const Expression & var, Context * context) { + assert(result != nullptr); + if (mu.deepIsMatrix(context) || var.deepIsMatrix(context)) { + *result = false; + return true; + } + + if (!mu.isReal(context) || !var.isReal(context)) { + // We cannot check that mu and variance are real + return false; + } + + { + ExpressionNode::Sign s = var.sign(context); + if (s == ExpressionNode::Sign::Negative) { + *result = false; + return true; + } + // We cannot check that the variance is positive + if (s != ExpressionNode::Sign::Positive) { + return false; + } + } + + if (var.type() != ExpressionNode::Type::Rational) { + // We cannot check that the variance is not null + return false; + } + + const Rational rationalVar = static_cast(var); + if (rationalVar.isZero()) { + *result = false; + return true; + } + *result = true; + return true; +} + template T NormalDistribution::StandardNormalCumulativeDistributiveFunctionAtAbscissa(T abscissa) { if (std::isnan(abscissa) || std::isinf(abscissa)) {