[poincare/vectors] Implement Vectors operations

Change-Id: Ib5aa2f5951f4eabe3b04078eab8b7d3a4d3769e7
This commit is contained in:
Hugo Saint-Vignes
2020-07-21 17:44:15 +02:00
committed by Émilie Feral
parent ca91b7c43d
commit 6db02ea122
19 changed files with 462 additions and 5 deletions

View File

@@ -151,6 +151,9 @@ poincare_src += $(addprefix poincare/src/,\
unit_convert.cpp \
unreal.cpp \
variable_context.cpp \
vector_cross.cpp \
vector_dot.cpp \
vector_norm.cpp \
)
poincare_src += $(addprefix poincare/src/parsing/,\

View File

@@ -34,6 +34,9 @@ public:
Expression complexToExpression(Preferences::Preferences::ComplexFormat complexFormat) const override;
std::complex<T> trace() const override { return *this; }
std::complex<T> determinant() const override { return *this; }
Evaluation<T> cross(Evaluation<T> * e) const override { return Complex<T>::Undefined(); }
std::complex<T> dot(Evaluation<T> * e) const override { return std::complex<T>(NAN, NAN); }
std::complex<T> norm() const override { return std::complex<T>(NAN, NAN); }
};
template<typename T>

View File

@@ -32,6 +32,9 @@ public:
virtual Expression complexToExpression(Preferences::ComplexFormat complexFormat) const = 0;
virtual std::complex<T> trace() const = 0;
virtual std::complex<T> determinant() const = 0;
virtual Evaluation<T> cross(Evaluation<T> * e) const = 0;
virtual std::complex<T> dot(Evaluation<T> * e) const = 0;
virtual std::complex<T> norm() const = 0;
};
template<typename T>
@@ -64,6 +67,9 @@ public:
Expression complexToExpression(Preferences::ComplexFormat complexFormat) const;
std::complex<T> trace() const { return node()->trace(); }
std::complex<T> determinant() const { return node()->determinant(); }
Evaluation<T> cross(Evaluation<T> * e) const { return node()->cross(e); }
std::complex<T> dot(Evaluation<T> * e) const { return node()->dot(e); }
std::complex<T> norm() const { return node()->norm(); }
protected:
Evaluation(EvaluationNode<T> * n) : TreeHandle(n) {}
};

View File

