diff --git a/src/lol/math/matrix.h b/src/lol/math/matrix.h index 426c21a0..3283f8b3 100644 --- a/src/lol/math/matrix.h +++ b/src/lol/math/matrix.h @@ -548,6 +548,19 @@ static inline mat_t &operator *=(mat_t &a, return a = a * b; } +/* + * Vector-vector outer product + */ + +template +static inline mat_t outer(vec_t const &a, + vec_t const &b) +{ + /* Valid cast because mat_t and vec_t have similar layouts */ + return *reinterpret_cast const *>(&a) + * *reinterpret_cast const *>(&b); +} + /* * Matrix-matrix outer product (Kronecker product) */ @@ -556,29 +569,17 @@ template static inline mat_t outer(mat_t const &a, mat_t const &b) { - /* FIXME: could this be optimised somehow? */ mat_t ret; for (int i1 = 0; i1 < COLS1; ++i1) - for (int j1 = 0; j1 < ROWS1; ++j1) - for (int i2 = 0; i2 < COLS2; ++i2) - for (int j2 = 0; j2 < ROWS2; ++j2) - ret[i1 * COLS2 + i2][j1 * ROWS2 + j2] - = a[i1][j1] * b[i2][j2]; + for (int i2 = 0; i2 < COLS2; ++i2) + { + /* Valid cast because mat_t and vec_t have similar layouts */ + *reinterpret_cast *>(&ret[i1 * COLS2 + i2]) + = outer(b[i2], a[i1]); + } return ret; } -/* - * Vector-vector outer product - */ - -template -static inline mat_t outer(vec_t const &a, - vec_t const &b) -{ - return *reinterpret_cast const *>(&a) - * *reinterpret_cast const *>(&b); -} - #if !LOL_FEATURE_CXX11_CONSTEXPR #undef constexpr #endif diff --git a/test/unit/matrix.cpp b/test/unit/matrix.cpp index 021ff6a2..fc14a403 100644 --- a/test/unit/matrix.cpp +++ b/test/unit/matrix.cpp @@ -158,6 +158,36 @@ LOLUNIT_FIXTURE(MatrixTest) LOLUNIT_ASSERT_EQUAL(m2[3][3], 1.0f); } + LOLUNIT_TEST(Kronecker) + { + int const COLS1 = 2, ROWS1 = 3; + int const COLS2 = 5, ROWS2 = 7; + + mat_t a; + mat_t b; + + for (int i = 0; i < COLS1; ++i) + for (int j = 0; j < ROWS1; ++j) + a[i][j] = (i + 11) * (j + 13); + + for (int i = 0; i < COLS2; ++i) + for (int j = 0; j < ROWS2; ++j) + b[i][j] = (i + 17) * (j + 19); + + mat_t m = outer(a, b); + + for (int i1 = 0; i1 < COLS1; ++i1) + for (int j1 = 0; j1 < ROWS1; ++j1) + for (int i2 = 0; i2 < COLS2; ++i2) + for (int j2 = 0; j2 < ROWS2; ++j2) + { + int expected = a[i1][j1] * b[i2][j2]; + int actual = m[i1 * COLS2 + i2][j1 * ROWS2 + j2]; + + LOLUNIT_ASSERT_EQUAL(actual, expected); + } + } + mat2 tri2, id2, inv2; mat3 tri3, id3, inv3; mat4 tri4, id4, inv4;