[poincare] Fix extractUnit into removeUnit paradigm

This commit is contained in:
Émilie Feral
2020-04-16 15:25:49 +02:00
parent 4e2da5de05
commit d805c59202
19 changed files with 108 additions and 68 deletions

View File

@@ -25,7 +25,7 @@ public:
// Properties
Type type() const override { return Type::Division; }
int polynomialDegree(Context * context, const char * symbolName) const override;
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Approximation
virtual Evaluation<float> approximate(SinglePrecision p, Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const override {

View File

@@ -22,7 +22,7 @@ public:
// Properties
Type type() const override { return Type::EmptyExpression; }
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Simplification
LayoutShape leftLayoutShape() const override {

View File

@@ -199,7 +199,7 @@ public:
Expression replaceSymbolWithExpression(const SymbolAbstract & symbol, const Expression & expression) { return node()->replaceSymbolWithExpression(symbol, expression); }
/* Units */
Expression extractUnits() { return node()->extractUnits(); }
Expression removeUnit(Expression * unit) { return node()->removeUnit(unit); }
bool hasUnit() const;
/* Complex */

View File

@@ -178,7 +178,7 @@ public:
virtual float characteristicXRange(Context * context, Preferences::AngleUnit angleUnit) const;
bool isOfType(Type * types, int length) const;
virtual Expression extractUnits(); // Only reduced nodes should answer
virtual Expression removeUnit(Expression * unit); // Only reduced nodes should answer
/* Simplification */
/* SimplificationOrder returns:

View File

@@ -25,7 +25,7 @@ public:
int polynomialDegree(Context * context, const char * symbolName) const override;
int getPolynomialCoefficients(Context * context, const char * symbolName, Expression coefficients[], ExpressionNode::SymbolicComputation symbolicComputation) const override;
bool childAtIndexNeedsUserParentheses(const Expression & child, int childIndex) const override;
Expression extractUnits() override;
Expression removeUnit(Expression * unit) override;
// Approximation
template<typename T> static Complex<T> compute(const std::complex<T> c, const std::complex<T> d, Preferences::ComplexFormat complexFormat) { return Complex<T>::Builder(c*d); }
@@ -65,6 +65,7 @@ private:
class Multiplication : public NAryExpression {
friend class Addition;
friend class Power;
friend class MultiplicationNode;
public:
Multiplication(const MultiplicationNode * n) : NAryExpression(n) {}
static Multiplication Builder(const Tuple & children = {}) { return TreeHandle::NAryBuilder<Multiplication, MultiplicationNode>(convert(children)); }
@@ -76,7 +77,7 @@ public:
// Properties
int getPolynomialCoefficients(Context * context, const char * symbolName, Expression coefficients[], ExpressionNode::SymbolicComputation symbolicComputation) const;
Expression extractUnits();
// Approximation
template<typename T> static void computeOnArrays(T * m, T * n, T * result, int mNumberOfColumns, int mNumberOfRows, int nNumberOfColumns);
// Simplification
@@ -88,6 +89,9 @@ public:
NAryExpression::sortChildrenInPlace(order, context, false, canBeInterrupted);
}
private:
// Unit
Expression removeUnit(Expression * unit);
// Simplification
Expression privateShallowReduce(ExpressionNode::ReductionContext reductionContext, bool expand, bool canBeInterrupted);
void factorizeBase(int i, int j, ExpressionNode::ReductionContext reductionContext);

View File

@@ -20,7 +20,7 @@ public:
#endif
private:
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
int serialize(char * buffer, int bufferSize, Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;

View File

@@ -21,7 +21,7 @@ public:
// Properties
Type type() const override { return Type::Parenthesis; }
int polynomialDegree(Context * context, const char * symbolName) const override;
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Layout
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;

View File

@@ -29,7 +29,7 @@ public:
Sign sign(Context * context) const override;
Expression setSign(Sign s, ReductionContext reductionContext) override;
bool childAtIndexNeedsUserParentheses(const Expression & child, int childIndex) const override;
Expression extractUnits() override;
Expression removeUnit(Expression * unit) override;
int polynomialDegree(Context * context, const char * symbolName) const override;
int getPolynomialCoefficients(Context * context, const char * symbolName, Expression coefficients[], ExpressionNode::SymbolicComputation symbolicComputation) const override;
@@ -84,6 +84,9 @@ private:
constexpr static int k_maxExactPowerMatrix = 100;
constexpr static int k_maxNumberOfTermsInExpandedMultinome = 25;
// Unit
Expression removeUnit(Expression * unit);
// Simplification
Expression denominator(ExpressionNode::ReductionContext reductionContext) const;

View File

@@ -24,7 +24,7 @@ public:
Type type() const override { return Type::Subtraction; }
int polynomialDegree(Context * context, const char * symbolName) const override;
bool childAtIndexNeedsUserParentheses(const Expression & child, int childIndex) const override;
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Approximation
template<typename T> static Complex<T> compute(const std::complex<T> c, const std::complex<T> d, Preferences::ComplexFormat complexFormat) { return Complex<T>::Builder(c - d); }

View File

@@ -151,7 +151,7 @@ public:
// Expression Properties
Type type() const override { return Type::Unit; }
Sign sign(Context * context) const override;
Expression extractUnits() override;
Expression removeUnit(Expression * unit) override;
/* Layout */
Layout createLayout(Preferences::PrintFloatMode floatDisplayMode, int numberOfSignificantDigits) const override;
@@ -183,6 +183,7 @@ private:
};
class Unit final : public Expression {
friend class UnitNode;
public:
typedef UnitNode::Prefix Prefix;
typedef UnitNode::Representative Representative;
@@ -728,6 +729,9 @@ public:
// Simplification
Expression shallowReduce(ExpressionNode::ReductionContext reductionContext);
void chooseBestMultipleForValue(double & value, const int exponent, ExpressionNode::ReductionContext reductionContext);
private:
Expression removeUnit(Expression * unit);
};
}

View File

@@ -20,7 +20,7 @@ public:
Type type() const override { return Type::UnitConvert; }
private:
Expression extractUnits() override { assert(false); return ExpressionNode::extractUnits(); }
Expression removeUnit(Expression * unit) override { assert(false); return ExpressionNode::removeUnit(unit); }
// Simplification
Expression shallowReduce(ReductionContext reductionContext) override;
// Evalutation

View File

@@ -155,10 +155,12 @@ Expression Addition::shallowReduce(ExpressionNode::ReductionContext reductionCon
/* Step 2: Handle the units. All children should have the same unit, otherwise
* the result is not homogeneous. */
{
Expression unit = childAtIndex(0).extractUnits();
Expression unit;
childAtIndex(0).removeUnit(&unit);
const bool hasUnit = !unit.isUninitialized();
for (int i = 1; i < childrenCount; i++) {
Expression otherUnit = childAtIndex(i).extractUnits();
Expression otherUnit;
childAtIndex(i).removeUnit(&otherUnit);
if (hasUnit == otherUnit.isUninitialized() ||
(hasUnit && !unit.isIdenticalTo(otherUnit)))
{
@@ -166,15 +168,6 @@ Expression Addition::shallowReduce(ExpressionNode::ReductionContext reductionCon
}
}
if (hasUnit) {
for (int i = 0; i < childrenCount; i++) {
/* Any unary hierarchy must be squashed here since it has not been
* done in extractUnits.
*/
Expression child = childAtIndex(i);
if (child.type() == ExpressionNode::Type::Multiplication) {
static_cast<Multiplication &>(child).squashUnaryHierarchyInPlace();
}
}
Expression addition = shallowReduce(reductionContext);
Multiplication result = Multiplication::Builder(unit);
result.mergeSameTypeChildrenInPlace();

View File

@@ -350,7 +350,9 @@ Expression Expression::defaultHandleUnitsInChildren() {
// Generically, an Expression does not accept any Unit in its children.
const int childrenCount = numberOfChildren();
for (int i = 0; i < childrenCount; i++) {
if (!childAtIndex(i).extractUnits().isUninitialized()) {
Expression unit;
childAtIndex(i).removeUnit(&unit);
if (!unit.isUninitialized()) {
return replaceWithUndefinedInPlace();
}
}

View File

@@ -136,8 +136,8 @@ bool ExpressionNode::isOfType(Type * types, int length) const {
return false;
}
Expression ExpressionNode::extractUnits() {
return Expression();
Expression ExpressionNode::removeUnit(Expression * unit) {
return Expression(this);
}
void ExpressionNode::setChildrenInPlace(Expression other) {

View File

@@ -63,8 +63,8 @@ bool MultiplicationNode::childAtIndexNeedsUserParentheses(const Expression & chi
return child.isOfType(types, 2);
}
Expression MultiplicationNode::extractUnits() {
return Multiplication(this).extractUnits();
Expression MultiplicationNode::removeUnit(Expression * unit) {
return Multiplication(this).removeUnit(unit);
}
template<typename T>
@@ -255,31 +255,30 @@ int Multiplication::getPolynomialCoefficients(Context * context, const char * sy
return deg;
}
Expression Multiplication::extractUnits() {
Multiplication result = Multiplication::Builder();
Expression Multiplication::removeUnit(Expression * unit) {
Multiplication unitMult = Multiplication::Builder();
int resultChildrenCount = 0;
for (int i = 0; i < numberOfChildren(); i++) {
Expression currentUnit = childAtIndex(i).extractUnits();
Expression currentUnit;
childAtIndex(i).removeUnit(&currentUnit);
if (!currentUnit.isUninitialized()) {
assert(childAtIndex(i) == currentUnit);
result.addChildAtIndexInPlace(currentUnit, resultChildrenCount, resultChildrenCount);
unitMult.addChildAtIndexInPlace(currentUnit, resultChildrenCount, resultChildrenCount);
resultChildrenCount++;
assert(childAtIndex(i).isRationalOne());
removeChildAtIndexInPlace(i--);
}
}
if (resultChildrenCount == 0) {
return Expression();
*unit = Expression();
} else {
*unit = unitMult.squashUnaryHierarchyInPlace();
}
/* squashUnaryHierarchyInPlace();
* That would make 'this' invalid, so we would rather keep any unary
* hierarchy as it is and handle it later.
* TODO ?
* A possible solution would be that the extractUnits method becomes
* Expression extractUnits(Expression & units)
* returning the Expression that is left after extracting the units
* and setting the units reference instead of returning the units.
*/
return result.squashUnaryHierarchyInPlace();
if (numberOfChildren() == 0) {
Expression one = Rational::Builder(1);
replaceWithInPlace(one);
return one;
}
return squashUnaryHierarchyInPlace();
}
template<typename T>
@@ -356,9 +355,11 @@ Expression Multiplication::shallowBeautify(ExpressionNode::ReductionContext redu
return std::move(o);
}
Expression result = *this;
Expression units = extractUnits();
Expression self = *this;
Expression units;
self = removeUnit(&units);
Expression result;
if (!units.isUninitialized()) {
/* Step 2: Handle the units
*
@@ -408,8 +409,12 @@ Expression Multiplication::shallowBeautify(ExpressionNode::ReductionContext redu
}
if (unitsAccu.numberOfChildren() > 0) {
units = Division::Builder(units, unitsAccu.clone()).deepReduce(reductionContext);
Expression newUnits = units.extractUnits();
result = Multiplication::Builder(result, units);
Expression newUnits;
units = units.removeUnit(&newUnits);
Multiplication m = Multiplication::Builder(units);
self.replaceWithInPlace(m);
m.addChildAtIndexInPlace(self, 0, 1);
self = m;
if (newUnits.isUninitialized()) {
units = unitsAccu;
} else {
@@ -426,7 +431,7 @@ Expression Multiplication::shallowBeautify(ExpressionNode::ReductionContext redu
* most relevant.
*/
double value = result.approximateToScalar<double>(reductionContext.context(), reductionContext.complexFormat(), reductionContext.angleUnit());
double value = self.approximateToScalar<double>(reductionContext.context(), reductionContext.complexFormat(), reductionContext.angleUnit());
if (std::isnan(value)) {
// If the value is undefined, return "undef" without any unit
result = Undefined::Builder();
@@ -480,7 +485,7 @@ Expression Multiplication::shallowBeautify(ExpressionNode::ReductionContext redu
}
}
replaceWithInPlace(result);
self.replaceWithInPlace(result);
return result;
}
@@ -1006,7 +1011,7 @@ bool Multiplication::TermHasNumeralExponent(const Expression & e) {
}
void Multiplication::splitIntoNormalForm(Expression & numerator, Expression & denominator, ExpressionNode::ReductionContext reductionContext) const {
assert(const_cast<Multiplication*>(this)->extractUnits().isUninitialized());
assert(!hasUnit());
Multiplication mNumerator = Multiplication::Builder();
Multiplication mDenominator = Multiplication::Builder();
int numberOfFactorsInNumerator = 0;

View File

@@ -81,12 +81,8 @@ int PowerNode::polynomialDegree(Context * context, const char * symbolName) cons
return -1;
}
Expression PowerNode::extractUnits() {
if (!childAtIndex(0)->extractUnits().isUninitialized()) {
assert(childAtIndex(0)->type() == ExpressionNode::Type::Unit);
return Power(this);
}
return ExpressionNode::extractUnits();
Expression PowerNode::removeUnit(Expression * unit) {
return Power(this).removeUnit(unit);
}
int PowerNode::getPolynomialCoefficients(Context * context, const char * symbolName, Expression coefficients[], ExpressionNode::SymbolicComputation symbolicComputation) const {
@@ -378,6 +374,23 @@ int Power::getPolynomialCoefficients(Context * context, const char * symbolName,
return -1;
}
Expression Power::removeUnit(Expression * unit) {
Expression childUnit;
Expression child = childAtIndex(0).node()->removeUnit(&childUnit);
if (!childUnit.isUninitialized()) {
// Reduced power containing unit are of form "unit^i" with i integer
assert(child.isRationalOne());
assert(childUnit.type() == ExpressionNode::Type::Unit);
Power p = *this;
Expression result = child;
replaceWithInPlace(child);
p.replaceChildAtIndexInPlace(0, childUnit);
*unit = p;
return child;
}
return *this;
}
Expression Power::shallowReduce(ExpressionNode::ReductionContext reductionContext) {
{
@@ -392,10 +405,13 @@ Expression Power::shallowReduce(ExpressionNode::ReductionContext reductionContex
// Step 1: Handle the units
{
if (!index.extractUnits().isUninitialized()) {
Expression indexUnit;
index.removeUnit(&indexUnit);
if (!indexUnit.isUninitialized()) {
// There must be no unit in the exponent
return replaceWithUndefinedInPlace();
}
assert(index == childAtIndex(1));
if (base.hasUnit()) {
if (index.type() != ExpressionNode::Type::Rational || !static_cast<Rational &>(index).isInteger()) {
// The exponent must be an Integer

View File

@@ -164,8 +164,8 @@ ExpressionNode::Sign UnitNode::sign(Context * context) const {
return Sign::Positive;
}
Expression UnitNode::extractUnits() {
return Unit(this);
Expression UnitNode::removeUnit(Expression * unit) {
return Unit(this).removeUnit(unit);
}
int UnitNode::simplificationOrderSameType(const ExpressionNode * e, bool ascending, bool canBeInterrupted, bool ignoreParentheses) const {
@@ -341,6 +341,13 @@ void Unit::chooseBestMultipleForValue(double & value, const int exponent, Expres
value = bestVal;
}
Expression Unit::removeUnit(Expression * unit) {
*unit = *this;
Expression one = Rational::Builder(1);
replaceWithInPlace(one);
return one;
}
template Evaluation<float> UnitNode::templatedApproximate<float>(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;
template Evaluation<double> UnitNode::templatedApproximate<double>(Context * context, Preferences::ComplexFormat complexFormat, Preferences::AngleUnit angleUnit) const;

View File

@@ -41,7 +41,8 @@ Expression UnitConvert::shallowReduce(ExpressionNode::ReductionContext reduction
reductionContext.angleUnit(),
reductionContext.target(),
ExpressionNode::SymbolicComputation::ReplaceAllSymbolsWithUndefinedAndReplaceUnits);
Expression unit = childAtIndex(1).clone().reduce(reductionContextWithUnits).extractUnits();
Expression unit;
childAtIndex(1).clone().reduce(reductionContextWithUnits).removeUnit(&unit);
if (unit.isUninitialized()) {
// There is no unit on the right
return replaceWithUndefinedInPlace();
@@ -54,18 +55,21 @@ Expression UnitConvert::shallowReduce(ExpressionNode::ReductionContext reduction
reductionContext.angleUnit(),
reductionContext.target(),
ExpressionNode::SymbolicComputation::ReplaceAllSymbolsWithUndefinedAndDoNotReplaceUnits);
Expression finalUnit = childAtIndex(1).reduce(reductionContextWithoutUnits).extractUnits();
Expression finalUnit;
childAtIndex(1).reduce(reductionContextWithoutUnits).removeUnit(&finalUnit);
// Divide the left member by the new unit
Expression division = Division::Builder(childAtIndex(0), finalUnit.clone());
division = division.reduce(reductionContext);
if (!division.extractUnits().isUninitialized()) {
Expression divisionUnit;
division = division.removeUnit(&divisionUnit);
if (!divisionUnit.isUninitialized()) {
// The left and right members are not homogeneous
return replaceWithUndefinedInPlace();
}
double floatValue = division.approximateToScalar<double>(reductionContext.context(), reductionContext.complexFormat(), reductionContext.angleUnit());
if (std::isinf(floatValue)) {
return Infinity::Builder(false); //FIXME sign?
return Infinity::Builder(floatValue < 0.0); //FIXME sign?
}
if (std::isnan(floatValue)) {
return Undefined::Builder();

View File

@@ -373,10 +373,12 @@ void assert_reduced_expression_unit(const char * expression, const char * unit,
ExpressionNode::ReductionContext redContext(&globalContext, Real, Degree, SystemForApproximation, symbolicComutation);
Expression e = parse_expression(expression, &globalContext, false);
e = e.reduce(redContext);
Expression u1 = e.extractUnits();
Expression u2 = parse_expression(unit, &globalContext, false);
u2 = u2.reduce(redContext);
u2 = u2.extractUnits();
Expression u1;
e = e.removeUnit(&u1);
Expression e2 = parse_expression(unit, &globalContext, false);
Expression u2;
e2 = e2.reduce(redContext);
e2.removeUnit(&u2);
quiz_assert_print_if_failure(u1.isUninitialized() == u2.isUninitialized() && (u1.isUninitialized() || u1.isIdenticalTo(u2)), expression);
}