ref: 515b8e89c1a6c0859eaefbf7835a4e6497a57e0b
dir: /libstd/bigint.myr/
use "alloc.use"
use "chartype.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "hasprefix.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
use "errno.use"

pkg std =
	type bigint = struct
		dig	: uint32[:] 	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/* administrivia */
	generic mkbigint	: (v : @a::(numeric,integral) -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigmove	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigclear	: (a : bigint# -> bigint#)
	const bigbfmt	: (b : byte[:], a : bigint#, base : int -> size)
	/*
	const bigtoint	: (a : bigint#	-> @a::(numeric,integral))
	*/

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigeq	: (a : bigint#, b : bigint# -> bool)
	generic bigeqi	: (a : bigint#, b : @a::(numeric,integral) -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/* bigint*bigint -> bigint ops */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)
	const bigmodpow	: (b : bigint#, e : bigint#, m : bigint# -> bigint#)
	/*
	const bigpow	: (a : bigint#, b : bigint# -> bigint#)
	*/

	/* bigint*int -> bigint ops */
	generic bigaddi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigsubi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigmuli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigdivi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshri	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	/*
	const bigpowi	: (a : bigint#, b : uint64 -> bigint#)
	*/
;;

/* one more than the largest value a single digit can hold: 2^32 */
const Base = 0x100000000ul

/*
 allocates a new bigint holding the value of the integer 'v'.
 the result is owned by the caller, and released with bigfree().
 */
generic mkbigint = {v : @a::(integral,numeric)
	var a
	var val

	a = zalloc()

	/*
	 take the magnitude after widening, the same way bigdigit()
	 does, so that negating never happens in the narrow type.
	 */
	val = v castto(uint64)
	if v < 0
		a.sign = -1
		val = -val
	elif v > 0
		a.sign = 1
	;;
	a.dig = slpush([][:], val castto(uint32))
	/* '>=': a value of exactly Base needs the second digit too */
	if val >= Base
		a.dig = slpush(a.dig, (val/Base) castto(uint32))
	;;
	-> trim(a)
}

/* releases the digits of 'a', and then 'a' itself. */
const bigfree = {a
	slfree(a.dig)
	free(a)
}

/* returns a freshly allocated copy of 'a' */
const bigdup = {a
	-> bigassign(zalloc(), a)
}

/* d = s; 'd' gets its own copy of the digits. returns 'd'. */
const bigassign = {d, s
	/* self assignment would free the digits before copying them */
	if d == s
		-> d
	;;
	slfree(d.dig)
	d# = s#
	d.dig = sldup(s.dig)
	-> d
}

/*
 d = s, stealing the digits of 's'. 's' is left as zero, but is
 not freed: the caller still owns the struct. returns 'd'.
 */
const bigmove = {d, s
	/* moving a value onto itself would free, then zero, its own digits */
	if d == s
		-> d
	;;
	slfree(d.dig)
	d# = s#
	s.dig = [][:]
	s.sign = 0
	-> d
}

/* v = 0, releasing the digits. returns 'v'. */
const bigclear = {v
	slfree(v.dig)
	v.sign = 0
	v.dig = [][:]
	-> v
}

/*
 formats 'x' into 'buf' in the given base, returning the number of
 bytes written. a base of 0 means base 10; otherwise, the base must
 be in the range 2..36.

 for now, just dump out something for debugging...
 */
const bigbfmt = {buf, x, base
	const digitchars = [
		'0','1','2','3','4','5','6','7','8','9',
		'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
		'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
		'u', 'v', 'w', 'x', 'y', 'z']
	var v, val
	var n, i
	var tmp, rem, b

	/* base 1 can never reduce the value, so the digit loop would not end */
	if base < 0 || base == 1 || base > 36
		die("invalid base in bigbfmt\n")
	;;

	if base == 0
		b = mkbigint(10)
	else
		b = mkbigint(base)
	;;
	n = 0
	val = bigdup(x)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, b)
		if rem.dig.len > 0
			n += encode(buf[n:], digitchars[rem.dig[0]])
		else
			n += encode(buf[n:], '0')
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(b)

	/* this is done last, so we get things right when we reverse the string */
	if x.sign == 0
		n += encode(buf[n:], '0')
	elif x.sign == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii digits, so this works for reversing. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
 parses 'str' as an unsigned integer. a prefix of 0x, 0o or 0b (in
 either case) selects base 16, 8 or 2; otherwise the base is 10.
 '_' is skipped wherever it appears. returns `None on the first
 character that is not a digit of the selected base.
 */
const bigparse = {str
	var c, val : int, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base)
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		/* digits run from 0 to base - 1, so 'base' itself is out of range */
		if val < 0 || val >= base
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val castto(uint32)
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		bigmul(a, b)
		bigadd(a, v)
	;;
	/* the scratch values are only needed while accumulating */
	bigfree(b)
	bigfree(v)
	-> `Some a
}

/* true if 'v' is zero, which is represented by having no digits */
const bigiszero = {v
	-> v.dig.len == 0
}

/* true if 'a' and 'b' have the same sign and the same digits */
const bigeq = {a, b
	var i

	if a.sign != b.sign || a.dig.len != b.dig.len
		-> false
	;;
	for i = 0; i < a.dig.len; i++
		if a.dig[i] != b.dig[i]
			-> false
		;;
	;;
	-> true
}

/* true if 'a' has the same value as the integer 'b' */
generic bigeqi = {a, b
	var v
	var dig : uint32[2]

	bigdigit(&v, b < 0, b castto(uint64), dig[:])
	-> bigeq(a, &v)
}

/* orders 'a' relative to 'b' by signed value */
const bigcmp = {a, b
	var i
	var da, db, sa, sb

	sa = a.sign castto(int64)
	sb = b.sign castto(int64)
	if sa < sb
		-> `Before
	elif sa > sb
		-> `After
	elif a.dig.len < b.dig.len
		-> signedorder(-sa)
	elif a.dig.len > b.dig.len
		-> signedorder(sa)
	else
		/* otherwise, the one with the first larger digit is bigger */
		for i = a.dig.len; i > 0; i--
			da = a.dig[i - 1] castto(int64)
			db = b.dig[i - 1] castto(int64)
			/* equal digits decide nothing: keep scanning downwards */
			if da != db
				-> signedorder(sa * (da - db))
			;;
		;;
	;;
	-> `Equal
}

/* maps a negative, zero or positive value to `Before, `Equal or `After */
const signedorder = {sign
	if sign < 0
		-> `Before
	elif sign == 0
		-> `Equal
	else
		-> `After
	;;
}

/*
 orders 'a' relative to 'b' by magnitude, ignoring the signs.
 relies on neither value having leading zero digits.
 */
const ucmp = {a : bigint#, b : bigint# -> order
	var i

	if a.dig.len < b.dig.len
		-> `Before
	elif a.dig.len > b.dig.len
		-> `After
	;;
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		elif a.dig[i - 1] > b.dig[i - 1]
			-> `After
		;;
	;;
	-> `Equal
}

/*
 replaces the digits of 'a' with |b| - |a|, where 'b' is strictly
 larger in magnitude than 'a'. 'b' is not modified, and the sign
 of 'a' is left for the caller to set.
 */
const usubfrom = {a : bigint#, b : bigint# -> bigint#
	var t

	t = bigdup(b)
	usub(t, a)
	slfree(a.dig)
	a.dig = t.dig
	/* the digits now belong to 'a'; don't let bigfree() release them */
	t.dig = [][:]
	bigfree(t)
	-> a
}

/* a += b */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/*
		 the signs differ, so this is a subtraction of magnitudes;
		 the result takes the sign of whichever operand is larger
		 in magnitude.
		 */
		match ucmp(a, b)
		| `Before:
			usubfrom(a, b)
			a.sign = b.sign
			-> a
		| `After:
			-> usub(a, b)
		| `Equal:
			-> bigclear(a)
		;;
	;;
}

/* adds two unsigned values together. */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;

		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	/* i == n here: the final carry goes into the extra digit */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/* a -= b */
const bigsub = {a, b
	/* 0 - x = -x */
	if a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* x - 0 = x */
	elif b.sign == 0
		-> a
	elif a.sign != b.sign
		-> uadd(a, b)
	else
		/*
		 same sign: subtract the magnitudes. if 'b' is the larger
		 in magnitude, the result crosses zero and the sign flips.
		 */
		match ucmp(a, b)
		| `Before:
			usubfrom(a, b)
			a.sign = -b.sign
			-> a
		| `After:
			-> usub(a, b)
		| `Equal:
			-> bigclear(a)
		;;
	;;
	-> a
}

/* subtracts two unsigned values, where 'a' is strictly greater than 'b' */
const usub = {a, b
	var carry
	var v, i

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - carry
		if i < b.dig.len
			v -= (b.dig[i] castto(int64))
		;;
		/* a negative digit means we borrow from the next one up */
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/* a *= b */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	elif a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* schoolbook multiplication, accumulated into the scratch slice 'w' */
	w = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			bj = b.dig[j] castto(uint64)
			wij = w[i+j] castto(uint64)
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: the carry out of this row */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}

/* a /= b, discarding the remainder */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	bigmove(a, q)
	/* bigmove() only takes the digits; the struct is still ours to free */
	bigfree(q)
	-> a
}

/* a %= b, discarding the quotient */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	bigmove(a, r)
	/* bigmove() only takes the digits; the struct is still ours to free */
	bigfree(r)
	-> a
}

/*
 returns (a / b, a % b) as two freshly allocated values, leaving
 'a' and 'b' untouched. the quotient is negative when the signs of
 the operands differ, and a nonzero remainder takes the sign of 'a'.
 dies if 'b' is zero.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	 */
	var m : int64, n : int64
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var i, j : int64
	var q, r

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b > a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		q = trim(q)
		/*
		 the remainder can be as large as b0 - 1, which needs all
		 32 bits, so it has to be passed on unsigned. it takes the
		 sign of the dividend, as it does on the other paths.
		 */
		r = mkbigint(carry)
		if !bigiszero(r)
			r.sign = a.sign
		;;
		-> (q, r)
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize, so that the top bit of the divisor is set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	u.dig = slzgrow(u.dig, u.dig.len + 1)

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - qhat*z
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))

			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)

			/*
			 'p' can use all 64 bits, so its high half has to be
			 taken with an unsigned shift; only 't' is signed.
			 */
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		t = (u.dig[j + n] castto(uint64)) - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)

		/* adjust: qhat was one too large, so add the divisor back */
		if (t castto(int64)) < 0
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;
	;;
	/* the normalized copy of the divisor is no longer needed */
	bigfree(v)
	/* undo the biasing for remainder */
	u = bigshri(u, shift)
	-> (trim(q), trim(u))
}

/*
 computes b^e % m. the result replaces 'base', which is also
 returned; 'exp' and 'mod' are left untouched.
 */
const bigmodpow = {base, exp, mod
	var r, e

	r = mkbigint(1)
	/* the exponent is consumed bit by bit, so work on a copy */
	e = bigdup(exp)
	while !bigiszero(e)
		if (e.dig[0] & 1) != 0
			bigmul(r, base)
			bigmod(r, mod)
		;;
		bigshri(e, 1)
		bigmul(base, base)
		bigmod(base, mod)
	;;
	bigfree(e)
	bigmove(base, r)
	/* bigmove() only takes the digits; the struct is still ours to free */
	bigfree(r)
	-> base
}

/* returns the number of leading zeros */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/* binary search for the top set bit, halving the window each step */
	n = 0
	if a <= 0x0000ffff
		n += 16
		a <<= 16
	;;
	if a <= 0x00ffffff
		n += 8
		a <<= 8
	;;
	if a <= 0x0fffffff
		n += 4
		a <<= 4
	;;
	if a <= 0x3fffffff
		n += 2
		a <<= 2
	;;
	if a <= 0x7fffffff
		n += 1
		a <<= 1
	;;
	-> n
}

/* a <<= b; only the digits of 'b' are looked at, not its sign */
const bigshl = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshli(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/* a >>= b, unsigned */
const bigshr = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshri(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/*
 a + b, b is integer.
 FIXME: actually make this a performance improvement
 */
generic bigaddi = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigadd(a, &bigb)
	-> a
}

/* a - b, b is integer. */
generic bigsubi = {a, b : @a::(numeric,integral)
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigsub(a, &bigb)
	-> a
}

/* a * b, b is integer. */
generic bigmuli = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigmul(a, &bigb)
	-> a
}

/* a / b, b is integer. */
generic bigdivi = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigdiv(a, &bigb)
	-> a
}

/*
 a << s, with integer arg.
 logical left shift. any other type would be illogical.
 */
generic bigshli = {a, s : @a::(numeric,integral)
	var off, shift
	var t, carry
	var i

	assert(s >= 0, "shift amount must be positive")
	/* split the shift into whole digits, and the bits left over */
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/* zero shifted by anything is zero */
	if a.sign == 0
		-> a
	;;

	a.dig = slzgrow(a.dig, 1 + a.dig.len + off castto(size))
	/* blit over the base values */
	for i = a.dig.len; i > off; i--
		a.dig[i - 1] = a.dig[i - 1 - off]
	;;
	for i = 0; i < off; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = 0; i < a.dig.len; i++
		t = (a.dig[i] castto(uint64)) << shift
		a.dig[i] = (t | carry) castto(uint32)
		carry = t >> 32
	;;
	-> trim(a)
}

/* logical shift right, zero fills. sign remains untouched. */
generic bigshri = {a, s
	var off, shift
	var t, carry
	var i

	assert(s >= 0, "shift amount must be positive")
	/* split the shift into whole digits, and the bits left over */
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/*
	 every digit is shifted out, so the result is zero. bailing
	 out here also keeps 'a.dig.len - off' from wrapping below.
	 */
	if off >= a.dig.len
		-> bigclear(a)
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		carry = t << (32 - shift)
	;;
	-> trim(a)
}

/*
 creates a bigint on the stack; should not be modified.
 'val' holds the value widened to 64 bits, 'isneg' says whether it
 was negative, and 'dig' is the caller's storage for the digits.
 */
const bigdigit = {v, isneg : bool, val : uint64, dig
	v.sign = 1
	if isneg
		val = -val
		v.sign = -1
	;;
	if val == 0
		v.sign = 0
		v.dig = [][:]
	elif val < Base
		v.dig = dig[:1]
		v.dig[0] = val castto(uint32)
	else
		v.dig = dig
		v.dig[0] = val castto(uint32)
		v.dig[1] = (val >> 32) castto(uint32)
	;;
}

/*
 trims leading zeros, and makes the sign agree with the digits:
 a value with no digits left is zero, and a value with digits but
 a zero sign becomes positive.
 */
const trim = {a
	var i

	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] != 0
			break
		;;
	;;
	a.dig = slgrow(a.dig, i)
	if i == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}