From 6db02ea1221e45b4b15e4cf24c6b0e2b9ba3b840 Mon Sep 17 00:00:00 2001 From: Hugo Saint-Vignes Date: Tue, 21 Jul 2020 17:44:15 +0200 Subject: [PATCH] [poincare/vectors] Implement Vectors operations Change-Id: Ib5aa2f5951f4eabe3b04078eab8b7d3a4d3769e7 --- poincare/Makefile | 3 + poincare/include/poincare/complex.h | 3 + poincare/include/poincare/evaluation.h | 6 ++ poincare/include/poincare/expression.h | 3 + poincare/include/poincare/expression_node.h | 18 ++++-- poincare/include/poincare/matrix.h | 3 + poincare/include/poincare/matrix_complex.h | 3 + poincare/include/poincare/vector_cross.h | 48 ++++++++++++++++ poincare/include/poincare/vector_dot.h | 48 ++++++++++++++++ poincare/include/poincare/vector_norm.h | 48 ++++++++++++++++ poincare/include/poincare_nodes.h | 3 + poincare/src/expression.cpp | 3 +- poincare/src/matrix.cpp | 48 ++++++++++++++++ poincare/src/matrix_complex.cpp | 44 +++++++++++++++ poincare/src/parsing/parser.h | 3 + poincare/src/tree_handle.cpp | 3 + poincare/src/vector_cross.cpp | 61 +++++++++++++++++++++ poincare/src/vector_dot.cpp | 61 +++++++++++++++++++++ poincare/src/vector_norm.cpp | 58 ++++++++++++++++++++ 19 files changed, 462 insertions(+), 5 deletions(-) create mode 100644 poincare/include/poincare/vector_cross.h create mode 100644 poincare/include/poincare/vector_dot.h create mode 100644 poincare/include/poincare/vector_norm.h create mode 100644 poincare/src/vector_cross.cpp create mode 100644 poincare/src/vector_dot.cpp create mode 100644 poincare/src/vector_norm.cpp diff --git a/poincare/Makefile b/poincare/Makefile index 71dcd46db..a6c0524b9 100644 --- a/poincare/Makefile +++ b/poincare/Makefile @@ -151,6 +151,9 @@ poincare_src += $(addprefix poincare/src/,\ unit_convert.cpp \ unreal.cpp \ variable_context.cpp \ + vector_cross.cpp \ + vector_dot.cpp \ + vector_norm.cpp \ ) poincare_src += $(addprefix poincare/src/parsing/,\ diff --git a/poincare/include/poincare/complex.h b/poincare/include/poincare/complex.h index 5f4aec757..ac5be2ccb 100644 --- a/poincare/include/poincare/complex.h +++ b/poincare/include/poincare/complex.h @@ -34,6 +34,9 @@ public: Expression complexToExpression(Preferences::Preferences::ComplexFormat complexFormat) const override; std::complex trace() const override { return *this; } std::complex determinant() const override { return *this; } + Evaluation cross(Evaluation * e) const override { return Complex::Undefined(); } + std::complex dot(Evaluation * e) const override { return std::complex(NAN, NAN); } + std::complex norm() const override { return std::complex(NAN, NAN); } }; template diff --git a/poincare/include/poincare/evaluation.h b/poincare/include/poincare/evaluation.h index 2cca0158d..d0f508123 100644 --- a/poincare/include/poincare/evaluation.h +++ b/poincare/include/poincare/evaluation.h @@ -32,6 +32,9 @@ public: virtual Expression complexToExpression(Preferences::ComplexFormat complexFormat) const = 0; virtual std::complex trace() const = 0; virtual std::complex determinant() const = 0; + virtual Evaluation cross(Evaluation * e) const = 0; + virtual std::complex dot(Evaluation * e) const = 0; + virtual std::complex norm() const = 0; }; template @@ -64,6 +67,9 @@ public: Expression complexToExpression(Preferences::ComplexFormat complexFormat) const; std::complex trace() const { return node()->trace(); } std::complex determinant() const { return node()->determinant(); } + Evaluation cross(Evaluation * e) const { return node()->cross(e); } + std::complex dot(Evaluation * e) const { return node()->dot(e); } + std::complex norm() const { return node()->norm(); } protected: Evaluation(EvaluationNode * n) : TreeHandle(n) {} }; diff --git a/poincare/include/poincare/expression.h b/poincare/include/poincare/expression.h index d74be3a9f..6e71cc117 100644 --- a/poincare/include/poincare/expression.h +++ b/poincare/include/poincare/expression.h @@ -103,6 +103,9 @@ class Expression : public TreeHandle { friend class TrigonometryCheatTable; friend class Unit; friend class UnitConvert; + friend class VectorCross; + friend class VectorDot; + friend class VectorNorm; friend class AdditionNode; friend class DerivativeNode; diff --git a/poincare/include/poincare/expression_node.h b/poincare/include/poincare/expression_node.h index 853355a89..8e56f3126 100644 --- a/poincare/include/poincare/expression_node.h +++ b/poincare/include/poincare/expression_node.h @@ -23,7 +23,8 @@ class ExpressionNode : public TreeNode { friend class PowerNode; friend class SymbolNode; public: - enum class Type : uint8_t { + // The types order is important here. + enum class Type : uint8_t { Uninitialized = 0, Undefined = 1, Unreal, @@ -33,6 +34,8 @@ public: Double, Float, Infinity, + /* When merging number nodes together, we do a linear scan which stops at + * the first non-number child. */ Multiplication, Power, Addition, @@ -95,10 +98,15 @@ public: SquareRoot, Subtraction, Sum, - + VectorDot, + VectorNorm, + /* When sorting the children of an expression, we assert that the following + * nodes are at the end of the list : */ + // - Units Unit, + // - Complexes ComplexCartesian, - + // - Any kind of matrices : ConfidenceInterval, MatrixDimension, MatrixIdentity, @@ -107,9 +115,11 @@ public: MatrixRowEchelonForm, MatrixReducedRowEchelonForm, PredictionInterval, + VectorCross, Matrix, + EmptyExpression - }; + }; /* Poor man's RTTI */ virtual Type type() const = 0; diff --git a/poincare/include/poincare/matrix.h b/poincare/include/poincare/matrix.h index def7da820..30ea4f5d4 100644 --- a/poincare/include/poincare/matrix.h +++ b/poincare/include/poincare/matrix.h @@ -85,6 +85,9 @@ public: * not. */ Expression createInverse(ExpressionNode::ReductionContext reductionContext, bool * couldComputeInverse) const; Expression determinant(ExpressionNode::ReductionContext reductionContext, bool * couldComputeDeterminant, bool inPlace); + Expression norm(ExpressionNode::ReductionContext reductionContext) const; + Expression dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const; + Matrix cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const; // TODO: find another solution for inverse and determinant (avoid capping the matrix) static constexpr int k_maxNumberOfCoefficients = 100; diff --git a/poincare/include/poincare/matrix_complex.h b/poincare/include/poincare/matrix_complex.h index ce2343b1e..275785cd4 100644 --- a/poincare/include/poincare/matrix_complex.h +++ b/poincare/include/poincare/matrix_complex.h @@ -47,6 +47,9 @@ public: MatrixComplex inverse() const; MatrixComplex transpose() const; MatrixComplex ref(bool reduced) const; + std::complex norm() const override; + std::complex dot(Evaluation * e) const override; + Evaluation cross(Evaluation * e) const override; private: // See comment on Matrix uint16_t m_numberOfRows; diff --git a/poincare/include/poincare/vector_cross.h b/poincare/include/poincare/vector_cross.h new file mode 100644 index 000000000..e1f75e03e --- /dev/null +++ b/poincare/include/poincare/vector_cross.h @@ -0,0 +1,48 @@ +#ifndef POINCARE_VECTOR_CROSS_H +#define POINCARE_VECTOR_CROSS_H + +#include + +namespace Poincare { + +class VectorCrossNode final : public ExpressionNode { +public: + + // TreeNode + size_t size() const override { return sizeof(VectorCrossNode); } + int numberOfChildren() const override; +#if POINCARE_TREE_LOG + void logNodeName(std::ostream & stream) const override { + stream << "VectorCross"; + } +#endif + + // Properties + Type type() const override { return Type::VectorCross; } +private: + // Layout + Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + // Simplification + Expression shallowReduce(ReductionContext reductionContext) override; + LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; }; + LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; } + // Evaluation + Evaluation approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + Evaluation approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + template Evaluation templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const; +}; + +class VectorCross final : public Expression { +public: + VectorCross(const VectorCrossNode * n) : Expression(n) {} + static VectorCross Builder(Expression child0, Expression child1) { return TreeHandle::FixedArityBuilder({child0, child1}); } + + static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("cross", 2, &UntypedBuilderTwoChildren); + + Expression shallowReduce(ExpressionNode::ReductionContext reductionContext); +}; + +} + +#endif diff --git a/poincare/include/poincare/vector_dot.h b/poincare/include/poincare/vector_dot.h new file mode 100644 index 000000000..3b9c80342 --- /dev/null +++ b/poincare/include/poincare/vector_dot.h @@ -0,0 +1,48 @@ +#ifndef POINCARE_VECTOR_DOT_H +#define POINCARE_VECTOR_DOT_H + +#include + +namespace Poincare { + +class VectorDotNode final : public ExpressionNode { +public: + + // TreeNode + size_t size() const override { return sizeof(VectorDotNode); } + int numberOfChildren() const override; +#if POINCARE_TREE_LOG + void logNodeName(std::ostream & stream) const override { + stream << "VectorDot"; + } +#endif + + // Properties + Type type() const override { return Type::VectorDot; } +private: + // Layout + Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + // Simplification + Expression shallowReduce(ReductionContext reductionContext) override; + LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; }; + LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; } + // Evaluation + Evaluation approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + Evaluation approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + template Evaluation templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const; +}; + +class VectorDot final : public Expression { +public: + VectorDot(const VectorDotNode * n) : Expression(n) {} + static VectorDot Builder(Expression child0, Expression child1) { return TreeHandle::FixedArityBuilder({child0, child1}); } + + static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("dot", 2, &UntypedBuilderTwoChildren); + + Expression shallowReduce(ExpressionNode::ReductionContext reductionContext); +}; + +} + +#endif diff --git a/poincare/include/poincare/vector_norm.h b/poincare/include/poincare/vector_norm.h new file mode 100644 index 000000000..f0572ada4 --- /dev/null +++ b/poincare/include/poincare/vector_norm.h @@ -0,0 +1,48 @@ +#ifndef POINCARE_VECTOR_NORM_H +#define POINCARE_VECTOR_NORM_H + +#include + +namespace Poincare { + +class VectorNormNode final : public ExpressionNode { +public: + + // TreeNode + size_t size() const override { return sizeof(VectorNormNode); } + int numberOfChildren() const override; +#if POINCARE_TREE_LOG + void logNodeName(std::ostream & stream) const override { + stream << "VectorNorm"; + } +#endif + + // Properties + Type type() const override { return Type::VectorNorm; } +private: + // Layout + Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override; + // Simplification + Expression shallowReduce(ReductionContext reductionContext) override; + LayoutShape leftLayoutShape() const override { return LayoutShape::MoreLetters; }; + LayoutShape rightLayoutShape() const override { return LayoutShape::BoundaryPunctuation; } + // Evaluation + Evaluation approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + Evaluation approximate(DoublePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override { return templatedApproximate(context, complexFormat, angleUnit); } + template Evaluation templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const; +}; + +class VectorNorm final : public Expression { +public: + VectorNorm(const VectorNormNode * n) : Expression(n) {} + static VectorNorm Builder(Expression child) { return TreeHandle::FixedArityBuilder({child}); } + + static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("norm", 1, &UntypedBuilderOneChild); + + Expression shallowReduce(ExpressionNode::ReductionContext reductionContext); +}; + +} + +#endif diff --git a/poincare/include/poincare_nodes.h b/poincare/include/poincare_nodes.h index 41f1d0d92..ed7316b8e 100644 --- a/poincare/include/poincare_nodes.h +++ b/poincare/include/poincare_nodes.h @@ -89,5 +89,8 @@ #include #include #include +#include +#include +#include #endif diff --git a/poincare/src/expression.cpp b/poincare/src/expression.cpp index a9032a621..21daee9cb 100644 --- a/poincare/src/expression.cpp +++ b/poincare/src/expression.cpp @@ -187,7 +187,8 @@ bool Expression::IsMatrix(const Expression e, Context * context) { || e.type() == ExpressionNode::Type::MatrixIdentity || e.type() == ExpressionNode::Type::MatrixTranspose || e.type() == ExpressionNode::Type::MatrixRowEchelonForm - || e.type() == ExpressionNode::Type::MatrixReducedRowEchelonForm; + || e.type() == ExpressionNode::Type::MatrixReducedRowEchelonForm + || e.type() == ExpressionNode::Type::VectorCross; } bool Expression::IsInfinity(const Expression e, Context * context) { diff --git a/poincare/src/matrix.cpp b/poincare/src/matrix.cpp index cabbb2992..6c7df3815 100644 --- a/poincare/src/matrix.cpp +++ b/poincare/src/matrix.cpp @@ -6,8 +6,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -483,6 +485,52 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext return result; } +Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const { + assert(numberOfColumns() == 1); + Addition sum = Addition::Builder(); + for (int j = 0; j < numberOfRows(); j++) { + Expression absValue = AbsoluteValue::Builder(const_cast(this)->matrixChild(0, j).clone()); + Expression squaredAbsValue = Power::Builder(absValue, Rational::Builder(2)); + absValue.shallowReduce(reductionContext); + sum.addChildAtIndexInPlace(squaredAbsValue, sum.numberOfChildren(), sum.numberOfChildren()); + squaredAbsValue.shallowReduce(reductionContext); + } + Expression result = SquareRoot::Builder(sum); + sum.shallowReduce(reductionContext); + return result; +} + +Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { + // Dot product is defined between two vectors of same size + assert(numberOfRows() == b->numberOfRows() && numberOfColumns() == 1 && b->numberOfColumns() == 1); + Addition sum = Addition::Builder(); + for (int j = 0; j < numberOfRows(); j++) { + Expression product = Multiplication::Builder(const_cast(this)->matrixChild(0, j).clone(), const_cast(b)->matrixChild(0, j).clone()); + sum.addChildAtIndexInPlace(product, sum.numberOfChildren(), sum.numberOfChildren()); + product.shallowReduce(reductionContext); + } + return std::move(sum); +} + +Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { + // Cross product is defined between two vectors of size 3 + assert(numberOfRows() == 3 && numberOfColumns() == 1 && b->numberOfRows() == 3 && b->numberOfColumns() == 1); + Matrix matrix = Matrix::Builder(); + for (int j = 0; j < 3; j++) { + int j1 = (j+1)%3; + int j2 = (j+2)%3; + Expression a1b2 = Multiplication::Builder(const_cast(this)->matrixChild(0, j1).clone(), const_cast(b)->matrixChild(0, j2).clone()); + Expression a2b1 = Multiplication::Builder(const_cast(this)->matrixChild(0, j2).clone(), const_cast(b)->matrixChild(0, j1).clone()); + Expression difference = Subtraction::Builder(a1b2, a2b1); + a1b2.shallowReduce(reductionContext); + a2b1.shallowReduce(reductionContext); + matrix.addChildAtIndexInPlace(difference, matrix.numberOfChildren(), matrix.numberOfChildren()); + difference.shallowReduce(reductionContext); + } + matrix.setDimensions(3, 1); + return matrix; +} + Expression Matrix::shallowReduce(Context * context) { { Expression e = Expression::defaultShallowReduce(); diff --git a/poincare/src/matrix_complex.cpp b/poincare/src/matrix_complex.cpp index d0fa0b68c..eb94bab40 100644 --- a/poincare/src/matrix_complex.cpp +++ b/poincare/src/matrix_complex.cpp @@ -135,6 +135,50 @@ MatrixComplex MatrixComplexNode::ref(bool reduced) const { return MatrixComplex::Builder(operandsCopy, m_numberOfRows, m_numberOfColumns); } +template +std::complex MatrixComplexNode::norm() const { + if (numberOfChildren() == 0 || numberOfColumns() > 1) { + return std::complex(NAN, NAN); + } + std::complex sum = 0; + for (int i = 0; i < numberOfChildren(); i++) { + sum += std::norm(complexAtIndex(i)); + } + return std::sqrt(sum); +} + +template +std::complex MatrixComplexNode::dot(Evaluation * e) const { + if (e->type() != EvaluationNode::Type::MatrixComplex) { + return std::complex(NAN, NAN); + } + MatrixComplex * b = static_cast*>(e); + if (numberOfChildren() == 0 || numberOfColumns() > 1 || b->numberOfChildren() == 0 || b->numberOfColumns() > 1 || numberOfRows() != b->numberOfRows()) { + return std::complex(NAN, NAN); + } + std::complex sum = 0; + for (int i = 0; i < numberOfChildren(); i++) { + sum += complexAtIndex(i) * b->complexAtIndex(i); + } + return sum; +} + +template +Evaluation MatrixComplexNode::cross(Evaluation * e) const { + if (e->type() != EvaluationNode::Type::MatrixComplex) { + return MatrixComplex::Undefined(); + } + MatrixComplex * b = static_cast*>(e); + if (numberOfChildren() == 0 || numberOfColumns() != 1 || numberOfRows() != 3 || b->numberOfChildren() == 0 || b->numberOfColumns() != 1 || b->numberOfRows() != 3) { + return MatrixComplex::Undefined(); + } + std::complex operandsCopy[3]; + operandsCopy[0] = complexAtIndex(1) * b->complexAtIndex(2) - complexAtIndex(2) * b->complexAtIndex(1); + operandsCopy[1] = complexAtIndex(2) * b->complexAtIndex(0) - complexAtIndex(0) * b->complexAtIndex(2); + operandsCopy[2] = complexAtIndex(0) * b->complexAtIndex(1) - complexAtIndex(1) * b->complexAtIndex(0); + return MatrixComplex::Builder(operandsCopy, 3, 1); +} + // MATRIX COMPLEX REFERENCE template diff --git a/poincare/src/parsing/parser.h b/poincare/src/parsing/parser.h index 5064d12f7..8770d6220 100644 --- a/poincare/src/parsing/parser.h +++ b/poincare/src/parsing/parser.h @@ -110,9 +110,11 @@ private: &Conjugate::s_functionHelper, &Cosine::s_functionHelper, &HyperbolicCosine::s_functionHelper, + &VectorCross::s_functionHelper, &Determinant::s_functionHelper, &Derivative::s_functionHelper, &MatrixDimension::s_functionHelper, + &VectorDot::s_functionHelper, &Factor::s_functionHelper, &Floor::s_functionHelper, &FracPart::s_functionHelper, @@ -127,6 +129,7 @@ private: &NaperianLogarithm::s_functionHelper, &CommonLogarithm::s_functionHelper, &Logarithm::s_functionHelper, + &VectorNorm::s_functionHelper, &NormCDF::s_functionHelper, &NormCDF2::s_functionHelper, &NormPDF::s_functionHelper, diff --git a/poincare/src/tree_handle.cpp b/poincare/src/tree_handle.cpp index 1b70f0b74..4e1d2c511 100644 --- a/poincare/src/tree_handle.cpp +++ b/poincare/src/tree_handle.cpp @@ -369,6 +369,9 @@ template Tangent TreeHandle::FixedArityBuilder(const Tuple template Undefined TreeHandle::FixedArityBuilder(const Tuple &); template UnitConvert TreeHandle::FixedArityBuilder(const Tuple &); template Unreal TreeHandle::FixedArityBuilder(const Tuple &); +template VectorCross TreeHandle::FixedArityBuilder(const Tuple &); +template VectorDot TreeHandle::FixedArityBuilder(const Tuple &); +template VectorNorm TreeHandle::FixedArityBuilder(const Tuple &); template MatrixLayout TreeHandle::NAryBuilder(const Tuple &); } diff --git a/poincare/src/vector_cross.cpp b/poincare/src/vector_cross.cpp new file mode 100644 index 000000000..dc5f6dc19 --- /dev/null +++ b/poincare/src/vector_cross.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include +#include +#include + +namespace Poincare { + +constexpr Expression::FunctionHelper VectorCross::s_functionHelper; + +int VectorCrossNode::numberOfChildren() const { return VectorCross::s_functionHelper.numberOfChildren(); } + +Expression VectorCrossNode::shallowReduce(ReductionContext reductionContext) { + return VectorCross(this).shallowReduce(reductionContext); +} + +Layout VectorCrossNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return LayoutHelper::Prefix(VectorCross(this), floatDisplayMode, numberOfSignificantDigits, VectorCross::s_functionHelper.name()); +} + +int VectorCrossNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorCross::s_functionHelper.name()); +} + +template +Evaluation VectorCrossNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const { + Evaluation input0 = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit); + Evaluation input1 = childAtIndex(1)->approximate(T(), context, complexFormat, angleUnit); + return input0.cross(&input1); +} + + +Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reductionContext) { + { + Expression e = Expression::defaultShallowReduce(); + e = e.defaultHandleUnitsInChildren(); + if (e.isUndefined()) { + return e; + } + } + Expression c0 = childAtIndex(0); + Expression c1 = childAtIndex(1); + if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { + Matrix matrixChild0 = static_cast(c0); + Matrix matrixChild1 = static_cast(c1); + // Cross product is defined between two column matrices of size 3 + if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != 3 || matrixChild1.numberOfRows() != 3) { + return replaceWithUndefinedInPlace(); + } + Expression a = matrixChild0.cross(&matrixChild1, reductionContext); + replaceWithInPlace(a); + return a.shallowReduce(reductionContext); + } + if (c0.deepIsMatrix(reductionContext.context()) && c1.deepIsMatrix(reductionContext.context())) { + return *this; + } + return replaceWithUndefinedInPlace(); +} + +} diff --git a/poincare/src/vector_dot.cpp b/poincare/src/vector_dot.cpp new file mode 100644 index 000000000..5dc5d23f1 --- /dev/null +++ b/poincare/src/vector_dot.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include +#include +#include + +namespace Poincare { + +constexpr Expression::FunctionHelper VectorDot::s_functionHelper; + +int VectorDotNode::numberOfChildren() const { return VectorDot::s_functionHelper.numberOfChildren(); } + +Expression VectorDotNode::shallowReduce(ReductionContext reductionContext) { + return VectorDot(this).shallowReduce(reductionContext); +} + +Layout VectorDotNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return LayoutHelper::Prefix(VectorDot(this), floatDisplayMode, numberOfSignificantDigits, VectorDot::s_functionHelper.name()); +} + +int VectorDotNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorDot::s_functionHelper.name()); +} + +template +Evaluation VectorDotNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const { + Evaluation input0 = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit); + Evaluation input1 = childAtIndex(1)->approximate(T(), context, complexFormat, angleUnit); + return Complex::Builder(input0.dot(&input1)); +} + + +Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionContext) { + { + Expression e = Expression::defaultShallowReduce(); + e = e.defaultHandleUnitsInChildren(); + if (e.isUndefined()) { + return e; + } + } + Expression c0 = childAtIndex(0); + Expression c1 = childAtIndex(1); + if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { + Matrix matrixChild0 = static_cast(c0); + Matrix matrixChild1 = static_cast(c1); + // Dot product is defined between two column matrices of the same dimensions + if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) { + return replaceWithUndefinedInPlace(); + } + Expression a = matrixChild0.dot(&matrixChild1, reductionContext); + replaceWithInPlace(a); + return a.shallowReduce(reductionContext); + } + if (c0.deepIsMatrix(reductionContext.context()) && c1.deepIsMatrix(reductionContext.context())) { + return *this; + } + return replaceWithUndefinedInPlace(); +} + +} diff --git a/poincare/src/vector_norm.cpp b/poincare/src/vector_norm.cpp new file mode 100644 index 000000000..aeef1f8a3 --- /dev/null +++ b/poincare/src/vector_norm.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include + +namespace Poincare { + +constexpr Expression::FunctionHelper VectorNorm::s_functionHelper; + +int VectorNormNode::numberOfChildren() const { return VectorNorm::s_functionHelper.numberOfChildren(); } + +Expression VectorNormNode::shallowReduce(ReductionContext reductionContext) { + return VectorNorm(this).shallowReduce(reductionContext); +} + +Layout VectorNormNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return LayoutHelper::Prefix(VectorNorm(this), floatDisplayMode, numberOfSignificantDigits, VectorNorm::s_functionHelper.name()); +} + +int VectorNormNode::serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { + return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, VectorNorm::s_functionHelper.name()); +} + +template +Evaluation VectorNormNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const { + Evaluation input = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit); + return Complex::Builder(input.norm()); +} + + +Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionContext) { + { + Expression e = Expression::defaultShallowReduce(); + e = e.defaultHandleUnitsInChildren(); + if (e.isUndefined()) { + return e; + } + } + Expression c = childAtIndex(0); + if (c.type() == ExpressionNode::Type::Matrix) { + Matrix matrixChild = static_cast(c); + if (matrixChild.numberOfColumns() != 1) { + // Norm is only defined on column matrices + return replaceWithUndefinedInPlace(); + } + Expression a = matrixChild.norm(reductionContext); + replaceWithInPlace(a); + return a.shallowReduce(reductionContext); + } + if (c.deepIsMatrix(reductionContext.context())) { + return *this; + } + return replaceWithUndefinedInPlace(); +} + +}