ソースを参照

math: add bitwise operators for bigints, comparison operators, unary

plus and minus, subtraction, and a lot of unit tests.
undefined
Sam Hocevar 10年前
コミット
d92547bf3e
2個のファイルの変更329行の追加11行の削除
  1. +159
    -6
      src/lol/math/bigint.h
  2. +170
    -5
      src/t/math/bigint.cpp

+ 159
- 6
src/lol/math/bigint.h ファイルの表示

@@ -94,30 +94,106 @@ public:
template<int M>
explicit bigint(bigint<M,T> const &x)
{
for (int i = 0; i < (N < M) ? N : M; ++i)
for (unsigned int i = 0; i < (N < M) ? N : M; ++i)
m_digits[i] = x.m_digits[i];

if (N > M)
{
T padding = x.is_negative() ? digit_mask : (T)0;
for (int i = M; i < N; ++i)
for (unsigned int i = M; i < N; ++i)
m_digits[i] = padding;
}
}

/*
* bigint bitwise not: we just flip all bits except the unused one.
*/
bigint<N,T> operator ~() const
{
bigint<N,T> ret;
for (unsigned int i = 0; i < N; ++i)
ret.m_digits[i] = m_digits[i] ^ digit_mask;
return ret;
}

/*
* bigint bitwise AND: just perform a bitwise AND on all digits.
*/
bigint<N,T> & operator &=(bigint<N,T> const &x)
{
for (unsigned int i = 0; i < N; ++i)
m_digits[i] &= x.m_digits[i];
}

inline bigint<N,T> operator &(bigint<N,T> const &x) const
{
return bigint<N,T>(*this) &= x;
}

/*
* bigint bitwise OR: just perform a bitwise OR on all digits.
*/
bigint<N,T> & operator |=(bigint<N,T> const &x)
{
for (unsigned int i = 0; i < N; ++i)
m_digits[i] |= x.m_digits[i];
}

inline bigint<N,T> operator |(bigint<N,T> const &x) const
{
return bigint<N,T>(*this) |= x;
}

/*
* bigint bitwise XOR: just perform a bitwise XOR on all digits.
*/
bigint<N,T> & operator ^=(bigint<N,T> const &x)
{
for (unsigned int i = 0; i < N; ++i)
m_digits[i] ^= x.m_digits[i];
}

inline bigint<N,T> operator ^(bigint<N,T> const &x) const
{
return bigint<N,T>(*this) ^= x;
}

/*
* bigint unary plus: a no-op
*/
inline bigint<N,T> operator +() const
{
return *this;
}

/*
* bigint unary minus: flip bits and add one
*/
bigint<N,T> operator -() const
{
bigint<N,T> ret;
T carry(1);
for (unsigned int i = 0; i < N; ++i)
{
T digit = (m_digits[i] ^ digit_mask) + carry;
ret.m_digits[i] = digit & digit_mask;
carry = (digit & ~digit_mask) ? T(1) : T(0);
}
return ret;
}

