[poincare/normal_distribution_function] Factorize code in parent class

This commit is contained in:
Léa Saviot
2019-08-26 10:44:19 +02:00
parent ff0105182d
commit 5c443e0412
12 changed files with 89 additions and 136 deletions

View File

@@ -99,6 +99,7 @@ poincare_src += $(addprefix poincare/src/,\
norm_cdf.cpp \
norm_cdf2.cpp \
norm_pdf.cpp \
normal_distribution_function.cpp \
nth_root.cpp \
number.cpp \
opposite.cpp \

View File

@@ -66,9 +66,10 @@ class Expression : public TreeHandle {
friend class Multiplication;
friend class MultiplicationNode;
friend class NaperianLogarithm;
friend class NormPDF;
friend class NormalDistributionFunction;
friend class NormCDF;
friend class NormCDF2;
friend class NormPDF;
friend class NthRoot;
friend class Number;
friend class Opposite;

View File

@@ -2,11 +2,11 @@
#define POINCARE_INV_NORM_H
#include <poincare/approximation_helper.h>
#include <poincare/expression.h>
#include <poincare/normal_distribution_function.h>
namespace Poincare {
class InvNormNode final : public ExpressionNode {
class InvNormNode final : public NormalDistributionFunctionNode {
public:
// TreeNode
@@ -28,8 +28,6 @@ private:
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
@@ -37,9 +35,9 @@ private:
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class InvNorm final : public Expression {
class InvNorm final : public NormalDistributionFunction {
public:
InvNorm(const InvNormNode * n) : Expression(n) {}
InvNorm(const InvNormNode * n) : NormalDistributionFunction(n) {}
static InvNorm Builder(Expression child0, Expression child1, Expression child2) { return TreeHandle::FixedArityBuilder<InvNorm, InvNormNode>(ArrayBuilder<TreeHandle>(child0, child1, child2).array(), 3); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("invnorm", 3, &UntypedBuilderThreeChildren<InvNorm>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);

View File

@@ -2,11 +2,11 @@
#define POINCARE_NORM_CDF_H
#include <poincare/approximation_helper.h>
#include <poincare/expression.h>
#include <poincare/normal_distribution_function.h>
namespace Poincare {
class NormCDFNode final : public ExpressionNode {
class NormCDFNode final : public NormalDistributionFunctionNode {
public:
// TreeNode
@@ -28,23 +28,17 @@ private:
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class NormCDF final : public Expression {
class NormCDF final : public NormalDistributionFunction {
public:
NormCDF(const NormCDFNode * n) : Expression(n) {}
NormCDF(const NormCDFNode * n) : NormalDistributionFunction(n) {}
static NormCDF Builder(Expression child0, Expression child1, Expression child2) { return TreeHandle::FixedArityBuilder<NormCDF, NormCDFNode>(ArrayBuilder<TreeHandle>(child0, child1, child2).array(), 3); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("normcdf", 3, &UntypedBuilderThreeChildren<NormCDF>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}

View File

@@ -2,11 +2,11 @@
#define POINCARE_NORM_CDF2_H
#include <poincare/approximation_helper.h>
#include <poincare/expression.h>
#include <poincare/normal_distribution_function.h>
namespace Poincare {
class NormCDF2Node final : public ExpressionNode {
class NormCDF2Node final : public NormalDistributionFunctionNode {
public:
// TreeNode
@@ -28,23 +28,17 @@ private:
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class NormCDF2 final : public Expression {
class NormCDF2 final : public NormalDistributionFunction {
public:
NormCDF2(const NormCDF2Node * n) : Expression(n) {}
NormCDF2(const NormCDF2Node * n) : NormalDistributionFunction(n) {}
static NormCDF2 Builder(Expression child0, Expression child1, Expression child2, Expression child3) { return TreeHandle::FixedArityBuilder<NormCDF2, NormCDF2Node>(ArrayBuilder<TreeHandle>(child0, child1, child2, child3).array(), 4); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("normcdf2", 4, &UntypedBuilderFourChildren<NormCDF2>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}

View File

@@ -2,11 +2,11 @@
#define POINCARE_NORM_PDF_H
#include <poincare/approximation_helper.h>
#include <poincare/expression.h>
#include <poincare/normal_distribution_function.h>
namespace Poincare {
class NormPDFNode final : public ExpressionNode {
class NormPDFNode final : public NormalDistributionFunctionNode {
public:
// TreeNode
@@ -28,23 +28,17 @@ private:
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class NormPDF final : public Expression {
class NormPDF final : public NormalDistributionFunction {
public:
NormPDF(const NormPDFNode * n) : Expression(n) {}
NormPDF(const NormPDFNode * n) : NormalDistributionFunction(n) {}
static NormPDF Builder(Expression child0, Expression child1, Expression child2) { return TreeHandle::FixedArityBuilder<NormPDF, NormPDFNode>(ArrayBuilder<TreeHandle>(child0, child1, child2).array(), 3); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("normpdf", 3, &UntypedBuilderThreeChildren<NormPDF>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}

View File

@@ -0,0 +1,26 @@
#ifndef POINCARE_NORMAL_DISTRIBUTION_FUNCTION_H
#define POINCARE_NORMAL_DISTRIBUTION_FUNCTION_H
#include <poincare/expression.h>
namespace Poincare {
// NormalDistributionFunctions are NormCDF, NormCDF2, InvNorm and NormPDF
class NormalDistributionFunctionNode : public ExpressionNode {
private:
// Simplication
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
};
class NormalDistributionFunction : public Expression {
public:
NormalDistributionFunction(const NormalDistributionFunctionNode * n) : Expression(n) {}
Expression shallowReduce(Context * context, bool * stopReduction = nullptr);
};
}
#endif

View File

@@ -41,26 +41,16 @@ Evaluation<T> InvNormNode::templatedApproximate(Context * context, Preferences::
Expression InvNorm::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
bool stopReduction = false;
Expression e = NormalDistributionFunction::shallowReduce(reductionContext.context(), &stopReduction);
if (stopReduction) {
return e;
}
}
Expression a = childAtIndex(0);
Expression mu = childAtIndex(1);
Expression var = childAtIndex(2);
Context * context = reductionContext.context();
// Check mu and var
bool muAndVarOK = false;
bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, mu, var, context);
if (!couldCheckMuAndVar) {
return *this;
}
if (!muAndVarOK) {
return replaceWithUndefinedInPlace();
}
// Check a
if (a.deepIsMatrix(context)) {
return replaceWithUndefinedInPlace();

View File

@@ -23,10 +23,6 @@ int NormCDFNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloa
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, NormCDF::s_functionHelper.name());
}
Expression NormCDFNode::shallowReduce(ReductionContext reductionContext) {
return NormCDF(this).shallowReduce(reductionContext);
}
template<typename T>
Evaluation<T> NormCDFNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> aEvaluation = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
@@ -41,29 +37,4 @@ Evaluation<T> NormCDFNode::templatedApproximate(Context * context, Preferences::
return Complex<T>::Builder(NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(a, mu, var));
}
Expression NormCDF::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
return e;
}
}
Expression mu = childAtIndex(1);
Expression var = childAtIndex(2);
Context * context = reductionContext.context();
// Check mu and var
bool muAndVarOK = false;
bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, mu, var, context);
if (!couldCheckMuAndVar) {
return *this;
}
if (!muAndVarOK) {
return replaceWithUndefinedInPlace();
}
return *this;
}
}

View File

@@ -23,10 +23,6 @@ int NormCDF2Node::serialize(char * buffer, int bufferSize, Preferences::PrintFlo
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, NormCDF2::s_functionHelper.name());
}
Expression NormCDF2Node::shallowReduce(ReductionContext reductionContext) {
return NormCDF2(this).shallowReduce(reductionContext);
}
template<typename T>
Evaluation<T> NormCDF2Node::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> aEvaluation = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
@@ -48,29 +44,4 @@ Evaluation<T> NormCDF2Node::templatedApproximate(Context * context, Preferences:
return Complex<T>::Builder(NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(b, mu, var) - NormalDistribution::CumulativeDistributiveFunctionAtAbscissa(a, mu, var));
}
Expression NormCDF2::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
return e;
}
}
// TODO Factorize with norm_cdf and inv_norm ?
Expression mu = childAtIndex(2);
Expression var = childAtIndex(3);
Context * context = reductionContext.context();
// Check mu and var
bool muAndVarOK = false;
bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, mu, var, context);
if (!couldCheckMuAndVar) {
return *this;
}
if (!muAndVarOK) {
return replaceWithUndefinedInPlace();
}
return *this;
}
}

View File

@@ -23,10 +23,6 @@ int NormPDFNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloa
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, NormPDF::s_functionHelper.name());
}
Expression NormPDFNode::shallowReduce(ReductionContext reductionContext) {
return NormPDF(this).shallowReduce(reductionContext);
}
template<typename T>
Evaluation<T> NormPDFNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> xEvaluation = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
@@ -41,28 +37,4 @@ Evaluation<T> NormPDFNode::templatedApproximate(Context * context, Preferences::
return Complex<T>::Builder(NormalDistribution::EvaluateAtAbscissa(x, mu, var));
}
Expression NormPDF::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
return e;
}
}
Expression mu = childAtIndex(1);
Expression var = childAtIndex(2);
Context * context = reductionContext.context();
// Check mu and var
bool muAndVarOK = false;
bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, mu, var, context);
if (!couldCheckMuAndVar) {
return *this;
}
if (!muAndVarOK) {
return replaceWithUndefinedInPlace();
}
return *this;
}
}

View File

@@ -0,0 +1,41 @@
#include <poincare/normal_distribution_function.h>
#include <poincare/normal_distribution.h>
#include <assert.h>
namespace Poincare {
Expression NormalDistributionFunctionNode::shallowReduce(ReductionContext reductionContext) {
return NormalDistributionFunction(this).shallowReduce(reductionContext.context());
}
Expression NormalDistributionFunction::shallowReduce(Context * context, bool * stopReduction) {
if (stopReduction != nullptr) {
*stopReduction = true;
}
{
Expression e = Expression::defaultShallowReduce();
if (e.isUndefined()) {
return e;
}
}
Expression mu = childAtIndex(1);
Expression var = childAtIndex(2);
// Check mu and var
bool muAndVarOK = false;
bool couldCheckMuAndVar = NormalDistribution::ExpressionParametersAreOK(&muAndVarOK, mu, var, context);
if (!couldCheckMuAndVar) {
return *this;
}
if (!muAndVarOK) {
return replaceWithUndefinedInPlace();
}
if (stopReduction != nullptr) {
*stopReduction = false;
}
return *this;
}
}