shithub: mc

ref: 21b91ad9fa7ae830272dc67fbe3150dc30921702
dir: /libstd/bigint.myr/

View raw version
use "alloc.use"
use "chartype.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "hasprefix.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
use "errno.use"

pkg std =
	/* arbitrary precision integer, stored as a sign plus a magnitude. */
	type bigint = struct
		dig	: uint32[:] 	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/* administrivia */
	/* allocation, copying, parsing and formatting of bigints. */
	generic mkbigint	: (v : @a::(numeric,integral) -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigmove	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigclear	: (a : bigint# -> bigint#)
	const bigfmt	: (a : bigint#, base : int -> byte[:])
	const bigbfmt	: (b : byte[:], a : bigint#, base : int -> size)
	/*
	const bigtoint	: (a : bigint#	-> @a::(numeric,integral))
	*/

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigeq	: (a : bigint#, b : bigint# -> bool)
	generic bigeqi	: (a : bigint#, b : @a::(numeric,integral) -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/* bigint*bigint -> bigint ops */
	/*
	 the definitions are commented as 'a += b', 'a -= b' and so on:
	 they update 'a' in place. bigdivmod is the exception, returning
	 a (quotient, remainder) pair.
	 */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)
	const bigmodpow	: (b : bigint#, e : bigint#, m : bigint# -> bigint#)
	/*
	const bigpow	: (a : bigint#, b : bigint# -> bigint#)
	*/

	/* bigint*int -> bigint ops */
	/* same operations, with a plain integer as the right hand side. */
	generic bigaddi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigsubi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigmuli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigdivi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshri	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	/*
	const bigpowi	: (a : bigint#, b : uint64 -> bigint#)
	*/
;;

/* digits are 32 bits wide; Base is one more than the largest digit value. */
const Base = 0x100000000ul

/*
 allocates a new bigint holding the value of the integer 'v'.
 the result owns its digit slice, and should be released
 with bigfree.
 */
generic mkbigint = {v : @a::(integral,numeric)
	var a
	var val

	a = zalloc()

	/*
	 widen before negating, the same way bigdigit does: negating
	 in the narrow type can't represent the magnitude of the most
	 negative value.
	 */
	val = v castto(uint64)
	if v < 0
		a.sign = -1
		val = -val
	elif v > 0
		a.sign = 1
	;;
	a.dig = slpush([][:], val castto(uint32))
	/*
	 '>=', not '>': a value of exactly Base has a low digit of
	 zero and a high digit of one, which would otherwise be lost.
	 */
	if val >= Base
		a.dig = slpush(a.dig, (val/Base) castto(uint32))
	;;
	-> trim(a)
}

/* releases the digit storage of 'a', and then 'a' itself. */
const bigfree = {a
	var digits

	digits = a.dig
	slfree(digits)
	free(a)
}

/* returns a newly allocated copy of 'a'. */
const bigdup = {a
	var copy

	copy = zalloc()
	-> bigassign(copy, a)
}

/*
 d = s. the digits are copied, so 'd' and 's' stay independent.
 returns 'd'.
 */
const bigassign = {d, s
	var dig

	/*
	 duplicate before freeing, so that bigassign(x, x)
	 does not copy out of storage that was just released.
	 */
	dig = sldup(s.dig)
	slfree(d.dig)
	d# = s#
	d.dig = dig
	-> d
}

/*
 transfers the value of 's' into 'd' without copying digits.
 the old digits of 'd' are freed, and 's' is left as zero
 with no digit storage. the struct 's' itself is not freed.
 returns 'd'.
 */
const bigmove = {d, s
	slfree(d.dig)
	d.dig = s.dig
	d.sign = s.sign
	s.sign = 0
	s.dig = [][:]
	-> d
}

/* resets 'v' to zero, freeing its digits. returns 'v'. */
const bigclear = {v
	std.slfree(v.dig)
	v.dig = [][:]
	v.sign = 0
	-> v
}

/*
 formats 'a' in the given base into a newly allocated buffer,
 and returns the slice of that buffer holding the text.
 the base is interpreted by bigbfmt.
 */
const bigfmt = {a, base
	var buf
	var n
	var perdig

	/*
	allocate a buffer guaranteed to be big enough.
	each 32 bit digit needs at most
		ceil(32/log_2(base))
	characters. for base 10 and up (including base 0, which
	bigbfmt treats as 10) that is at most 10. smaller bases
	need more, up to 32 characters per digit for base 2, so
	the base 10 estimate would overflow the buffer.
	add one for the - sign, and slack for a lone '0'.
	*/
	if base == 0 || base >= 10
		perdig = 10
	else
		perdig = 32
	;;
	buf = slalloc(3 + a.dig.len * perdig)
	n = bigbfmt(buf, a, base)
	-> buf[:n]
}

/*
 writes the text of 'x' in the given base into 'buf', and returns
 the number of bytes written. a base of 0 means base 10. bases of
 1, below 0, or above 36 are rejected with die().

 NOTE(review): nothing here checks that 'buf' is large enough; the
 caller is expected to size it (see bigfmt).
 */
const bigbfmt = {buf, x, base
	const digitchars = [
	'0','1','2','3','4','5','6','7','8','9',
	'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
	'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
	't', 'u', 'v', 'w', 'x', 'y', 'z']
	var v, val
	var n, i
	var tmp, rem, b

	/* dividing by 1 never shrinks the value, so base 1 would loop forever */
	if base < 0 || base == 1 || base > 36
		die("invalid base in bigbfmt\n")
	;;

	if base == 0
		b = mkbigint(10)
	else
		b = mkbigint(base)
	;;
	n = 0
	val = bigdup(x)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, b)
		/* a zero remainder has no digits at all */
		if rem.dig.len > 0
			n += encode(buf[n:], digitchars[rem.dig[0]])
		else
			n += encode(buf[n:], '0')
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(b)

	/* this is done last, so we get things right when we reverse the string */
	if x.sign == 0
		n += encode(buf[n:], '0')
	elif x.sign == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii digits, so this works for reversing. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
 parses an unsigned integer literal from 'str'. the prefixes 0x, 0o
 and 0b (in either case) select base 16, 8 and 2; anything else is
 read as base 10. '_' characters are skipped. returns `None if a
 character is not a valid digit for the base, and `Some of a newly
 allocated bigint otherwise.
 */
const bigparse = {str
	var c, val : int, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base)
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		/* valid digits are 0..base-1, so a value equal to the base is out of range too */
		if val < 0 || val >= base
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val castto(uint32)
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		bigmul(a, b)
		bigadd(a, v)

	;;
	/* the scratch values are only needed while parsing */
	bigfree(b)
	bigfree(v)
	-> `Some a
}

/* zero is the only value stored with no digits. */
const bigiszero = {v
	if v.dig.len == 0
		-> true
	;;
	-> false
}

/* true when 'a' and 'b' have the same sign and the same digits. */
const bigeq = {a, b
	var idx

	if a.sign != b.sign
		-> false
	;;
	if a.dig.len != b.dig.len
		-> false
	;;
	idx = 0
	while idx < a.dig.len
		if a.dig[idx] != b.dig[idx]
			-> false
		;;
		idx++
	;;
	-> true
}

/*
 compares 'a' against the plain integer 'b', by building a
 temporary bigint for 'b' in local storage.
 */
generic bigeqi = {a, b
	var rhs
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	-> bigeq(a, &rhs)
}

/*
 orders 'a' relative to 'b': `Before if a < b, `After if a > b,
 and `Equal otherwise. relies on neither value having leading
 zero digits.
 */
const bigcmp = {a, b
	var i
	var da, db, sa, sb

	sa = a.sign castto(int64)
	sb = b.sign castto(int64)
	if sa < sb
		-> `Before
	elif sa > sb
		-> `After
	elif a.dig.len < b.dig.len
		/* same sign: fewer digits is smaller if positive, larger if negative */
		-> signedorder(-sa)
	elif a.dig.len > b.dig.len
		-> signedorder(sa)
	else
		/* otherwise, the one with the first larger digit is bigger */
		for i = a.dig.len; i > 0; i--
			da = a.dig[i - 1] castto(int64)
			db = b.dig[i - 1] castto(int64)
			/* equal digits decide nothing; keep looking at lower ones */
			if da != db
				-> signedorder(sa * (da - db))
			;;
		;;
	;;
	-> `Equal
}

/* maps a negative, zero or positive number to `Before, `Equal or `After. */
const signedorder = {sign
	if sign > 0
		-> `After
	elif sign < 0
		-> `Before
	;;
	-> `Equal
}

/*
 a += b. 'b' is never modified. returns 'a'.
 */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0 
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/*
		 the signs differ, so this is a subtraction of magnitudes.
		 the result takes the sign of whichever magnitude is larger.
		 */
		match bigucmp(a, b)
		| `Before: /* |a| < |b| */
			a.sign = b.sign
			-> bigursub(a, b)
		| `After: /* |a| > |b| */
			-> usub(a, b)
		| `Equal: /* x + -x = 0 */
			-> bigclear(a)
		;;
	;;
}

/* adds two unsigned values together. */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;

		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	/* the extra digit was zero filled, so this just stores the final carry */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/*
 a -= b. 'b' is never modified. returns 'a'.
 */
const bigsub = {a, b
	/* 0 - x = -x */
	if a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* x - 0 = x */
	elif b.sign == 0
		-> a
	/* differing signs: the magnitudes add, and the sign of 'a' stays */
	elif a.sign != b.sign
		-> uadd(a, b)
	else
		/*
		 same sign: subtract the smaller magnitude from the larger.
		 if |b| is the larger one, the result crosses zero, and the
		 sign of 'a' flips.
		 */
		match bigucmp(a, b)
		| `Before: /* |a| < |b| */
			a.sign = -a.sign
			-> bigursub(a, b)
		| `After: /* |a| > |b| */
			-> usub(a, b)
		| `Equal:
			-> bigclear(a)
		;;
	;;
	-> a
}

/* subtracts two unsigned values, where 'a' is strictly greater than 'b' */
const usub = {a, b
	var carry
	var v, i

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - carry
		if i < b.dig.len
			v -= (b.dig[i] castto(int64)) 
		;;
		/* a negative digit means we borrow from the next one up */
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 compares the magnitudes of 'a' and 'b', ignoring their signs.
 relies on neither value having leading zero digits.
 */
const bigucmp = {a, b
	var i

	if a.dig.len < b.dig.len
		-> `Before
	elif a.dig.len > b.dig.len
		-> `After
	;;
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		elif a.dig[i - 1] > b.dig[i - 1]
			-> `After
		;;
	;;
	-> `Equal
}

/*
 replaces the digits of 'a' with |b| - |a|, where |b| is strictly
 greater than |a|. the sign of 'a' is left for the caller to set,
 and 'b' is not modified: it may be a temporary made by bigdigit,
 whose digits must not be written to or resized.
 */
const bigursub = {a, b
	var t

	t = bigdup(b)
	usub(t, a)
	slfree(a.dig)
	a.dig = t.dig
	/* the digits now belong to 'a'; only the struct is released */
	free(t)
	-> a
}

/*
 a *= b, using schoolbook multiplication into a fresh digit
 buffer that then replaces the digits of 'a'. returns 'a'.
 */
const bigmul = {a, b
	var i, j
	var prod, carry
	var acc

	/* anything times zero is zero */
	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	;;
	if a.sign == b.sign
		a.sign = 1
	else
		a.sign = -1
	;;

	/* the product never needs more digits than both operands together */
	acc = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			prod = (a.dig[i] castto(uint64)) * (b.dig[j] castto(uint64))
			prod += (acc[i + j] castto(uint64)) + carry
			acc[i + j] = prod castto(uint32)
			carry = prod >> 32
		;;
		/* 'i' is now a.dig.len: the carry lands one past this row */
		acc[i + j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = acc
	-> trim(a)
}

/* a /= b. returns 'a'. */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	bigmove(a, q)
	/* bigmove only takes the digits of q; free the emptied struct as well */
	bigfree(q)
	-> a
}
/* a %= b. returns 'a'. */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	bigmove(a, r)
	/* bigmove only takes the digits of r; free the emptied struct as well */
	bigfree(r)
	-> a
}

/*
 divides 'a' by 'b', returning (quotient, remainder) as two newly
 allocated bigints. neither 'a' nor 'b' is modified. dies if 'b'
 is zero.

 NOTE(review): the single digit divisor path always returns a
 non-negative remainder, while the other two paths return a
 remainder carrying the sign of 'a'. confirm which is intended.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	Implements bigint division using Algorithm D from
	Knuth: Seminumerical algorithms, Section 4.3.1.
	*/
	var m : int64, n : int64
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var i, j : int64
	var q

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b > a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		q = trim(q)
		/*
		 the remainder can be as large as 2^32 - 2, so it must stay
		 unsigned: squeezing it through int32 turns large remainders
		 into negative numbers.
		 */
		-> (q, trim(mkbigint(carry)))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize, so that the top digit of the divisor has its high bit set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	u.dig = slzgrow(u.dig, u.dig.len + 1)

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - qhat*z
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))

			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)

			/*
			 the high half of p is unsigned, so shift it before
			 converting: p can exceed 2^63, and an arithmetic shift
			 would smear its top bit. the high half of t is a small
			 signed borrow, and does need the arithmetic shift.
			 */
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		t = (u.dig[j + n] castto(uint64)) - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)
		/* adjust: qhat was one too large, so add the divisor back */
		if t castto(int64) < 0
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;

	;;
	/* the normalized copy of the divisor is no longer needed */
	bigfree(v)
	/* undo the biasing for remainder */
	u = bigshri(u, shift)
	-> (trim(q), trim(u))
}

/*
 computes b^e % m by square and multiply.
 the result is stored into 'base', which is returned.
 'exp' and 'mod' are left unmodified.
 */
const bigmodpow = {base, exp, mod
	var r, e

	r = mkbigint(1)
	/* shift down a private copy, so the caller's exponent survives */
	e = bigdup(exp)
	while !bigiszero(e)
		if (e.dig[0] & 1) != 0
			bigmul(r, base)
			bigmod(r, mod)
		;;
		bigshri(e, 1)
		bigmul(base, base)
		bigmod(base, mod)
	;;
	bigfree(e)
	bigmove(base, r)
	/* bigmove only takes the digits of r; free the emptied struct as well */
	bigfree(r)
	-> base
}

/*
 counts the zero bits above the highest set bit of 'a'.
 a value of zero has 32 of them.
 */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/*
	 binary search: if the top half of what remains is empty,
	 count it and slide the rest up into its place.
	 */
	n = 0
	if (a >> 16) == 0
		n += 16
		a <<= 16
	;;
	if (a >> 24) == 0
		n += 8
		a <<= 8
	;;
	if (a >> 28) == 0
		n += 4
		a <<= 4
	;;
	if (a >> 30) == 0
		n += 2
		a <<= 2
	;;
	if (a >> 31) == 0
		n += 1
	;;
	-> n
}


/*
 a <<= b. only shift counts that fit in a single digit
 are supported; anything larger dies.
 */
const bigshl = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshli(a, b.dig[0] castto(uint64))
	;;
	die("shift by way too much\n")
}

/*
 a >>= b, unsigned. only shift counts that fit in a single
 digit are supported; anything larger dies.
 */
const bigshr = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshri(a, b.dig[0] castto(uint64))
	;;
	die("shift by way too much\n")
}

/*
 a += b, where b is a plain integer.
 FIXME: actually make this a performance improvement
 */
generic bigaddi = {a, b
	var rhs : bigint
	var store : uint32[2]

	/* wrap 'b' in a temporary bigint backed by local storage */
	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigadd(a, &rhs)
	-> a
}

/* a -= b, where b is a plain integer. */
generic bigsubi = {a, b : @a::(numeric,integral)
	var rhs : bigint
	var store : uint32[2]

	/* wrap 'b' in a temporary bigint backed by local storage */
	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigsub(a, &rhs)
	-> a
}

/* a *= b, where b is a plain integer. */
generic bigmuli = {a, b
	var rhs : bigint
	var store : uint32[2]

	/* wrap 'b' in a temporary bigint backed by local storage */
	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigmul(a, &rhs)
	-> a
}

/* a /= b, where b is a plain integer. */
generic bigdivi = {a, b
	var rhs : bigint
	var store : uint32[2]

	/* wrap 'b' in a temporary bigint backed by local storage */
	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigdiv(a, &rhs)
	-> a
}

/*
  a <<= s, with an integer shift count.
  the shift is done in two passes: whole digits are moved first,
  and then the leftover bits are shifted across digit boundaries.
 */
generic bigshli = {a, s : @a::(numeric,integral)
	var nwords, nbits
	var wide, spill
	var i

	assert(s >= 0, "shift amount must be positive")
	nwords = (s castto(uint64)) / 32
	nbits = (s castto(uint64)) % 32

	/* there are no bits to move in a zero */
	if a.sign == 0
		-> a
	;;
	/* room for the moved digits, plus one for bits spilling off the top */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + nwords castto(size))
	/* move whole digits up, from the top down, so nothing is overwritten early */
	for i = a.dig.len; i > nwords; i--
		a.dig[i - 1] = a.dig[i - 1 - nwords]
	;;
	/* the vacated low digits become zero */
	for i = 0; i < nwords; i++
		a.dig[i] = 0
	;;
	/* then shift by the remaining bits, passing the overflow upwards */
	spill = 0
	for i = 0; i < a.dig.len; i++
		wide = (a.dig[i] castto(uint64)) << nbits
		a.dig[i] = (wide | spill) castto(uint32) 
		spill = wide >> 32
	;;
	-> trim(a)
}

/*
 a >>= s, with an integer shift count.
 logical shift right, zero fills. the sign is only changed if
 the result is zero. shifting by more bits than the value holds
 gives zero.
 */
generic bigshri = {a, s
	var off, shift
	var t, carry
	var i

	assert(s >= 0, "shift amount must be positive")
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/*
	 never drop more digits than there are: otherwise
	 'a.dig.len - off' below goes out of range, and the
	 loops index outside of the digits.
	 */
	if off > a.dig.len
		off = a.dig.len
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32) 
		/* the low bits shifted out here become the high bits of the digit below */
		carry = t << (32 - shift)
	;;
	-> trim(a)
}

/*
 fills in 'v' as a bigint for the magnitude 'val', negated if
 'isneg' is set. the digits are stored in the caller supplied
 'dig' rather than being allocated, so the result should not be
 modified.
 */
const bigdigit = {v, isneg : bool, val : uint64, dig
	if isneg
		val = -val
		v.sign = -1
	else
		v.sign = 1
	;;
	if val == 0
		v.sign = 0
		v.dig = [][:]
	elif val >= Base
		/* too big for one digit: split into low and high halves */
		v.dig = dig
		v.dig[0] = val castto(uint32)
		v.dig[1] = (val >> 32) castto(uint32)
	else
		v.dig = dig[:1]
		v.dig[0] = val castto(uint32)
	;;
}

/*
 drops zero digits from the most significant end of 'a', and
 makes the sign agree with the result: zero if no digits are
 left, and positive if digits remain but the sign was zero.
 */
const trim = {a
	var n

	n = a.dig.len
	while n > 0 && a.dig[n - 1] == 0
		n--
	;;
	a.dig = slgrow(a.dig, n)
	if n == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}