@@ -103,6 +103,9 @@ class Expression : public TreeHandle {
friend class TrigonometryCheatTable;
friend class Unit;
friend class UnitConvert;
friend class VectorCross;
friend class VectorDot;
friend class VectorNorm;
friend class AdditionNode;
friend class DerivativeNode;

View File

@@ -23,7 +23,8 @@ class ExpressionNode : public TreeNode {
friend class PowerNode;
friend class SymbolNode;
public:
enum class Type : uint8_t {
// The types order is important here.
enum class Type : uint8_t {
Uninitialized = 0,
Undefined = 1,
Unreal,
@@ -33,6 +34,8 @@ public:
Double,
Float,
Infinity,
/* When merging number nodes together, we do a linear scan which stops at
* the first non-number child. */
Multiplication,
Power,
Addition,
@@ -95,10 +98,15 @@ public:
SquareRoot,
Subtraction,
Sum,
VectorDot,
VectorNorm,
/* When sorting the children of an expression, we assert that the following
* nodes are at the end of the list : */
// - Units
Unit,
// - Complexes
ComplexCartesian,
// - Any kind of matrices :
ConfidenceInterval,
MatrixDimension,
MatrixIdentity,
@@ -107,9 +115,11 @@ public:
MatrixRowEchelonForm,
MatrixReducedRowEchelonForm,
PredictionInterval,
VectorCross,
Matrix,
EmptyExpression
};
};
/* Poor man's RTTI */
virtual Type type() const = 0;

View File

@@ -85,6 +85,9 @@ public:
* not. */
Expression createInverse(ExpressionNode::ReductionContext reductionContext, bool * couldComputeInverse) const;
Expression determinant(ExpressionNode::ReductionContext reductionContext, bool * couldComputeDeterminant, bool inPlace);
Expression norm(ExpressionNode::ReductionContext reductionContext) const;
Expression dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const;
Matrix cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const;
// TODO: find another solution for inverse and determinant (avoid capping the matrix)
static constexpr int k_maxNumberOfCoefficients = 100;

View File

@@ -47,6 +47,9 @@ public:
MatrixComplex<T> inverse() const;
MatrixComplex<T> transpose() const;
MatrixComplex<T> ref(bool reduced) const;
std::complex<T> norm() const override;
std::complex<T> dot(Evaluation<T> * e) const override;
Evaluation<T> cross(Evaluation<T> * e) const override;
private:
// See comment on Matrix
uint16_t m_numberOfRows;

View File

@@ -0,0 +1,48 @@
#ifndef POINCARE_VECTOR_CROSS_H
#define POINCARE_VECTOR_CROSS_H
#include <poincare/expression.h>
namespace Poincare {
class VectorCrossNode final : public ExpressionNode {
public:
// TreeNode
size_t size() const override { return sizeof(VectorCrossNode); }
int numberOfChildren() const override;
#if POINCARE_TREE_LOG
void logNodeName(std::ostream & stream) const override {
stream << "VectorCross";
}
#endif
// Properties
Type type() const override { return Type::VectorCross; }
private:
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplification
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class VectorCross final : public Expression {
public:
VectorCross(const VectorCrossNode * n) : Expression(n) {}
static VectorCross Builder(Expression child0, Expression child1) { return TreeHandle::FixedArityBuilder<VectorCross, VectorCrossNode>({child0, child1}); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("cross", 2, &UntypedBuilderTwoChildren<VectorCross>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}
#endif

View File

@@ -0,0 +1,48 @@
#ifndef POINCARE_VECTOR_DOT_H
#define POINCARE_VECTOR_DOT_H
#include <poincare/expression.h>
namespace Poincare {
class VectorDotNode final : public ExpressionNode {
public:
// TreeNode
size_t size() const override { return sizeof(VectorDotNode); }
int numberOfChildren() const override;
#if POINCARE_TREE_LOG
void logNodeName(std::ostream & stream) const override {
stream << "VectorDot";
}
#endif
// Properties
Type type() const override { return Type::VectorDot; }
private:
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplification
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class VectorDot final : public Expression {
public:
VectorDot(const VectorDotNode * n) : Expression(n) {}
static VectorDot Builder(Expression child0, Expression child1) { return TreeHandle::FixedArityBuilder<VectorDot, VectorDotNode>({child0, child1}); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("dot", 2, &UntypedBuilderTwoChildren<VectorDot>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}
#endif

View File

@@ -0,0 +1,48 @@
#ifndef POINCARE_VECTOR_NORM_H
#define POINCARE_VECTOR_NORM_H
#include <poincare/expression.h>
namespace Poincare {
class VectorNormNode final : public ExpressionNode {
public:
// TreeNode
size_t size() const override { return sizeof(VectorNormNode); }
int numberOfChildren() const override;
#if POINCARE_TREE_LOG
void logNodeName(std::ostream & stream) const override {
stream << "VectorNorm";
}
#endif
// Properties
Type type() const override { return Type::VectorNorm; }
private:
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
// Simplification
Expression shallowReduce(ReductionContext reductionContext) override;
LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; };
LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; }
// Evaluation
Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<float>(context, complexFormat, angleUnit); }
Evaluation<double> approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate<double>(context, complexFormat, angleUnit); }
template<typename T> Evaluation<T> templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
};
class VectorNorm final : public Expression {
public:
VectorNorm(const VectorNormNode * n) : Expression(n) {}
static VectorNorm Builder(Expression child) { return TreeHandle::FixedArityBuilder<VectorNorm, VectorNormNode>({child}); }
static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("norm", 1, &UntypedBuilderOneChild<VectorNorm>);
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
};
}
#endif

View File

@@ -89,5 +89,8 @@
#include <poincare/unit_convert.h>
#include <poincare/unreal.h>
#include <poincare/variable_context.h>
#include <poincare/vector_cross.h>
#include <poincare/vector_dot.h>
#include <poincare/vector_norm.h>
#endif

View File

@@ -187,7 +187,8 @@ bool Expression::IsMatrix(const Expression e, Context * context) {
|| e.type() == ExpressionNode::Type::MatrixIdentity
|| e.type() == ExpressionNode::Type::MatrixTranspose
|| e.type() == ExpressionNode::Type::MatrixRowEchelonForm
|| e.type() == ExpressionNode::Type::MatrixReducedRowEchelonForm;
|| e.type() == ExpressionNode::Type::MatrixReducedRowEchelonForm
|| e.type() == ExpressionNode::Type::VectorCross;
}
bool Expression::IsInfinity(const Expression e, Context * context) {

View File

@@ -6,8 +6,10 @@
#include <poincare/matrix_complex.h>
#include <poincare/matrix_layout.h>
#include <poincare/multiplication.h>
#include <poincare/power.h>
#include <poincare/rational.h>
#include <poincare/serialization_helper.h>
#include <poincare/square_root.h>
#include <poincare/subtraction.h>
#include <poincare/undefined.h>
#include <assert.h>
@@ -483,6 +485,52 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext
return result;
}
Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const {
assert(numberOfColumns() == 1);
Addition sum = Addition::Builder();
for (int j = 0; j < numberOfRows(); j++) {
Expression absValue = AbsoluteValue::Builder(const_cast<Matrix *>(this)->matrixChild(0, j).clone());
Expression squaredAbsValue = Power::Builder(absValue, Rational::Builder(2));
absValue.shallowReduce(reductionContext);
sum.addChildAtIndexInPlace(squaredAbsValue, sum.numberOfChildren(), sum.numberOfChildren());
squaredAbsValue.shallowReduce(reductionContext);
}
Expression result = SquareRoot::Builder(sum);
sum.shallowReduce(reductionContext);
return result;
}
Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
// Dot product is defined between two vectors of same size
assert(numberOfRows() == b->numberOfRows() && numberOfColumns() == 1 && b->numberOfColumns() == 1);
Addition sum = Addition::Builder();
for (int j = 0; j < numberOfRows(); j++) {
Expression product = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j).clone(), const_cast<Matrix *>(b)->matrixChild(0, j).clone());
sum.addChildAtIndexInPlace(product, sum.numberOfChildren(), sum.numberOfChildren());
product.shallowReduce(reductionContext);
}
return std::move(sum);
}
Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
// Cross product is defined between two vectors of size 3
assert(numberOfRows() == 3 && numberOfColumns() == 1 && b->numberOfRows() == 3 && b->numberOfColumns() == 1);
Matrix matrix = Matrix::Builder();
for (int j = 0; j < 3; j++) {
int j1 = (j+1)%3;
int j2 = (j+2)%3;
Expression a1b2 = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j1).clone(), const_cast<Matrix *>(b)->matrixChild(0, j2).clone());
Expression a2b1 = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j2).clone(), const_cast<Matrix *>(b)->matrixChild(0, j1).clone());
Expression difference = Subtraction::Builder(a1b2, a2b1);
a1b2.shallowReduce(reductionContext);
a2b1.shallowReduce(reductionContext);
matrix.addChildAtIndexInPlace(difference, matrix.numberOfChildren(), matrix.numberOfChildren());
difference.shallowReduce(reductionContext);
}
matrix.setDimensions(3, 1);
return matrix;
}
Expression Matrix::shallowReduce(Context * context) {
{
Expression e = Expression::defaultShallowReduce();

View File

@@ -135,6 +135,50 @@ MatrixComplex<T> MatrixComplexNode<T>::ref(bool reduced) const {
return MatrixComplex<T>::Builder(operandsCopy, m_numberOfRows, m_numberOfColumns);
}
template<typename T>
std::complex<T> MatrixComplexNode<T>::norm() const {
if (numberOfChildren() == 0 || numberOfColumns() > 1) {
return std::complex<T>(NAN, NAN);
}
std::complex<T> sum = 0;
for (int i = 0; i < numberOfChildren(); i++) {
sum += std::norm(complexAtIndex(i));
}
return std::sqrt(sum);
}
template<typename T>
std::complex<T> MatrixComplexNode<T>::dot(Evaluation<T> * e) const {
if (e->type() != EvaluationNode<T>::Type::MatrixComplex) {
return std::complex<T>(NAN, NAN);
}
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (numberOfChildren() == 0 || numberOfColumns() > 1 || b->numberOfChildren() == 0 || b->numberOfColumns() > 1 || numberOfRows() != b->numberOfRows()) {
return std::complex<T>(NAN, NAN);
}
std::complex<T> sum = 0;
for (int i = 0; i < numberOfChildren(); i++) {
sum += complexAtIndex(i) * b->complexAtIndex(i);
}
return sum;
}
template<typename T>
Evaluation<T> MatrixComplexNode<T>::cross(Evaluation<T> * e) const {
if (e->type() != EvaluationNode<T>::Type::MatrixComplex) {
return MatrixComplex<T>::Undefined();
}
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (numberOfChildren() == 0 || numberOfColumns() != 1 || numberOfRows() != 3 || b->numberOfChildren() == 0 || b->numberOfColumns() != 1 || b->numberOfRows() != 3) {
return MatrixComplex<T>::Undefined();
}
std::complex<T> operandsCopy[3];
operandsCopy[0] = complexAtIndex(1) * b->complexAtIndex(2) - complexAtIndex(2) * b->complexAtIndex(1);
operandsCopy[1] = complexAtIndex(2) * b->complexAtIndex(0) - complexAtIndex(0) * b->complexAtIndex(2);
operandsCopy[2] = complexAtIndex(0) * b->complexAtIndex(1) - complexAtIndex(1) * b->complexAtIndex(0);
return MatrixComplex<T>::Builder(operandsCopy, 3, 1);
}
// MATRIX COMPLEX REFERENCE
template<typename T>

View File

@@ -110,9 +110,11 @@ private:
&Conjugate::s_functionHelper,
&Cosine::s_functionHelper,
&HyperbolicCosine::s_functionHelper,
&VectorCross::s_functionHelper,
&Determinant::s_functionHelper,
&Derivative::s_functionHelper,
&MatrixDimension::s_functionHelper,
&VectorDot::s_functionHelper,
&Factor::s_functionHelper,
&Floor::s_functionHelper,
&FracPart::s_functionHelper,
@@ -127,6 +129,7 @@ private:
&NaperianLogarithm::s_functionHelper,
&CommonLogarithm::s_functionHelper,
&Logarithm::s_functionHelper,
&VectorNorm::s_functionHelper,
&NormCDF::s_functionHelper,
&NormCDF2::s_functionHelper,
&NormPDF::s_functionHelper,

View File

@@ -369,6 +369,9 @@ template Tangent TreeHandle::FixedArityBuilder<Tangent, TangentNode>(const Tuple
template Undefined TreeHandle::FixedArityBuilder<Undefined, UndefinedNode>(const Tuple &);
template UnitConvert TreeHandle::FixedArityBuilder<UnitConvert, UnitConvertNode>(const Tuple &);
template Unreal TreeHandle::FixedArityBuilder<Unreal, UnrealNode>(const Tuple &);
template VectorCross TreeHandle::FixedArityBuilder<VectorCross, VectorCrossNode>(const Tuple &);
template VectorDot TreeHandle::FixedArityBuilder<VectorDot, VectorDotNode>(const Tuple &);
template VectorNorm TreeHandle::FixedArityBuilder<VectorNorm, VectorNormNode>(const Tuple &);
template MatrixLayout TreeHandle::NAryBuilder<MatrixLayout, MatrixLayoutNode>(const Tuple &);
}

View File

@@ -0,0 +1,61 @@
#include <poincare/vector_cross.h>
#include <poincare/division.h>
#include <poincare/layout_helper.h>
#include <poincare/matrix.h>
#include <poincare/serialization_helper.h>
#include <poincare/undefined.h>
namespace Poincare {
constexpr Expression::FunctionHelper VectorCross::s_functionHelper;
int VectorCrossNode::numberOfChildren() const { return VectorCross::s_functionHelper.numberOfChildren(); }
Expression VectorCrossNode::shallowReduce(ReductionContext reductionContext) {
return VectorCross(this).shallowReduce(reductionContext);
}
Layout VectorCrossNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return LayoutHelper::Prefix(VectorCross(this), floatDisplayMode, numberOfSignificantDigits, VectorCross::s_functionHelper.name());
}
int VectorCrossNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorCross::s_functionHelper.name());
}
template<typename T>
Evaluation<T> VectorCrossNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> input0 = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
Evaluation<T> input1 = childAtIndex(1)->approximate(T(), context, complexFormat, angleUnit);
return input0.cross(&input1);
}
Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
e = e.defaultHandleUnitsInChildren();
if (e.isUndefined()) {
return e;
}
}
Expression c0 = childAtIndex(0);
Expression c1 = childAtIndex(1);
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1);
// Cross product is defined between two column matrices of size 3
if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != 3 || matrixChild1.numberOfRows() != 3) {
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild0.cross(&matrixChild1, reductionContext);
replaceWithInPlace(a);
return a.shallowReduce(reductionContext);
}
if (c0.deepIsMatrix(reductionContext.context()) && c1.deepIsMatrix(reductionContext.context())) {
return *this;
}
return replaceWithUndefinedInPlace();
}
}

