[poincare/inv_norm] Don't simplify if variance is not checked to be >0

This commit is contained in:
Léa Saviot
2019-08-23 16:53:50 +02:00
parent cd227ef3bd
commit a2f435bed3
2 changed files with 34 additions and 4 deletions

View File

@@ -50,16 +50,43 @@ Expression InvNorm::shallowReduce(ExpressionNode::ReductionContext reductionCont
}
Expression c0 = childAtIndex(0);
Expression c1 = childAtIndex(1);
Expression c2 = childAtIndex(2);
Context * context = reductionContext.context();
if (c0.deepIsMatrix(context) || c1.deepIsMatrix(context) || c2.deepIsMatrix(context)) {
return replaceWithUndefinedInPlace();
}
if (!c1.isReal(context) || !c2.isReal(context)) {
// If we cannot check that mu and variance are real, return
return *this;
}
{
Context * context = reductionContext.context();
Expression c2 = childAtIndex(2);
if (c0.deepIsMatrix(context) || c1.deepIsMatrix(context) || c2.deepIsMatrix(context)) {
ExpressionNode::Sign s = c2.sign(context);
if (s == ExpressionNode::Sign::Negative) {
return replaceWithUndefinedInPlace();
}
if (c0.type() != ExpressionNode::Type::Rational) {
// If we cannot check that the variance is positive, return
if (s != ExpressionNode::Sign::Positive) {
return *this;
}
}
// If we cannot check that the variance is not null, return
if (c2.type() != ExpressionNode::Type::Rational) {
return *this;
}
{
Rational r2 = static_cast<Rational &>(c2);
if (r2.isZero()) {
return replaceWithUndefinedInPlace();
}
}
if (c0.type() != ExpressionNode::Type::Rational) {
return *this;
}
// Undef if x < 0 or x > 1
Rational r0 = static_cast<Rational &>(c0);
if (r0.isNegative()) {

View File

@@ -1066,4 +1066,7 @@ QUIZ_CASE(poincare_probabolity) {
assert_parsed_expression_simplify_to("invnorm(0.5,2,3)", "2");
assert_parsed_expression_simplify_to("invnorm(1,2,3)", "inf");
assert_parsed_expression_simplify_to("invnorm(1.3,2,3)", "undef");
assert_parsed_expression_simplify_to("invnorm(3/4,2,random())", "invnorm(3/4,2,random())"); // random can be 0
assert_parsed_expression_simplify_to("invnorm(0.5,2,0)", Undefined::Name());
assert_parsed_expression_simplify_to("invnorm(0.5,2,-1)", Undefined::Name());
}