diff --git a/src/core.h b/src/core.h index 895f9f7c..a3dd9baa 100644 --- a/src/core.h +++ b/src/core.h @@ -17,9 +17,9 @@ #define __LOL_CORE_H__ // Base types +#include "half.h" #include "matrix.h" #include "numeric.h" -#include "half.h" #include "timer.h" // Static classes diff --git a/src/half.h b/src/half.h index cc13c5e8..30ba4d45 100644 --- a/src/half.h +++ b/src/half.h @@ -63,6 +63,26 @@ public: operator float() const; inline operator int() const { return (int)(float)*this; } + /* Operations */ + inline half operator -() { return makebits(m_bits ^ 0x8000u); } + inline half &operator +=(float f) { return (*this = (half)(*this + f)); } + inline half &operator -=(float f) { return (*this = (half)(*this - f)); } + inline half &operator *=(float f) { return (*this = (half)(*this * f)); } + inline half &operator /=(float f) { return (*this = (half)(*this / f)); } + inline half &operator +=(half h) { return (*this = (half)(*this + h)); } + inline half &operator -=(half h) { return (*this = (half)(*this - h)); } + inline half &operator *=(half h) { return (*this = (half)(*this * h)); } + inline half &operator /=(half h) { return (*this = (half)(*this / h)); } + + inline float operator +(float f) const { return (float)*this + f; } + inline float operator -(float f) const { return (float)*this - f; } + inline float operator *(float f) const { return (float)*this * f; } + inline float operator /(float f) const { return (float)*this / f; } + inline float operator +(half h) const { return (float)*this + (float)h; } + inline float operator -(half h) const { return (float)*this - (float)h; } + inline float operator *(half h) const { return (float)*this * (float)h; } + inline float operator /(half h) const { return (float)*this / (float)h; } + /* Factories */ static half makeslow(float f); static half makefast(float f); @@ -74,6 +94,16 @@ public: } }; +inline float &operator +=(float &f, half h) { return f += (float)h; } +inline float &operator -=(float &f, half h) { return f -= (float)h; } +inline float &operator *=(float &f, half h) { return f *= (float)h; } +inline float &operator /=(float &f, half h) { return f /= (float)h; } + +inline float operator +(float f, half h) { return f + (float)h; } +inline float operator -(float f, half h) { return f - (float)h; } +inline float operator *(float f, half h) { return f * (float)h; } +inline float operator /(float f, half h) { return f / (float)h; } + } /* namespace lol */ #endif // __LOL_HALF_H__ diff --git a/test/half.cpp b/test/half.cpp index 739f67cb..fdcef86c 100644 --- a/test/half.cpp +++ b/test/half.cpp @@ -35,6 +35,8 @@ class HalfTest : public CppUnit::TestCase CPPUNIT_TEST(test_half_classify); CPPUNIT_TEST(test_half_to_float); CPPUNIT_TEST(test_half_to_int); + CPPUNIT_TEST(test_float_op_half); + CPPUNIT_TEST(test_half_op_float); CPPUNIT_TEST_SUITE_END(); public: @@ -207,6 +209,89 @@ public: CPPUNIT_ASSERT_EQUAL((int)(half)(-65504.0f), -65504); } + void test_float_op_half() + { + half zero = 0; + half one = 1; + half two = 2; + + float a = zero + one; + CPPUNIT_ASSERT_EQUAL(1.0f, a); + a += zero; + CPPUNIT_ASSERT_EQUAL(1.0f, a); + a -= zero; + CPPUNIT_ASSERT_EQUAL(1.0f, a); + a *= one; + CPPUNIT_ASSERT_EQUAL(1.0f, a); + a /= one; + CPPUNIT_ASSERT_EQUAL(1.0f, a); + + float b = one + zero; + CPPUNIT_ASSERT_EQUAL(1.0f, b); + b += one; + CPPUNIT_ASSERT_EQUAL(2.0f, b); + b *= two; + CPPUNIT_ASSERT_EQUAL(4.0f, b); + b -= two; + CPPUNIT_ASSERT_EQUAL(2.0f, b); + b /= two; + CPPUNIT_ASSERT_EQUAL(1.0f, b); + + float c = one - zero; + CPPUNIT_ASSERT_EQUAL(1.0f, c); + + float d = two - one; + CPPUNIT_ASSERT_EQUAL(1.0f, d); + + float e = two + (-one); + CPPUNIT_ASSERT_EQUAL(1.0f, e); + + float f = (two * two) / (one + one); + CPPUNIT_ASSERT_EQUAL(2.0f, f); + } + + void test_half_op_float() + { + half zero = 0; + half one = 1; + half two = 2; + half four = 4; + + half a = one + 0.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), a.bits()); + a += 0.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), a.bits()); + a -= 0.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), a.bits()); + a *= 1.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), a.bits()); + a /= 1.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), a.bits()); + + half b = one + 0.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), b.bits()); + b += 1.0f; + CPPUNIT_ASSERT_EQUAL(two.bits(), b.bits()); + b *= 2.0f; + CPPUNIT_ASSERT_EQUAL(four.bits(), b.bits()); + b -= 2.0f; + CPPUNIT_ASSERT_EQUAL(two.bits(), b.bits()); + b /= 2.0f; + CPPUNIT_ASSERT_EQUAL(one.bits(), b.bits()); + + half c = 1.0f - zero; + CPPUNIT_ASSERT_EQUAL(one.bits(), c.bits()); + + half d = 2.0f - one; + CPPUNIT_ASSERT_EQUAL(one.bits(), d.bits()); + + half e = 2.0f + (-one); + CPPUNIT_ASSERT_EQUAL(one.bits(), e.bits()); + + half f = (2.0f * two) / (1.0f + one); + CPPUNIT_ASSERT_EQUAL(two.bits(), f.bits()); + } + private: struct TestPair { float f; uint16_t x; };