From 7e2cc375d26e67ea133560cf937963615351ffd0 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Sat, 9 Sep 2017 10:06:02 -0400 Subject: [PATCH] Fix potential overflow in complex radius calculation. --- liba/include/math.h | 2 ++ poincare/src/complex.cpp | 32 ++++++++++++++++++++++++++++++-- poincare/test/complex.cpp | 10 ++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/liba/include/math.h b/liba/include/math.h index a8d1efbd0..1ecbd2249 100644 --- a/liba/include/math.h +++ b/liba/include/math.h @@ -10,6 +10,8 @@ LIBA_BEGIN_DECLS #define INFINITY __builtin_inff() #define M_E 2.71828182845904523536028747135266250 #define M_PI 3.14159265358979323846264338327950288 +#define M_PI_4 0.78539816339744830961566084581987572 +#define M_SQRT2 1.41421356237309504880168872420969808 #define FP_INFINITE 0x01 #define FP_NAN 0x02 diff --git a/poincare/src/complex.cpp b/poincare/src/complex.cpp index e5e5d7008..ed1256c42 100644 --- a/poincare/src/complex.cpp +++ b/poincare/src/complex.cpp @@ -161,10 +161,38 @@ T Complex::b() const { template T Complex::r() const { + // We want to avoid a^2 and b^2 which could both easily overflow. + // min, max = minmax(abs(a), abs(b)) (*minmax returns both arguments sorted*) + // abs(a + bi) == sqrt(a^2 + b^2) + // == sqrt(abs(a)^2 + abs(b)^2) + // == sqrt(min^2 + max^2) + // == sqrt((min^2 + max^2) * max^2/max^2) + // == sqrt((min^2 + max^2) / max^2)*sqrt(max^2) + // == sqrt(min^2/max^2 + 1) * max + // == sqrt((min/max)^2 + 1) * max + // min >= 0 && + // max >= 0 && + // min <= max => min/max <= 1 + // => (min/max)^2 <= 1 + // => (min/max)^2 + 1 <= 2 + // => sqrt((min/max)^2 + 1) <= sqrt(2) + // So the calculation is guaranteed to not overflow until the final multiply. + // If (min/max)^2 underflows then min doesn't contribute anything significant + // compared to max, and the formula reduces to simply max as it should. + // We do need to be careful about the case where a == 0 && b == 0 which would + // cause a division by zero. + T min = std::fabs(m_a); if (m_b == 0) { - return std::fabs(m_a); + return min; } - return std::sqrt(m_a*m_a + m_b*m_b); + T max = std::fabs(m_b); + if (max < min) { + T temp = min; + min = max; + max = temp; + } + T temp = min/max; + return std::sqrt(temp*temp + 1) * max; } template diff --git a/poincare/test/complex.cpp b/poincare/test/complex.cpp index 738e3ad6d..d08b3ab16 100644 --- a/poincare/test/complex.cpp +++ b/poincare/test/complex.cpp @@ -128,4 +128,14 @@ QUIZ_CASE(poincare_complex_constructor) { b = new Complex(Complex::Polar(12.04159457879229548012824103, 1.4876550949)); assert(std::fabs(b->a() - 1.0) < 0.0000000001 && std::fabs(b->b()-12.0) < 0.0000000001); delete b; + + Complex * c = new Complex(Complex::Cartesian(-2.0e20f, 2.0e20f)); + assert(c->a() == -2.0e20f && c->b() == 2.0e20f); + assert(c->r() == 2.0e20f*(float)M_SQRT2 && c->th() == 3*(float)M_PI_4); + delete c; + + Complex * d = new Complex(Complex::Cartesian(1.0e155, -1.0e155)); + assert(d->a() == 1.0e155 && d->b() == -1.0e155); + assert(d->r() == 1.0e155*M_SQRT2 && d->th() == -M_PI_4); + delete d; }