View File

@@ -0,0 +1,61 @@
#include <poincare/vector_dot.h>
#include <poincare/addition.h>
#include <poincare/layout_helper.h>
#include <poincare/matrix.h>
#include <poincare/serialization_helper.h>
#include <poincare/undefined.h>
namespace Poincare {
constexpr Expression::FunctionHelper VectorDot::s_functionHelper;
int VectorDotNode::numberOfChildren() const { return VectorDot::s_functionHelper.numberOfChildren(); }
Expression VectorDotNode::shallowReduce(ReductionContext reductionContext) {
return VectorDot(this).shallowReduce(reductionContext);
}
Layout VectorDotNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return LayoutHelper::Prefix(VectorDot(this), floatDisplayMode, numberOfSignificantDigits, VectorDot::s_functionHelper.name());
}
int VectorDotNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorDot::s_functionHelper.name());
}
template<typename T>
Evaluation<T> VectorDotNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> input0 = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
Evaluation<T> input1 = childAtIndex(1)->approximate(T(), context, complexFormat, angleUnit);
return Complex<T>::Builder(input0.dot(&input1));
}
Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
e = e.defaultHandleUnitsInChildren();
if (e.isUndefined()) {
return e;
}
}
Expression c0 = childAtIndex(0);
Expression c1 = childAtIndex(1);
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1);
// Dot product is defined between two column matrices of the same dimensions
if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) {
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild0.dot(&matrixChild1, reductionContext);
replaceWithInPlace(a);
return a.shallowReduce(reductionContext);
}
if (c0.deepIsMatrix(reductionContext.context()) && c1.deepIsMatrix(reductionContext.context())) {
return *this;
}
return replaceWithUndefinedInPlace();
}
}

