[poincare] NormCDF2

This commit is contained in:
Léa Saviot
2019-08-23 14:14:44 +02:00
parent 294def02bd
commit 0316702adf
11 changed files with 139 additions and 7 deletions

View File

@@ -96,6 +96,7 @@ poincare_src += $(addprefix poincare/src/,\
n_ary_expression.cpp \
naperian_logarithm.cpp \
norm_cdf.cpp \
norm_cdf2.cpp \
nth_root.cpp \
number.cpp \
opposite.cpp \

View File

@@ -66,6 +66,7 @@ class Expression : public TreeHandle {
friend class MultiplicationNode;
friend class NaperianLogarithm;
friend class NormCDF;
friend class NormCDF2;
friend class NthRoot;
friend class Number;
friend class Opposite;
@@ -280,6 +281,11 @@ protected:
assert(children.type() == ExpressionNode::Type::Matrix);
return U::Builder(children.childAtIndex(0), children.childAtIndex(1), children.childAtIndex(2));
}
template<typename U>
static Expression UntypedBuilderFourChildren(Expression children) {
assert(children.type() == ExpressionNode::Type::Matrix);
return U::Builder(children.childAtIndex(0), children.childAtIndex(1), children.childAtIndex(2), children.childAtIndex(3));
}
template<class T> T convert() const {
/* This function allows to convert Expression to derived Expressions.

View File

@@ -73,6 +73,7 @@ public:
MatrixTrace,
NaperianLogarithm,
NormCDF,
NormCDF2,
NthRoot,
Opposite,
Parenthesis,

View File

@@ -0,0 +1,52 @@
#ifndef POINCARE_NORMCDF2_H
#define POINCARE_NORMCDF2_H
#include <poincare/approximation_helper.h>
#include <poincare/expression.h>
namespace Poincare {
class NormCDF2Node final : public ExpressionNode {
public:
// TreeNode
size_t size() const override { return sizeof(NormCDF2Node); }
int numberOfChildren() const override;
#if POINCARE_TREE_LOG
virtual void logNodeName(std::ostream & stream) const override {
stream << "NormCDF2";
}
#endif
// Properties
Type type() const override { return Type::NormCDF2; }
Sign sign(Context * context) const override { return Sign::Positive; }
Expression setSign(Sign s, ReductionContext reductionContext) override;
private:
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class NormCDF2 final : public Expression {
public:
NormCDF2(const NormCDF2Node * n) : Expression(n) {}
static NormCDF2 Builder(Expression child0, Expression child1, Expression child2, Expression child3) { return TreeHandle::FixedArityBuilder<NormCDF2, NormCDF2Node>(ArrayBuilder<TreeHandle>(child0, child1, child2, child3).array(), 4); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("normcdf2", 4, &UntypedBuilderFourChildren<NormCDF2>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}
#endif

View File

@@ -9,7 +9,7 @@ class NormalDistribution final {
public:
template<typename T> static T EvaluateAtAbscissa(T x, T mu, T sigma);
template<typename T> static T CumulativeDistributiveFunctionAtAbscissa(T x, T mu, T sigma);
static double CumulativeDistributiveInverseForProbability(double probability, float mu, float sigma);
template<typename T> static T CumulativeDistributiveInverseForProbability(T probability, T mu, T sigma);
private:
/* For the standard normal distribution, P(X < y) > 0.9999995 for y >= 4.892 so the
* value displayed is 1. But this is dependent on the fact that we display
@@ -17,7 +17,7 @@ private:
static_assert(Preferences::LargeNumberOfSignificantDigits == 7, "k_boundStandardNormalDistribution is ill-defined compared to LargeNumberOfSignificantDigits");
constexpr static double k_boundStandardNormalDistribution = 4.892;
template<typename T> static T StandardNormalCumulativeDistributiveFunctionAtAbscissa(T abscissa);
static double StandardNormalCumulativeDistributiveInverseForProbability(double probability);
template<typename T> static T StandardNormalCumulativeDistributiveInverseForProbability(T probability);
};
}

View File

@@ -52,6 +52,7 @@
#include <poincare/multiplication.h>
#include <poincare/naperian_logarithm.h>
#include <poincare/norm_cdf.h>
#include <poincare/norm_cdf2.h>
#include <poincare/nth_root.h>
#include <poincare/number.h>
#include <poincare/opposite.h>

View File

@@ -0,0 +1,62 @@
#include <poincare/norm_cdf2.h>
#include <poincare/layout_helper.h>
#include <poincare/normal_distribution.h>
#include <poincare/serialization_helper.h>
#include <assert.h>
namespace Poincare {
constexpr Expression::FunctionHelper NormCDF2::s_functionHelper;
int NormCDF2Node::numberOfChildren() const { return NormCDF2::s_functionHelper.numberOfChildren(); }
Expression NormCDF2Node::setSign(Sign s, ReductionContext reductionContext) {
assert(s == Sign::Positive);
return NormCDF2(this);
}
Layout NormCDF2Node::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return LayoutHelper::Prefix(NormCDF2(this), floatDisplayMode, numberOfSignificantDigits, NormCDF2::s_functionHelper.name());
}
int NormCDF2Node::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, NormCDF2::s_functionHelper.name());
}
Expression NormCDF2Node::shallowReduce(ReductionContext reductionContext) {
return NormCDF2(this).shallowReduce(reductionContext);
}
template<typename T>
Evaluation<T> NormCDF2Node::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> aEvaluation = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
Evaluation<T> bEvaluation = childAtIndex(1)->approximate(T(), context, complexFormat, angleUnit);
Evaluation<T> muEvaluation = childAtIndex(2)->approximate(T(), context, complexFormat, angleUnit);
Evaluation<T> sigmaEvaluation = childAtIndex(3)->approximate(T(), context, complexFormat, angleUnit);
T a = aEvaluation.toScalar();
T b = bEvaluation.toScalar();
T mu = muEvaluation.toScalar();
T sigma = sigmaEvaluation.toScalar();
if (std::isnan(a) || std::isnan(b) || std::isnan(mu) || std::isnan(sigma)) {
return Complex<T>::Undefined();
}
if (b <= a) {
return Complex<T>::Builder((T)0.0);
}
return Complex<T>::Builder(NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(b, mu, sigma) - NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(a, mu, sigma));
}
Expression NormCDF2::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
return e;
}
}
//TODO LEA
return *this;
}
}

View File

@@ -23,8 +23,9 @@ T NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(T x, T mu, T sigm
return StandardNormalCumulativeDistributiveFunctionAtAbscissa<T>((x-mu)/std::fabs(sigma));
}
double NormalDistribution::CumulativeDistributiveInverseForProbability(double probability, float mu, float sigma) {
if (sigma == 0.0f) {
template<typename T>
T NormalDistribution::CumulativeDistributiveInverseForProbability(T probability, T mu, T sigma) {
if (sigma == (T)0.0) {
return NAN;
}
return StandardNormalCumulativeDistributiveInverseForProbability(probability) * std::fabs(sigma) + mu;
@@ -44,7 +45,8 @@ T NormalDistribution::StandardNormalCumulativeDistributiveFunctionAtAbscissa(T a
return ((T)0.5) + ((T)0.5) * std::erf(abscissa/std::sqrt(((T)2.0)));
}
double NormalDistribution::StandardNormalCumulativeDistributiveInverseForProbability(double probability) {
template<typename T>
T NormalDistribution::StandardNormalCumulativeDistributiveInverseForProbability(T probability) {
if (probability >= 1.0) {
return INFINITY;
}
@@ -52,13 +54,15 @@ double NormalDistribution::StandardNormalCumulativeDistributiveInverseForProbabi
return -INFINITY;
}
if (probability < 0.5) {
return -StandardNormalCumulativeDistributiveInverseForProbability(1-probability);
return -StandardNormalCumulativeDistributiveInverseForProbability(1.0-probability);
}
return std::sqrt(2.0) * erfInv(2.0 * probability - 1.0);
}
template float NormalDistribution::EvaluateAtAbscissa<float>(float, float, float);
template double NormalDistribution::CumulativeDistributiveFunctionAtAbscissa<double>(double, double, double);
template float NormalDistribution::CumulativeDistributiveFunctionAtAbscissa<float>(float, float, float);
template double NormalDistribution::CumulativeDistributiveFunctionAtAbscissa<double>(double, double, double);
template float NormalDistribution::CumulativeDistributiveInverseForProbability<float>(float, float, float);
template double NormalDistribution::CumulativeDistributiveInverseForProbability<double>(double, double, double);
}

View File

@@ -121,6 +121,7 @@ private:
&CommonLogarithm::s_functionHelper,
&Logarithm::s_functionHelper,
&NormCDF::s_functionHelper,
&NormCDF2::s_functionHelper,
&PermuteCoefficient::s_functionHelper,
&SimplePredictionInterval::s_functionHelper,
&PredictionInterval::s_functionHelper,

View File

@@ -317,6 +317,7 @@ template MatrixTranspose TreeHandle::FixedArityBuilder<MatrixTranspose, MatrixTr
template Multiplication TreeHandle::NAryBuilder<Multiplication, MultiplicationNode>(TreeHandle*, size_t);
template NaperianLogarithm TreeHandle::FixedArityBuilder<NaperianLogarithm, NaperianLogarithmNode>(TreeHandle*, size_t);
template NormCDF TreeHandle::FixedArityBuilder<NormCDF, NormCDFNode>(TreeHandle*, size_t);
template NormCDF2 TreeHandle::FixedArityBuilder<NormCDF2, NormCDF2Node>(TreeHandle*, size_t);
template NthRoot TreeHandle::FixedArityBuilder<NthRoot, NthRootNode>(TreeHandle*, size_t);
template Opposite TreeHandle::FixedArityBuilder<Opposite, OppositeNode>(TreeHandle*, size_t);
template Parenthesis TreeHandle::FixedArityBuilder<Parenthesis, ParenthesisNode>(TreeHandle*, size_t);

View File

@@ -279,6 +279,9 @@ QUIZ_CASE(poincare_approximation_function) {
assert_expression_approximates_to<float>("normcdf(1.2, 3.4, 5.6)", "0.3472125");
assert_expression_approximates_to<double>("normcdf(1.2, 3.4, 5.6)", "3.4721249841587ᴇ-1");
assert_expression_approximates_to<float>("normcdf2(0.5, 3.6, 1.3, 3.4)", "0.3436388");
assert_expression_approximates_to<double>("normcdf2(0.5, 3.6, 1.3, 3.4)", "3.4363881299147ᴇ-1");
assert_expression_approximates_to<float>("permute(10, 4)", "5040");
assert_expression_approximates_to<double>("permute(10, 4)", "5040");