ソースを参照

math: improve outer product and add unit tests.

undefined
Sam Hocevar 10年前
コミット
17fe1306d2
2個のファイルの変更49行の追加18行の削除
  1. +19
    -18
      src/lol/math/matrix.h
  2. +30
    -0
      test/unit/matrix.cpp

+ 19
- 18
src/lol/math/matrix.h ファイルの表示

@@ -548,6 +548,19 @@ static inline mat_t<T, N, N> &operator *=(mat_t<T, N, N> &a,
return a = a * b;
}

/*
* Vector-vector outer product
*/

template<typename T, int COLS, int ROWS>
static inline mat_t<T, COLS, ROWS> outer(vec_t<T, ROWS> const &a,
vec_t<T, COLS> const &b)
{
/* Valid cast because mat_t and vec_t have similar layouts */
return *reinterpret_cast<mat_t<T, 1, ROWS> const *>(&a)
* *reinterpret_cast<mat_t<T, COLS, 1> const *>(&b);
}

/*
* Matrix-matrix outer product (Kronecker product)
*/
@@ -556,29 +569,17 @@ template<typename T, int COLS1, int COLS2, int ROWS1, int ROWS2>
static inline mat_t<T, COLS1 * COLS2, ROWS1 * ROWS2>
outer(mat_t<T, COLS1, ROWS1> const &a, mat_t<T, COLS2, ROWS2> const &b)
{
/* FIXME: could this be optimised somehow? */
mat_t<T, COLS1 * COLS2, ROWS1 * ROWS2> ret;
for (int i1 = 0; i1 < COLS1; ++i1)
for (int j1 = 0; j1 < ROWS1; ++j1)
for (int i2 = 0; i2 < COLS2; ++i2)
for (int j2 = 0; j2 < ROWS2; ++j2)
ret[i1 * COLS2 + i2][j1 * ROWS2 + j2]
= a[i1][j1] * b[i2][j2];
for (int i2 = 0; i2 < COLS2; ++i2)
{
/* Valid cast because mat_t and vec_t have similar layouts */
*reinterpret_cast<mat_t<T, ROWS1, ROWS2> *>(&ret[i1 * COLS2 + i2])
= outer(b[i2], a[i1]);
}
return ret;
}

/*
* Vector-vector outer product
*/

template<typename T, int COLS, int ROWS>
static inline mat_t<T, COLS, ROWS> outer(vec_t<T, ROWS> const &a,
vec_t<T, COLS> const &b)
{
return *reinterpret_cast<mat_t<T, 1, ROWS> const *>(&a)
* *reinterpret_cast<mat_t<T, COLS, 1> const *>(&b);
}

#if !LOL_FEATURE_CXX11_CONSTEXPR
#undef constexpr
#endif


+ 30
- 0
test/unit/matrix.cpp ファイルの表示

@@ -158,6 +158,36 @@ LOLUNIT_FIXTURE(MatrixTest)
LOLUNIT_ASSERT_EQUAL(m2[3][3], 1.0f);
}

LOLUNIT_TEST(Kronecker)
{
int const COLS1 = 2, ROWS1 = 3;
int const COLS2 = 5, ROWS2 = 7;

mat_t<int, COLS1, ROWS1> a;
mat_t<int, COLS2, ROWS2> b;

for (int i = 0; i < COLS1; ++i)
for (int j = 0; j < ROWS1; ++j)
a[i][j] = (i + 11) * (j + 13);

for (int i = 0; i < COLS2; ++i)
for (int j = 0; j < ROWS2; ++j)
b[i][j] = (i + 17) * (j + 19);

mat_t<int, COLS1 * COLS2, ROWS1 * ROWS2> m = outer(a, b);

for (int i1 = 0; i1 < COLS1; ++i1)
for (int j1 = 0; j1 < ROWS1; ++j1)
for (int i2 = 0; i2 < COLS2; ++i2)
for (int j2 = 0; j2 < ROWS2; ++j2)
{
int expected = a[i1][j1] * b[i2][j2];
int actual = m[i1 * COLS2 + i2][j1 * ROWS2 + j2];

LOLUNIT_ASSERT_EQUAL(actual, expected);
}
}

mat2 tri2, id2, inv2;
mat3 tri3, id3, inv3;
mat4 tri4, id4, inv4;


読み込み中…
キャンセル
保存