shithub: mc

ref: 814c5bcd2efc3bb9139a56a1e0eae177437f4f56
dir: /lib/crypto/ctbig.myr/

View raw version
use std

use "ct"

pkg crypto =
	/*
	  Fixed-width unsigned integer. The constructors in this file
	  allocate ndig(nbit) 32-bit digits regardless of the value stored.
	 */
	type ctbig = struct
		nbit	: std.size	/* declared width of the value, in bits */
		dig	: uint32[:] 	/* little endian; fixed length, high digits are zero padded. */
	;;

	/* build a ctbig of width nbit from a native integer */
	generic mkctbign 	: (v : @a, nbit : std.size -> ctbig#) :: numeric,integral @a

	/* constructors: all return a freshly allocated value to be released with ctfree */
	const ctzero	: (nbit : std.size -> ctbig#)
	const mkctbigle	: (v : byte[:], nbit : std.size -> ctbig#)
	//const mkctbigbe	: (v : byte[:], nbit : std.size -> ctbig#)

	/* lifetime management and conversion to/from std.bigint */
	const ctfree	: (v : ctbig# -> void)
	const ctbigdup	: (v : ctbig# -> ctbig#)
	const ctlike	: (v : ctbig# -> ctbig#)
	const ct2big	: (v : ctbig# -> std.bigint#)
	const big2ct	: (v : std.bigint#, nbit : std.size -> ctbig#)

	/* arithmetic: the result is stored into r; r, a and b must have matching sizes */
	const ctadd	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	//const ctdivmod	: (q : ctbig#, u : ctbig#, a : ctbig#, b : ctbig# -> void)
	//const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)

	/* predicates: the two-operand forms assert that both operands have matching sizes */
	const ctiszero	: (v : ctbig# -> bool)
	const cteq	: (a : ctbig#, b : ctbig# -> bool)
	const ctne	: (a : ctbig#, b : ctbig# -> bool)
	const ctgt	: (a : ctbig#, b : ctbig# -> bool)
	const ctge	: (a : ctbig#, b : ctbig# -> bool)
	const ctlt	: (a : ctbig#, b : ctbig# -> bool)
	const ctle	: (a : ctbig#, b : ctbig# -> bool)

	/* equality through std.equatable is routed to cteq */
	impl std.equatable ctbig#
;;

/* Width in bits of one digit; digits are stored as uint32. */
const Bits = 32
/* 2^32, i.e. one past the largest value a single digit can hold. */
const Base = 0x100000000ul

/* std.equatable for ctbig#: equality is delegated to cteq. */
impl std.equatable ctbig# =
	eq = {a, b
		-> cteq(a, b)
	}
;;

/*
  Registers ctfmt as the formatter for ctbig# values.

  A throwaway zero-width value is created only so std.typeof has
  something to inspect; it is released as soon as the formatter
  has been installed.
 */
const __init__ = {
	var sample : ctbig#

	sample = ctzero(0)
	std.fmtinstall(std.typeof(sample), ctfmt)
	ctfree(sample)
}

/*
  Formatter callback for ctbig#.

  Pulls the next ctbig# from the argument list and writes every digit
  to sb as 8 zero-padded hex characters, in storage order (index 0
  first). opts is not consulted.
 */
const ctfmt = {sb, ap, opts
	var ct : ctbig#
	var i

	ct = std.vanext(ap)
	for i = 0; i < ct.dig.len; i++
		std.sbfmt(sb, "{w=8,p=0,x}", ct.dig[i])
	;;
}

/*
  Builds an nbit-wide ctbig from the native integer v.

  v is converted to uint64 and its low and high 32-bit halves become
  digits 0 and 1 (the latter only when nbit > 32). All remaining
  digits are zero. The result is passed through clip before being
  returned, except for nbit == 0, where there are no digits to fill
  or mask.

  The caller owns the result and releases it with ctfree.
 */
generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
	var a
	var val

	a = std.zalloc()

	val = (v : uint64)
	a.nbit = nbit
	/* zeroed allocation: digits past the first two must not hold garbage */
	a.dig = std.slzalloc(ndig(nbit))
	if nbit == 0
		/* empty digit slice: nothing to store, and clip would index past it */
		-> a
	;;
	a.dig[0] = (val : uint32)
	if nbit > 32
		a.dig[1] = (val >> 32 : uint32)
	;;
	-> clip(a)
}

/*
  Returns a newly allocated ctbig of width nbit whose ndig(nbit)
  digits are all zero. Release it with ctfree.
 */
const ctzero = {nbit
	var digits

	digits = std.slzalloc(ndig(nbit))
	-> std.mk([.nbit=nbit, .dig=digits])
}

/*
  Converts ct into a newly allocated std.bigint holding the same value.

  A ctbig keeps a fixed number of digits, so its high digits are often
  zero. The high zero digits are dropped here and the sign is set to 0
  for a zero value and 1 otherwise.
  NOTE(review): this assumes std.bigint expects no high zero digits and
  sign 0 for zero - confirm against the std.bigint implementation.

  The scan below depends on the digit values, so this conversion is not
  constant time.
 */
const ct2big = {ct
	var n, sign

	/* n = number of digits left once high zero digits are dropped */
	n = ct.dig.len
	while n > 0 && ct.dig[n - 1] == 0
		n--
	;;

	sign = 1
	if n == 0
		sign = 0
	;;
	-> std.mk([
		.sign=sign,
		.dig=std.sldup(ct.dig[:n])
	])
}

/*
  Converts big into a newly allocated ctbig of width nbit.

  The result always has ndig(nbit) digits: the low digits are copied
  from big, digits that big does not have stay zero, and digits of big
  beyond that count are dropped. Only big.dig is read; big.sign is not
  consulted. The result is passed through clip.
 */
const big2ct = {big, nbit
	var digits, want, have

	want = ndig(nbit)
	digits = std.slzalloc(want)
	/* copy only as many digits as both sides actually have */
	have = std.min(want, big.dig.len)
	std.slcp(digits[:have], big.dig[:have])
	-> clip(std.mk([.nbit=nbit, .dig=digits]))
}

/*
  Builds an nbit-wide ctbig from the little-endian byte string v.

  Every complete group of four bytes becomes one digit; up to three
  trailing bytes are packed into one final digit. Digits that v does
  not cover stay zero. The result is passed through clip.

  v must fit in ndig(nbit) digits; a longer v indexes past the digit
  buffer.
 */
const mkctbigle = {v, nbit
	var a, last, i, o, off

	/*
	  It's ok to depend on the length of v here: we can leak the
	  size of the numbers.
	 */
	o = 0
	a = std.slzalloc(ndig(nbit))
	for i = 0; i + 4 <= v.len; i += 4
		/* widen each byte before shifting, or the shifted bits fall off the byte */
		a[o++] = \
			((v[i + 0] : uint32) <<  0) | \
			((v[i + 1] : uint32) <<  8) | \
			((v[i + 2] : uint32) << 16) | \
			((v[i + 3] : uint32) << 24)
	;;

	/*
	  Only emit a partial digit when bytes are actually left over;
	  otherwise a v that fills every digit would write one past the end.
	 */
	if i < v.len
		last = 0
		for i; i < v.len; i++
			/* off is the byte's position within the digit; the data comes from v[i] */
			off = i & 0x3
			last |= (v[i] : uint32) << (8 *off)
		;;
		a[o++] = last
	;;
	-> clip(std.mk([.nbit=nbit, .dig=a]))
}

/*
  Returns a newly allocated, all-zero ctbig with the same nbit and the
  same number of digits as v. Release it with ctfree.
 */
const ctlike = {v
	var digits

	digits = std.slzalloc(v.dig.len)
	-> std.mk([.nbit=v.nbit, .dig=digits])
}

/*
  Returns a newly allocated copy of v: same nbit, with its own copy of
  the digit storage. Release it with ctfree.
 */
const ctbigdup = {v
	var digits

	digits = std.sldup(v.dig)
	-> std.mk([.nbit=v.nbit, .dig=digits])
}

/*
  Releases v: first its digit storage, then the struct itself.
  v must not be used afterwards.
 */
const ctfree = {v
	var digits

	digits = v.dig
	std.slfree(digits)
	std.free(v)
}

/*
  r = a + b.

  r, a and b must have matching nbit and digit counts (asserted by
  checksz). Digits are added from index 0 upward in 64-bit arithmetic
  so the carry into the next digit is bit 32 of each partial sum. The
  carry out of the last digit is discarded, and the result is passed
  through clip, as in ctsub and ctmul.

  The loop runs over every digit regardless of the values involved.
 */
const ctadd = {r, a, b
	var v, i, carry

	checksz(a, b)
	checksz(a, r)

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry
		r.dig[i] = (v  : uint32)
		carry = v >> 32
	;;
	/* a sum can carry into the unused bits of the top digit; mask them like ctsub/ctmul do */
	clip(r)
}

/*
  r = a - b.

  r, a and b must have matching nbit and digit counts (asserted by
  checksz). Each digit difference is computed in 64-bit arithmetic; a
  difference that wrapped below zero has bit 63 set, and that bit is
  the borrow carried into the next digit. The borrow out of the last
  digit is discarded and the result is passed through clip.
 */
const ctsub = {r, a, b
	var i, diff, brw

	checksz(a, b)
	checksz(a, r)

	brw = 0
	for i = 0; i < a.dig.len; i++
		diff = (a.dig[i] : uint64) - (b.dig[i] : uint64) - brw
		/* top bit of the 64-bit difference: 1 when the subtraction wrapped */
		brw = diff >> 63
		/* selected through mux rather than a branch on the borrow */
		diff = mux(brw, diff + Base, diff)
		r.dig[i] = (diff  : uint32)
	;;
	clip(r)
}

/*
  r = a * b, keeping only the low digits.

  r, a and b must have matching nbit and digit counts (asserted by
  checksz). The full double-width product is accumulated into a
  scratch buffer with the schoolbook method; its low r.dig.len digits
  are then copied into r's existing storage and the scratch buffer is
  freed. The result is passed through clip.

  r may alias a or b: r.dig is written only after the product is
  complete. Both loops run over every digit regardless of the values.
 */
const ctmul = {r, a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	checksz(a, b)
	checksz(a, r)

	w  = std.slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = (a.dig[i]  : uint64)
			bj = (b.dig[j]  : uint64)
			wij = (w[i+j]  : uint64)
			/* cannot overflow: (2^32-1)^2 + 2*(2^32-1) == 2^64 - 1 */
			t = ai * bj + wij + carry
			w[i+j] = (t  : uint32)
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the carry out of this row */
		w[i + j] = (carry  : uint32)
	;;
	/*
	  Copy into r's own buffer instead of replacing it: nothing is
	  leaked when r is distinct from a, and no aliasing check is needed.
	 */
	std.slcp(r.dig, w[:r.dig.len])
	std.slfree(w)
	clip(r)
}

/*
  Returns whether every digit of a is zero.

  Every digit is visited; there is no early exit. The flag starts at 1
  and is updated through mux for each digit.
  NOTE(review): assumes mux(c, x, y) yields x when c is nonzero and y
  otherwise, as the comment in cteq indicates - confirm in the ct package.
 */
const ctiszero = {a
	var allzero, digzero

	allzero = 1
	for var i = 0; i < a.dig.len; i++
		/* digzero: 0 for a nonzero digit, 1 for a zero digit */
		digzero = mux(a.dig[i], 0, 1)
		/* keep the running flag only while digits stay zero */
		allzero = mux(digzero, allzero, 0)
	;;
	-> (allzero : bool)
}

/*
  Returns whether a and b hold the same digits.

  a and b must have matching nbit and digit counts (asserted by
  checksz). Every digit pair is visited; there is no early exit on the
  first mismatch.
 */
const cteq = {a, b
	var same, diff

	checksz(a, b)

	same = 1
	for var i = 0; i < a.dig.len; i++
		diff = a.dig[i] - b.dig[i]
		/* inner mux: diff != 0 ? 0 : 1; outer mux: once same is 0 it stays 0 */
		same = mux(same, mux(diff, 0, 1), 0)
	;;
	-> (same : bool)
}

/* Returns the negation of cteq(a, b), computed with not rather than a branch. */
const ctne = {a, b
	-> (not((cteq(a, b) : byte)) : bool)
}

/*
  Returns whether a > b.

  a and b must have matching nbit and digit counts (asserted by
  checksz). Digits are visited from index 0 (least significant) upward
  with no early exit: wherever the two digits differ the verdict is
  replaced by that digit's comparison, so the highest differing digit
  decides.
  NOTE(review): relies on not() mapping a zero difference to a truthy
  flag and on mux(c, x, y) selecting x for nonzero c - confirm in the
  ct package.
 */
const ctgt = {a, b
	var res, same, bigger

	checksz(a, b)

	res = 0
	for var i = 0; i < a.dig.len; i++
		bigger = gt(a.dig[i], b.dig[i])
		same = not(a.dig[i] - b.dig[i])
		/* equal digits keep the verdict from the lower digits */
		res = mux(same, res, bigger)
	;;
	-> (res : bool)
}

/* Returns the negation of ctlt(a, b), computed with not rather than a branch. */
const ctge = {a, b
	-> (not((ctlt(a, b) : byte)) : bool)
}

/*
  Returns whether a < b.

  a and b must have matching nbit and digit counts (asserted by
  checksz). Digits are visited from index 0 (least significant) upward
  with no early exit: wherever the two digits differ the verdict is
  replaced by that digit's comparison, so the highest differing digit
  decides.
  NOTE(review): relies on not() mapping a zero difference to a truthy
  flag and on mux(c, x, y) selecting x for nonzero c - confirm in the
  ct package.
 */
const ctlt = {a, b
	var e, d, l

	checksz(a, b)

	l = 0
	for var i = 0; i < a.dig.len; i++
		e = not(a.dig[i] - b.dig[i])
		/* a < b is b > a; comparing in the other order would compute ctgt */
		d = gt(b.dig[i], a.dig[i])
		l = mux(e, l, d)
	;;
	-> (l : bool)
}

/* Returns the negation of ctgt(a, b), computed with not rather than a branch. */
const ctle = {a, b
	-> (not((ctgt(a, b) : byte)) : bool)
}

/*
  Number of uint32 digits needed to hold nbit bits: nbit divided by
  the bit width of a digit, rounded up.
 */
const ndig = {nbit
	var digbits

	digbits = 8*sizeof(uint32)
	-> (nbit + digbits - 1)/digbits
}

/*
  Asserts that a and b are compatible operands: same declared bit
  width and same number of backing digits. The bit width is checked
  first.
 */
const checksz = {a, b
	std.assert(b.nbit == a.nbit, "mismatched bit sizes")
	std.assert(b.dig.len == a.dig.len, "mismatched backing sizes")
}

/*
  Masks off the bits of the top digit that lie above v.nbit, so that
  the stored value is reduced modulo 2^nbit. Returns v to allow
  chaining.

  edge is the number of valid bits in the top digit. When nbit is a
  multiple of the digit width the whole top digit is valid and nothing
  is masked. A value with no digits is returned untouched.

  The branches depend only on nbit and dig.len, never on the digit
  values.
 */
const clip = {v
	var mask, edge

	if v.dig.len == 0
		/* nothing to mask, and dig[len - 1] would be out of range */
		-> v
	;;
	edge = v.nbit & (Bits - 1)
	if edge != 0
		/* keep the low edge bits of the top digit */
		mask = (1 << edge) - 1
		v.dig[v.dig.len - 1] &= (mask : uint32)
	;;
	-> v
}