ref: 90d9a10c0234f5868c2e86882aae72fc931d53fd
dir: /lib/crypto/ctbig.myr/
use std
use iter
use "ct"
/*
 * Interface of the fixed-width big integer type. Entries marked
 * pkglocal are only visible to the rest of the crypto package.
 */
pkg crypto =
/* A big integer with a declared width of `nbit` bits, held as 32-bit digits. */
type ctbig = struct
nbit : std.size
dig : uint32[:] /* little endian; the constructors in this file allocate ndig(nbit) digits, so high digits may be zero. */
;;
/* construction, serialization, release and duplication */
generic mkctbign : (v : @a, nbit : std.size -> ctbig#) :: numeric,integral @a
const ctzero : (nbit : std.size -> ctbig#)
const ctbytesle : (v : ctbig# -> byte[:])
const ctbytesbe : (v : ctbig# -> byte[:])
const mkctbigle : (v : byte[:], nbit : std.size -> ctbig#)
const mkctbigbe : (v : byte[:], nbit : std.size -> ctbig#)
const ctfree : (v : ctbig# -> void)
const ctbigdup : (v : ctbig# -> ctbig#)
/* conversion to and from std.bigint */
pkglocal const ct2big : (v : ctbig# -> std.bigint#)
pkglocal const big2ct : (v : std.bigint#, nbit : std.size -> ctbig#)
/* arithmetic */
pkglocal const ctadd : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
pkglocal const ctsub : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
pkglocal const ctmul : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
pkglocal const ctmodpow : (r : ctbig#, a : ctbig#, b : ctbig#, m : ctbig# -> void)
/* predicates and comparisons */
pkglocal const ctiszero : (v : ctbig# -> bool)
pkglocal const cteq : (a : ctbig#, b : ctbig# -> bool)
pkglocal const ctne : (a : ctbig#, b : ctbig# -> bool)
pkglocal const ctgt : (a : ctbig#, b : ctbig# -> bool)
pkglocal const ctge : (a : ctbig#, b : ctbig# -> bool)
pkglocal const ctlt : (a : ctbig#, b : ctbig# -> bool)
pkglocal const ctle : (a : ctbig#, b : ctbig# -> bool)
/* for testing */
pkglocal const growmod : (r : ctbig#, a : ctbig#, k : uint32, m : ctbig# -> void)
pkglocal const clip : (v : ctbig# -> ctbig#)
impl std.equatable ctbig#
;;
/* Width of one digit in bits; clip() uses it to find the ragged top-digit edge. */
const Bits = 32
/* 2**32, i.e. one more than the largest value a 32-bit digit can hold. */
const Base = 0x100000000ul
/* std.equatable for ctbig#: equality is whatever cteq() reports. */
impl std.equatable ctbig# =
	eq = {lhs, rhs
		-> cteq(lhs, rhs)
	}
;;
/*
 * Installs ctfmt as the formatter for ctbig# values. A throwaway
 * zero-width value is created only so std.typeof() has something
 * to describe; it is released again right away.
 */
const __init__ = {
	var sample : ctbig#

	sample = ctzero(0)
	std.fmtinstall(std.typeof(sample), ctfmt)
	ctfree(sample)
}
/*
 * Formatter callback for ctbig#: pulls the value out of the vararg
 * list and writes every digit as 8 zero-padded hex characters
 * followed by a '.', walking the digits through iter.byreverse().
 * The formatting options are not consulted.
 */
const ctfmt = {sb, ap, opts
	var num : ctbig#

	num = std.vanext(ap)
	for digit : iter.byreverse(num.dig)
		std.sbfmt(sb, "{w=8,p=0,x}.", digit)
	;;
}
/*
 * Builds an nbit-wide ctbig from a machine integer.
 *
 * v is cast to uint64; its low 32 bits become digit 0 and, when the
 * number is wider than 32 bits, its high 32 bits become digit 1.
 * The result is passed through clip() before being returned, and
 * the caller releases it with ctfree().
 */
generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
	var a
	var val

	a = std.zalloc()
	val = (v : uint64)
	a.nbit = nbit
	/*
	 * Zeroed allocation: only digits 0 and 1 are written below, so
	 * for nbit > 64 the remaining digits must already be zero.
	 */
	a.dig = std.slzalloc(ndig(nbit))
	if nbit > 0
		a.dig[0] = (val : uint32)
	;;
	if nbit > 32
		a.dig[1] = (val >> 32 : uint32)
	;;
	-> clip(a)
}
/*
 * Returns a freshly allocated nbit-wide number whose ndig(nbit)
 * digits are all zero. Release it with ctfree().
 */
const ctzero = {nbit
	var z

	z = std.zalloc()
	z.nbit = nbit
	z.dig = std.slzalloc(ndig(nbit))
	-> z
}
/*
 * File-local alias for ctbigdup(): returns a new number with the
 * same width as v and its own copy of v's digits.
 */
const ctdup = {v
	-> ctbigdup(v)
}
/*
 * Converts to a heap-allocated std.bigint. The digits are copied
 * as-is and the sign field is always set to 1.
 * NOTE(review): the digits are not trimmed and a zero value still
 * gets sign=1; confirm this matches what std.bigint expects.
 */
const ct2big = {ct
	var digits

	digits = std.sldup(ct.dig)
	-> std.mk([.sign=1, .dig=digits])
}
/*
 * Converts a std.bigint into an nbit-wide ctbig. At most ndig(nbit)
 * digits are copied from the bigint; any digits left over in the
 * result stay zero, and the top digit is clipped to nbit bits.
 * The sign of the bigint is not consulted.
 */
const big2ct = {big, nbit
	var digits, want, ncopy

	want = ndig(nbit)
	digits = std.slzalloc(want)
	ncopy = std.min(want, big.dig.len)
	std.slcp(digits[:ncopy], big.dig[:ncopy])
	-> clip(std.mk([.nbit=nbit, .dig=digits]))
}
/*
 * Builds an nbit-wide number from little-endian bytes. Byte i of v
 * lands in digit i/4 at bit offset 8*(i%4); the result is clipped
 * to nbit bits.
 */
const mkctbigle = {v, nbit
	var digits

	/*
	  Depending on the length of v is fine here: the size of the
	  numbers is allowed to leak.
	 */
	digits = std.slzalloc(ndig(nbit))
	/* digits start out zeroed, so each byte can simply be or'ed in */
	for var i = 0; i < v.len; i++
		digits[i / 4] |= (v[i] : uint32) << (8 * (i & 0x3))
	;;
	-> clip(std.mk([.nbit=nbit, .dig=digits]))
}
/*
 * Builds an nbit-wide number from big-endian bytes. The bytes are
 * consumed four at a time from the end of v; a leftover group of
 * 1..3 leading bytes is right-aligned in a zeroed scratch word.
 * The result is clipped to nbit bits.
 */
const mkctbigbe = {v, nbit
	var digits, rest, ndone, head : byte[4]

	/*
	  Depending on the length of v is fine here: the size of the
	  numbers is allowed to leak.
	 */
	ndone = 0
	digits = std.slzalloc(ndig(nbit))
	rest = v.len
	while rest >= 4
		digits[ndone++] = std.getbe32(v[rest-4:rest])
		rest -= 4
	;;
	if rest != 0
		std.slfill(head[:], 0)
		std.slcp(head[4-rest:], v[:rest])
		digits[ndone++] = std.getbe32(head[:])
	;;
	-> clip(std.mk([.nbit=nbit, .dig=digits]))
}
/*
 * Serializes v as (v.nbit + 7)/8 little-endian bytes in a freshly
 * allocated slice owned by the caller.
 *
 * Digits that fit completely are written four bytes at a time; if
 * the byte count is not a multiple of four, the remaining 1..3
 * bytes come from the low end of the next digit.
 */
const ctbytesle = {v
	var d, i, n, o, ret

	o = 0
	n = (v.nbit + 7) / 8
	ret = std.slalloc(n)
	/* only digits whose four bytes all fit within n */
	for i = 0; (i + 1) * 4 <= n; i++
		d = v.dig[i]
		ret[o++] = (d >>  0 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >> 24 : byte)
	;;
	/* ragged tail: count output bytes, not digits */
	if o != n
		d = v.dig[i]
		while o < n
			ret[o++] = (d : byte)
			d >>= 8
		;;
	;;
	-> ret
}
/*
 * Serializes v as (v.nbit + 7)/8 big-endian bytes in a freshly
 * allocated slice owned by the caller.
 *
 * When the byte count is not a multiple of four, the top digit only
 * contributes its low (n & 3) bytes; every digit below it is
 * written out in full, most significant byte first.
 */
const ctbytesbe = {v : ctbig#
	var d : uint32, i, n, o, ret, rem

	/* i counts the digits still to be written; it never goes below 0 */
	i = v.dig.len
	o = 0
	n = (v.nbit + 7) / 8
	ret = std.slalloc(n)
	rem = n & 0x3
	if rem != 0
		i--
		d = v.dig[i]
		/* explicit parentheses: do not rely on shift/multiply precedence */
		for var j = rem; j > 0; j--
			ret[o++] = (d >> (8*(j - 1 : uint32)) : byte)
		;;
	;;
	while i > 0
		i--
		d = v.dig[i]
		ret[o++] = (d >> 24 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >>  0 : byte)
	;;
	-> ret
}
/*
 * Returns a new number with the same width as v and a private copy
 * of its digits. Release it with ctfree().
 */
const ctbigdup = {v
	var copy

	copy = std.zalloc()
	copy.nbit = v.nbit
	copy.dig = std.sldup(v.dig)
	-> copy
}
/* Releases a number: first its digit slice, then the struct itself. */
const ctfree = {v
/* the digits must go first; v.dig cannot be read once v is freed */
std.slfree(v.dig)
std.free(v)
}
/* r = a + b: ctaddcc() with the control word fixed to 1. */
const ctadd = {r, a, b
	var always

	always = 1
	ctaddcc(r, a, b, always)
}
/*
 * Conditional add. The digit-wise sum a + b is always computed, but
 * each digit of r is chosen with mux(ctl, sum, old r digit), so ctl
 * decides whether r receives the sum or keeps its value. All three
 * operands must have matching sizes. The carry out of the top digit
 * is dropped, and r is clipped at the end.
 */
const ctaddcc = {r, a, b, ctl
	var sum, carry

	checksz(a, b)
	checksz(a, r)
	carry = 0
	for var i = 0; i < a.dig.len; i++
		sum = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry
		carry = sum >> 32
		r.dig[i] = mux(ctl, (sum : uint32), r.dig[i])
	;;
	clip(r)
}
/* r = a - b: ctsubcc() with the control word fixed to 1; the borrow is discarded. */
const ctsub = {r, a, b
	var always

	always = 1
	ctsubcc(r, a, b, always)
}
/*
 * Conditional subtract. The digit-wise difference a - b is always
 * computed, but each digit of r is chosen with
 * mux(ctl, difference, old r digit). All three operands must have
 * matching sizes. r is clipped at the end.
 *
 * Returns the borrow out of the top digit (1 when the subtraction
 * wrapped, 0 otherwise), whatever the value of ctl.
 */
const ctsubcc = {r, a, b, ctl
	var diff, borrow

	checksz(a, b)
	checksz(a, r)
	borrow = 0
	for var i = 0; i < a.dig.len; i++
		diff = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
		r.dig[i] = mux(ctl, (diff : uint32), r.dig[i])
		/* bit 63 of the 64-bit difference is set exactly when it wrapped */
		borrow = diff >> 63
	;;
	clip(r)
	-> borrow
}
/*
 * r = a * b, truncated to the width of the operands.
 *
 * The full double-width schoolbook product is accumulated in a
 * scratch buffer, which is then shrunk to a.dig.len digits and
 * installed as r's digit slice. r may alias a or b.
 */
const ctmul = {r, a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w, n

	checksz(a, b)
	checksz(a, r)

	n = a.dig.len
	w = std.slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = (a.dig[i] : uint64)
			bj = (b.dig[j] : uint64)
			wij = (w[i+j] : uint64)
			t = ai * bj + wij + carry
			w[i+j] = (t : uint32)
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the carry out of this row */
		w[i + j] = (carry : uint32)
	;;
	/* keep only the low n digits of the product */
	std.slgrow(&w, n)
	/*
	 * r's old digits are replaced wholesale, so release them no
	 * matter whether r aliases an operand; otherwise they leak.
	 * This happens after the loops because a or b may be r.
	 */
	std.slfree(r.dig)
	r.dig = w[:n]
	clip(r)
}
/*
 * Returns the index of the top digit in the number that has
 * a bit set, or 0 if there is none. This is useful for finding
 * our division.
 *
 * NOTE(review): the digit itself is passed as the mux() selector;
 * confirm that mux() accepts selectors other than 0 and 1.
 */
const topfull = {n : ctbig#
	var top

	top = 0
	for var i = 0; i < n.dig.len; i++
		top = mux(n.dig[i], i, top)
	;;
	/* hand back the index that was computed, not a constant */
	-> top
}
/*
 * Reads the 32 bits of v that start at bit position `bit`. When the
 * position is digit-aligned this is just one digit; otherwise the
 * word is stitched together from two neighbouring digits. The
 * branch depends only on the bit position, not on the digits.
 */
const unalignedword = {v, bit
	var idx, shift, lo, hi

	shift = (bit & 0x1f : uint32)
	idx = (bit >> 5 : uint32)
	lo = v.dig[idx]
	/* aligned reads never touch digit idx + 1 */
	hi = 0
	if shift != 0
		hi = v.dig[idx + 1]
	;;
	-> (lo >> shift) | (hi << (32 - shift))
}
/*
 * Computes r = (a * 2**32 + k) mod m: the digits of a are moved up
 * by one position, k becomes the new low digit, and the result is
 * reduced modulo m using an estimated quotient word followed by a
 * single corrective add or subtract.
 *
 * Requirements checked by the assertions: a and m have matching
 * sizes, there is more than one digit, the width is a multiple of
 * 32 bits, and the top bit of m is set.
 * NOTE(review): the reduction presumably also needs a < m on entry,
 * as in BearSSL's br_i32_muladd_small; confirm against callers.
 */
const growmod = {r, a, k, m
	var a0, a1, b0, hi, g, q, tb, e
	var chf, clow, under, over
	var cc : uint64

	checksz(a, m)
	std.assert(a.dig.len > 1, "bad modulus\n")
	std.assert(m.dig[m.dig.len - 1] & (1 << 31) != 0, "top of mod not set: m={}, nbit={}\n", m, m.nbit)
	std.assert(m.nbit % 32 == 0, "ragged sizes not yet supported: a.nbit=={}\n", a.nbit)

	/* top two words of a form the 64-bit dividend; b0 is the top word of m */
	a0 = (unalignedword(a, a.nbit - 32) : uint64) << 32
	a1 = (unalignedword(a, a.nbit - 64) : uint64) << 0
	b0 = (unalignedword(m, m.nbit - 32) : uint64)

	/*
	 * We hold the top digit here, so
	 * this keeps the number of digits the same, and
	 * as a result, keeps checksz() happy.
	 */
	hi = a.dig[a.dig.len - 1]

	/* Do the multiplication of x by 2**32 */
	std.slcp(r.dig[1:], a.dig[:a.dig.len-1])
	r.dig[0] = k
	g = ((a0 + a1) / b0 : uint32)
	/*
	 * When the top word of a equals the top word of m, the true
	 * quotient does not fit in 32 bits and q must saturate. a0
	 * already sits in the upper half of the dividend, so shift it
	 * back down before comparing it with b0.
	 */
	e = eq(a0 >> 32, b0)
	q = mux((e : uint32), 0xffffffff, mux(eq(g, 0), 0, g - 1));

	/* subtract q*m from r; tb tracks whether the result is still >= m */
	cc = 0;
	tb = 1;
	for var u = 0; u < r.dig.len; u++
		var mw, zw, xw, nxw
		var zl : uint64

		mw = m.dig[u];
		zl = (mw : uint64) * (q : uint64) + cc
		cc = zl >> 32
		zw = (zl : uint32)
		xw = r.dig[u]
		nxw = xw - zw;
		cc += (gt(nxw, xw) : uint64)
		r.dig[u] = nxw;
		tb = mux(eq(nxw, mw), tb, gt(nxw, mw));
	;;

	/*
	 * We can either underestimate or overestimate q,
	 *  - If we underestimated, either cc < hi, or cc == hi && tb != 0.
	 *  - If we overestimated, cc > hi.
	 *  - Otherwise, we got it exactly right.
	 *
	 * If we overestimated, we need to add 'm' back once. If we
	 * underestimated, we need to subtract it once.
	 */
	chf = (cc >> 32 : uint32)
	clow = (cc >> 0 : uint32)
	over = chf | gt(clow, hi);
	under = ~over & (tb | (~chf & lt(clow, hi)));
	ctaddcc(r, r, m, over);
	ctsubcc(r, r, m, under);
	clip(r)
}
/*
 * Copies x into r and then runs growmod(r, r, 0, m) once per digit
 * of m, i.e. multiplies by 2**32 modulo m that many times.
 * NOTE(review): presumably this is the conversion to Montgomery
 * representation used by ctmodpow; confirm.
 */
const tomonty = {r, x, m
	var rounds

	checksz(x, r)
	checksz(x, m)
	std.slcp(r.dig, x.dig)
	rounds = m.dig.len
	while rounds > 0
		growmod(r, r, 0, m)
		rounds -= 1
	;;
}
/*
 * Conditional copy: every digit of r becomes
 * mux(ctl, digit of v, digit of r), so ctl selects between taking
 * v and keeping r. Every digit is visited either way.
 */
const ccopy = {r, v, ctl
	var idx

	checksz(r, v)
	for idx = 0; idx < r.dig.len; idx++
		r.dig[idx] = mux(ctl, v.dig[idx], r.dig[idx])
	;;
}
/* Returns a*b + k, with every operand widened to uint64 first. */
const muladd = {a, b, k
	var prod

	prod = (a : uint64) * (b : uint64)
	-> prod + (k : uint64)
}
/*
 * Word-by-word multiply-and-reduce of x and y modulo m, written to r.
 * All four numbers must have matching sizes. m0i is a per-modulus
 * constant; ctmodpow passes ninv32(m.dig[0]).
 * NOTE(review): the structure matches Montgomery multiplication
 * (r = x*y / 2**(32*ndigits) mod m); confirm against the reference
 * algorithm. r is zeroed before x and y are read, so r presumably
 * must not alias x or y.
 */
const montymul = {r : ctbig#, x : ctbig#, y : ctbig#, m : ctbig#, m0i : uint32
/* dh carries the overflow word above the top digit between outer rounds */
var dh : uint64
var s
checksz(x, y)
checksz(x, m)
checksz(x, r)
std.slfill(r.dig, 0)
dh = 0
for var u = 0; u < x.dig.len; u++
var f : uint32, xu : uint32
var r1 : uint64, r2 : uint64, zh : uint64
xu = x.dig[u]
/* f is picked from the low digits and m0i; it scales m in the inner loop */
f = (r.dig[0] + x.dig[u] * y.dig[0]) * m0i;
r1 = 0;
r2 = 0;
for var v = 0; v < y.dig.len; v++
var z : uint64
var t : uint32
/* r1 is the carry of the xu*y row, r2 the carry of the f*m row */
z = muladd(xu, y.dig[v], r.dig[v]) + r1
r1 = z >> 32
t = (z : uint32)
z = muladd(f, m.dig[v], t) + r2
r2 = z >> 32
/* results are stored one digit lower; the word for v == 0 is not stored */
if v != 0
r.dig[v - 1] = (z : uint32)
;;
;;
/* fold both carries and the previous overflow into the top digit */
zh = dh + r1 + r2;
r.dig[r.dig.len - 1] = (zh : uint32)
dh = zh >> 32;
;;
/*
* r may still be greater than m at that point; notably, the
* 'dh' word may be non-zero.
*/
/* subtract m once when there is overflow or r >= m */
s = ne(dh, 0) | (ctge(r, m) : uint64)
ctsubcc(r, r, m, (s : uint32))
}
/*
 * Starts from 2 - x and applies the refinement y *= 2 - y*x four
 * times, then returns mux(x & 1, -y, 0): the negated result when x
 * is odd, and 0 when x is even.
 * NOTE(review): presumably this yields -1/x mod 2**32 for odd x;
 * confirm.
 */
const ninv32 = {x
	var inv

	inv = 2 - x
	for var round = 0; round < 4; round++
		inv *= 2 - inv * x
	;;
	-> mux(x & 1, -inv, 0)
}
/*
 * Modular exponentiation by square-and-multiply over all e.nbit
 * exponent bits, lowest bit first. r receives the result.
 * NOTE(review): presumably r = a**e mod m; confirm against tests.
 *
 * Every exponent bit does the same work: the product is always
 * computed, and ccopy() decides whether r takes it. t1 holds the
 * running square, t2 is scratch space; both are released before
 * returning.
 * NOTE(review): ninv32() returns 0 for an even m.dig[0], so m is
 * presumably required to be odd; confirm against callers.
 */
const ctmodpow = {r, a, e, m
	var t1, t2, m0i, ctl

	t1 = ctdup(a)
	t2 = ctzero(a.nbit)
	m0i = ninv32(m.dig[0])

	tomonty(t1, a, m);
	std.slfill(r.dig, 0);
	r.dig[0] = 1;
	for var i = 0; i < e.nbit; i++
		/* bit i of the exponent */
		ctl = (e.dig[i>>5] >> (i & 0x1f : uint32)) & 1
		montymul(t2, r, t1, m, m0i)
		ccopy(r, t2, ctl);
		montymul(t2, t1, t1, m, m0i);
		std.slcp(t1.dig, t2.dig);
	;;
	ctfree(t1)
	ctfree(t2)
}
/*
 * Returns true when every digit of a is zero.
 *
 * All digits are or'ed together and the result is compared against
 * zero once with eq(), the same way cteq() finishes. No digit is
 * used as a mux() selector, so this does not depend on how mux()
 * treats selectors other than 0 and 1.
 */
const ctiszero = {a
	var nz

	nz = 0
	for var i = 0; i < a.dig.len; i++
		nz = nz | a.dig[i]
	;;
	-> (eq(nz, 0) : bool)
}
/*
 * Returns true when a and b hold the same digits. Both must have
 * matching sizes. Every digit pair is visited; the differences are
 * or'ed together and compared against zero once at the end.
 */
const cteq = {a, b
	var acc, idx

	checksz(a, b)
	acc = 0
	for idx = 0; idx < a.dig.len; idx++
		acc |= (a.dig[idx] - b.dig[idx])
	;;
	-> (eq(acc, 0) : bool)
}
/* a != b: the result of cteq(), passed through not(). */
const ctne = {a, b
	-> (not((cteq(a, b) : byte)) : bool)
}
/*
 * a > b: runs b - a with the control word set to 0, so nothing is
 * stored apart from the clip() of b, and reports the borrow.
 */
const ctgt = {a, b
	var borrow

	borrow = ctsubcc(b, b, a, 0)
	-> (borrow : bool)
}
/* a >= b: the result of ctlt(), passed through not(). */
const ctge = {a, b
	-> (not((ctlt(a, b) : byte)) : bool)
}
/*
 * a < b: runs a - b with the control word set to 0, so nothing is
 * stored apart from the clip() of a, and reports the borrow.
 */
const ctlt = {a, b
	var borrow

	borrow = ctsubcc(a, a, b, 0)
	-> (borrow : bool)
}
/* a <= b: the result of ctgt(), passed through not(). */
const ctle = {a, b
	-> (not((ctgt(a, b) : byte)) : bool)
}
/* Number of 32-bit digits needed to hold nbit bits, rounding up. */
const ndig = {nbit
	var digbits

	digbits = 8*sizeof(uint32)
	-> (nbit + digbits - 1)/digbits
}
/*
 * Asserts that two numbers are compatible operands: the same
 * declared bit width and the same number of backing digits.
 */
const checksz = {a, b
	std.assert(b.nbit == a.nbit, "mismatched bit sizes")
	std.assert(b.dig.len == a.dig.len, "mismatched backing sizes")
}
/*
 * Masks off any bits of the top digit that lie beyond v.nbit and
 * returns v. When nbit is a multiple of 32 the mask is all ones
 * and the digit is left unchanged.
 *
 * A zero-width number has no digits, so there is nothing to mask.
 * The check depends only on the size, never on the digit values.
 */
const clip = {v
	var mask, edge : uint64

	if v.dig.len == 0
		-> v
	;;
	/* number of bits in use in the top digit; 0 means all 32 */
	edge = (v.nbit : uint64) & (Bits - 1)
	mask = mux(edge, (1 << edge) - 1, ~0)
	v.dig[v.dig.len - 1] &= (mask : uint32)
	-> v
}