Add the possibility to detect Addition(a,a) in the rules.

Change-Id: I57ea5af186304ce8872001793a3c46e7828948e2
This commit is contained in:
Felix Raimundo
2016-04-22 14:33:48 +02:00
parent ad33e7ffa4
commit b7b43edea3
6 changed files with 105 additions and 24 deletions

View File

@@ -19,6 +19,10 @@ int ExpressionSelector::numberOfNonWildcardChildren() {
}
int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) {
return this->match(e, matches, 0);
}
int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches, int offset) {
int numberOfMatches = 0;
// Does the current node match?
@@ -63,18 +67,25 @@ int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) {
/* This should not happen as a wildcard should be matched _before_ */
assert(false);
break;
case ExpressionSelector::Match::SameAs:
/* Here we assume that the match can only have a single child, as we
* don't want to match on wildcards. */
assert(matches[m_integerValue].numberOfExpressions() == 1);
if (!e->isEquivalentTo((Expression *)matches[m_sameAsPosition].expression(0))) {
return 0;
}
break;
}
// The current node does match. Let's add it to our matches
matches[numberOfMatches++] = ExpressionMatch(&e, 1);
matches[offset + numberOfMatches++] = ExpressionMatch(&e, 1);
if (m_numberOfChildren != 0) {
int numberOfChildMatches = 0;
if (!e->isCommutative()) {
numberOfChildMatches = sequentialMatch(e, matches+numberOfMatches);
numberOfChildMatches = sequentialMatch(e, matches, offset+numberOfMatches);
} else {
numberOfChildMatches = commutativeMatch(e, matches+numberOfMatches);
numberOfChildMatches = commutativeMatch(e, matches, offset+numberOfMatches);
}
// We check whether the children matched or not.
if (numberOfChildMatches == 0) {
@@ -88,7 +99,8 @@ int ExpressionSelector::match(const Expression * e, ExpressionMatch * matches) {
}
/* This tries to match the children selector sequentialy */
int ExpressionSelector::sequentialMatch(const Expression * e, ExpressionMatch * matches) {
int ExpressionSelector::sequentialMatch(const Expression * e,
ExpressionMatch * matches, int offset) {
int numberOfMatches = 0;
for (int i=0; i<m_numberOfChildren; i++) {
ExpressionSelector * childSelector = child(i);
@@ -97,7 +109,10 @@ int ExpressionSelector::sequentialMatch(const Expression * e, ExpressionMatch *
if (childSelector->m_match == ExpressionSelector::Match::Wildcard) {
assert(false); // There should not be a wildcard for non commutative op.
} else {
int numberOfChildMatches = childSelector->match(childExpression, (matches+numberOfMatches));
int numberOfChildMatches = childSelector->match(
childExpression,
matches,
offset+numberOfMatches);
if (numberOfChildMatches == 0) {
return 0;
} else {
@@ -112,7 +127,7 @@ int ExpressionSelector::sequentialMatch(const Expression * e, ExpressionMatch *
* a selector and then writes the output ExpressionMatch to matches just like
* match would do.
*/
int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * matches) {
int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch * matches, int offset) {
// If we have more children to match than the expression has, we cannot match.
if (e->numberOfOperands() < m_numberOfChildren) {
return 0;
@@ -136,7 +151,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch *
* yet. */
int numberOfChildren = this->numberOfNonWildcardChildren();
if (!canCommutativelyMatch(e, matches, selectorMatched, numberOfChildren)) {
if (!canCommutativelyMatch(e, matches, selectorMatched, numberOfChildren, offset)) {
free(selectorMatched);
return 0;
}
@@ -186,7 +201,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch *
* table.
*
* Using the example in the previous comment we would write
* | + | + | (Integer(2),Ineteger(3)) | Integer(4) |
* | + | + | (Integer(2),Integer(3)) | Integer(4) |
*
* The pointer arithmetic with numberOfMatches, allows us to know how many
* matches a selector has written.
@@ -202,7 +217,10 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch *
* Integer(4) */
int numberOfMatches = 0;
for (int i(0); i<numberOfChildren; i++) {
int numberOfChildMatches = child(i)->match(e->operand(expressionMatched[i]), matches+numberOfMatches);
int numberOfChildMatches = child(i)->match(
e->operand(expressionMatched[i]),
matches,
offset + numberOfMatches);
assert(numberOfChildMatches > 0);
numberOfMatches += numberOfChildMatches;
}
@@ -218,7 +236,7 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch *
local_expr[j++] = e->operand(i);
}
}
matches[numberOfMatches++] = ExpressionMatch(local_expr, j);
matches[offset + numberOfMatches++] = ExpressionMatch(local_expr, j);
free(local_expr);
}
@@ -234,7 +252,11 @@ int ExpressionSelector::commutativeMatch(const Expression * e, ExpressionMatch *
* leftToMatch tells it how many selectors still have to be matched.
* Implementation detail: selectors are matched in ascending order.
*/
bool ExpressionSelector::canCommutativelyMatch(const Expression * e, ExpressionMatch * matches, uint8_t * selectorMatched, int leftToMatch) {
bool ExpressionSelector::canCommutativelyMatch(const Expression * e,
ExpressionMatch * matches,
uint8_t * selectorMatched,
int leftToMatch,
int offset) {
bool hasWildcard = child(m_numberOfChildren-1)->m_match == ExpressionSelector::Match::Wildcard;
// This part is used to make sure that we stop once we matched everything.
@@ -276,12 +298,13 @@ bool ExpressionSelector::canCommutativelyMatch(const Expression * e, ExpressionM
if (selectorMatched[j] != kUnmatched) {
continue;
}
if (child(i)->match(e->operand(j), matches)) {
int numberOfMatches = child(i)->match(e->operand(j), matches, offset);
if (numberOfMatches) {
// We managed to match this selector.
selectorMatched[j] = i;
/* We check that we can match the rest in this configuration, if so we
* are good. */
if (this->canCommutativelyMatch(e, matches, selectorMatched, leftToMatch - 1)) {
if (this->canCommutativelyMatch(e, matches, selectorMatched, leftToMatch - 1, offset + numberOfMatches)) {
return true;
}
// Otherwise we backtrack.

View File

@@ -11,6 +11,7 @@ class ExpressionSelector {
public:
static constexpr ExpressionSelector Any(uint8_t numberOfChildren);
static constexpr ExpressionSelector Wildcard(uint8_t numberOfChildren);
static constexpr ExpressionSelector SameAs(int index, uint8_t numberOfChildren);
static constexpr ExpressionSelector Type(Expression::Type type, uint8_t numberOfChildren);
static constexpr ExpressionSelector TypeAndValue(Expression::Type type, int32_t value, uint8_t numberOfChildren);
@@ -27,15 +28,18 @@ private:
Type,
Wildcard,
TypeAndValue,
SameAs,
};
int match(const Expression * e, ExpressionMatch * matches, int offset);
constexpr ExpressionSelector(Match match, Expression::Type type, int32_t integerValue, uint8_t numberOfChildren);
int numberOfNonWildcardChildren();
bool canCommutativelyMatch(const Expression * e, ExpressionMatch * matches,
uint8_t * selectorMatched, int leftToMatch);
int commutativeMatch(const Expression * e, ExpressionMatch * matches);
int sequentialMatch(const Expression * e, ExpressionMatch * matches);
uint8_t * selectorMatched, int leftToMatch, int offset);
int commutativeMatch(const Expression * e, ExpressionMatch * matches, int offset);
int sequentialMatch(const Expression * e, ExpressionMatch * matches, int offset);
ExpressionSelector * child(int index);
Match m_match;
@@ -50,6 +54,8 @@ private:
int32_t m_integerValue;
// m_expressionType == Symbol
char * m_symbolName;
// Position of the other match we must be equal to.
int32_t m_sameAsPosition;
};
};
};
@@ -68,6 +74,10 @@ constexpr ExpressionSelector ExpressionSelector::Wildcard(uint8_t numberOfChildr
return ExpressionSelector(Match::Wildcard, (Expression::Type)0, 0, numberOfChildren);
}
constexpr ExpressionSelector ExpressionSelector::SameAs(int index, uint8_t numberOfChildren) {
return ExpressionSelector(Match::SameAs, (Expression::Type)0, index, numberOfChildren);
}
constexpr ExpressionSelector ExpressionSelector::Type(Expression::Type type, uint8_t numberOfChildren) {
return ExpressionSelector(Match::Type, type, 0, numberOfChildren);
}

View File

@@ -2,9 +2,20 @@
Addition(Addition(a*),b*)->Addition(a*,b*);
Addition(Integer.a,Integer.b)->$AddIntegers(a,b);
Addition(Integer.a,Integer.b,c*)->Addition($AddIntegers(a,b),c*);
Subtraction(a,b)->Addition(a,Product(b,Integer[-1]));
Addition(a, Product(a,Integer[-1]))->Integer[0];
Addition(a, Product(a,Integer[-1]), b)->b;
Addition(a, Product(a,Integer[-1]), b, c*)->Addition(b,c*);
Addition(a,a,b*)->Addition(Product(a,Integer[2]),b*);
Addition(a,Product(a,b),c*)->Addition(Product(a,Addition(b,Integer[1])),c*);
Addition(Product(a,b),Product(a,c),d*)->Addition(Product(a,Addition(b,c)),d*);
Addition(a,a)->Product(a,Integer[2]);
Addition(a,Product(a,b))->Product(a,Addition(b,Integer[1]));
Addition(Product(a,b),Product(a,c))->Product(a,Addition(b,c));
Product(Product(a*),b*)->Product(a*,b*);
Product(Integer[0],a*)->Integer[0];
Product(Integer.a,Integer.b)->$MultiplyIntegers(a,b);
Product(Integer.a,Integer.b,c*)->Product($MultiplyIntegers(a,b),c*);
Product(Addition(a, b), c*)->Addition(Product(a, c*), Product(b, c*)); /* Distributivity. */
Product(Addition(a, b, c*), d*)->Addition(Product(a, d*), Product(d*, Addition(b, c*)));

View File

@@ -48,7 +48,17 @@ std::string Node::generateSelectorConstructor(Rule * context) {
switch (m_referenceMode) {
case Node::ReferenceMode::None:
case Node::ReferenceMode::SingleNode:
result << "ExpressionSelector::Any(";
{
// We try to see if we already saw this node before.
Node * selector = context->selector();
int index = selector->flatIndexOfChildNamed(*m_referenceName);
int my_index = selector->flatIndexOfChildRef(this);
if (index >= 0 && index < my_index) {
result << "ExpressionSelector::SameAs(" << index << ", ";
} else {
result << "ExpressionSelector::Any(";
}
}
break;
case Node::ReferenceMode::Wildcard:
result << "ExpressionSelector::Wildcard(";
@@ -106,6 +116,22 @@ std::string Node::generateBuilderConstructor(Rule * context) {
return result.str();
}
int Node::flatIndexOfChildRef(Node * node) {
if (m_referenceName != nullptr && node == this) {
return 0;
}
int sum=1;
for (Node * child : *m_children) {
int index = child->flatIndexOfChildRef(node);
if (index >= 0) {
return sum+index;
} else {
sum += child->totalDescendantCountIncludingSelf();
}
}
return -1;
}
int Node::flatIndexOfChildNamed(std::string name) {
if (m_referenceName != nullptr && *m_referenceName == name) {
return 0;

View File

@@ -28,6 +28,7 @@ public:
int totalDescendantCountIncludingSelf();
int flatIndexOfChildNamed(std::string name);
int flatIndexOfChildRef(Node * node);
void generateSelectorTree(Rule * context);
void generateBuilderTree(Rule * context);

View File

@@ -15,8 +15,18 @@ QUIZ_CASE(poincare_simplify_product_by_zero) {
assert(simplifies_to("3*(5+4)", "27"));
}
QUIZ_CASE(poincare_simplify_distributive) {
assert(equivalent_to("3*(x+y)", "3*x+3*y"));
assert(equivalent_to("3*(x+y+z)", "3*x+3*y+3*z"));
assert(equivalent_to("3*(x+y+z+w)", "3*x+3*y+3*z+3*w"));
QUIZ_CASE(poincare_simplify_distributive_reverse) {
assert(equivalent_to("x+x", "2*x"));
assert(equivalent_to("2*x+x", "3*x"));
assert(equivalent_to("x*2+x", "3*x"));
assert(equivalent_to("2*x+2*x", "4*x"));
assert(equivalent_to("x*2+2*y", "2*(x+y)"));
assert(equivalent_to("x+x+y+y", "2*x+2*y"));
assert(equivalent_to("2*x+2*y", "2*(x+y)"));
//assert(equivalent_to("x+x+y+y", "2*(x+y)"));
assert(equivalent_to("x-x-y+y", "0)"));
assert(equivalent_to("x+y-x-y", "0"));
assert(equivalent_to("x+x*y", "x*(y+1)"));
assert(equivalent_to("x-x", "0"));
assert(equivalent_to("x-x", "0"));
}