View File

@@ -0,0 +1,58 @@
#include <poincare/vector_norm.h>
#include <poincare/addition.h>
#include <poincare/layout_helper.h>
#include <poincare/matrix.h>
#include <poincare/serialization_helper.h>
#include <poincare/undefined.h>
namespace Poincare {
constexpr Expression::FunctionHelper VectorNorm::s_functionHelper;
int VectorNormNode::numberOfChildren() const { return VectorNorm::s_functionHelper.numberOfChildren(); }
Expression VectorNormNode::shallowReduce(ReductionContext reductionContext) {
return VectorNorm(this).shallowReduce(reductionContext);
}
Layout VectorNormNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return LayoutHelper::Prefix(VectorNorm(this), floatDisplayMode, numberOfSignificantDigits, VectorNorm::s_functionHelper.name());
}
int VectorNormNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const {
return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorNorm::s_functionHelper.name());
}
template<typename T>
Evaluation<T> VectorNormNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const {
Evaluation<T> input = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit);
return Complex<T>::Builder(input.norm());
}
Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
Expression e = Expression::defaultShallowReduce();
e = e.defaultHandleUnitsInChildren();
if (e.isUndefined()) {
return e;
}
}
Expression c = childAtIndex(0);
if (c.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild = static_cast<Matrix&>(c);
if (matrixChild.numberOfColumns() != 1) {
// Norm is only defined on column matrices
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild.norm(reductionContext);
replaceWithInPlace(a);
return a.shallowReduce(reductionContext);
}
if (c.deepIsMatrix(reductionContext.context())) {
return *this;
}
return replaceWithUndefinedInPlace();
}
}