ref: 6d4651504d015f853c5ce79e77c73b5b8d48793f
dir: /libstd/bigint.myr/
use "alloc.use"
use "chartype.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "hasprefix.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
use "errno.use"
/*
 Exported interface of the arbitrary precision integer code.
 Values are handled through bigint# pointers.
 */
pkg std =
	type bigint = struct
		dig	: uint32[:] 	/* magnitude as 32 bit digits: little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;
	/* administrivia: creation, destruction, copying, parsing and formatting */
	generic mkbigint	: (v : @a::(numeric,integral) -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigmove	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigclear	: (a : bigint# -> bigint#)
	const bigbfmt	: (b : byte[:], a : bigint#, base : int -> size)
	/*
	not implemented yet:
	const bigtoint	: (a : bigint#	-> @a::(numeric,integral))
	*/
	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigeq	: (a : bigint#, b : bigint# -> bool)
	generic bigeqi	: (a : bigint#, b : @a::(numeric,integral) -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)
	/* bigint*bigint -> bigint ops */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	/* returns a (quotient, remainder) pair */
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)
	const bigmodpow	: (b : bigint#, e : bigint#, m : bigint# -> bigint#)
	/*
	not implemented yet:
	const bigpow	: (a : bigint#, b : bigint# -> bigint#)
	*/
	/* bigint*int -> bigint ops: as above, with a machine integer on the right */
	generic bigaddi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigsubi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigmuli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigdivi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshri	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	/*
	not implemented yet:
	const bigpowi	: (a : bigint#, b : uint64 -> bigint#)
	*/
;;
/* radix of one digit: 2^32, one more than the largest value a uint32 digit holds */
const Base = 0x100000000ul
/*
 Allocates a new bigint holding the value of the machine integer `v`.

 The magnitude is computed in uint64 and split into at most two
 32 bit digits; trim() then drops any leading zero digits and fixes up
 the sign for zero. The caller owns the result and releases it with
 bigfree().
 */
generic mkbigint = {v : @a::(integral,numeric)
	var a
	var val

	a = zalloc()
	val = v castto(uint64)
	if v < 0
		a.sign = -1
		/*
		 negate after widening, as bigdigit() does: negating in
		 the original type overflows for the most negative value.
		 */
		val = -val
	elif v > 0
		a.sign = 1
	;;
	a.dig = slpush([][:], val castto(uint32))
	/* '>=': a value of exactly Base has a zero low digit and needs the high one */
	if val >= Base
		a.dig = slpush(a.dig, (val/Base) castto(uint32))
	;;
	-> trim(a)
}
/* Releases a bigint: the digit storage first, then the struct itself. */
const bigfree = {a
	var digits

	digits = a.dig
	slfree(digits)
	free(a)
}
/*
 Returns a freshly allocated copy of `a`. The copy is made by
 assigning `a` into a zeroed bigint.
 */
const bigdup = {a
	var copy

	copy = zalloc()
	-> bigassign(copy, a)
}
/*
 d = s. Replaces the contents of `d` with a copy of `s`, giving `d`
 its own duplicate of the digits. Returns `d`.
 */
const bigassign = {d, s
	var dig

	/*
	 duplicate before freeing: if d and s are the same bigint,
	 freeing first would leave nothing valid to copy from.
	 */
	dig = sldup(s.dig)
	slfree(d.dig)
	d# = s#
	d.dig = dig
	-> d
}
/*
 Moves the value of `s` into `d` without copying the digits.
 Afterwards `d` owns the digits that `s` held and `s` is zero.
 The struct `s` points to is not freed. Returns `d`.
 */
const bigmove = {d, s
	var dig, sign

	/*
	 detach the value from s before touching d, so that
	 moving a bigint onto itself leaves it intact instead of
	 freeing its own digits.
	 */
	dig = s.dig
	sign = s.sign
	s.dig = [][:]
	s.sign = 0

	slfree(d.dig)
	d.dig = dig
	d.sign = sign
	-> d
}
/* Resets `v` to zero in place, releasing its digit storage. Returns `v`. */
const bigclear = {v
	std.slfree(v.dig)
	v.dig = [][:]
	v.sign = 0
	-> v
}
/*
 Formats `x` into `buf` using the given base and returns the number
 of bytes written. For now this is mostly a debugging aid.

 A base of 0 selects base 10. Bases that are negative, equal to 1, or
 above 36 are rejected with die(). Digits beyond 9 are written as
 lower case letters. Negative values get a leading '-'.

 NOTE(review): nothing here checks that `buf` is large enough; that
 is left to encode() and the caller.
 */
const bigbfmt = {buf, x, base
	const digitchars = [
	'0','1','2','3','4','5','6','7','8','9',
	'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
	'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
	't', 'u', 'v', 'w', 'x', 'y', 'z']
	var v, val
	var n, i
	var tmp, rem, b

	/* base 1 would never terminate: dividing by one never reaches zero */
	if base < 0 || base == 1 || base > 36
		die("invalid base in bigbfmt\n")
	;;
	if base == 0
		b = mkbigint(10)
	else
		b = mkbigint(base)
	;;
	n = 0
	val = bigdup(x)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, b)
		/* a zero remainder has no digits at all */
		if rem.dig.len > 0
			n += encode(buf[n:], digitchars[rem.dig[0]])
		else
			n += encode(buf[n:], '0')
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(b)
	/* this is done last, so we get things right when we reverse the string */
	if x.sign == 0
		n += encode(buf[n:], '0')
	elif x.sign == -1
		n += encode(buf[n:], '-')
	;;
	/* we only generated ascii digits, so this works for reversing. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}
/*
 Parses an unsigned integer literal from `str`.

 A "0x"/"0X", "0o"/"0O" or "0b"/"0B" prefix selects base 16, 8 or 2;
 otherwise the text is read as base 10. Underscores are skipped.
 Returns `Some with a newly allocated bigint, or `None if a character
 is not a valid digit for the base.
 */
const bigparse = {str
	var c, val : int, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	if base != 10
		str = str[2:]
	;;
	a = mkbigint(0)
	b = mkbigint(base)
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		/* valid digits are 0..base-1, so a value equal to base is an error too */
		if val < 0 || val >= base
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val castto(uint32)
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		/* a = a*base + digit */
		bigmul(a, b)
		bigadd(a, v)
	;;
	/* the scratch values are only needed while parsing */
	bigfree(b)
	bigfree(v)
	-> `Some a
}
/* True when `v` has no digits, which is how zero is represented. */
const bigiszero = {v
	var ndig

	ndig = v.dig.len
	-> ndig == 0
}
/*
 True when `a` and `b` hold the same value: same sign, same number
 of digits, and every digit equal.
 */
const bigeq = {a, b
	var idx

	if a.sign != b.sign
		-> false
	;;
	if a.dig.len != b.dig.len
		-> false
	;;
	idx = 0
	while idx < a.dig.len
		if a.dig[idx] != b.dig[idx]
			-> false
		;;
		idx += 1
	;;
	-> true
}
/*
 Compares the bigint `a` with the machine integer `b` for equality.
 `b` is wrapped in a temporary bigint backed by local storage, so
 nothing is allocated.
 */
generic bigeqi = {a, b
	var rhs
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	-> bigeq(a, &rhs)
}
/*
 Signed three way comparison of `a` and `b`.

 Signs are compared first. With equal signs, the value with more
 digits has the larger magnitude; with equal lengths the digits are
 compared from most to least significant. A magnitude difference is
 flipped for negative values by multiplying by the sign.
 */
const bigcmp = {a, b
	var i
	var da, db, sa, sb

	sa = a.sign castto(int64)
	sb = b.sign castto(int64)
	if sa < sb
		-> `Before
	elif sa > sb
		-> `After
	elif a.dig.len < b.dig.len
		-> signedorder(-sa)
	elif a.dig.len > b.dig.len
		-> signedorder(sa)
	else
		/* otherwise, the one with the first larger digit is bigger */
		for i = a.dig.len; i > 0; i--
			da = a.dig[i - 1] castto(int64)
			db = b.dig[i - 1] castto(int64)
			/* only a differing digit decides; equal ones defer to the next digit down */
			if da != db
				-> signedorder(sa * (da - db))
			;;
		;;
	;;
	-> `Equal
}
/* Maps a signed number onto an ordering: negative, zero, positive. */
const signedorder = {sign
	if sign == 0
		-> `Equal
	elif sign < 0
		-> `Before
	else
		-> `After
	;;
}
/*
 a += b. Updates `a` in place and returns it; `b` is left untouched.

 Equal signs (or a zero `a`) reduce to adding magnitudes. With
 opposite signs the smaller magnitude is subtracted from the larger
 and the result takes the sign of the larger operand.
 */
const bigadd = {a, b
	var i, mag, tmp

	if a.sign == b.sign || a.sign == 0
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	;;

	/*
	 the signs differ, so what matters is which magnitude is
	 larger, not the signed order. mag is negative, zero or
	 positive as |a| is less than, equal to or greater than |b|.
	 */
	mag = 0
	if a.dig.len > b.dig.len
		mag = 1
	elif a.dig.len < b.dig.len
		mag = -1
	else
		for i = a.dig.len; i > 0; i--
			if a.dig[i - 1] > b.dig[i - 1]
				mag = 1
				break
			elif a.dig[i - 1] < b.dig[i - 1]
				mag = -1
				break
			;;
		;;
	;;

	if mag > 0
		/* |a| - |b|, keeping the sign of a */
		-> usub(a, b)
	elif mag < 0
		/* |b| - |a| with the sign of b, computed in a scratch copy so b survives */
		tmp = bigdup(b)
		usub(tmp, a)
		bigmove(a, tmp)
		free(tmp)
		-> a
	else
		/* x + -x = 0 */
		-> bigclear(a)
	;;
}
/*
 Adds the magnitude of `b` into `a`, ignoring signs, and returns `a`
 after trimming. `a` is first grown by one digit beyond the longer
 operand to leave room for a final carry.
 */
const uadd = {a, b
	var sum, idx
	var carry
	var ndig

	ndig = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, ndig + 1)
	carry = 0
	for idx = 0; idx < ndig; idx++
		sum = (a.dig[idx] castto(uint64)) + carry
		if idx < b.dig.len
			sum += (b.dig[idx] castto(uint64))
		;;
		a.dig[idx] = sum castto(uint32)
		/* the sum of two digits and a carry fits in 33 bits, so this is 0 or 1 */
		carry = sum >> 32
	;;
	a.dig[idx] += carry castto(uint32)
	-> trim(a)
}
/*
 a -= b. Updates `a` in place and returns it; `b` is left untouched.

 Opposite signs reduce to adding magnitudes. With equal signs the
 smaller magnitude is subtracted from the larger; if `b` had the
 larger magnitude the sign of the result is the opposite of the
 operands' sign.
 */
const bigsub = {a, b
	var i, mag, tmp

	/* 0 - x = -x */
	if a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* x - 0 = x */
	elif b.sign == 0
		-> a
	/* x - -y = x + y: magnitudes add, the sign of a is kept */
	elif a.sign != b.sign
		-> uadd(a, b)
	;;

	/*
	 same sign: compare magnitudes rather than signed values.
	 mag is negative, zero or positive as |a| is less than,
	 equal to or greater than |b|.
	 */
	mag = 0
	if a.dig.len > b.dig.len
		mag = 1
	elif a.dig.len < b.dig.len
		mag = -1
	else
		for i = a.dig.len; i > 0; i--
			if a.dig[i - 1] > b.dig[i - 1]
				mag = 1
				break
			elif a.dig[i - 1] < b.dig[i - 1]
				mag = -1
				break
			;;
		;;
	;;

	if mag > 0
		/* |a| - |b|, keeping the sign of a */
		-> usub(a, b)
	elif mag < 0
		/* |b| - |a| with the sign flipped, computed in a scratch copy so b survives */
		tmp = bigdup(b)
		usub(tmp, a)
		tmp.sign = -b.sign
		bigmove(a, tmp)
		free(tmp)
		-> a
	else
		/* x - x = 0 */
		-> bigclear(a)
	;;
}
/*
 Subtracts the magnitude of `b` from `a` in place, ignoring signs.
 The magnitude of `a` must be strictly greater than that of `b`.
 Returns `a` after trimming.
 */
const usub = {a, b
	var borrow
	var diff, idx

	borrow = 0
	idx = 0
	while idx < a.dig.len
		diff = (a.dig[idx] castto(int64)) - borrow
		if idx < b.dig.len
			diff -= (b.dig[idx] castto(int64)) 
		;;
		/* the truncating cast wraps a negative difference into digit range */
		a.dig[idx] = diff castto(uint32)
		if diff < 0
			borrow = 1
		else
			borrow = 0
		;;
		idx += 1
	;;
	-> trim(a)
}
/*
 a *= b, using schoolbook multiplication. The product is built in a
 separate buffer which then replaces the digits of `a`. Returns `a`.
 */
const bigmul = {a, b
	var col, row
	var lhs, rhs, acc
	var carry, prod
	var out

	/* anything times zero is zero */
	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	;;
	if a.sign == b.sign
		a.sign = 1
	else
		a.sign = -1
	;;

	out = slzalloc(a.dig.len + b.dig.len)
	for row = 0; row < b.dig.len; row++
		rhs = b.dig[row] castto(uint64)
		carry = 0
		for col = 0; col < a.dig.len; col++
			lhs = a.dig[col] castto(uint64)
			acc = out[col + row] castto(uint64)
			prod = lhs * rhs + acc + carry
			out[col + row] = (prod castto(uint32))
			carry = prod >> 32
		;;
		/* col is now a.dig.len: the carry out of this row lands one digit up */
		out[col + row] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = out
	-> trim(a)
}
/* a /= b. Replaces `a` with the quotient of a/b and returns `a`. */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	bigmove(a, q)
	/* bigmove only takes the digits: release the emptied struct too */
	free(q)
	-> a
}
/* a %= b. Replaces `a` with the remainder of a/b and returns `a`. */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	bigmove(a, r)
	/* bigmove only takes the digits: release the emptied struct too */
	free(r)
	-> a
}
/*
 Computes a/b and returns a newly allocated (quotient, remainder)
 pair. Neither `a` nor `b` is modified; the caller frees both results.
 Dies on division by zero.

 The quotient is negative when the operand signs differ. The
 remainder of the single digit path is never negative; on the
 multi digit path it is derived from a copy of `a` and keeps its sign.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	Implements bigint division using Algorithm D from
	Knuth: Seminumerical algorithms, Section 4.3.1.
	*/
	var m : int64, n : int64
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var i, j : int64
	var q

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b has more digits than a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;
	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;
	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		q = trim(q)
		/*
		 the remainder is below b0 but can still exceed 2^31:
		 widen it to int64 so it cannot turn negative.
		 */
		-> (q, trim(mkbigint(carry castto(int64))))
	;;
	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len
	/* normalize: shift both so the top digit of the divisor has its high bit set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	u.dig = slzgrow(u.dig, u.dig.len + 1)
	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)
		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - qhat*z
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;
		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))
			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)
			/*
			 the high half of p is taken with an unsigned shift:
			 p may exceed 2^63, and shifting it as a signed value
			 would sign extend. t is shifted as a signed value on
			 purpose, since it holds the borrow.
			 */
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64);
		;;
		t = (u.dig[j + n] castto(uint64)) - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)
		/* adjust: qhat was one too large, so add the divisor back */
		if t castto(int64) < 0
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32));
		;;
	;;
	/* the normalized copy of the divisor is no longer needed */
	bigfree(v)
	/* undo the biasing for remainder */
	u = bigshri(u, shift)
	-> (trim(q), trim(u))
}
/*
 Computes b^e % m by square and multiply.

 The result is stored into `base`, which is returned. `exp` is
 consumed: it is shifted right one bit per round until it reaches
 zero. `mod` is only read.
 */
const bigmodpow = {base, exp, mod
	var r

	r = mkbigint(1)
	while !bigiszero(exp)
		/* low bit set: fold the current power of base into the result */
		if (exp.dig[0] & 1) != 0
			bigmul(r, base)
			bigmod(r, mod)
		;;
		bigshri(exp, 1)
		bigmul(base, base)
		bigmod(base, mod)
	;;
	bigmove(base, r)
	/* bigmove only takes the digits: release the emptied struct too */
	free(r)
	-> base
}
/*
 Returns the number of leading zero bits in `a`, which is 32 when
 `a` is zero.
 */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/* a is nonzero, so shifting left must eventually reach the top bit */
	n = 0
	while (a & 0x80000000) == 0
		n += 1
		a <<= 1
	;;
	-> n
}
/*
 a <<= b, with a bigint shift count. Only the digits of `b` are
 looked at: a count of zero returns `a` as is, a single digit count
 is handed to bigshli(), and anything longer dies.
 */
const bigshl = {a, b
	var ndig

	ndig = b.dig.len
	match ndig
	| 0:	-> a
	| 1:	-> bigshli(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}
/*
 a >>= b, unsigned, with a bigint shift count. Only the digits of
 `b` are looked at: a count of zero returns `a` as is, a single digit
 count is handed to bigshri(), and anything longer dies.
 */
const bigshr = {a, b
	var ndig

	ndig = b.dig.len
	match ndig
	| 0:	-> a
	| 1:	-> bigshri(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}
/*
 a += b, where b is a machine integer. `b` is wrapped in a temporary
 bigint backed by local storage and passed to bigadd(). Returns `a`.
 FIXME: actually make this a performance improvement
 */
generic bigaddi = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigadd(a, &rhs)
	-> a
}
/*
 a -= b, where b is a machine integer. `b` is wrapped in a temporary
 bigint backed by local storage and passed to bigsub(). Returns `a`.
 */
generic bigsubi = {a, b : @a::(numeric,integral)
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigsub(a, &rhs)
	-> a
}
/*
 a *= b, where b is a machine integer. `b` is wrapped in a temporary
 bigint backed by local storage and passed to bigmul(). Returns `a`.
 */
generic bigmuli = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigmul(a, &rhs)
	-> a
}
/*
 a /= b, where b is a machine integer. `b` is wrapped in a temporary
 bigint backed by local storage and passed to bigdiv(). Returns `a`.
 */
generic bigdivi = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigdiv(a, &rhs)
	-> a
}
/*
  a <<= s, with integer arg.
  logical left shift. any other type would be illogical.

  The shift is split into whole digits, which are moved, and the
  remaining bits, which are shifted across digit boundaries with a
  carry. A negative `s` trips the assert.
 */
generic bigshli = {a, s : @a::(numeric,integral)
	var words, bits
	var wide, spill
	var i

	assert(s >= 0, "shift amount must be positive")
	/* zero shifted by anything is zero */
	if a.sign == 0
		-> a
	;;
	words = (s castto(uint64)) / 32
	bits = (s castto(uint64)) % 32

	/* room for the moved digits plus one digit of overflow from the bit shift */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + words castto(size))
	/* move whole digits up, working from the top so nothing is overwritten early */
	for i = a.dig.len; i > words; i--
		a.dig[i - 1] = a.dig[i - 1 - words]
	;;
	/* the vacated low digits become zero */
	for i = 0; i < words; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	spill = 0
	for i = 0; i < a.dig.len; i++
		wide = (a.dig[i] castto(uint64)) << bits
		a.dig[i] = (wide | spill) castto(uint32) 
		spill = wide >> 32
	;;
	-> trim(a)
}
/*
 a >>= s, with integer arg.
 logical shift right, zero fills. a nonzero result keeps its sign.

 The shift is split into whole digits, which are moved, and the
 remaining bits, which are shifted across digit boundaries with a
 carry. A negative `s` trips the assert.
 */
generic bigshri = {a, s
	var off, shift
	var t, carry
	var i

	assert(s >= 0, "shift amount must be positive")
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32
	/*
	 shifting out every digit leaves zero. this has to be caught
	 up front: the length arithmetic below is unsigned, and would
	 wrap around if off were larger than the digit count.
	 */
	if off >= a.dig.len
		for i = 0; i < a.dig.len; i++
			a.dig[i] = 0
		;;
		-> trim(a)
	;;
	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		/* bits of carry above the low 32 are dropped by the cast */
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32) 
		carry = t << (32 - shift)
	;;
	-> trim(a)
}
/*
 Fills in `v` as a bigint that borrows the caller's storage `dig`
 for its digits, so nothing is allocated. The result should not be
 modified. `val` carries the bits of the value and `isneg` says
 whether they are to be read as a negative number.
 */
const bigdigit = {v, isneg : bool, val : uint64, dig
	if isneg
		/* recover the magnitude from the negative bit pattern */
		val = -val
		v.sign = -1
	else
		v.sign = 1
	;;

	if val >= Base
		/* needs both digits */
		v.dig = dig
		v.dig[0] = val castto(uint32)
		v.dig[1] = (val >> 32) castto(uint32)
	elif val != 0
		v.dig = dig[:1]
		v.dig[0] = val castto(uint32)
	else
		/* zero has no digits and a zero sign */
		v.sign = 0
		v.dig = [][:]
	;;
}
/*
 Trims leading (most significant) zero digits from `a` and makes the
 sign agree with the digits: no digits means sign 0, and a value with
 digits but sign 0 becomes positive. Returns `a`.
 */
const trim = {a
	var keep

	/* count down past the zero digits at the top */
	keep = a.dig.len
	while keep > 0 && a.dig[keep - 1] == 0
		keep -= 1
	;;
	a.dig = slgrow(a.dig, keep)
	if keep == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}