[poincare] Handle horizontal vectors

Change-Id: I98088b2b9f2dbc0549795a5c3eed4787fea70068
This commit is contained in:
Hugo Saint-Vignes
2020-11-16 11:09:49 +01:00
committed by EmilieNumworks
parent 0ecfa0012c
commit bee8d8531b
7 changed files with 22 additions and 18 deletions

View File

@@ -12,6 +12,7 @@ public:
m_numberOfColumns(0) {} m_numberOfColumns(0) {}
bool hasMatrixChild(Context * context) const; bool hasMatrixChild(Context * context) const;
bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; }
int numberOfRows() const { return m_numberOfRows; } int numberOfRows() const { return m_numberOfRows; }
int numberOfColumns() const { return m_numberOfColumns; } int numberOfColumns() const { return m_numberOfColumns; }
virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; } virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; }
@@ -67,6 +68,7 @@ public:
static Matrix Builder() { return TreeHandle::NAryBuilder<Matrix, MatrixNode>(); } static Matrix Builder() { return TreeHandle::NAryBuilder<Matrix, MatrixNode>(); }
void setDimensions(int rows, int columns); void setDimensions(int rows, int columns);
bool isVector() const { return node()->isVector(); }
int numberOfRows() const { return node()->numberOfRows(); } int numberOfRows() const { return node()->numberOfRows(); }
int numberOfColumns() const { return node()->numberOfColumns(); } int numberOfColumns() const { return node()->numberOfColumns(); }
using TreeHandle::addChildAtIndexInPlace; using TreeHandle::addChildAtIndexInPlace;

View File

@@ -36,6 +36,7 @@ public:
// EvaluationNode // EvaluationNode
typename EvaluationNode<T>::Type type() const override { return EvaluationNode<T>::Type::MatrixComplex; } typename EvaluationNode<T>::Type type() const override { return EvaluationNode<T>::Type::MatrixComplex; }
bool isVector() const { return m_numberOfRows == 1 || m_numberOfColumns == 1; }
int numberOfRows() const { return m_numberOfRows; } int numberOfRows() const { return m_numberOfRows; }
int numberOfColumns() const { return m_numberOfColumns; } int numberOfColumns() const { return m_numberOfColumns; }
virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; } virtual void setNumberOfRows(int rows) { assert(rows >= 0); m_numberOfRows = rows; }
@@ -71,6 +72,7 @@ public:
std::complex<T> complexAtIndex(int index) const { std::complex<T> complexAtIndex(int index) const {
return node()->complexAtIndex(index); return node()->complexAtIndex(index);
} }
bool isVector() const { return node()->isVector(); }
int numberOfRows() const { return node()->numberOfRows(); } int numberOfRows() const { return node()->numberOfRows(); }
int numberOfColumns() const { return node()->numberOfColumns(); } int numberOfColumns() const { return node()->numberOfColumns(); }
void setDimensions(int rows, int columns); void setDimensions(int rows, int columns);

View File

