diff --git a/poincare/src/simplify/expression_selector.cpp b/poincare/src/simplify/expression_selector.cpp index 6664911bf..980bcd04f 100644 --- a/poincare/src/simplify/expression_selector.cpp +++ b/poincare/src/simplify/expression_selector.cpp @@ -19,6 +19,10 @@ int ExpressionSelector::numberOfNonWildcardChildren() { } int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) { + return this->match(e, matches, 0); +} + +int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches, int offset) { int numberOfMatches = 0; // Does the current node match? @@ -63,18 +67,25 @@ int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) { /* This should not happen as a wildcard should be matched _before_ */ assert(false); break; + case ExpressionSelector::Match::SameAs: + /* Here we assume that the match can only have a single child, as we + * don't want to match on wildcards. */ + assert(matches[m_integerValue].numberOfExpressions() == 1); + if (!e->isEquivalentTo((Expression *)matches[m_sameAsPosition].expression(0))) { + return 0; + } + break; } // The current node does match. Let's add it to our matches - matches[numberOfMatches++] = ExpressionMatch(&e, 1); - + matches[offset + numberOfMatches++] = ExpressionMatch(&e, 1); if (m_numberOfChildren != 0) { int numberOfChildMatches = 0; if (!e->isCommutative()) { - numberOfChildMatches = sequentialMatch(e, matches+numberOfMatches); + numberOfChildMatches = sequentialMatch(e, matches, offset+numberOfMatches); } else { - numberOfChildMatches = commutativeMatch(e, matches+numberOfMatches); + numberOfChildMatches = commutativeMatch(e, matches, offset+numberOfMatches); } // We check whether the children matched or not. if (numberOfChildMatches == 0) { @@ -88,7 +99,8 @@ int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) { } /* This tries to match the children selector sequentialy */ -int ExpressionSelector::sequentialMatch(const Expression * e, ExpressionMatch * matches) { +int ExpressionSelector::sequentialMatch(const Expression * e, + ExpressionMatch * matches, int offset) { int numberOfMatches = 0; for (int i=0; im_match == ExpressionSelector::Match::Wildcard) { assert(false); // There should not be a wildcard for non commutative op. } else { - int numberOfChildMatches = childSelector->match(childExpression, (matches+numberOfMatches)); + int numberOfChildMatches = childSelector->match( + childExpression, + matches, + offset+numberOfMatches); if (numberOfChildMatches == 0) { return 0; } else { @@ -112,7 +127,7 @@ int ExpressionSelector::sequentialMatch(const Expression * e, ExpressionMatch * * a selector and then writes the output ExpressionMatch to matches just like * match would do. */ -int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * matches) { +int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * matches, int offset) { // If we have more children to match than the expression has, we cannot match. if (e->numberOfOperands() < m_numberOfChildren) { return 0; @@ -136,7 +151,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * * yet. */ int numberOfChildren = this->numberOfNonWildcardChildren(); - if (!canCommutativelyMatch(e, matches, selectorMatched, numberOfChildren)) { + if (!canCommutativelyMatch(e, matches, selectorMatched, numberOfChildren, offset)) { free(selectorMatched); return 0; } @@ -186,7 +201,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * * table. * * Using the example in the previous comment we would write - * | + | + | (Integer(2),Ineteger(3)) | Integer(4) | + * | + | + | (Integer(2),Integer(3)) | Integer(4) | * * The pointer arithmetic with numberOfMatches, allows us to know how many * matches a selector has written. @@ -202,7 +217,10 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * * Integer(4) */ int numberOfMatches = 0; for (int i(0); imatch(e->operand(expressionMatched[i]), matches+numberOfMatches); + int numberOfChildMatches = child(i)->match( + e->operand(expressionMatched[i]), + matches, + offset + numberOfMatches); assert(numberOfChildMatches > 0); numberOfMatches += numberOfChildMatches; } @@ -218,7 +236,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * local_expr[j++] = e->operand(i); } } - matches[numberOfMatches++] = ExpressionMatch(local_expr, j); + matches[offset + numberOfMatches++] = ExpressionMatch(local_expr, j); free(local_expr); } @@ -234,7 +252,11 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * * leftToMatch tells it how many selectors still have to be matched. * Implementation detail: selectors are matched in ascending order. */ -bool ExpressionSelector::canCommutativelyMatch(const Expression * e, ExpressionMatch * matches, uint8_t * selectorMatched, int leftToMatch) { +bool ExpressionSelector::canCommutativelyMatch(const Expression * e, + ExpressionMatch * matches, + uint8_t * selectorMatched, + int leftToMatch, + int offset) { bool hasWildcard = child(m_numberOfChildren-1)->m_match == ExpressionSelector::Match::Wildcard; // This part is used to make sure that we stop once we matched everything. @@ -276,12 +298,13 @@ bool ExpressionSelector::canCommutativelyMatch(const Expression * e, ExpressionM if (selectorMatched[j] != kUnmatched) { continue; } - if (child(i)->match(e->operand(j), matches)) { + int numberOfMatches = child(i)->match(e->operand(j), matches, offset); + if (numberOfMatches) { // We managed to match this selector. selectorMatched[j] = i; /* We check that we can match the rest in this configuration, if so we * are good. */ - if (this->canCommutativelyMatch(e, matches, selectorMatched, leftToMatch - 1)) { + if (this->canCommutativelyMatch(e, matches, selectorMatched, leftToMatch - 1, offset + numberOfMatches)) { return true; } // Otherwise we backtrack. diff --git a/poincare/src/simplify/expression_selector.h b/poincare/src/simplify/expression_selector.h index a4a2cfbf4..d43519124 100644 --- a/poincare/src/simplify/expression_selector.h +++ b/poincare/src/simplify/expression_selector.h @@ -11,6 +11,7 @@ class ExpressionSelector { public: static constexpr ExpressionSelector Any(uint8_t numberOfChildren); static constexpr ExpressionSelector Wildcard(uint8_t numberOfChildren); + static constexpr ExpressionSelector SameAs(int index, uint8_t numberOfChildren); static constexpr ExpressionSelector Type(Expression::Type type, uint8_t numberOfChildren); static constexpr ExpressionSelector TypeAndValue(Expression::Type type, int32_t value, uint8_t numberOfChildren); @@ -27,15 +28,18 @@ private: Type, Wildcard, TypeAndValue, + SameAs, }; + int match(const Expression * e, ExpressionMatch * matches, int offset); + constexpr ExpressionSelector(Match match, Expression::Type type, int32_t integerValue, uint8_t numberOfChildren); int numberOfNonWildcardChildren(); bool canCommutativelyMatch(const Expression * e, ExpressionMatch * matches, - uint8_t * selectorMatched, int leftToMatch); - int commutativeMatch(const Expression * e, ExpressionMatch * matches); - int sequentialMatch(const Expression * e, ExpressionMatch * matches); + uint8_t * selectorMatched, int leftToMatch, int offset); + int commutativeMatch(const Expression * e, ExpressionMatch * matches, int offset); + int sequentialMatch(const Expression * e, ExpressionMatch * matches, int offset); ExpressionSelector * child(int index); Match m_match; @@ -50,6 +54,8 @@ private: int32_t m_integerValue; // m_expressionType == Symbol char * m_symbolName; + // Position of the other match we must be equal to. + int32_t m_sameAsPosition; }; }; }; @@ -68,6 +74,10 @@ constexpr ExpressionSelector ExpressionSelector::Wildcard(uint8_t numberOfChildr return ExpressionSelector(Match::Wildcard, (Expression::Type)0, 0, numberOfChildren); } +constexpr ExpressionSelector ExpressionSelector::SameAs(int index, uint8_t numberOfChildren) { + return ExpressionSelector(Match::SameAs, (Expression::Type)0, index, numberOfChildren); +} + constexpr ExpressionSelector ExpressionSelector::Type(Expression::Type type, uint8_t numberOfChildren) { return ExpressionSelector(Match::Type, type, 0, numberOfChildren); } diff --git a/poincare/src/simplify/rules.pr b/poincare/src/simplify/rules.pr index 27297c52d..8a27eb103 100644 --- a/poincare/src/simplify/rules.pr +++ b/poincare/src/simplify/rules.pr @@ -2,9 +2,20 @@ Addition(Addition(a*),b*)->Addition(a*,b*); Addition(Integer.a,Integer.b)->$AddIntegers(a,b); Addition(Integer.a,Integer.b,c*)->Addition($AddIntegers(a,b),c*); + +Subtraction(a,b)->Addition(a,Product(b,Integer[-1])); +Addition(a, Product(a,Integer[-1]))->Integer[0]; +Addition(a, Product(a,Integer[-1]), b)->b; +Addition(a, Product(a,Integer[-1]), b, c*)->Addition(b,c*); + +Addition(a,a,b*)->Addition(Product(a,Integer[2]),b*); +Addition(a,Product(a,b),c*)->Addition(Product(a,Addition(b,Integer[1])),c*); +Addition(Product(a,b),Product(a,c),d*)->Addition(Product(a,Addition(b,c)),d*); +Addition(a,a)->Product(a,Integer[2]); +Addition(a,Product(a,b))->Product(a,Addition(b,Integer[1])); +Addition(Product(a,b),Product(a,c))->Product(a,Addition(b,c)); + Product(Product(a*),b*)->Product(a*,b*); Product(Integer[0],a*)->Integer[0]; Product(Integer.a,Integer.b)->$MultiplyIntegers(a,b); Product(Integer.a,Integer.b,c*)->Product($MultiplyIntegers(a,b),c*); -Product(Addition(a, b), c*)->Addition(Product(a, c*), Product(b, c*)); /* Distributivity. */ -Product(Addition(a, b, c*), d*)->Addition(Product(a, d*), Product(d*, Addition(b, c*))); diff --git a/poincare/src/simplify/rules_generation/node.cpp b/poincare/src/simplify/rules_generation/node.cpp index db2345fbc..0755b8ddb 100644 --- a/poincare/src/simplify/rules_generation/node.cpp +++ b/poincare/src/simplify/rules_generation/node.cpp @@ -48,7 +48,17 @@ std::string Node::generateSelectorConstructor(Rule * context) { switch (m_referenceMode) { case Node::ReferenceMode::None: case Node::ReferenceMode::SingleNode: - result << "ExpressionSelector::Any("; + { + // We try to see if we already saw this node before. + Node * selector = context->selector(); + int index = selector->flatIndexOfChildNamed(*m_referenceName); + int my_index = selector->flatIndexOfChildRef(this); + if (index >= 0 && index < my_index) { + result << "ExpressionSelector::SameAs(" << index << ", "; + } else { + result << "ExpressionSelector::Any("; + } + } break; case Node::ReferenceMode::Wildcard: result << "ExpressionSelector::Wildcard("; @@ -106,6 +116,22 @@ std::string Node::generateBuilderConstructor(Rule * context) { return result.str(); } +int Node::flatIndexOfChildRef(Node * node) { + if (m_referenceName != nullptr && node == this) { + return 0; + } + int sum=1; + for (Node * child : *m_children) { + int index = child->flatIndexOfChildRef(node); + if (index >= 0) { + return sum+index; + } else { + sum += child->totalDescendantCountIncludingSelf(); + } + } + return -1; +} + int Node::flatIndexOfChildNamed(std::string name) { if (m_referenceName != nullptr && *m_referenceName == name) { return 0; diff --git a/poincare/src/simplify/rules_generation/node.h b/poincare/src/simplify/rules_generation/node.h index 2f070f510..08d8d1ee5 100644 --- a/poincare/src/simplify/rules_generation/node.h +++ b/poincare/src/simplify/rules_generation/node.h @@ -28,6 +28,7 @@ public: int totalDescendantCountIncludingSelf(); int flatIndexOfChildNamed(std::string name); + int flatIndexOfChildRef(Node * node); void generateSelectorTree(Rule * context); void generateBuilderTree(Rule * context); diff --git a/poincare/test/simplify_product.cpp b/poincare/test/simplify_product.cpp index dce68a966..5992395ba 100644 --- a/poincare/test/simplify_product.cpp +++ b/poincare/test/simplify_product.cpp @@ -15,8 +15,18 @@ QUIZ_CASE(poincare_simplify_product_by_zero) { assert(simplifies_to("3*(5+4)", "27")); } -QUIZ_CASE(poincare_simplify_distributive) { - assert(equivalent_to("3*(x+y)", "3*x+3*y")); - assert(equivalent_to("3*(x+y+z)", "3*x+3*y+3*z")); - assert(equivalent_to("3*(x+y+z+w)", "3*x+3*y+3*z+3*w")); +QUIZ_CASE(poincare_simplify_distributive_reverse) { + assert(equivalent_to("x+x", "2*x")); + assert(equivalent_to("2*x+x", "3*x")); + assert(equivalent_to("x*2+x", "3*x")); + assert(equivalent_to("2*x+2*x", "4*x")); + assert(equivalent_to("x*2+2*y", "2*(x+y)")); + assert(equivalent_to("x+x+y+y", "2*x+2*y")); + assert(equivalent_to("2*x+2*y", "2*(x+y)")); + //assert(equivalent_to("x+x+y+y", "2*(x+y)")); + assert(equivalent_to("x-x-y+y", "0)")); + assert(equivalent_to("x+y-x-y", "0")); + assert(equivalent_to("x+x*y", "x*(y+1)")); + assert(equivalent_to("x-x", "0")); + assert(equivalent_to("x-x", "0")); }