diff --git a/poincare/include/poincare/determinant.h b/poincare/include/poincare/determinant.h index 903760908..0abe92881 100644 --- a/poincare/include/poincare/determinant.h +++ b/poincare/include/poincare/determinant.h @@ -39,7 +39,7 @@ public: static constexpr Expression::FunctionHelper s_functionHelper = Expression::FunctionHelper("det", 1, &UntypedBuilderOneChild); - Expression shallowReduce(Context * context); + Expression shallowReduce(ExpressionNode::ReductionContext reductionContext); }; } diff --git a/poincare/include/poincare/multiplication.h b/poincare/include/poincare/multiplication.h index 260e530c7..8d1905833 100644 --- a/poincare/include/poincare/multiplication.h +++ b/poincare/include/poincare/multiplication.h @@ -70,6 +70,7 @@ public: static Multiplication Builder(Expression e1) { return Multiplication::Builder(&e1, 1); } static Multiplication Builder(Expression e1, Expression e2) { return Multiplication::Builder(ArrayBuilder(e1, e2).array(), 2); } static Multiplication Builder(Expression e1, Expression e2, Expression e3) { return Multiplication::Builder(ArrayBuilder(e1, e2, e3).array(), 3); } + static Multiplication Builder(Expression e1, Expression e2, Expression e3, Expression e4) { return Multiplication::Builder(ArrayBuilder(e1, e2, e3, e4).array(), 4); } static Multiplication Builder(Expression * children, size_t numberOfChildren) { return TreeHandle::NAryBuilder(children, numberOfChildren); } template static void computeOnArrays(T * m, T * n, T * result, int mNumberOfColumns, int mNumberOfRows, int nNumberOfColumns); diff --git a/poincare/src/determinant.cpp b/poincare/src/determinant.cpp index 98f19ffbe..dbe616a16 100644 --- a/poincare/src/determinant.cpp +++ b/poincare/src/determinant.cpp @@ -1,7 +1,10 @@ #include +#include #include #include +#include #include +#include extern "C" { #include } @@ -21,7 +24,6 @@ int DeterminantNode::serialize(char * buffer, int bufferSize, Preferences::Print return SerializationHelper::Prefix(this, buffer, bufferSize, floatDisplayMode, numberOfSignificantDigits, Determinant::s_functionHelper.name()); } -// TODO: handle this exactly in shallowReduce for small dimensions. template Evaluation DeterminantNode::templatedApproximate(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const { Evaluation input = childAtIndex(0)->approximate(T(), context, complexFormat, angleUnit); @@ -29,10 +31,10 @@ Evaluation DeterminantNode::templatedApproximate(Context * context, Preferenc } Expression DeterminantNode::shallowReduce(ReductionContext reductionContext) { - return Determinant(this).shallowReduce(reductionContext.context()); + return Determinant(this).shallowReduce(reductionContext); } -Expression Determinant::shallowReduce(Context * context) { +Expression Determinant::shallowReduce(ExpressionNode::ReductionContext reductionContext) { { Expression e = Expression::defaultShallowReduce(); if (e.isUndefined()) { @@ -41,11 +43,60 @@ Expression Determinant::shallowReduce(Context * context) { } Expression c0 = childAtIndex(0); // det(A) = A if A is not a matrix - if (!SortedIsMatrix(c0, context)) { + if (!SortedIsMatrix(c0, reductionContext.context())) { replaceWithInPlace(c0); return c0; } - //TODO LEA for matrix + if (c0.type() == ExpressionNode::Type::Matrix) { + Matrix m0 = static_cast(c0); + int dim = m0.numberOfRows(); + if (dim != m0.numberOfColumns()) { + // Determinant is for square matrices + return replaceWithUndefinedInPlace(); + } + Expression result; + if (dim == 1) { + // Determinant of [[a]] is a + result = m0.childAtIndex(0); + } else if (dim == 2) { + /* |a b| + * Determinant of |c d| is ad-bc */ + Multiplication ad = Multiplication::Builder(m0.matrixChild(0,0), m0.matrixChild(1,1)); + Multiplication bc = Multiplication::Builder(m0.matrixChild(0,1), m0.matrixChild(1,0)); + result = Subtraction::Builder(ad, bc); + ad.shallowReduce(reductionContext); + bc.shallowReduce(reductionContext); + } else if (dim == 3) { + /* |a b c| + * Determinant of |d e f| is aei+bfg+cdh-ceg-bdi-afh + * |g h i| */ + Expression a = m0.matrixChild(0,0); + Expression b = m0.matrixChild(0,1); + Expression c = m0.matrixChild(0,2); + Expression d = m0.matrixChild(1,0); + Expression e = m0.matrixChild(1,1); + Expression f = m0.matrixChild(1,2); + Expression g = m0.matrixChild(2,0); + Expression h = m0.matrixChild(2,1); + Expression i = m0.matrixChild(2,2); + constexpr int additionChildrenCount = 6; + Expression additionChildren[additionChildrenCount] = { + Multiplication::Builder(a.clone(), e.clone(), i.clone()), + Multiplication::Builder(b.clone(), f.clone(), g.clone()), + Multiplication::Builder(c.clone(), d.clone(), h.clone()), + Multiplication::Builder(Rational::Builder(-1), c, e, g), + Multiplication::Builder(Rational::Builder(-1), b, d, i), + Multiplication::Builder(Rational::Builder(-1), a, f, h)}; + result = Addition::Builder(additionChildren, additionChildrenCount); + for (int i = 0; i < additionChildrenCount; i++) { + additionChildren[i].shallowReduce(reductionContext); + } + } + if (!result.isUninitialized()) { + replaceWithInPlace(result); + return result.shallowReduce(reductionContext); + } + } return *this; } diff --git a/poincare/test/matrix.cpp b/poincare/test/matrix.cpp index 964eeee3b..5499f7284 100644 --- a/poincare/test/matrix.cpp +++ b/poincare/test/matrix.cpp @@ -33,9 +33,15 @@ QUIZ_CASE(poincare_matrix_simplify) { assert_parsed_expression_simplify_to("[[1,2][3,4]]^(-1)", "[[-2,1][3/2,-1/2]]"); // Determinant - assert_parsed_expression_simplify_to("det([[1,2][3,4]])", "det([[1,2][3,4]])"); // TODO: implement determinant if dim < 3 - assert_parsed_expression_simplify_to("det([[2,2][3,4]])", "det([[2,2][3,4]])"); - assert_parsed_expression_simplify_to("det([[2,2][3,3]])", "det([[2,2][3,3]])"); + assert_parsed_expression_simplify_to("det(π+π)", "2×π"); + assert_parsed_expression_simplify_to("det([[π+π]])", "2×π"); + assert_parsed_expression_simplify_to("det([[1,2][3,4]])", "-2"); + assert_parsed_expression_simplify_to("det([[2,2][3,4]])", "2"); + assert_parsed_expression_simplify_to("det([[2,2][3,4][3,4]])", Undefined::Name()); + assert_parsed_expression_simplify_to("det([[2,2][3,3]])", "0"); + assert_parsed_expression_simplify_to("det([[1,2,3][4,5,6][7,8,9]])", "0"); + assert_parsed_expression_simplify_to("det([[1,2,3][4,5,6][7,8,9]])", "0"); + assert_parsed_expression_simplify_to("det([[1,2,3][4π,5,6][7,8,9]])", "24×π-24"); // Dimension assert_parsed_expression_simplify_to("dim(3)", "[[1,1]]");