[poincare] Add a vector type for matrix product and norm operations

Change-Id: I28b0956273f8c0a3a9bdc4389956caa106d6b8af
This commit is contained in:
Hugo Saint-Vignes
2020-11-18 10:35:10 +01:00
committed by EmilieNumworks
parent 522456677a
commit d8527b02ad
8 changed files with 23 additions and 19 deletions

View File

@@ -8,10 +8,15 @@ namespace Poincare {
class Array {
public:
enum class VectorType {
None,
Vertical,
Horizontal
};
Array() :
m_numberOfRows(0),
m_numberOfColumns(0) {}
bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; }
VectorType vectorType() const { return m_numberOfColumns == 1 ? VectorType::Vertical : (m_numberOfRows == 1 ? VectorType::Horizontal : VectorType::None); }
int numberOfRows() const { return m_numberOfRows; }
int numberOfColumns() const { return m_numberOfColumns; }
void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; }

View File

@@ -57,7 +57,7 @@ public:
static Matrix Builder() { return TreeHandle::NAryBuilder<Matrix, MatrixNode>(); }
void setDimensions(int rows, int columns);
bool isVector() const { return node()->isVector(); }
Array::VectorType vectorType() const { return node()->vectorType(); }
int numberOfRows() const { return node()->numberOfRows(); }
int numberOfColumns() const { return node()->numberOfColumns(); }
using TreeHandle::addChildAtIndexInPlace;

View File

@@ -63,7 +63,7 @@ public:
std::complex<T> complexAtIndex(int index) const {
return node()->complexAtIndex(index);
}
bool isVector() const { return node()->isVector(); }
Array::VectorType vectorType() const { return node()->vectorType(); }
int numberOfRows() const { return node()->numberOfRows(); }
int numberOfColumns() const { return node()->numberOfColumns(); }
void setDimensions(int rows, int columns);

View File

@@ -494,7 +494,8 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext
}
Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const {
assert(isVector());
// Norm is defined on vectors only
assert(vectorType() != Array::VectorType::None);
Addition sum = Addition::Builder();
for (int j = 0; j < numberOfChildren(); j++) {
Expression absValue = AbsoluteValue::Builder(const_cast<Matrix *>(this)->childAtIndex(j).clone());
@@ -509,8 +510,8 @@ Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const
}
Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
// Dot product is defined between two vectors of same size and orientation
assert(isVector() && b->isVector() && numberOfChildren() == b->numberOfChildren() && numberOfRows() == b->numberOfRows());
// Dot product is defined between two vectors of same size and type
assert(vectorType() != Array::VectorType::None && vectorType() == b->vectorType() && numberOfChildren() == b->numberOfChildren());
Addition sum = Addition::Builder();
for (int j = 0; j < numberOfChildren(); j++) {
Expression product = Multiplication::Builder(const_cast<Matrix *>(this)->childAtIndex(j).clone(), const_cast<Matrix *>(b)->childAtIndex(j).clone());
@@ -521,9 +522,8 @@ Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionCon
}
Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
/* Cross product is defined between two vectors of size 3 and of same
* orientation */
assert(isVector() && b->isVector() && numberOfChildren() == 3 && b->numberOfChildren() == 3 && numberOfRows() == b->numberOfRows());
// Cross product is defined between two vectors of size 3 and of same type.
assert(vectorType() != Array::VectorType::None && vectorType() == b->vectorType() && numberOfChildren() == 3 && b->numberOfChildren() == 3);
Matrix matrix = Matrix::Builder();
for (int j = 0; j < 3; j++) {
int j1 = (j+1)%3;

View File

@@ -137,7 +137,7 @@ MatrixComplex<T> MatrixComplexNode<T>::ref(bool reduced) const {
template<typename T>
std::complex<T> MatrixComplexNode<T>::norm() const {
if (!isVector()) {
if (vectorType() == Array::VectorType::None) {
return std::complex<T>(NAN, NAN);
}
std::complex<T> sum = 0;
@@ -153,7 +153,7 @@ std::complex<T> MatrixComplexNode<T>::dot(Evaluation<T> * e) const {
return std::complex<T>(NAN, NAN);
}
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (!isVector() || !b->isVector() || numberOfChildren() != b->numberOfChildren() || numberOfRows() != b->numberOfRows()) {
if (vectorType() == Array::VectorType::None || vectorType() != b->vectorType() || numberOfChildren() != b->numberOfChildren()) {
return std::complex<T>(NAN, NAN);
}
std::complex<T> sum = 0;
@@ -169,7 +169,7 @@ Evaluation<T> MatrixComplexNode<T>::cross(Evaluation<T> * e) const {
return MatrixComplex<T>::Undefined();
}
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (!isVector() || !b->isVector() || numberOfChildren() != 3 || b->numberOfChildren() != 3 || numberOfRows() != b->numberOfRows()) {
if (vectorType() == Array::VectorType::None || vectorType() != b->vectorType() || numberOfChildren() != 3 || b->numberOfChildren() != 3) {
return MatrixComplex<T>::Undefined();
}
std::complex<T> operandsCopy[3];

View File

@@ -44,8 +44,8 @@ Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reduction
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1);
// Cross product is defined between two vectors of size 3
if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) {
// Cross product is defined between two vectors of same type and of size 3
if (matrixChild0.vectorType() == Array::VectorType::None || matrixChild0.vectorType() != matrixChild1.vectorType() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3) {
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild0.cross(&matrixChild1, reductionContext);

View File

@@ -44,9 +44,8 @@ Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionCo
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1);
/* Dot product is defined between two vectors of the same dimension and
* orientation */
if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren() || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) {
// Dot product is defined between two vectors of the same dimension and type
if (matrixChild0.vectorType() == Array::VectorType::None || matrixChild0.vectorType() != matrixChild1.vectorType() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren()) {
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild0.dot(&matrixChild1, reductionContext);

View File

@@ -42,8 +42,8 @@ Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionC
Expression c = childAtIndex(0);
if (c.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild = static_cast<Matrix&>(c);
if (!matrixChild.isVector()) {
// Norm is only defined on vectors
// Norm is only defined on vectors only
if (matrixChild.vectorType() == Array::VectorType::None) {
return replaceWithUndefinedInPlace();
}
Expression a = matrixChild.norm(reductionContext);