/*
* bigint addition: we add the digits one-to-one, carrying overflows,
* and replace digits with padding if one of the two operands is
* shorter.
* and pad missing digits if one of the two operands is shorter.
*/
template<int M>
template<unsigned int M>
bigint<(N > M) ? N : M, T> operator +(bigint<M,T> const &x) const
{
bigint<(N > M) ? N : M, T> ret;
T padding = is_negative() ? digit_mask : (T)0;
T x_padding = x.is_negative() ? digit_mask : (T)0;
T carry(0);
for (int i = 0; i < (N > M) ? N : M; ++i)
for (unsigned int i = 0; i < ((N > M) ? N : M); ++i)
{
T digit = (i < N ? m_digits[i] : padding)
+ (i < M ? x.m_digits[i] : x_padding)
@@ -128,7 +204,80 @@ public:
return ret;
}

/*
* bigint subtraction: a combination of addition and unary minus;
* we add the result of flipping digits and adding one.
*/
template<unsigned int M>
bigint<(N > M) ? N : M, T> operator -(bigint<M,T> const &x) const
{
bigint<(N > M) ? N : M, T> ret;
T padding = is_negative() ? digit_mask : (T)0;
T x_padding = x.is_negative() ? digit_mask : (T)0;
T carry(1);
for (unsigned int i = 0; i < ((N > M) ? N : M); ++i)
{
T digit = (i < N ? m_digits[i] : padding)
+ ((i < M ? x.m_digits[i] : x_padding) ^ digit_mask)
+ carry;
ret.m_digits[i] = digit & digit_mask;
carry = (digit & ~digit_mask) ? T(1) : T(0);
}
return ret;
}

/*
* bigint equality operator: just use memcmp.
* FIXME: we could easily support operands of different sizes.
*/
inline bool operator ==(bigint<N,T> const &x) const
{
return memcmp(m_digits, x.m_digits, sizeof(m_digits)) == 0;
}

inline bool operator !=(bigint<N,T> const &x) const
{
return !(*this == x);
}

/*
* bigint comparison operators: take a quick decision if signs
* differ. Otherwise, compare all digits.
*/
bool operator >(bigint<N,T> const &x) const
{
if (is_negative() ^ x.is_negative())
return x.is_negative();
for (unsigned int i = 0; i < N; ++i)
if (m_digits[i] != x.m_digits[i])
return m_digits[i] > x.m_digits[i];
return false;
}

bool operator <(bigint<N,T> const &x) const
{
if (is_negative() ^ x.is_negative())
return is_negative();
for (unsigned int i = 0; i < N; ++i)
if (m_digits[i] != x.m_digits[i])
return m_digits[i] < x.m_digits[i];
return false;
}

inline bool operator >=(bigint<N,T> const &x) const
{
return !(*this < x);
}

inline bool operator <=(bigint<N,T> const &x) const
{
return !(*this > x);
}

private:
/* Allow other types of bigints to access our private members */
template<unsigned int, typename> friend class bigint;

inline bool is_negative() const
{
if (N < 1)
@@ -139,5 +288,9 @@ private:
T m_digits[N];
};

typedef bigint<8, uint32_t> int248_t;
typedef bigint<16, uint32_t> int496_t;
typedef bigint<32, uint32_t> int992_t;

} /* namespace lol */


+ 170
- 5
src/t/math/bigint.cpp ファイルの表示

@@ -46,13 +46,178 @@ lolunit_declare_fixture(bigint_test)
lolunit_assert_equal((uint32_t)c, ~(uint32_t)0);
}

lolunit_declare_test(empty_bigint_is_zero)
lolunit_declare_test(operator_equal)
{
bigint<0> a, b(1), c(-1);
bigint<> a(-1), b(0), c(1);

lolunit_assert_equal((int)a, 0);
lolunit_assert_equal((int)b, 0);
lolunit_assert_equal((int)c, 0);
lolunit_assert(a == a);
lolunit_assert(!(a == b));
lolunit_assert(!(a == c));

lolunit_assert(!(b == a));
lolunit_assert(b == b);
lolunit_assert(!(b == c));

lolunit_assert(!(c == a));
lolunit_assert(!(c == b));
lolunit_assert(c == c);
}

lolunit_declare_test(operator_notequal)
{
bigint<> a(-1), b(0), c(1);

lolunit_assert(!(a != a));
lolunit_assert(a != b);
lolunit_assert(a != c);

lolunit_assert(b != a);
lolunit_assert(!(b != b));
lolunit_assert(b != c);

lolunit_assert(c != a);
lolunit_assert(c != b);
lolunit_assert(!(c != c));
}

