diff --git a/poincare/include/poincare/array.h b/poincare/include/poincare/array.h index 09f2b2dda..d3f001a41 100644 --- a/poincare/include/poincare/array.h +++ b/poincare/include/poincare/array.h @@ -8,10 +8,15 @@ namespace Poincare { class Array { public: + enum class VectorType { + None, + Vertical, + Horizontal + }; Array() : m_numberOfRows(0), m_numberOfColumns(0) {} - bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; } + VectorType vectorType() const { return m_numberOfColumns == 1 ? VectorType::Vertical : (m_numberOfRows == 1 ? VectorType::Horizontal : VectorType::None); } int numberOfRows() const { return m_numberOfRows; } int numberOfColumns() const { return m_numberOfColumns; } void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; } diff --git a/poincare/include/poincare/matrix.h b/poincare/include/poincare/matrix.h index 45ad18b72..ee443e1bd 100644 --- a/poincare/include/poincare/matrix.h +++ b/poincare/include/poincare/matrix.h @@ -57,7 +57,7 @@ public: static Matrix Builder() { return TreeHandle::NAryBuilder(); } void setDimensions(int rows, int columns); - bool isVector() const { return node()->isVector(); } + Array::VectorType vectorType() const { return node()->vectorType(); } int numberOfRows() const { return node()->numberOfRows(); } int numberOfColumns() const { return node()->numberOfColumns(); } using TreeHandle::addChildAtIndexInPlace; diff --git a/poincare/include/poincare/matrix_complex.h b/poincare/include/poincare/matrix_complex.h index 57cf48f5f..d24dbfa3e 100644 --- a/poincare/include/poincare/matrix_complex.h +++ b/poincare/include/poincare/matrix_complex.h @@ -63,7 +63,7 @@ public: std::complex complexAtIndex(int index) const { return node()->complexAtIndex(index); } - bool isVector() const { return node()->isVector(); } + Array::VectorType vectorType() const { return node()->vectorType(); } int numberOfRows() const { return node()->numberOfRows(); } int numberOfColumns() const { return node()->numberOfColumns(); } void setDimensions(int rows, int columns); diff --git a/poincare/src/matrix.cpp b/poincare/src/matrix.cpp index b2ddc8988..48422d6b6 100644 --- a/poincare/src/matrix.cpp +++ b/poincare/src/matrix.cpp @@ -494,7 +494,8 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext } Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const { - assert(isVector()); + // Norm is defined on vectors only + assert(vectorType() != Array::VectorType::None); Addition sum = Addition::Builder(); for (int j = 0; j < numberOfChildren(); j++) { Expression absValue = AbsoluteValue::Builder(const_cast(this)->childAtIndex(j).clone()); @@ -509,8 +510,8 @@ Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const } Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { - // Dot product is defined between two vectors of same size and orientation - assert(isVector() && b->isVector() && numberOfChildren() == b->numberOfChildren() && numberOfRows() == b->numberOfRows()); + // Dot product is defined between two vectors of same size and type + assert(vectorType() != Array::VectorType::None && vectorType() == b->vectorType() && numberOfChildren() == b->numberOfChildren()); Addition sum = Addition::Builder(); for (int j = 0; j < numberOfChildren(); j++) { Expression product = Multiplication::Builder(const_cast(this)->childAtIndex(j).clone(), const_cast(b)->childAtIndex(j).clone()); @@ -521,9 +522,8 @@ Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionCon } Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { - /* Cross product is defined between two vectors of size 3 and of same - * orientation */ - assert(isVector() && b->isVector() && numberOfChildren() == 3 && b->numberOfChildren() == 3 && numberOfRows() == b->numberOfRows()); + // Cross product is defined between two vectors of size 3 and of same type. + assert(vectorType() != Array::VectorType::None && vectorType() == b->vectorType() && numberOfChildren() == 3 && b->numberOfChildren() == 3); Matrix matrix = Matrix::Builder(); for (int j = 0; j < 3; j++) { int j1 = (j+1)%3; diff --git a/poincare/src/matrix_complex.cpp b/poincare/src/matrix_complex.cpp index de32b2620..5099f7dfe 100644 --- a/poincare/src/matrix_complex.cpp +++ b/poincare/src/matrix_complex.cpp @@ -137,7 +137,7 @@ MatrixComplex MatrixComplexNode::ref(bool reduced) const { template std::complex MatrixComplexNode::norm() const { - if (!isVector()) { + if (vectorType() == Array::VectorType::None) { return std::complex(NAN, NAN); } std::complex sum = 0; @@ -153,7 +153,7 @@ std::complex MatrixComplexNode::dot(Evaluation * e) const { return std::complex(NAN, NAN); } MatrixComplex * b = static_cast*>(e); - if (!isVector() || !b->isVector() || numberOfChildren() != b->numberOfChildren() || numberOfRows() != b->numberOfRows()) { + if (vectorType() == Array::VectorType::None || vectorType() != b->vectorType() || numberOfChildren() != b->numberOfChildren()) { return std::complex(NAN, NAN); } std::complex sum = 0; @@ -169,7 +169,7 @@ Evaluation MatrixComplexNode::cross(Evaluation * e) const { return MatrixComplex::Undefined(); } MatrixComplex * b = static_cast*>(e); - if (!isVector() || !b->isVector() || numberOfChildren() != 3 || b->numberOfChildren() != 3 || numberOfRows() != b->numberOfRows()) { + if (vectorType() == Array::VectorType::None || vectorType() != b->vectorType() || numberOfChildren() != 3 || b->numberOfChildren() != 3) { return MatrixComplex::Undefined(); } std::complex operandsCopy[3]; diff --git a/poincare/src/vector_cross.cpp b/poincare/src/vector_cross.cpp index 750f82e96..6aeb6e7a1 100644 --- a/poincare/src/vector_cross.cpp +++ b/poincare/src/vector_cross.cpp @@ -44,8 +44,8 @@ Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reduction if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild0 = static_cast(c0); Matrix matrixChild1 = static_cast(c1); - // Cross product is defined between two vectors of size 3 - if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) { + // Cross product is defined between two vectors of same type and of size 3 + if (matrixChild0.vectorType() == Array::VectorType::None || matrixChild0.vectorType() != matrixChild1.vectorType() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3) { return replaceWithUndefinedInPlace(); } Expression a = matrixChild0.cross(&matrixChild1, reductionContext); diff --git a/poincare/src/vector_dot.cpp b/poincare/src/vector_dot.cpp index ce1791575..ed8305faf 100644 --- a/poincare/src/vector_dot.cpp +++ b/poincare/src/vector_dot.cpp @@ -44,9 +44,8 @@ Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionCo if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild0 = static_cast(c0); Matrix matrixChild1 = static_cast(c1); - /* Dot product is defined between two vectors of the same dimension and - * orientation */ - if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren() || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) { + // Dot product is defined between two vectors of the same dimension and type + if (matrixChild0.vectorType() == Array::VectorType::None || matrixChild0.vectorType() != matrixChild1.vectorType() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren()) { return replaceWithUndefinedInPlace(); } Expression a = matrixChild0.dot(&matrixChild1, reductionContext); diff --git a/poincare/src/vector_norm.cpp b/poincare/src/vector_norm.cpp index 1f280ff3d..cccd7f0bc 100644 --- a/poincare/src/vector_norm.cpp +++ b/poincare/src/vector_norm.cpp @@ -42,8 +42,8 @@ Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionC Expression c = childAtIndex(0); if (c.type() == ExpressionNode::Type::Matrix) { Matrix matrixChild = static_cast(c); - if (!matrixChild.isVector()) { - // Norm is only defined on vectors + // Norm is only defined on vectors only + if (matrixChild.vectorType() == Array::VectorType::None) { return replaceWithUndefinedInPlace(); } Expression a = matrixChild.norm(reductionContext);