@@ -494,10 +494,10 @@ Expression Matrix::determinant(ExpressionNode::ReductionContext reductionContext
} }
Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const { Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const {
assert(numberOfColumns() == 1); assert(isVector());
Addition sum = Addition::Builder(); Addition sum = Addition::Builder();
for (int j = 0; j < numberOfRows(); j++) { for (int j = 0; j < numberOfChildren(); j++) {
Expression absValue = AbsoluteValue::Builder(const_cast<Matrix *>(this)->matrixChild(0, j).clone()); Expression absValue = AbsoluteValue::Builder(const_cast<Matrix *>(this)->childAtIndex(j).clone());
Expression squaredAbsValue = Power::Builder(absValue, Rational::Builder(2)); Expression squaredAbsValue = Power::Builder(absValue, Rational::Builder(2));
absValue.shallowReduce(reductionContext); absValue.shallowReduce(reductionContext);
sum.addChildAtIndexInPlace(squaredAbsValue, sum.numberOfChildren(), sum.numberOfChildren()); sum.addChildAtIndexInPlace(squaredAbsValue, sum.numberOfChildren(), sum.numberOfChildren());
@@ -510,10 +510,10 @@ Expression Matrix::norm(ExpressionNode::ReductionContext reductionContext) const
Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
// Dot product is defined between two vectors of same size // Dot product is defined between two vectors of same size
assert(numberOfRows() == b->numberOfRows() && numberOfColumns() == 1 && b->numberOfColumns() == 1); assert(isVector() && b->isVector() && numberOfChildren() == b->numberOfChildren());
Addition sum = Addition::Builder(); Addition sum = Addition::Builder();
for (int j = 0; j < numberOfRows(); j++) { for (int j = 0; j < numberOfChildren(); j++) {
Expression product = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j).clone(), const_cast<Matrix *>(b)->matrixChild(0, j).clone()); Expression product = Multiplication::Builder(const_cast<Matrix *>(this)->childAtIndex(j).clone(), const_cast<Matrix *>(b)->childAtIndex(j).clone());
sum.addChildAtIndexInPlace(product, sum.numberOfChildren(), sum.numberOfChildren()); sum.addChildAtIndexInPlace(product, sum.numberOfChildren(), sum.numberOfChildren());
product.shallowReduce(reductionContext); product.shallowReduce(reductionContext);
} }
@@ -522,13 +522,13 @@ Expression Matrix::dot(Matrix * b, ExpressionNode::ReductionContext reductionCon
Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const { Matrix Matrix::cross(Matrix * b, ExpressionNode::ReductionContext reductionContext) const {
// Cross product is defined between two vectors of size 3 // Cross product is defined between two vectors of size 3
assert(numberOfRows() == 3 && numberOfColumns() == 1 && b->numberOfRows() == 3 && b->numberOfColumns() == 1); assert(isVector() && b->isVector() && numberOfChildren() == 3 && b->numberOfChildren() == 3);
Matrix matrix = Matrix::Builder(); Matrix matrix = Matrix::Builder();
for (int j = 0; j < 3; j++) { for (int j = 0; j < 3; j++) {
int j1 = (j+1)%3; int j1 = (j+1)%3;
int j2 = (j+2)%3; int j2 = (j+2)%3;
Expression a1b2 = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j1).clone(), const_cast<Matrix *>(b)->matrixChild(0, j2).clone()); Expression a1b2 = Multiplication::Builder(const_cast<Matrix *>(this)->childAtIndex(j1).clone(), const_cast<Matrix *>(b)->childAtIndex(j2).clone());
Expression a2b1 = Multiplication::Builder(const_cast<Matrix *>(this)->matrixChild(0, j2).clone(), const_cast<Matrix *>(b)->matrixChild(0, j1).clone()); Expression a2b1 = Multiplication::Builder(const_cast<Matrix *>(this)->childAtIndex(j2).clone(), const_cast<Matrix *>(b)->childAtIndex(j1).clone());
Expression difference = Subtraction::Builder(a1b2, a2b1); Expression difference = Subtraction::Builder(a1b2, a2b1);
a1b2.shallowReduce(reductionContext); a1b2.shallowReduce(reductionContext);
a2b1.shallowReduce(reductionContext); a2b1.shallowReduce(reductionContext);

View File

@@ -137,7 +137,7 @@ MatrixComplex<T> MatrixComplexNode<T>::ref(bool reduced) const {
template<typename T> template<typename T>
std::complex<T> MatrixComplexNode<T>::norm() const { std::complex<T> MatrixComplexNode<T>::norm() const {
if (numberOfChildren() == 0 || numberOfColumns() > 1) { if (!isVector()) {
return std::complex<T>(NAN, NAN); return std::complex<T>(NAN, NAN);
} }
std::complex<T> sum = 0; std::complex<T> sum = 0;
@@ -153,7 +153,7 @@ std::complex<T> MatrixComplexNode<T>::dot(Evaluation<T> * e) const {
return std::complex<T>(NAN, NAN); return std::complex<T>(NAN, NAN);
} }
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e); MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (numberOfChildren() == 0 || numberOfColumns() > 1 || b->numberOfChildren() == 0 || b->numberOfColumns() > 1 || numberOfRows() != b->numberOfRows()) { if (!isVector() || !b->isVector() || numberOfChildren() != b->numberOfChildren()) {
return std::complex<T>(NAN, NAN); return std::complex<T>(NAN, NAN);
} }
std::complex<T> sum = 0; std::complex<T> sum = 0;
@@ -169,7 +169,7 @@ Evaluation<T> MatrixComplexNode<T>::cross(Evaluation<T> * e) const {
return MatrixComplex<T>::Undefined(); return MatrixComplex<T>::Undefined();
} }
MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e); MatrixComplex<T> * b = static_cast<MatrixComplex<T>*>(e);
if (numberOfChildren() == 0 || numberOfColumns() != 1 || numberOfRows() != 3 || b->numberOfChildren() == 0 || b->numberOfColumns() != 1 || b->numberOfRows() != 3) { if (!isVector() || !b->isVector() || numberOfChildren() != 3 || b->numberOfChildren() != 3) {
return MatrixComplex<T>::Undefined(); return MatrixComplex<T>::Undefined();
} }
std::complex<T> operandsCopy[3]; std::complex<T> operandsCopy[3];

View File

@@ -44,8 +44,8 @@ Expression VectorCross::shallowReduce(ExpressionNode::ReductionContext reduction
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0); Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1); Matrix matrixChild1 = static_cast<Matrix&>(c1);
// Cross product is defined between two column matrices of size 3 // Cross product is defined between two vectors of size 3
if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != 3 || matrixChild1.numberOfRows() != 3) { if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != 3 || matrixChild1.numberOfChildren() != 3) {
return replaceWithUndefinedInPlace(); return replaceWithUndefinedInPlace();
} }
Expression a = matrixChild0.cross(&matrixChild1, reductionContext); Expression a = matrixChild0.cross(&matrixChild1, reductionContext);

View File

@@ -44,8 +44,8 @@ Expression VectorDot::shallowReduce(ExpressionNode::ReductionContext reductionCo
if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) { if (c0.type() == ExpressionNode::Type::Matrix && c1.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild0 = static_cast<Matrix&>(c0); Matrix matrixChild0 = static_cast<Matrix&>(c0);
Matrix matrixChild1 = static_cast<Matrix&>(c1); Matrix matrixChild1 = static_cast<Matrix&>(c1);
// Dot product is defined between two column matrices of the same dimensions // Dot product is defined between two vectors of the same dimensions
if (matrixChild0.numberOfColumns() != 1 || matrixChild1.numberOfColumns() != 1 || matrixChild0.numberOfRows() != matrixChild1.numberOfRows()) { if (!matrixChild0.isVector() || !matrixChild1.isVector() || matrixChild0.numberOfChildren() != matrixChild1.numberOfChildren()) {
return replaceWithUndefinedInPlace(); return replaceWithUndefinedInPlace();
} }
Expression a = matrixChild0.dot(&matrixChild1, reductionContext); Expression a = matrixChild0.dot(&matrixChild1, reductionContext);

View File

@@ -42,8 +42,8 @@ Expression VectorNorm::shallowReduce(ExpressionNode::ReductionContext reductionC
Expression c = childAtIndex(0); Expression c = childAtIndex(0);
if (c.type() == ExpressionNode::Type::Matrix) { if (c.type() == ExpressionNode::Type::Matrix) {
Matrix matrixChild = static_cast<Matrix&>(c); Matrix matrixChild = static_cast<Matrix&>(c);
if (matrixChild.numberOfColumns() != 1) { if (!matrixChild.isVector()) {
// Norm is only defined on column matrices // Norm is only defined on vectors
return replaceWithUndefinedInPlace(); return replaceWithUndefinedInPlace();
} }
Expression a = matrixChild.norm(reductionContext); Expression a = matrixChild.norm(reductionContext);