shithub: mc

ref: 6e944f5fe207cd0294eaa4cd023533771fd403d8
dir: /lib/crypto/ctbig.myr/

View raw version
use std
use iter

use "ct"

/*
 * Interface of the fixed-width big integer type used by the crypto
 * package. Only declarations appear here; the definitions follow in
 * this file.
 */
pkg crypto =
	type ctbig = struct
		/* width of the value in bits */
		nbit	: std.size
		dig	: uint32[:] 	/* little endian; always ndig(nbit) digits, leading zero digits are kept. */
	;;

	/* construction from integers and byte strings, and conversion back to bytes */
	generic mkctbign 	: (v : @a, nbit : std.size -> ctbig#) :: numeric,integral @a
	const ctzero	: (nbit : std.size -> ctbig#)
	const ctbytesle	: (v : ctbig# -> byte[:])
	const ctbytesbe	: (v : ctbig# -> byte[:])
	const mkctbigle	: (v : byte[:], nbit : std.size -> ctbig#)
	const mkctbigbe	: (v : byte[:], nbit : std.size -> ctbig#)

	/* lifetime management and conversion to/from std.bigint */
	const ctfree	: (v : ctbig# -> void)
	const ctbigdup	: (v : ctbig# -> ctbig#)
	pkglocal const ct2big	: (v : ctbig# -> std.bigint#)
	pkglocal const big2ct	: (v : std.bigint#, nbit : std.size -> ctbig#)

	/* arithmetic; the result is written into r */
	pkglocal const ctadd	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig#, m : ctbig# -> void)

	/* comparisons; operands must have the same nbit and digit count */
	pkglocal const ctiszero	: (v : ctbig# -> bool)
	pkglocal const cteq	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctne	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctgt	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctge	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctlt	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctle	: (a : ctbig#, b : ctbig# -> bool)

	/* for testing */
	pkglocal const growmod	: (r : ctbig#, a : ctbig#, k : uint32, m : ctbig# -> void)
	pkglocal const clip	: (v : ctbig# -> ctbig#)

	/* equality on ctbig# is defined through cteq */
	impl std.equatable ctbig#
;;

/* number of bits in one digit of a ctbig; clip() uses it to find the ragged top bits */
const Bits = 32
/* 2**32; presumably the digit radix -- it is not referenced in the visible part of this file */
const Base = 0x100000000ul

/*
 * std.equatable for ctbig#: equality is delegated to cteq, which
 * asserts (via checksz) that both operands have the same size.
 */
impl std.equatable ctbig# =
	eq = {a, b
		-> cteq(a, b)
	}
;;

/*
 * Package initialiser: registers ctfmt as the formatter for ctbig#.
 * A throwaway zero-width value is created only so that std.typeof
 * has something of the right type to inspect; it is freed right away.
 */
const __init__ = {
	var probe : ctbig#

	probe = ctzero(0)
	std.fmtinstall(std.typeof(probe), ctfmt)
	ctfree(probe)
}

/*
 * Formatter callback installed by __init__. Pulls one ctbig# from
 * the argument list and writes its digits to sb from the most
 * significant to the least significant, each as 8 zero-padded hex
 * characters followed by a '.'. The opts argument is not consulted.
 */
const ctfmt = {sb, ap, opts
	var val : ctbig#

	val = std.vanext(ap)
	/* walk the little endian digit array backwards */
	for var i = val.dig.len; i > 0; i--
		std.sbfmt(sb, "{w=8,p=0,x}.", val.dig[i - 1])
	;;
}

/*
 * Allocates a new ctbig that is nbit bits wide and holds the integer v.
 *
 * v is converted to uint64 first, so at most its low 64 bits are
 * stored: the low 32 bits go into digit 0 and, when nbit > 32, the
 * next 32 bits go into digit 1. Every remaining digit is zero. The
 * result is passed through clip() before it is returned, and must be
 * released with ctfree().
 */
generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
	var a
	var val

	a = std.zalloc()

	val = (v : uint64)
	a.nbit = nbit
	/*
	 * The digits must start out zeroed: only the first two are
	 * assigned below, so for nbit > 64 the rest would otherwise
	 * hold whatever the allocator handed back.
	 */
	a.dig = std.slzalloc(ndig(nbit))
	if nbit > 0
		a.dig[0] = (val : uint32)
	;;
	if nbit > 32
		a.dig[1] = (val >> 32 : uint32)
	;;
	-> clip(a)
}

/*
 * Allocates a new ctbig of width nbit whose ndig(nbit) digits are
 * all zero. The caller owns the result and frees it with ctfree().
 */
const ctzero = {nbit
	var z : ctbig#

	z = std.zalloc()
	z.nbit = nbit
	z.dig = std.slzalloc(ndig(nbit))
	-> z
}

/*
 * File-local shorthand for ctbigdup(): returns a freshly allocated
 * copy of v with the same nbit and its own copy of the digits.
 * Delegating keeps a single implementation of the copy.
 */
const ctdup = {v
	-> ctbigdup(v)
}

/*
 * Builds a new std.bigint from ct: the digits are duplicated as-is
 * (including any leading zero digits) and the sign is set to 1.
 * ct itself is left untouched.
 */
const ct2big = {ct
	var digits

	digits = std.sldup(ct.dig)
	-> std.mk([.sign=1, .dig=digits])
}

/*
 * Builds a new nbit-wide ctbig from the digits of big. At most
 * ndig(nbit) digits are copied; missing high digits are zero and
 * surplus high digits of big are dropped. The sign of big is not
 * consulted. The result is passed through clip().
 */
const big2ct = {big, nbit
	var digs, want, have, ct

	want = ndig(nbit)
	digs = std.slzalloc(want)
	have = std.min(want, big.dig.len)
	std.slcp(digs[:have], big.dig[:have])
	ct = std.mk([
		.nbit=nbit,
		.dig=digs,
	])
	-> clip(ct)
}

/*
 * Builds a new nbit-wide ctbig from the little endian byte string v.
 * Bytes are packed four at a time into 32-bit digits; a final group
 * of one to three bytes becomes a partial digit. The result is
 * passed through clip().
 */
const mkctbigle = {v, nbit
	var a, last, i, o, off

	/*
	  Branching on v.len is fine: the size of a number is
	  allowed to leak.
	 */
	a = std.slzalloc(ndig(nbit))
	o = 0
	/* all complete groups of four bytes */
	for i = 0; i + 4 <= v.len; i += 4
		a[o++] = \
			((v[i + 0] : uint32) <<  0) | \
			((v[i + 1] : uint32) <<  8) | \
			((v[i + 2] : uint32) << 16) | \
			((v[i + 3] : uint32) << 24)
	;;

	/* remaining bytes, if any, form the low bytes of one more digit */
	if i != v.len
		last = 0
		while i < v.len
			off = i & 0x3
			last |= (v[i] : uint32) << (8 *off)
			i++
		;;
		a[o++] = last
	;;
	-> clip(std.mk([.nbit=nbit, .dig=a]))
}

/*
 * Builds a new nbit-wide ctbig from the big endian byte string v.
 * Digits are read from the end of v towards the front, four bytes
 * at a time; a leftover group of one to three leading bytes is
 * right-aligned in a zeroed 4-byte buffer and read as one more
 * digit. The result is passed through clip().
 */
const mkctbigbe = {v, nbit
	var digs, rem, nd, pad : byte[4]

	/*
	  Branching on v.len is fine: the size of a number is
	  allowed to leak.
	 */
	nd = 0
	digs = std.slzalloc(ndig(nbit))
	rem = v.len
	while rem >= 4
		digs[nd++] = std.getbe32(v[rem-4:rem])
		rem -= 4
	;;

	if rem != 0
		std.slfill(pad[:], 0)
		std.slcp(pad[4-rem:], v[:rem])
		digs[nd++] = std.getbe32(pad[:])
	;;
	-> clip(std.mk([.nbit=nbit, .dig=digs]))
}

/*
 * Returns a newly allocated little endian byte string holding v.
 * The result is (v.nbit + 7) / 8 bytes long; the caller owns it.
 *
 * Digits that fit entirely are written four bytes at a time. When
 * the byte count is not a multiple of four, the remaining one to
 * three bytes are taken from the low end of the top digit.
 */
const ctbytesle = {v
	var d, i, n, o, ret

	o = 0
	n = (v.nbit + 7) / 8
	ret = std.slalloc(n)
	/* only digits whose four bytes all fit within n */
	for i = 0; (i + 1) * 4 <= n; i++
		d = v.dig[i]
		ret[o++] = (d >>  0 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >> 24 : byte)
	;;

	/*
	 * Ragged tail: count output bytes (o), not digits (i), so
	 * that exactly n bytes are written.
	 */
	if o != n
		d = v.dig[i]
		while o < n
			ret[o++] = (d : byte)
			d >>= 8
		;;
	;;
	-> ret
}

/*
 * Returns a newly allocated big endian byte string holding v.
 * The result is (v.nbit + 7) / 8 bytes long; the caller owns it.
 *
 * When the byte count is not a multiple of four, the top digit
 * contributes only its low (n & 3) bytes, most significant first.
 * All remaining digits contribute four bytes each.
 */
const ctbytesbe = {v : ctbig#
	var d : uint32, i, n, o, ret, tail

	o = 0
	n = (v.nbit + 7) / 8
	ret = std.slalloc(n)
	/*
	 * i counts the digits still to be written, so it never has to
	 * go below zero and an empty digit array is handled naturally.
	 */
	i = v.dig.len
	/* parenthesised so the value does not hinge on operator precedence */
	tail = (n & 0x3)
	if tail != 0
		i--
		d = v.dig[i]
		for var j = tail; j > 0; j--
			ret[o++] = (d >> (8 * (j - 1 : uint32)) : byte)
		;;
	;;
	for ; i > 0; i--
		d = v.dig[i - 1]
		ret[o++] = (d >> 24 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >>  0 : byte)
	;;
	-> ret
}

/*
 * Returns a newly allocated copy of v: same nbit, and a separate
 * copy of the digit array. Free the result with ctfree().
 */
const ctbigdup = {v
	var c : ctbig#

	c = std.zalloc()
	c.nbit = v.nbit
	c.dig = std.sldup(v.dig)
	-> c
}

/*
 * Releases a ctbig: first its digit array, then the struct itself.
 * v must not be used afterwards.
 */
const ctfree = {v
	var digs

	digs = v.dig
	std.slfree(digs)
	std.free(v)
}

/*
 * r = a + b, truncated to the common width. This is ctaddcc with
 * the control word fixed at 1, so the sum is always stored.
 */
const ctadd = {r, a, b
	var always

	always = 1
	ctaddcc(r, a, b, always)
}

/*
 * Conditional add. The digit-wise sum a + b is always computed;
 * each digit of r is then chosen by mux(ctl, sum digit, old r digit),
 * so the same work is done whatever ctl is. The carry out of the top
 * digit is dropped and r is passed through clip() at the end.
 * a, b and r must all have the same size (asserted by checksz).
 */
const ctaddcc = {r, a, b, ctl
	var sum, carry

	checksz(a, b)
	checksz(a, r)

	carry = 0
	for var i = 0; i < a.dig.len; i++
		sum = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry
		carry = sum >> 32
		r.dig[i] = mux(ctl, (sum : uint32), r.dig[i])
	;;
	clip(r)
}

/*
 * r = a - b, truncated to the common width. This is ctsubcc with
 * the control word fixed at 1; the borrow it reports is discarded.
 */
const ctsub = {r, a, b
	var always

	always = 1
	ctsubcc(r, a, b, always)
}

/*
 * Conditional subtract. The digit-wise difference a - b is always
 * computed; each digit of r is then chosen by
 * mux(ctl, difference digit, old r digit). r is passed through
 * clip() and the final borrow (1 if the subtraction wrapped, else 0)
 * is returned regardless of ctl, which is what the comparison
 * helpers rely on.
 * a, b and r must all have the same size (asserted by checksz).
 */
const ctsubcc = {r, a, b, ctl
	var diff, borrow

	checksz(a, b)
	checksz(a, r)

	borrow = 0
	for var i = 0; i < a.dig.len; i++
		diff = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
		/* bit 63 of the 64-bit difference is set exactly when it wrapped */
		borrow = diff >> 63
		r.dig[i] = mux(ctl, (diff : uint32), r.dig[i])
	;;
	clip(r)
	-> borrow
}

/*
 * r = a * b, truncated to the common width.
 *
 * The full double-width product is accumulated in a scratch buffer
 * using schoolbook multiplication; every digit pair is processed, so
 * the amount of work depends only on the operand size. The low
 * a.dig.len digits are then copied into r's existing digit array and
 * the scratch buffer is released. Because r is written only after
 * a and b have been fully read, r may be the same object as a or b.
 * a, b and r must all have the same size (asserted by checksz).
 */
const ctmul = {r, a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	checksz(a, b)
	checksz(a, r)

	w  = std.slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = (a.dig[i]  : uint64)
			bj = (b.dig[j]  : uint64)
			wij = (w[i+j]  : uint64)
			t = ai * bj + wij + carry
			w[i+j] = (t  : uint32)
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the carry out of this row */
		w[i + j] = (carry  : uint32)
	;;
	/*
	 * Keep r's own storage and copy the low half into it. Swapping
	 * r.dig for the scratch buffer instead would drop the old digit
	 * array without freeing it whenever r is not a.
	 */
	std.slcp(r.dig, w[:r.dig.len])
	std.slfree(w)
	clip(r)
}

/*
 * Returns the index of the top digit in the number that has
 * a bit set, or 0 when no digit has one. This is useful for
 * finding our division.
 *
 * Every digit is visited and the candidate index is chosen with
 * mux() rather than a branch (this assumes mux(c, x, y) yields x
 * for any non-zero c).
 */
const topfull = {n : ctbig#
	var top

	top = 0
	for var i = 0; i < n.dig.len; i++
		top = mux(n.dig[i], i, top)
	;;
	/* hand back the index that was computed, not a constant */
	-> top
}

/*
 * Returns the 32-bit word of v that starts at bit offset `bit`.
 * When the offset is digit-aligned only one digit is read; otherwise
 * the word is stitched together from two neighbouring digits.
 */
const unalignedword = {v, bit
	var shift, idx, low, high

	idx = (bit >> 5 : uint32)
	shift = (bit & 0x1f : uint32)
	low = v.dig[idx]
	/* the neighbouring digit is only touched for unaligned offsets */
	high = 0
	if shift != 0
		high = v.dig[idx + 1]
	;;
	-> (low >> shift) | (high << (32 - shift))
}

/*
 * Shifts a up by one 32-bit digit, places k in the low digit, and
 * reduces the result modulo m, leaving it in r: in effect
 * r = (a * 2**32 + k) mod m.
 *
 * Requirements asserted below: a and m have the same size, there are
 * at least two digits, the top bit of m's top digit is set, and
 * m.nbit is a multiple of 32. r is not size-checked here, although
 * the ctaddcc/ctsubcc calls at the end check it against m.
 */
const growmod = {r, a, k, m
	var a0, a1, b0, hi, g, q, tb, e
	var chf, clow, under, over
	var cc : uint64

	checksz(a, m)
	std.assert(a.dig.len > 1, "bad modulus\n")
	std.assert(m.dig[m.dig.len - 1] & (1 << 31) != 0, "top of mod not set: m={}, nbit={}\n", m, m.nbit)
	std.assert(m.nbit % 32 == 0, "ragged sizes not yet supported: a.nbit=={}\n", a.nbit)

	/*
	 * a0:a1 is the top 64 bits of a (a0 already moved into the
	 * high half), b0 is the top 32 bits of m. They feed the
	 * quotient estimate below.
	 */
	a0 = (unalignedword(a, a.nbit - 32) : uint64) << 32
	a1 = (unalignedword(a, a.nbit - 64) : uint64) << 0
	b0 = (unalignedword(m, m.nbit - 32) : uint64)
	
	/* 
	 * We hold the top digit here, so 
	 * this keeps the number of digits the same, and
	 * as a result, keeps checksz() happy.
	 */
	hi = a.dig[a.dig.len - 1]

	/* Do the multiplication of x by 2**32 */
	std.slcp(r.dig[1:], a.dig[:a.dig.len-1])
	r.dig[0] = k
	/* estimate the quotient digit from the top words, then back it off by one (never below zero) */
	g = ((a0 + a1) / b0 : uint32)
	/*
	 * NOTE(review): a0 was shifted left by 32 above while b0 was
	 * not, so with the top bit of m set this comparison looks like
	 * it can never be true -- confirm whether the unshifted top
	 * word of a was meant here.
	 */
	e = eq(a0, b0)
	q = mux((e : uint32), 0xffffffff, mux(eq(g, 0), 0, g - 1));

	/*
	 * Subtract q*m from r, one digit at a time. cc carries the high
	 * half of each product plus the borrow of each subtraction. tb
	 * is updated from the low digit upwards, so the highest digit
	 * that differs from m decides it: it ends up set when the new r
	 * is not below m (assuming mux(c, x, y) yields x when c is set).
	 */
	cc = 0;
	tb = 1;
	for var u = 0; u < r.dig.len; u++
		var mw, zw, xw, nxw
		var zl : uint64

		mw = m.dig[u];
		zl = (mw : uint64) * (q : uint64) + cc
		cc = zl >> 32
		zw = (zl : uint32)
		xw = r.dig[u]
		nxw = xw - zw;
		cc += (gt(nxw, xw) : uint64)
		r.dig[u] = nxw;
		tb = mux(eq(nxw, mw), tb, gt(nxw, mw));
	;;

	/*
	 * We can either underestimate or overestimate q, 
	 *  - If we underestimated, either cc < hi, or cc == hi && tb != 0.
	 *  - If we overestimated, cc > hi.
	 *  - Otherwise, we got it exactly right.
	 * 
	 * If we overestimated, we need to add 'm' back once. If we
	 * underestimated, we need to subtract it once more. Both
	 * corrections are always executed; 'over' and 'under' only
	 * control whether their result is stored.
	 */
	chf = (cc >> 32 : uint32)
	clow = (cc >> 0 : uint32)
	over = chf | gt(clow, hi);
	under = ~over & (tb | (~chf & lt(clow, hi)));
	ctaddcc(r, r, m, over);
	ctsubcc(r, r, m, under);
	clip(r)

}

/*
 * Copies x into r and then applies growmod(r, r, 0, m) once per
 * digit of m, i.e. multiplies by 2**32 modulo m that many times.
 * Presumably this is the conversion into Montgomery form used by
 * montymul. x, r and m must all have the same size.
 */
const tomonty = {r, x, m
	checksz(x, r)
	checksz(x, m)

	std.slcp(r.dig, x.dig)
	for var left = m.dig.len; left > 0; left--
		growmod(r, r, 0, m)
	;;
}

/*
 * Conditional copy: every digit of r becomes
 * mux(ctl, digit of v, old digit of r). All digits are visited
 * whatever ctl is. r and v must have the same size.
 */
const ccopy = {r, v, ctl
	var i

	checksz(r, v)
	i = 0
	while i < r.dig.len
		r.dig[i] = mux(ctl, v.dig[i], r.dig[i])
		i++
	;;
}

/*
 * Returns a*b + k computed in 64 bits; each operand is widened to
 * uint64 before it is used.
 */
const muladd = {a, b, k
	var prod

	prod = (a : uint64) * (b : uint64)
	-> prod + (k : uint64)
}

/*
 * Word-by-word multiplication of x and y with interleaved reduction
 * by m, leaving the result in r. Presumably this is Montgomery
 * multiplication (x*y / 2**(32*digits) mod m); m0i is the value
 * ctmodpow derives from m's low digit via ninv32.
 *
 * r is cleared before x and y are read, so it must not share
 * storage with either of them. x, y, m and r must all have the
 * same size (asserted by checksz).
 */
const montymul = {r : ctbig#, x : ctbig#, y : ctbig#, m : ctbig#, m0i : uint32
	var dh : uint64
	var s

	checksz(x, y)
	checksz(x, m)
	checksz(x, r)

	std.slfill(r.dig, 0)
	/* dh holds whatever overflows past the top digit of r between rounds */
	dh = 0
	for var u = 0; u < x.dig.len; u++
		var f : uint32, xu : uint32
		var r1 : uint64, r2 : uint64, zh : uint64

		xu = x.dig[u]
		/* factor applied to m in this round, derived from the low digits and m0i */
		f = (r.dig[0] + x.dig[u] * y.dig[0]) * m0i;
		/* r1 and r2 carry the high halves of the two running products */
		r1 = 0;
		r2 = 0;
		for var v = 0; v < y.dig.len; v++
			var z : uint64
			var t : uint32

			z = muladd(xu, y.dig[v], r.dig[v]) + r1
			r1 = z >> 32
			t = (z : uint32)
			z = muladd(f, m.dig[v], t) + r2
			r2 = z >> 32
			/* results are stored one digit lower; the v == 0 word is not stored */
			if v != 0
				r.dig[v - 1] = (z : uint32)
			;;
		;;
		/* fold both carries and the previous overflow into the top digit */
		zh = dh + r1 + r2;
		r.dig[r.dig.len - 1] = (zh : uint32)
		dh = zh >> 32;
	;;

	/*
	 * r may still be greater than m at that point; notably, the
	 * 'dh' word may be non-zero. The subtraction below always
	 * runs; s only decides whether its result is stored.
	 */
	s = ne(dh, 0) | (ctge(r, m) : uint64)
	ctsubcc(r, r, m, (s : uint32))
}

/*
 * Starting from y = 2 - x, applies the step y *= 2 - y*x four
 * times, then returns mux(x & 1, -y, 0). For odd x each step
 * doubles the number of low bits in which x*y == 1, so y ends up
 * as the inverse of x modulo 2**32 and the result is its negation;
 * for even x the mux selects 0 (assuming mux(c, a, b) yields a
 * when c is 1 and b when c is 0).
 */
const ninv32 = {x
	var y

	y = 2 - x
	for var round = 0; round < 4; round++
		y *= 2 - y * x
	;;
	-> mux(x & 1, -y, 0)
}

/*
 * Modular exponentiation: intended to leave a**e mod m in r.
 *
 * t1 starts as tomonty(a) and is squared (via montymul) once per
 * exponent bit; r starts at 1 and is multiplied by t1 for every bit
 * of e, from bit 0 up to bit e.nbit - 1. The multiplication is
 * always carried out into the scratch value t2 and only copied into
 * r under control of the exponent bit, so the sequence of operations
 * does not depend on the value of e.
 *
 * r is overwritten before the loop reads e and m, so it must not
 * share storage with them. Both temporaries are freed before
 * returning.
 */
const ctmodpow = {r, a, e, m
	var t1, t2, m0i, ctl

	t1 = ctdup(a)
	t2 = ctzero(a.nbit)
	m0i = ninv32(m.dig[0])

	tomonty(t1, a, m);
	std.slfill(r.dig, 0);
	r.dig[0] = 1;
	for var i = 0; i < e.nbit; i++
		/* bit i of the exponent, as 0 or 1 */
		ctl = (e.dig[i>>5] >> (i & 0x1f : uint32)) & 1
		montymul(t2, r, t1, m, m0i)
		ccopy(r, t2, ctl);
		montymul(t2, t1, t1, m, m0i);
		std.slcp(t1.dig, t2.dig);
	;;
	ctfree(t1)
	ctfree(t2)
}

/*
 * Reports whether every digit of a is zero. All digits are visited
 * and the running answer is combined with mux() instead of
 * branching on the data.
 */
const ctiszero = {a
	var allzero, digzero

	allzero = 1
	for var i = 0; i < a.dig.len; i++
		/* digzero: mux picks 0 for a set digit, 1 otherwise */
		digzero = mux(a.dig[i], 0, 1)
		/* keep the running answer only while digits stay zero */
		allzero = mux(digzero, allzero, 0)
	;;
	-> (allzero : bool)
}

/*
 * Reports whether a and b hold the same digits. The digit-wise
 * differences are OR-ed together over the whole array and compared
 * against zero once at the end, so no early exit reveals where the
 * operands differ. a and b must have the same size.
 */
const cteq = {a, b
	var diff

	checksz(a, b)
	diff = 0
	for var i = 0; i < a.dig.len; i++
		diff |= (a.dig[i] - b.dig[i])
	;;
	-> (eq(diff, 0) : bool)
}

/* a != b: the negation of cteq, taken with not() on its byte value. */
const ctne = {a, b
	-> (not((cteq(a, b) : byte)) : bool)
}

/*
 * a > b: true when computing b - a borrows. The control word is 0,
 * so the difference is not stored back into b; only the borrow
 * returned by ctsubcc is used.
 */
const ctgt = {a, b
	var borrow

	borrow = ctsubcc(b, b, a, 0)
	-> (borrow : bool)
}

/* a >= b: the negation of ctlt, taken with not() on its byte value. */
const ctge = {a, b
	-> (not((ctlt(a, b) : byte)) : bool)
}

/*
 * a < b: true when computing a - b borrows. The control word is 0,
 * so the difference is not stored back into a; only the borrow
 * returned by ctsubcc is used.
 */
const ctlt = {a, b
	var borrow

	borrow = ctsubcc(a, a, b, 0)
	-> (borrow : bool)
}

/* a <= b: the negation of ctgt, taken with not() on its byte value. */
const ctle = {a, b
	-> (not((ctgt(a, b) : byte)) : bool)
}

/*
 * Number of uint32 digits needed to hold nbit bits: nbit divided by
 * the digit width, rounded up.
 */
const ndig = {nbit
	var width

	width = 8*sizeof(uint32)
	-> (nbit + width - 1)/width
}

/*
 * Asserts that a and b are interchangeable in size: first the same
 * declared bit width, then the same number of backing digits.
 */
const checksz = {a, b
	var samebits, samelen

	samebits = a.nbit == b.nbit
	std.assert(samebits, "mismatched bit sizes")
	samelen = a.dig.len == b.dig.len
	std.assert(samelen, "mismatched backing sizes")
}

/*
 * Masks off the bits of the top digit that lie above v.nbit and
 * returns v (the same pointer, modified in place).
 *
 * edge is the number of valid bits in the top digit, or 0 when the
 * width is a whole number of digits, in which case the mask keeps
 * every bit (assuming mux(c, x, y) yields x for non-zero c and y
 * for zero).
 */
const clip = {v
	var mask, edge : uint64

	/*
	 * A zero-width value has no digits at all, so there is no top
	 * digit to mask. The digit count is not secret, so branching
	 * on it is fine.
	 */
	if v.dig.len == 0
		-> v
	;;
	edge = (v.nbit : uint64) & (Bits - 1)
	mask = mux(edge, (1 << edge) - 1, ~0)
	v.dig[v.dig.len - 1] &= (mask : uint32)
	-> v
}