diff --git a/src/matrix.h b/src/matrix.h index 72c1e650..e104a0f0 100644 --- a/src/matrix.h +++ b/src/matrix.h @@ -37,13 +37,13 @@ namespace lol return *this = (*this) op val; \ } -#define BOOL_OP(elems, op, ret) \ +#define BOOL_OP(elems, op, op2, ret) \ inline bool operator op(Vec##elems const &val) const \ { \ for (int n = 0; n < elems; n++) \ - if ((*this)[n] != val[n]) \ - return ret; \ - return !ret; \ + if (!((*this)[n] op2 val[n])) \ + return !ret; \ + return ret; \ } #define SCALAR_OP(elems, op) \ @@ -78,8 +78,12 @@ namespace lol VECTOR_OP(elems, *) \ VECTOR_OP(elems, /) \ \ - BOOL_OP(elems, ==, false) \ - BOOL_OP(elems, !=, true) \ + BOOL_OP(elems, ==, ==, true) \ + BOOL_OP(elems, !=, ==, false) \ + BOOL_OP(elems, <=, <=, true) \ + BOOL_OP(elems, >=, >=, true) \ + BOOL_OP(elems, <, <, true) \ + BOOL_OP(elems, >, >, true) \ \ SCALAR_OP(elems, -) \ SCALAR_OP(elems, +) \ diff --git a/test/matrix.cpp b/test/matrix.cpp index 0cdf79e7..3a8ef2b4 100644 --- a/test/matrix.cpp +++ b/test/matrix.cpp @@ -25,6 +25,8 @@ namespace lol class MatrixTest : public CppUnit::TestCase { CPPUNIT_TEST_SUITE(MatrixTest); + CPPUNIT_TEST(test_vec_eq); + CPPUNIT_TEST(test_vec_lt); CPPUNIT_TEST(test_mat_det); CPPUNIT_TEST(test_mat_mul); CPPUNIT_TEST(test_mat_inv); @@ -48,6 +50,78 @@ public: void tearDown() {} + void test_vec_eq() + { + vec2 a2(1.0f, 2.0f); + vec2 b2(0.0f, 2.0f); + vec2 c2(1.0f, 0.0f); + + CPPUNIT_ASSERT(a2 == a2); + CPPUNIT_ASSERT(!(a2 != a2)); + + CPPUNIT_ASSERT(a2 != b2); + CPPUNIT_ASSERT(!(a2 == b2)); + CPPUNIT_ASSERT(a2 != c2); + CPPUNIT_ASSERT(!(a2 == c2)); + + vec3 a3(1.0f, 2.0f, 3.0f); + vec3 b3(0.0f, 2.0f, 3.0f); + vec3 c3(1.0f, 0.0f, 3.0f); + vec3 d3(1.0f, 2.0f, 0.0f); + + CPPUNIT_ASSERT(a3 == a3); + CPPUNIT_ASSERT(!(a3 != a3)); + + CPPUNIT_ASSERT(a3 != b3); + CPPUNIT_ASSERT(!(a3 == b3)); + CPPUNIT_ASSERT(a3 != c3); + CPPUNIT_ASSERT(!(a3 == c3)); + CPPUNIT_ASSERT(a3 != d3); + CPPUNIT_ASSERT(!(a3 == d3)); + + vec4 a4(1.0f, 2.0f, 3.0f, 4.0f); + vec4 b4(0.0f, 2.0f, 3.0f, 4.0f); + vec4 c4(1.0f, 0.0f, 3.0f, 4.0f); + vec4 d4(1.0f, 2.0f, 0.0f, 4.0f); + vec4 e4(1.0f, 2.0f, 3.0f, 0.0f); + + CPPUNIT_ASSERT(a4 == a4); + CPPUNIT_ASSERT(!(a4 != a4)); + + CPPUNIT_ASSERT(a4 != b4); + CPPUNIT_ASSERT(!(a4 == b4)); + CPPUNIT_ASSERT(a4 != c4); + CPPUNIT_ASSERT(!(a4 == c4)); + CPPUNIT_ASSERT(a4 != d4); + CPPUNIT_ASSERT(!(a4 == d4)); + CPPUNIT_ASSERT(a4 != e4); + CPPUNIT_ASSERT(!(a4 == e4)); + } + + void test_vec_lt() + { + vec2 a2(1.0f, 3.0f); + vec2 b2(0.0f, 0.0f); + vec2 c2(1.0f, 1.0f); + vec2 d2(2.0f, 2.0f); + vec2 e2(3.0f, 3.0f); + vec2 f2(4.0f, 4.0f); + + CPPUNIT_ASSERT(a2 <= a2); + CPPUNIT_ASSERT(!(a2 < a2)); + + CPPUNIT_ASSERT(!(a2 <= b2)); + CPPUNIT_ASSERT(!(a2 < b2)); + CPPUNIT_ASSERT(!(a2 <= c2)); + CPPUNIT_ASSERT(!(a2 < c2)); + CPPUNIT_ASSERT(!(a2 <= d2)); + CPPUNIT_ASSERT(!(a2 < d2)); + CPPUNIT_ASSERT(a2 <= e2); + CPPUNIT_ASSERT(!(a2 < e2)); + CPPUNIT_ASSERT(a2 <= f2); + CPPUNIT_ASSERT(a2 < f2); + } + void test_mat_det() { float d1 = triangular.det();