From 26c394053eafe62612403b208eecdab5fb499704 Mon Sep 17 00:00:00 2001 From: Sam Hocevar Date: Sat, 17 Jan 2015 16:39:12 +0000 Subject: [PATCH] math: implement Karatsuba algorithm for large bigint multiplications. --- doc/samples/sandbox/sample.cpp | 24 +++++--- src/lol/math/bigint.h | 103 ++++++++++++++++++++++++--------- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/doc/samples/sandbox/sample.cpp b/doc/samples/sandbox/sample.cpp index d41383f2..c2d108cd 100644 --- a/doc/samples/sandbox/sample.cpp +++ b/doc/samples/sandbox/sample.cpp @@ -20,17 +20,23 @@ using namespace lol; int main() { - bigint<16> i; - bigint<16> x(2), y(12); - auto z = x + y; - printf("0x%x\n", (int)z); + Timer t; + + bigint<128> x(17), y(23); + x.print(); + y.print(); + + auto z = x * y; z.print(); - bigint<0> lol; - auto w = z + lol; - printf("0x%x\n", (int)w); + for (int i = 0; i < 500000; ++i) + { + x = (bigint<128>)(x * x); + x ^= y; + } + + printf("%d %d\n", (int)x, (int)y); - bigint<2> f(0x10101010), g(0x20202020); - (f * f * f * g - g).print(); + printf("Time: %f s\n", t.Get()); } diff --git a/src/lol/math/bigint.h b/src/lol/math/bigint.h index a512d222..2e7f3ebf 100644 --- a/src/lol/math/bigint.h +++ b/src/lol/math/bigint.h @@ -106,7 +106,7 @@ public: } /* - * bigint bitwise not: we just flip all bits except the unused one. + * bigint bitwise NOT: we just flip all bits except the unused ones. */ bigint operator ~() const { @@ -237,7 +237,7 @@ public: template bigint operator *(bigint const &x) const { - return mul_naive(*this, x); + return multiply(x); } /* @@ -304,6 +304,18 @@ public: printf("\n"); } + template + inline bigint &slice() + { + return *reinterpret_cast *>((T *)this + A); + } + + template + inline bigint const &slice() const + { + return *reinterpret_cast const *>((T const *)this + A); + } + private: /* Allow other types of bigints to access our private members */ template friend class bigint; @@ -315,27 +327,12 @@ private: return (m_digits[N - 1] >> (bits_per_digit - 1)) != 0; } - inline uint32_t get_uint32(int offset) const - { - unsigned int bit = offset * 32; - unsigned int digit_index = bit / bits_per_digit; - unsigned int bit_index = bit % bits_per_digit; - - if (digit_index >= N) - return 0; - - uint32_t ret = m_digits[digit_index] >> bit_index; - if (bits_per_digit - bit_index < 32 && digit_index < N - 1) - ret |= m_digits[digit_index + 1] << (bits_per_digit - bit_index); - return ret; - } - template - static inline bigint mul_naive(bigint const &a, - bigint const &b) + typename std::enable_if<(N != M || (N > 1 && N < 64)), bigint> + ::type inline multiply(bigint const &b) const { /* FIXME: other digit sizes are not supported */ - static_assert(sizeof(T) == sizeof(uint32_t), "ensure T is uint32_t"); + static_assert(sizeof(T) <= sizeof(uint32_t), "ensure T is uint32_t"); bigint ret(0); for (unsigned int i = 0; i < N; ++i) @@ -344,16 +341,16 @@ private: for (unsigned int j = 0; j < M; ++j) { uint64_t digit = ret.m_digits[i + j] - + (uint64_t)a.m_digits[i] * b.m_digits[j] + + (uint64_t)m_digits[i] * b.m_digits[j] + carry; - ret.m_digits[i + j] = (T)digit & a.digit_mask; - carry = (digit >> a.bits_per_digit) & a.digit_mask; + ret.m_digits[i + j] = (T)digit & digit_mask; + carry = (digit >> bits_per_digit) & digit_mask; } for (unsigned int j = M; i + j < M + N && carry != 0; ++i) { T digit = ret.m_digits[i + j] + carry; - ret.m_digits[i + j] = (T)digit & a.digit_mask; + ret.m_digits[i + j] = (T)digit & digit_mask; carry = (digit & ~digit_mask) ? T(1) : T(0); } } @@ -361,12 +358,64 @@ private: return ret; } + template + typename std::enable_if<(N == M && N >= 64), bigint> + ::type inline multiply(bigint const &b) const + { + bigint<2 * N, T> ret, tmp(0); + + bigint const &a0 = slice<0, N / 2>(); + bigint const &a1 = slice(); + bigint const &b0 = b.template slice<0, N / 2>(); + bigint const &b1 = b.template slice(); + bigint &r0 = ret.template slice<0, N>(); + bigint &r1 = ret.template slice(); + bigint &r2 = ret.template slice(); + r0 = a0 * b0; + r2 = a1 * b1; + /* FIXME: check for overflows here */ + r1 = (a0 + a1) * (b0 + b1) - r0 - r2; + + return ret + tmp; + } + + template + typename std::enable_if<(N == M && N == 1), bigint> + ::type inline multiply(bigint const &b) const + { + bigint<2, T> ret; + uint64_t digit = (uint64_t)m_digits[0] * b.m_digits[0]; + ret.m_digits[0] = (T)(digit) & ret.digit_mask; + ret.m_digits[1] = (T)(digit >> ret.bits_per_digit) & ret.digit_mask; + return ret; + } + + inline uint32_t get_uint32(int offset) const + { + unsigned int bit = offset * 32; + unsigned int digit_index = bit / bits_per_digit; + unsigned int bit_index = bit % bits_per_digit; + + if (digit_index >= N) + return 0; + + uint32_t ret = m_digits[digit_index] >> bit_index; + if (bits_per_digit - bit_index < 32 && digit_index < N - 1) + ret |= m_digits[digit_index + 1] << (bits_per_digit - bit_index); + return ret; + } + T m_digits[N]; }; -typedef bigint<8, uint32_t> int248_t; -typedef bigint<16, uint32_t> int496_t; -typedef bigint<32, uint32_t> int992_t; +/* + * Some convenience typedefs + */ + +typedef bigint<8, uint32_t> int248_t; +typedef bigint<16, uint32_t> int496_t; +typedef bigint<32, uint32_t> int992_t; +typedef bigint<64, uint32_t> int1984_t; } /* namespace lol */