Browse Source

math: implement Karatsuba algorithm for large bigint multiplications.

undefined
Sam Hocevar 10 years ago
parent
commit
26c394053e
2 changed files with 91 additions and 36 deletions
  1. +15
    -9
      doc/samples/sandbox/sample.cpp
  2. +76
    -27
      src/lol/math/bigint.h

+ 15
- 9
doc/samples/sandbox/sample.cpp View File

@@ -20,17 +20,23 @@ using namespace lol;

int main()
{
bigint<16> i;
bigint<16> x(2), y(12);
auto z = x + y;
printf("0x%x\n", (int)z);
Timer t;

bigint<128> x(17), y(23);
x.print();
y.print();

auto z = x * y;
z.print();

bigint<0> lol;
auto w = z + lol;
printf("0x%x\n", (int)w);
for (int i = 0; i < 500000; ++i)
{
x = (bigint<128>)(x * x);
x ^= y;
}

printf("%d %d\n", (int)x, (int)y);

bigint<2> f(0x10101010), g(0x20202020);
(f * f * f * g - g).print();
printf("Time: %f s\n", t.Get());
}


+ 76
- 27
src/lol/math/bigint.h View File

@@ -106,7 +106,7 @@ public:
}

/*
* bigint bitwise not: we just flip all bits except the unused one.
* bigint bitwise NOT: we just flip all bits except the unused ones.
*/
bigint<N,T> operator ~() const
{
@@ -237,7 +237,7 @@ public:
template<unsigned int M>
bigint<N + M, T> operator *(bigint<M,T> const &x) const
{
return mul_naive(*this, x);
return multiply(x);
}

/*
@@ -304,6 +304,18 @@ public:
printf("\n");
}

/*
 * bigint slice: return a mutable view of digits [A, B) of this number,
 * typed as a bigint<B - A, T>. No digits are copied; writing through the
 * returned reference modifies this object.
 *
 * The bounds are checked at compile time: without the check, A > B would
 * make the unsigned B - A wrap around and B > N would produce a view
 * reaching past the digit array.
 *
 * NOTE(review): the cast assumes the digit array is the first data member
 * of bigint and that bigint has no other state -- confirm against the
 * full class definition.
 */
template<unsigned int A, unsigned int B>
inline bigint<B - A, T> &slice()
{
    static_assert(A <= B && B <= N,
                  "bigint slice bounds must satisfy A <= B <= N");

    T *first_digit = reinterpret_cast<T *>(this) + A;
    return *reinterpret_cast<bigint<B - A, T> *>(first_digit);
}

/*
 * bigint slice (const): return a read-only view of digits [A, B) of this
 * number, typed as a bigint<B - A, T>. No digits are copied.
 *
 * The bounds are checked at compile time: without the check, A > B would
 * make the unsigned B - A wrap around and B > N would produce a view
 * reaching past the digit array.
 *
 * NOTE(review): the cast assumes the digit array is the first data member
 * of bigint and that bigint has no other state -- confirm against the
 * full class definition.
 */
template<unsigned int A, unsigned int B>
inline bigint<B - A, T> const &slice() const
{
    static_assert(A <= B && B <= N,
                  "bigint slice bounds must satisfy A <= B <= N");

    T const *first_digit = reinterpret_cast<T const *>(this) + A;
    return *reinterpret_cast<bigint<B - A, T> const *>(first_digit);
}

private:
/* Allow other types of bigints to access our private members */
template<unsigned int, typename> friend class bigint;
@@ -315,27 +327,12 @@ private:
return (m_digits[N - 1] >> (bits_per_digit - 1)) != 0;
}

inline uint32_t get_uint32(int offset) const
{
unsigned int bit = offset * 32;
unsigned int digit_index = bit / bits_per_digit;
unsigned int bit_index = bit % bits_per_digit;

if (digit_index >= N)
return 0;

uint32_t ret = m_digits[digit_index] >> bit_index;
if (bits_per_digit - bit_index < 32 && digit_index < N - 1)
ret |= m_digits[digit_index + 1] << (bits_per_digit - bit_index);
return ret;
}

