From bee8d8531bfcdce6d95f1ff058c9d9fcda471f11 Mon Sep 17 00:00:00 2001 From: Hugo Saint-Vignes Date: Mon, 16 Nov 2020 11:09:49 +0100 Subject: [PATCH] [poincare] Handle horizontal vectors Change-Id: I98088b2b9f2dbc0549795a5c3eed4787fea70068 --- poincare/include/poincare/matrix.h | 2 ++ poincare/include/poincare/matrix_complex.h | 2 ++ poincare/src/matrix.cpp | 18 +++++++++--------- poincare/src/matrix_complex.cpp | 6 +++--- poincare/src/vector_cross.cpp | 4 ++-- poincare/src/vector_dot.cpp | 4 ++-- poincare/src/vector_norm.cpp | 4 ++-- 7 files changed, 22 insertions(+), 18 deletions(-) diff --git a/poincare/include/poincare/matrix.h b/poincare/include/poincare/matrix.h index 35ef39629..24aabf490 100644 --- a/poincare/include/poincare/matrix.h +++ b/poincare/include/poincare/matrix.h @@ -12,6 +12,7 @@ public: m_numberOfColumns(0) {} bool hasMatrixChild(Context * context) const; + bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; } int numberOfRows() const { return m_numberOfRows; } int numberOfColumns() const { return m_numberOfColumns; } virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; } @@ -67,6 +68,7 @@ public: static Matrix Builder() { return TreeHandle::NAryBuilder(); } void setDimensions(int rows, int columns); + bool isVector() const { return node()->isVector(); } int numberOfRows() const { return node()->numberOfRows(); } int numberOfColumns() const { return node()->numberOfColumns(); } using TreeHandle::addChildAtIndexInPlace; diff --git a/poincare/include/poincare/matrix_complex.h b/poincare/include/poincare/matrix_complex.h index 275785cd4..50d7c5922 100644 --- a/poincare/include/poincare/matrix_complex.h +++ b/poincare/include/poincare/matrix_complex.h @@ -36,6 +36,7 @@ public: // EvaluationNode typename EvaluationNode::Type type() const override { return EvaluationNode::Type::MatrixComplex; } + bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; } int numberOfRows() const { return m_numberOfRows; } int numberOfColumns() const { return m_numberOfColumns; } virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; } @@ -71,6 +72,7 @@ public: std::complex complexAtIndex(int index) const { return node()->complexAtIndex(index); } + bool isVector() const { return node()->isVector(); } int numberOfRows() const { return node()->numberOfRows(); } int numberOfColumns() const { return node()->numberOfColumns(); } void setDimensions(int rows, int columns); diff --git a/poincare/src/matrix.cpp b/poincare/src/matrix.cpp index 5dd9f56a0..2f2d0bfb7 100644 --- a/poincare/src/matrix.cpp +++ b/poincare/src/matrix.cpp @@ -494,10 +494,10 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext } Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const { - assert(numberOfColumns() == 1); + assert(isVector()); Addition sum = Addition::Builder(); - for (int j = 0; j < numberOfRows(); j++) { - Expression absValue = AbsoluteValue::Builder(const_cast(this)->matrixChild(0, j).clone()); + for (int j = 0; j < numberOfChildren(); j++) { + Expression absValue = AbsoluteValue::Builder(const_cast(this)->childAtIndex(j).clone()); Expression squaredAbsValue = Power::Builder(absValue, Rational::Builder(2)); absValue.shallowReduce(reductionContext); sum.addChildAtIndexInPlace(squaredAbsValue, sum.numberOfChildren(), sum.numberOfChildren()); @@ -510,10 +510,10 @@ Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { // Dot product is defined between two vectors of same size - assert(numberOfRows() == b->numberOfRows() && numberOfColumns() == 1 && b->numberOfColumns() == 1); + assert(isVector() && b->isVector() && numberOfChildren() == b->numberOfChildren()); Addition sum = Addition::Builder(); - for (int j = 0; j < numberOfRows(); j++) { - Expression product = Multiplication::Builder(const_cast(this)->matrixChild(0, j).clone(), const_cast(b)->matrixChild(0, j).clone()); + for (int j = 0; j < numberOfChildren(); j++) { + Expression product = Multiplication::Builder(const_cast(this)->childAtIndex(j).clone(), const_cast(b)->childAtIndex(j).clone()); sum.addChildAtIndexInPlace(product, sum.numberOfChildren(), sum.numberOfChildren()); product.shallowReduce(reductionContext); } @@ -522,13 +522,13 @@ Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionCon Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { // Cross product is defined between two vectors of size 3 - assert(numberOfRows() == 3 && numberOfColumns() == 1 && b->numberOfRows() == 3 && b->numberOfColumns() == 1); + assert(isVector() && b->isVector() && numberOfChildren() == 3 && b->numberOfChildren() == 3); Matrix matrix = Matrix::Builder(); for (int j = 0; j < 3; j++) { int j1 = (j+1)%3; int j2 = (j+2)%3; - Expression a1b2 = Multiplication::Builder(const_cast(this)->matrixChild(0, j1).clone(), const_cast(b)->matrixChild(0, j2).clone()); - Expression a2b1 = Multiplication::Builder(const_cast(this)->matrixChild(0, j2).clone(), const_cast(b)->matrixChild(0, j1).clone()); + Expression a1b2 = Multiplication::Builder(const_cast(this)->childAtIndex(j1).clone(), const_cast(b)->childAtIndex(j2).clone()); + Expression a2b1 = Multiplication::Builder(const_cast(this)->childAtIndex(j2).clone(), const_cast(b)->childAtIndex(j1).clone()); Expression difference = Subtraction::Builder(a1b2, a2b1); a1b2.shallowReduce(reductionContext); a2b1.shallowReduce(reductionContext); diff --git a/poincare/src/matrix_complex.cpp b/poincare/src/matrix_complex.cpp index eb94bab40..f0f087068 100644 --- a/poincare/src/matrix_complex.cpp +++ b/poincare/src/matrix_complex.cpp @@ -137,7 +137,7 @@ MatrixComplex MatrixComplexNode::ref(bool reduced) const { template std::complex MatrixComplexNode::norm() const { - if (numberOfChildren() == 0 || numberOfColumns() > 1) { + if (!isVector()) { return std::complex(NAN, NAN); } std::complex sum = 0; @@ -153,7 +153,7 @@ std::complex MatrixComplexNode::dot(Evaluation * e) const { return std::complex(NAN, NAN); } MatrixComplex * b = static_cast*>(e); - if (numberOfChildren() == 0 || numberOfColumns() > 1 || b->numberOfChildren() == 0 || b->numberOfColumns() > 1 || numberOfRows() != b->numberOfRows()) { + if (!isVector() || !b->isVector() || numberOfChildren() != b->numberOfChildren()) { return std::complex(NAN, NAN); } std::complex sum = 0; @@ -169,7 +169,7 @@ Evaluation MatrixComplexNode::cross(Evaluation * e) const { return MatrixComplex::Undefined(); } MatrixComplex * b = static_cast*>(e); - if (numberOfChildren() == 0 || numberOfColumns() != 1 || numberOfRows() != 3 || b->numberOfChildren() == 0 || b->numberOfColumns() != 1 || b->numberOfRows() != 3) { + if (!isVector() || !b->isVector() || numberOfChildren() != 3 || b->numberOfChildren() != 3) { return MatrixComplex::Undefined(); } std::complex operandsCopy[3]; diff --git a/poincare/src/vector_cross.cpp b/poincare/src/vector_cross.cpp index 6ee288083..4208be412 100644 --- a/poincare/src/vector_cross.cpp +++ b/poincare/src/vector_cross.cpp @@ -44,8 +44,8 @@ Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reduction if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild0 = static_cast(c0); Matrix matrixChild1 = static_cast(c1); - // Cross product is defined between two column matrices of size 3 - if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != 3 || matrixChild1.numberOfRows() != 3) { + // Cross product is defined between two vectors of size 3 + if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3) { return replaceWithUndefinedInPlace(); } Expression a = matrixChild0.cross(&matrixChild1, reductionContext); diff --git a/poincare/src/vector_dot.cpp b/poincare/src/vector_dot.cpp index aaabfe3c1..820abf7c1 100644 --- a/poincare/src/vector_dot.cpp +++ b/poincare/src/vector_dot.cpp @@ -44,8 +44,8 @@ Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionCo if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild0 = static_cast(c0); Matrix matrixChild1 = static_cast(c1); - // Dot product is defined between two column matrices of the same dimensions - if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) { + // Dot product is defined between two vectors of the same dimensions + if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren()) { return replaceWithUndefinedInPlace(); } Expression a = matrixChild0.dot(&matrixChild1, reductionContext); diff --git a/poincare/src/vector_norm.cpp b/poincare/src/vector_norm.cpp index c428d8b5e..1f280ff3d 100644 --- a/poincare/src/vector_norm.cpp +++ b/poincare/src/vector_norm.cpp @@ -42,8 +42,8 @@ Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionC Expression c = childAtIndex(0); if (c.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild = static_cast(c); - if (matrixChild.numberOfColumns() != 1) { - // Norm is only defined on column matrices + if (!matrixChild.isVector()) { + // Norm is only defined on vectors return replaceWithUndefinedInPlace(); } Expression a = matrixChild.norm(reductionContext);