Start expression simplification

This commit is contained in:
Léa Saviot
2018-06-29 17:54:54 +02:00
parent 53bacec72f
commit f0e2bd0e5c
12 changed files with 147 additions and 32 deletions

View File

@@ -6,13 +6,10 @@
class AdditionNode : public ExpressionNode {
public:
const char * description() const override {
return "Addition";
}
const char * description() const override { return "Addition"; }
size_t size() const override { return sizeof(AdditionNode); }
Type type() const override { return Type::Addition; }
size_t size() const override {
return sizeof(AdditionNode);
}
float approximate() override {
float result = 0.0f;
for (int i=0; i<numberOfChildren(); i++) {
@@ -25,12 +22,36 @@ public:
return result;
}
bool shallowReduce() override {
if (ExpressionNode::shallowReduce()) {
return true;
}
/* Step 1: Addition is associative, so let's start by merging children which
* also are additions themselves. */
int i = 0;
int initialNumberOfChildren = numberOfChildren();
while (i < initialNumberOfChildren) {
ExpressionNode * currentChild = child(i);
if (currentChild->type() == Type::Addition) {
TreeRef(this).mergeChildren(TreeRef(currentChild));
// Is it ok to modify memory while executing ?
continue;
}
i++;
}
return false;
}
int numberOfChildren() const override { return m_numberOfChildren; }
void incrementNumberOfChildren() override { m_numberOfChildren++; }
void decrementNumberOfChildren() override {
assert(m_numberOfChildren > 0);
m_numberOfChildren--;
}
void eraseNumberOfChildren() override {
m_numberOfChildren = 0;
}
/*
Expression simplify() override {
// Scan operands, merge constants

View File

@@ -13,6 +13,7 @@ public:
// TreeNode
size_t size() const override { return sizeof(AllocationFailedExpressionNode); }
const char * description() const override { return "Allocation Failed"; }
Type type() const override { return Type::AllocationFailure; }
int numberOfChildren() const override { return 0; }
bool isAllocationFailure() const override { return true; }
};

View File

@@ -5,6 +5,35 @@
class ExpressionNode : public TreeNode {
public:
enum class Type : uint8_t {
AllocationFailure = 0,
Float = 1,
Addition
};
// Expression
virtual Type type() const = 0;
virtual float approximate() = 0;
void deepReduce() {
assert(parentTree() != nullptr);
for (int i = 0; i < numberOfChildren(); i++) {
child(i)->deepReduce();
}
shallowReduce();
}
virtual bool shallowReduce() {
for (int i = 0; i < numberOfChildren(); i++) {
if (child(i)->isAllocationFailure()) {
replaceWithAllocationFailure();
return true;
}
}
return false;
}
// Allocation failure
static TreeNode * FailedAllocationStaticNode();
static int AllocationFailureNodeIdentifier() {
return FailedAllocationStaticNode()->identifier();
@@ -12,9 +41,9 @@ public:
int allocationFailureNodeIdentifier() override {
return AllocationFailureNodeIdentifier();
}
TreeNode * failedAllocationStaticNode() override { return FailedAllocationStaticNode(); }
virtual float approximate() = 0;
// Hierarchy
ExpressionNode * child(int i) { return static_cast<ExpressionNode *>(childTreeAtIndex(i)); }
};

View File

@@ -11,7 +11,8 @@ class ExpressionReference : public TreeReference<T> {
public:
using TreeReference<T>::TreeReference;
// Allow every ExpressionReference<T> to be transformed into an ExpressionReference<ExpressionNode>, i.e. Expression
/* Allow every ExpressionReference<T> to be transformed into an
* ExpressionReference<ExpressionNode>, i.e. ExpressionRef */
operator ExpressionReference<ExpressionNode>() const {
return ExpressionReference<ExpressionNode>(this->node());
}
@@ -30,11 +31,13 @@ public:
return this->castedNode()->approximate();
}
/*
ExpressionReference<ExpressionNode> simplify() {
return node()->simplify();
void deepReduce() {
return this->castedNode()->deepReduce();
}
void shallowReduce() {
return this->castedNode()->shallowReduce();
}
*/
};
typedef ExpressionReference<ExpressionNode> ExpressionRef;