template<unsigned int M>
static inline bigint<N + M, T> mul_naive(bigint<N,T> const &a,
bigint<M,T> const &b)
typename std::enable_if<(N != M || (N > 1 && N < 64)), bigint<N + M, T>>
::type inline multiply(bigint<M,T> const &b) const
{
/* FIXME: other digit sizes are not supported */
static_assert(sizeof(T) == sizeof(uint32_t), "ensure T is uint32_t");
static_assert(sizeof(T) <= sizeof(uint32_t), "ensure T is uint32_t");

bigint<N + M> ret(0);
for (unsigned int i = 0; i < N; ++i)
@@ -344,16 +341,16 @@ private:
for (unsigned int j = 0; j < M; ++j)
{
uint64_t digit = ret.m_digits[i + j]
+ (uint64_t)a.m_digits[i] * b.m_digits[j]
+ (uint64_t)m_digits[i] * b.m_digits[j]
+ carry;
ret.m_digits[i + j] = (T)digit & a.digit_mask;
carry = (digit >> a.bits_per_digit) & a.digit_mask;
ret.m_digits[i + j] = (T)digit & digit_mask;
carry = (digit >> bits_per_digit) & digit_mask;
}

for (unsigned int j = M; i + j < M + N && carry != 0; ++i)
{
T digit = ret.m_digits[i + j] + carry;
ret.m_digits[i + j] = (T)digit & a.digit_mask;
ret.m_digits[i + j] = (T)digit & digit_mask;
carry = (digit & ~digit_mask) ? T(1) : T(0);
}
}
@@ -361,12 +358,64 @@ private:
return ret;
}

/*
 * bigint multiplication, Karatsuba variant. This overload is selected
 * when both operands have the same digit count N and N >= 64.
 *
 * With LO = N / 2, HI = N - LO and B = 2^bits_per_digit, write
 *     a = a0 + a1 * B^LO        b = b0 + b1 * B^LO
 * so that
 *     a * b = z0 + mid * B^LO + z2 * B^(2 * LO)
 * where
 *     z0  = a0 * b0
 *     z2  = a1 * b1
 *     mid = (a0 + a1) * (b0 + b1) - z0 - z2
 * Only three half-size products are needed; they go through operator *.
 *
 * Digits are handled as unsigned values masked with digit_mask. The
 * middle term is ADDED into the result at digit offset LO (it overlaps
 * the top of z0 and the bottom of z2), and the carries out of a0 + a1
 * and b0 + b1 are tracked explicitly so that nothing is truncated.
 * Odd N is supported: the high halves then hold one more digit than the
 * low halves.
 */
