diff --git a/src/lol/math/matrix.h b/src/lol/math/matrix.h index 9a73e977..6f883c1e 100644 --- a/src/lol/math/matrix.h +++ b/src/lol/math/matrix.h @@ -424,9 +424,45 @@ T cofactor(mat_t const &m, int i, int j) return (i ^ j) ? -tmp : tmp; } -/* - * Compute LU-decomposition - */ +// Lu decomposition with partial pivoting +template LOL_ATTR_NODISCARD +std::tuple, vec_t> lu_decomposition(mat_t const &m) +{ + mat_t lu = m; + vec_t perm; + + for (int i = 0; i < N; ++i) + perm[i] = i; + perm[N] = 1; + + for (int k = 0; k < N; ++k) + { + // Find row with the largest absolute value + int best_j = k; + for (int j = k + 1; j < N; ++j) + if (abs(lu[k][j]) > lol::abs(lu[k][best_j])) + best_j = j; + + // Swap rows in result + if (best_j != k) + { + std::swap(perm[k], perm[best_j]); + perm[N] = -perm[N]; + for (int i = 0; i < N; ++i) + std::swap(lu[i][k], lu[i][best_j]); + } + + // Compute the Schur complement in the lower triangular part + for (int j = k + 1; j < N; ++j) + { + lu[k][j] /= lu[k][k]; + for (int i = k + 1; i < N; ++i) + lu[i][j] -= lu[i][k] * lu[k][j]; + } + } + + return std::make_tuple(lu, perm); +} template void lu_decomposition(mat_t const &m, mat_t & L, mat_t & U) @@ -523,12 +559,8 @@ mat_t permute_rows(mat_t const & m, vec_t const & perm mat_t result; for (int i = 0 ; i < N ; ++i) - { for (int j = 0 ; j < N ; ++j) - { result[i][j] = m[i][permutation[j]]; - } - } return result; } @@ -539,12 +571,7 @@ mat_t permute_cols(mat_t const & m, vec_t const & perm mat_t result; for (int i = 0 ; i < N ; ++i) - { - for (int j = 0 ; j < N ; ++j) - { - result[i][j] = m[permutation[i]][j]; - } - } + result[i] = m[permutation[i]]; return result; }