lolunit_declare_test(operator_smaller)
{
bigint<> a(-10), b(-1), c(0), d(1), e(10);

lolunit_assert(!(a < a));
lolunit_assert(a < b);
lolunit_assert(a < c);
lolunit_assert(a < d);
lolunit_assert(a < e);

lolunit_assert(!(b < a));
lolunit_assert(!(b < b));
lolunit_assert(b < c);
lolunit_assert(b < d);
lolunit_assert(b < e);

lolunit_assert(!(c < a));
lolunit_assert(!(c < b));
lolunit_assert(!(c < c));
lolunit_assert(c < d);
lolunit_assert(c < e);

lolunit_assert(!(d < a));
lolunit_assert(!(d < b));
lolunit_assert(!(d < c));
lolunit_assert(!(d < d));
lolunit_assert(d < e);

lolunit_assert(!(e < a));
lolunit_assert(!(e < b));
lolunit_assert(!(e < c));
lolunit_assert(!(e < d));
lolunit_assert(!(e < e));
}

lolunit_declare_test(operator_smaller_or_equal)
{
bigint<> a(-10), b(-1), c(0), d(1), e(10);

lolunit_assert(a <= a);
lolunit_assert(a <= b);
lolunit_assert(a <= c);
lolunit_assert(a <= d);
lolunit_assert(a <= e);

lolunit_assert(!(b <= a));
lolunit_assert(b <= b);
lolunit_assert(b <= c);
lolunit_assert(b <= d);
lolunit_assert(b <= e);

lolunit_assert(!(c <= a));
lolunit_assert(!(c <= b));
lolunit_assert(c <= c);
lolunit_assert(c <= d);
lolunit_assert(c <= e);

lolunit_assert(!(d <= a));
lolunit_assert(!(d <= b));
lolunit_assert(!(d <= c));
lolunit_assert(d <= d);
lolunit_assert(d <= e);

lolunit_assert(!(e <= a));
lolunit_assert(!(e <= b));
lolunit_assert(!(e <= c));
lolunit_assert(!(e <= d));
lolunit_assert(e <= e);
}

lolunit_declare_test(operator_greater)
{
bigint<> a(-10), b(-1), c(0), d(1), e(10);

lolunit_assert(!(a > a));
lolunit_assert(!(a > b));
lolunit_assert(!(a > c));
lolunit_assert(!(a > d));
lolunit_assert(!(a > e));

lolunit_assert(b > a);
lolunit_assert(!(b > b));
lolunit_assert(!(b > c));
lolunit_assert(!(b > d));
lolunit_assert(!(b > e));

lolunit_assert(c > a);
lolunit_assert(c > b);
lolunit_assert(!(c > c));
lolunit_assert(!(c > d));
lolunit_assert(!(c > e));

lolunit_assert(d > a);
lolunit_assert(d > b);
lolunit_assert(d > c);
lolunit_assert(!(d > d));
lolunit_assert(!(d > e));

lolunit_assert(e > a);
lolunit_assert(e > b);
lolunit_assert(e > c);
lolunit_assert(e > d);
lolunit_assert(!(e > e));
}

lolunit_declare_test(operator_greater_or_equal)
{
bigint<> a(-10), b(-1), c(0), d(1), e(10);

lolunit_assert(a >= a);
lolunit_assert(!(a >= b));
lolunit_assert(!(a >= c));
lolunit_assert(!(a >= d));
lolunit_assert(!(a >= e));

lolunit_assert(b >= a);
lolunit_assert(b >= b);
lolunit_assert(!(b >= c));
lolunit_assert(!(b >= d));
lolunit_assert(!(b >= e));

lolunit_assert(c >= a);
lolunit_assert(c >= b);
lolunit_assert(c >= c);
lolunit_assert(!(c >= d));
lolunit_assert(!(c >= e));

lolunit_assert(d >= a);
lolunit_assert(d >= b);
lolunit_assert(d >= c);
lolunit_assert(d >= d);
lolunit_assert(!(d >= e));

lolunit_assert(e >= a);
lolunit_assert(e >= b);
lolunit_assert(e >= c);
lolunit_assert(e >= d);
lolunit_assert(e >= e);
}
};



読み込み中…
キャンセル
保存