template<unsigned int M>
typename std::enable_if<(N == M && N >= 64), bigint<N + M, T>>
::type inline multiply(bigint<M,T> const &b) const
{
    /* Digit counts of the low and high halves of each operand */
    constexpr unsigned int LO = N / 2;
    constexpr unsigned int HI = N - LO;

    bigint<LO, T> const &a0 = slice<0, LO>();
    bigint<HI, T> const &a1 = slice<LO, N>();
    bigint<LO, T> const &b0 = b.template slice<0, LO>();
    bigint<HI, T> const &b1 = b.template slice<LO, N>();

    /* The two outer partial products. They are kept in their own
     * objects because they must still be readable, unmodified, while
     * the middle term is being accumulated into the result. */
    bigint<2 * LO, T> const z0 = a0 * b0;
    bigint<2 * HI, T> const z2 = a1 * b1;

    /* sa = (a0 + a1) mod B^HI and sb = (b0 + b1) mod B^HI; after the
     * loop, ca and cb hold the carries out of the top digit (0 or 1). */
    bigint<HI, T> sa, sb;
    uint64_t ca = 0, cb = 0;
    for (unsigned int i = 0; i < HI; ++i)
    {
        ca += (uint64_t)a1.m_digits[i]
            + (i < LO ? (uint64_t)a0.m_digits[i] : (uint64_t)0);
        cb += (uint64_t)b1.m_digits[i]
            + (i < LO ? (uint64_t)b0.m_digits[i] : (uint64_t)0);
        sa.m_digits[i] = (T)(ca & digit_mask);
        sb.m_digits[i] = (T)(cb & digit_mask);
        ca >>= bits_per_digit;
        cb >>= bits_per_digit;
    }

    bigint<2 * HI, T> const p = sa * sb;

    /* z0 and z2 tile the 2 * N result digits exactly. */
    bigint<2 * N, T> ret;
    for (unsigned int i = 0; i < 2 * LO; ++i)
        ret.m_digits[i] = z0.m_digits[i];
    for (unsigned int i = 0; i < 2 * HI; ++i)
        ret.m_digits[2 * LO + i] = z2.m_digits[i];

    /* Accumulate the middle term at digit offset LO. Using the carries,
     *   (a0 + a1) * (b0 + b1)
     *       = p + (ca * sb + cb * sa) * B^HI + ca * cb * B^(2 * HI)
     * The running carry is signed because z0 and z2 are subtracted
     * digit by digit; the complete middle term is never negative. */
    int64_t const base = (int64_t)1 << bits_per_digit;
    int64_t carry = 0;
    for (unsigned int k = 0; LO + k < 2 * N; ++k)
    {
        /* Past the last digit of the middle term, only a pending
         * carry can still change the result. */
        if (k > 2 * HI && carry == 0)
            break;

        int64_t acc = carry + (int64_t)ret.m_digits[LO + k];
        if (k < 2 * HI)
            acc += (int64_t)p.m_digits[k] - (int64_t)z2.m_digits[k];
        if (k < 2 * LO)
            acc -= (int64_t)z0.m_digits[k];
        if (k >= HI && k < 2 * HI)
        {
            if (ca)
                acc += (int64_t)sb.m_digits[k - HI];
            if (cb)
                acc += (int64_t)sa.m_digits[k - HI];
        }
        if (k == 2 * HI)
            acc += (int64_t)(ca & cb);

        /* Split acc into a digit in [0, base) and a signed carry. The
         * division is exact, which avoids right-shifting a negative
         * value. */
        int64_t const low = acc & (int64_t)digit_mask;
        ret.m_digits[LO + k] = (T)low;
        carry = (acc - low) / base;
    }

    return ret;
}

/*
 * bigint multiplication, single-digit case: selected when both operands
 * have exactly one digit. The two digits are multiplied in 64-bit
 * arithmetic and the product is split across the two result digits.
 */
template<unsigned int M>
typename std::enable_if<(N == M && N == 1), bigint<N + M, T>>
::type inline multiply(bigint<M,T> const &b) const
{
    uint64_t const product = (uint64_t)m_digits[0] * b.m_digits[0];

    bigint<2, T> result;
    /* High digit: whatever lies above the first bits_per_digit bits */
    result.m_digits[1] = (T)(product >> result.bits_per_digit)
                       & result.digit_mask;
    /* Low digit: the bottom bits_per_digit bits */
    result.m_digits[0] = (T)(product) & result.digit_mask;
    return result;
}

/*
 * Return the 32-bit word that starts at bit position 32 * offset of the
 * stored digits, or 0 when that position lies at or beyond digit N.
 * Bits are taken from the digit containing the start position and, if
 * that digit has fewer than 32 bits left and is not the last one, from
 * the digit that follows it.
 */
inline uint32_t get_uint32(int offset) const
{
    unsigned int const first_bit = offset * 32;
    unsigned int const word = first_bit / bits_per_digit;

    /* Nothing is stored at or past digit N */
    if (word >= N)
        return 0;

    unsigned int const shift = first_bit % bits_per_digit;
    unsigned int const available = bits_per_digit - shift;

    uint32_t value = m_digits[word] >> shift;
    /* Top up with the low bits of the next digit when one exists */
    if (available < 32 && word < N - 1)
        value |= m_digits[word + 1] << available;
    return value;
}

T m_digits[N];
};

typedef bigint<8, uint32_t> int248_t;
typedef bigint<16, uint32_t> int496_t;
typedef bigint<32, uint32_t> int992_t;
/*
* Some convenience typedefs
*/

typedef bigint<8, uint32_t> int248_t;
typedef bigint<16, uint32_t> int496_t;
typedef bigint<32, uint32_t> int992_t;
typedef bigint<64, uint32_t> int1984_t;

} /* namespace lol */


Loading…
Cancel
Save