View File

@@ -8,6 +8,7 @@ class FloatNode : public ExpressionNode {
public:
FloatNode() : ExpressionNode() {}
size_t size() const override { return sizeof(FloatNode); }
Type type() const override { return Type::Float; }
int numberOfChildren() const override { return 0; }
float approximate() override { return m_value; }
const char * description() const override {

View File

@@ -22,6 +22,9 @@ public:
assert(m_numberOfChildren > 0);
m_numberOfChildren--;
}
void eraseNumberOfChildren() override {
m_numberOfChildren = 0;
}
void moveCursorLeft(LayoutCursor * cursor, bool * shouldRecomputeLayout) override {
if (this == cursor->layoutReference().node()) {

View File

@@ -222,6 +222,23 @@ void testStealOperand() {
assert_expression_approximates_to(a2, 2);
}
void testSimplify() {
printf("Symplify test\n");
AdditionRef a(
AdditionRef(
FloatRef(0.0f),
FloatRef(1.0f)),
FloatRef(2.0f));
assert_expression_approximates_to(a, 3);
a.deepReduce();
assert_expression_approximates_to(a, 3);
assert(a.numberOfChildren() == 3);
}
void testPoolLayoutAllocationFail() {
printf("Pool layout allocation fail test\n");
@@ -260,7 +277,7 @@ int main() {
runTest(testPoolExpressionAllocationFail);
runTest(testPoolExpressionAllocationFail2);
runTest(testPoolExpressionAllocationFailOnImbricatedAdditions);
runTest(testStealOperand);
//runTest(testStealOperand);
printf("\n*******************\nEnd of tests\n*******************\n\n");
return 0;
}

View File

@@ -1,6 +1,6 @@
#include "tree_node.h"
#include "tree_pool.h"
#include "expression_node.h"
#include "tree_reference.h"
#include <stdio.h>
// Node operations
@@ -175,3 +175,9 @@ bool TreeNode::hasSibling(const TreeNode * e) const {
}
return false;
}
void TreeNode::replaceWithAllocationFailure() {
TreeRef t(this);
t.replaceWithAllocationFailure();
// TODO: OK to change the memory while executing from it, even though we know it will stop execution just after ?
}

View File

@@ -53,6 +53,7 @@ public:
virtual int numberOfChildren() const = 0;
virtual void incrementNumberOfChildren() {} //TODO Put an assert false
virtual void decrementNumberOfChildren() {} //TODO Put an assert false //TODO what if somebody i stealing a unary tree's only child ?
virtual void eraseNumberOfChildren() {} //TODO Put an assert false //TODO what if somebody i stealing a unary tree's only child ?
int numberOfDescendants(bool includeSelf) const;
TreeNode * childTreeAtIndex(int i) const;
int indexOfChildByIdentifier(int childID) const;
@@ -125,14 +126,7 @@ public:
return node;
}
protected:
TreeNode() :
m_identifier(-1),
m_referenceCounter(1)
{
}
/*TreeNode * lastDescendant() const {
TreeNode * lastDescendant() const {
TreeNode * node = const_cast<TreeNode *>(this);
int remainingNodesToVisit = node->numberOfChildren();
while (remainingNodesToVisit > 0) {
@@ -141,7 +135,17 @@ protected:
remainingNodesToVisit += node->numberOfChildren();
}
return node;
}*/
}
// Hierarchy operations
void replaceWithAllocationFailure();
protected:
TreeNode() :
m_identifier(-1),
m_referenceCounter(1)
{
}
TreeNode * lastChild() const {
if (numberOfChildren() == 0) {

View File

@@ -41,35 +41,45 @@ void TreePool::logNodeForIdentifierArray() {
}
void TreePool::move(TreeNode * source, TreeNode * destination) {
if (source == destination) {
size_t moveSize = source->deepSize();
moveNodes(source, destination, moveSize);
}
void TreePool::moveChildren(TreeNode * sourceParent, TreeNode * destination) {
size_t moveSize = sourceParent->deepSize() - sourceParent->size();
moveNodes(sourceParent->next(), destination, moveSize);
}
void TreePool::moveNodes(TreeNode * source, TreeNode * destination, size_t moveSize) {
if (source == destination || moveSize == 0) {
return;
}
// Move the Node
size_t srcDeepSize = source->deepSize();
char * destinationAddress = reinterpret_cast<char *>(destination);
char * sourceAddress = reinterpret_cast<char *>(source);
if (insert(destinationAddress, sourceAddress, srcDeepSize)) {
if (insert(destinationAddress, sourceAddress, moveSize)) {
// Update the nodeForIdentifier array
for (int i = 0; i < MaxNumberOfNodes; i++) {
char * nodeAddress = reinterpret_cast<char *>(m_nodeForIdentifier[i]);
if (nodeAddress == nullptr) {
continue;
} else if (nodeAddress >= sourceAddress && nodeAddress < sourceAddress + srcDeepSize) {
} else if (nodeAddress >= sourceAddress && nodeAddress < sourceAddress + moveSize) {
if (destinationAddress < sourceAddress) {
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress - (sourceAddress - destinationAddress));
} else {
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress + (destinationAddress - (sourceAddress + srcDeepSize)));
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress + (destinationAddress - (sourceAddress + moveSize)));
}
} else if (nodeAddress > sourceAddress && nodeAddress < destinationAddress) {
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress - srcDeepSize);
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress - moveSize);
} else if (nodeAddress < sourceAddress && nodeAddress >= destinationAddress) {
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress + srcDeepSize);
m_nodeForIdentifier[i] = reinterpret_cast<TreeNode *>(nodeAddress + moveSize);
}
}
}
}
#include <stdio.h>
void TreePool::log() {

View File

@@ -33,6 +33,7 @@ public:
}
void move(TreeNode * source, TreeNode * destination);
void moveChildren(TreeNode * sourceParent, TreeNode * destination);
TreeNode * deepCopy(TreeNode * node) {
size_t size = node->deepSize();
@@ -141,6 +142,7 @@ private:
void * alloc(size_t size);
void dealloc(TreeNode * ptr);
static inline bool insert(char * destination, char * source, size_t length);
void moveNodes(TreeNode * source, TreeNode * destination, size_t moveLength);
// Identifiers
int generateIdentifier() {

View File

@@ -11,6 +11,9 @@ class Cursor;
template <typename T>
class TreeReference {
friend class TreeNode;
friend class AdditionNode;
friend class Cursor;
template <typename U>
friend class TreeReference;
@@ -126,6 +129,12 @@ public:
node()->decrementNumberOfChildren();
}
void removeChildren() {
node()->releaseChildren();
TreePool::sharedPool()->moveChildren(node(), TreePool::sharedPool()->last());
node()->eraseNumberOfChildren();
}
void replaceWith(TreeReference<TreeNode> t) {
TreeReference<TreeNode> p = parent();
if (p.isDefined()) {
@@ -198,6 +207,15 @@ public:
TreePool::sharedPool()->move(secondChild.node(), firstChildNode);
}
void mergeChildren(TreeReference<T> t) {
// Steal operands
TreePool::sharedPool()->moveChildren(t.node(), node()->lastDescendant());
// If t is a child, remove it
if (node()->hasChild(t.node())) {
removeChild(t);
}
}
protected:
TreeReference() {
TreeNode * node = TreePool::sharedPool()->createTreeNode<T>();