shithub: mc

ref: f762f96f07cf2dfc7b3bb745d6a767dccda0a01a
dir: /lib/crypto/aesgcm.myr/

View raw version
use std

use "aes"
use "ct"

pkg crypto =
	/*
	  Working state for AES-GCM. Every field except M's lookup use is
	  written by aesgcminit; ctr is advanced as keystream is produced.
	*/
	type aesgcmctx = struct
		/* AES state prepared by aeskeysched from the caller's key */
		aes	: aesctx
		/* current counter block; inc bumps its last four bytes */
		ctr	: byte[16]
		/* encryption of the initial counter block, xored into the tag */
		j0	: byte[16]
		/* FIXME: constant time gfmul needed here. */
		/* M[i][j] holds gfmul of H with the block whose byte i is j */
		M	: uint32[4][256][16]
		/* encryption of the all-zero block, in load128 word order */
		H	: uint32[4]
	;;
	/* derives H, the M table, the starting counter and j0 from key and iv */
	const aesgcminit	: (c : aesgcmctx#, key : byte[:], iv : byte[:] -> void)
	/* clears state held in c once the caller is done with it */
	const aesgcmfin		: (c : aesgcmctx# -> void)
	/* encrypts buf in place and writes the 16 byte tag over buf and aad into tag */
	const aesgcmencrypt	: (c : aesgcmctx#, buf : byte[:], aad : byte[:], tag : byte[:] -> void)
	/* decrypts buf in place; returns the result of comparing tag with the computed tag */
	const aesgcmdecrypt	: (c : aesgcmctx#, buf : byte[:], aad : byte[:], tag : byte[:] -> bool)
;;

/*
  Prepares c for use with the given key and iv.

  In order, this:
    - schedules the AES key and sets H to the encryption of the zero block,
    - fills the M lookup table that ghash1 multiplies with,
    - derives the starting counter block from iv,
    - stores the encryption of that counter block in j0.
*/
const aesgcminit = {c, key, iv
	var zero : byte[16]
	var X, lenblk, Y, shift

	/* H = E(key, 0^128), kept as four words in load128 order */
	aeskeysched(&c.aes, key)
	std.slfill(zero[:], 0)
	aesencrypt(&c.aes, zero[:], zero[:])
	load128(zero[:], c.H[:])

	/*
	  For every byte position i and byte value j, store the product of
	  H with the block that has j in byte i and zero everywhere else.
	  TODO: constant time gfmul
	*/
	for var i=0; i < 16; i++
		shift = (i&3)<<3
		for var j=0; j < 256; j++
			X = [0, 0, 0, 0]
			X[i>>2] = j<<shift
			gfmul(X, c.H[:], c.M[i][j][:])
		;;
	;;

	/* Starting counter block. ghash needs the M table built above. */
	if iv.len != (96/8)
		/* any other length: hash the iv followed by a block holding its bit length */
		Y = [0,0,0,0]
		ghash(c, iv, Y[:])
		lenblk = [(iv.len << 3 : uint32), (iv.len >> 29 : uint32), 0, 0]
		ghash1(c, lenblk[:], Y[:])
		store128(Y[:], c.ctr[:])
	else
		/* 12 byte iv: iv followed by a 32 bit counter stepped from 0 to 1 */
		std.slcp(c.ctr[:96/8], iv)
		std.slfill(c.ctr[96/8:], 0)
		inc(c)
	;;

	/* j0 masks the tag in aesgcmencrypt/aesgcmdecrypt */
	aesencrypt(&c.aes, c.ctr[:], c.j0[:])
}

/*
  Wipes the secret state that this file derives and stores in c.

  Besides the counter block, this clears j0 (the mask xored into the
  tag), H (the hash subkey) and the M table computed from H, since
  each of them is derived from the key.

  NOTE(review): c.aes is not wiped here because the layout of aesctx
  is declared in the "aes" module; confirm whether that module offers
  its own cleanup and call it from here if so.
*/
const aesgcmfin = {c
	std.slfill(c.ctr[:], 0)
	std.slfill(c.j0[:], 0)
	std.slfill(c.H[:], 0)
	/* same bounds as the table fill in aesgcminit */
	for var i = 0; i < 16; i++
		for var j = 0; j < 256; j++
			std.slfill(c.M[i][j][:], 0)
		;;
	;;
}

/*
  Encrypts buf in place and writes the 16 byte authentication tag
  to tag.

  The tag is the hash of aad, then the ciphertext, then a block of
  their bit lengths, xored with c.j0. tag must have room for 16
  bytes.
*/
const aesgcmencrypt = {c, buf, aad, tag
	var lens, acc

	acc = [0,0,0,0]
	ghash(c, aad, acc[:])
	/* the hash covers ciphertext, so encrypt before hashing buf */
	aesctr(c, buf)
	ghash(c, buf, acc[:])

	/* bit lengths, each split into a low and a high 32 bit word */
	lens = [buf.len << 3, buf.len >> 29, aad.len << 3, aad.len >> 29]
	ghash1(c, lens[:], acc[:])

	store128(acc[:], tag)
	for var k = 0; k < 16; k++
		tag[k] = tag[k] ^ c.j0[k]
	;;
}

/*
  Decrypts buf in place and reports whether tag matches the tag
  computed over aad and the ciphertext.

  Returns the result of bufeq(tag, computed tag).

  NOTE: buf is decrypted whether or not the tag matches. When this
  returns false the contents of buf are unauthenticated and the
  caller must discard them.
*/
const aesgcmdecrypt = {c, buf, aad, tag
	var ctag : byte[16]
	var L, Y

	L = [0,0,0,0]
	Y = [0,0,0,0]
	/* buf still holds ciphertext here, which is what the tag covers */
	ghash(c, aad, Y[:])
	ghash(c, buf, Y[:])
	/* bit lengths of buf and aad, low word then high word */
	L[0] = buf.len << 3
	L[1] = buf.len >> 29
	L[2] = aad.len << 3
	L[3] = aad.len >> 29
	ghash1(c, L[:], Y[:])
	store128(Y[:], ctag[:])
	for var i = 0; i < 16; i++
		ctag[i] ^= c.j0[i]
	;;
	aesctr(c, buf)
	-> bufeq(tag, ctag[:])
}

/*
  Folds buf into the running hash Y, 16 bytes at a time, by way of
  ghash1. A trailing piece shorter than 16 bytes is padded with
  zero bytes before it is folded in. An empty buf leaves Y as is.
*/
const ghash = {c, buf, Y
	var pad : byte[16]
	var blk : uint32[4]
	var rest

	rest = buf
	while rest.len >= 16
		load128(rest[:16], blk[:])
		ghash1(c, blk[:], Y[:])
		rest = rest[16:]
	;;

	if rest.len > 0
		std.slcp(pad[:rest.len], rest)
		std.slfill(pad[rest.len:], 0)
		load128(pad[:], blk[:])
		ghash1(c, blk[:], Y[:])
	;;
}

/*
  One hash step: xors Y into X, then replaces Y with the xor of the
  sixteen c.M rows selected by the bytes of X. X is overwritten with
  the xored value.
*/
const ghash1 = {c, X, Y
	var row

	for var w = 0; w < 4; w++
		X[w] ^= Y[w]
		Y[w] = 0
	;;

	for var i=0; i<16; i++
		/* byte i of X picks the row of table i */
		row = c.M[i][(X[i>>2]>>((i&3)<<3))&0xFF]
		for var w = 0; w < 4; w++
			Y[w] ^= row[w]
		;;
	;;
}

/*
  Xors buf in place with keystream. For each 16 byte piece of buf,
  including a shorter final piece, the counter is stepped with inc
  and then encrypted to produce that piece's keystream.
*/
const aesctr = {c, buf
	var ks : byte[16]
	var rest, n

	rest = buf
	while rest.len != 0
		inc(c)
		aesencrypt(&c.aes, c.ctr[:], ks[:])

		/* the last piece may be shorter than a block */
		n = rest.len
		if n > 16
			n = 16
		;;
		for var k = 0; k < n; k++
			rest[k] ^= ks[k]
		;;
		rest = rest[n:]
	;;
}

/*
  Adds one to the big endian 32 bit value held in the last four
  bytes of c.ctr. The first twelve bytes are not touched.
*/
const inc = {c
	var low, next

	low = c.ctr[12:16]
	next = std.getbe32(low) + 1
	std.putbe32(low, next)
}

/*
  Multiplies X by Y, both 128 bit values held as four uint32 words
  with word 3 the most significant, and stores the product in Z.

  Bits of Y are visited from bit 127 down to bit 0. For each one, X
  is xored into Z when the bit is set (through an all-ones or
  all-zeros mask, not a branch), then X is shifted right by one bit
  and 0xE1000000 is xored into its top word if a one was shifted out.

  X is received by value and used as scratch. Z must not alias Y.
*/
const gfmul = {X, Y, Z
	/* i has to be signed: the loop ends when it drops below zero */
	var m : int32, i : int32

	Z[0] = 0
	Z[1] = 0
	Z[2] = 0
	Z[3] = 0
	for i = 127; i >= 0; i--
		/* move bit i of Y to the sign bit, then smear it across m */
		m = ((Y[i>>5]:int32) << (31-(i&31))) >> 31
		Z[0] ^= X[0] & (m : uint32)
		Z[1] ^= X[1] & (m : uint32)
		Z[2] ^= X[2] & (m : uint32)
		Z[3] ^= X[3] & (m : uint32)
		/* same trick for the bit about to be shifted out of X */
		m = ((X[0]:int32)<<31) >> 31
		X[0] = X[0]>>1 | X[1]<<31
		X[1] = X[1]>>1 | X[2]<<31
		X[2] = X[2]>>1 | X[3]<<31
		X[3] = X[3]>>1 ^ (0xE1000000 & m : uint32)
	;;
}

/*
  Reads 16 bytes from b as four big endian 32 bit words and stores
  them in B in reverse order: b[12..15] becomes B[0] and b[0..3]
  becomes B[3].
*/
const load128 = {b, B
	var o

	o = 16
	for var w = 0; w < 4; w++
		o -= 4
		B[w] = \
			(b[o + 0] : uint32) << 24 | \
			(b[o + 1] : uint32) << 16 | \
			(b[o + 2] : uint32) <<  8 | \
			(b[o + 3] : uint32) <<  0
	;;
}

/*
  Writes the four words of B to b as 16 bytes, undoing load128:
  B[3] goes to b[0..3] and B[0] to b[12..15], each word most
  significant byte first.
*/
const store128 = {B : uint32[:], b : byte[:]
	var o

	o = 16
	for var w = 0; w < 4; w++
		o -= 4
		b[o + 3] = (B[w] >>  0 : byte)
		b[o + 2] = (B[w] >>  8 : byte)
		b[o + 1] = (B[w] >> 16 : byte)
		b[o + 0] = (B[w] >> 24 : byte)
	;;
}