From a8d8e72fb48ff6d20d2aaaab75445d0eeb33f463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9a=20Saviot?= Date: Thu, 4 Jul 2019 09:41:24 +0200 Subject: [PATCH] [poincare] Fix MatrixTrace::shallowReduce --- poincare/include/poincare/matrix_trace.h | 2 +- poincare/src/matrix_trace.cpp | 36 ++++++++++-------------- poincare/test/matrix.cpp | 6 ++-- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/poincare/include/poincare/matrix_trace.h b/poincare/include/poincare/matrix_trace.h index cf130c12a..29e7609d3 100644 --- a/poincare/include/poincare/matrix_trace.h +++ b/poincare/include/poincare/matrix_trace.h @@ -38,7 +38,7 @@ public: static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("trace", 1, &UntypedBuilderOneChild); - Expression shallowReduce(); + Expression shallowReduce(ExpressionNode::ReductionContext reductionContext); }; } diff --git a/poincare/src/matrix_trace.cpp b/poincare/src/matrix_trace.cpp index f4dcf03a7..734875692 100644 --- a/poincare/src/matrix_trace.cpp +++ b/poincare/src/matrix_trace.cpp @@ -14,7 +14,7 @@ constexpr Expression::FunctionHelper MatrixTrace::s_functionHelper; int MatrixTraceNode::numberOfChildren() const { return MatrixTrace::s_functionHelper.numberOfChildren(); } Expression MatrixTraceNode::shallowReduce(ReductionContext reductionContext) { - return MatrixTrace(this).shallowReduce(); + return MatrixTrace(this).shallowReduce(reductionContext); } Layout MatrixTraceNode::createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const { @@ -33,7 +33,7 @@ Evaluation MatrixTraceNode::templatedApproximate(Context * context, Preferenc } -Expression MatrixTrace::shallowReduce() { +Expression MatrixTrace::shallowReduce(ExpressionNode::ReductionContext reductionContext) { { Expression e = Expression::defaultShallowReduce(); if (e.isUndefined()) { @@ -41,32 +41,26 @@ Expression MatrixTrace::shallowReduce() { } } Expression c = childAtIndex(0); -#if MATRIX_EXACT_REDUCING -#if 0 if (c.type() == ExpressionNode::Type::Matrix) { - Matrix m = static_cast(c); - if (m.numberOfRows() != m.numberOfColumns()) { - return Undefined::Builder(); + Matrix matrixChild = static_cast(c); + if (matrixChild.numberOfRows() != matrixChild.numberOfColumns()) { + return replaceWithUndefinedInPlace(); } - int n = m.numberOfRows(); + int n = matrixChild.numberOfRows(); Addition a = Addition::Builder(); for (int i = 0; i < n; i++) { - a.addChildAtIndexInPlace(m.childAtIndex(i+n*i), i, a.numberOfChildren()); + a.addChildAtIndexInPlace(matrixChild.matrixChild(i,i), i, i); } - return a.shallowReduce(context, complexFormat, angleUnit); + replaceWithInPlace(a); + return a.shallowReduce(reductionContext); } - if (!c.recursivelyMatches(Expression::IsMatrix)) { - return c; + /* TODO LEA + if (c.recursivelyMatches(Expression::IsMatrix)) { + return *this; } - return *this; -#endif -#else - if (c.type() != ExpressionNode::Type::Matrix) { - replaceWithInPlace(c); - return c; - } - return *this; -#endif + */ + replaceWithInPlace(c); + return c; } } diff --git a/poincare/test/matrix.cpp b/poincare/test/matrix.cpp index eee007696..cd954401a 100644 --- a/poincare/test/matrix.cpp +++ b/poincare/test/matrix.cpp @@ -84,10 +84,10 @@ QUIZ_CASE(poincare_matrix_simplify) { assert_parsed_expression_simplify_to("dim([[1/√(2),1/2,3][2,1,-3]])", "[[2,3]]"); assert_parsed_expression_simplify_to("inverse([[1/√(2),1/2,3][2,1,-3]])", Undefined::Name()); assert_parsed_expression_simplify_to("inverse([[1,2][3,4]])", "inverse([[1,2][3,4]])"); // TODO: implement matrix inverse if dim < 3 - assert_parsed_expression_simplify_to("trace([[1/√(2),1/2,3][2,1,-3]])", Undefined::Name()); - assert_parsed_expression_simplify_to("trace([[√(2),2][4,3+log(3)]])", "√(2)+3+log(3)"); - assert_parsed_expression_simplify_to("trace(√(2)+log(3))", "√(2)+log(3)"); #endif + assert_parsed_expression_simplify_to("trace([[1/√(2),1/2,3][2,1,-3]])", Undefined::Name()); + assert_parsed_expression_simplify_to("trace([[√(2),2][4,3+log(3)]])", "log(3)+√(2)+3"); + assert_parsed_expression_simplify_to("trace(√(2)+log(3))", "log(3)+√(2)"); assert_parsed_expression_simplify_to("transpose([[1/√(2),1/2,3][2,1,-3]])", "[[√(2)/2,2][1/2,1][3,-3]]"); assert_parsed_expression_simplify_to("transpose(√(4))", "2"); #if MATRIX_EXACT_REDUCING