ref: 93cff6aba077499c3693896cba43b1ceaaca7a62
dir: /lib/flate/flate.myr/
use std
use bio
use "types"
/* DEFLATE  https://www.ietf.org/rfc/rfc1951.txt
   Terminology:
     - A 'Huffman code' is one mechanism to map codes to symbols.
     - A 'code' is the sequence of bits used to encode a symbol
       in a Huffman code.  Codes are what appears in the DEFLATE
       stream.
     - A 'symbol' is what a code decodes to.
     - A 'length' is the length of an LZ77 reference (the number
       of bytes to repeat).
     - A 'distance' is how far back from the current output
       position the bytes referred to by an LZ77 reference start.
*/
pkg flate =
	/* state maintained during decompression */
	type flatedec = struct
		rdf : bio.file# /* compressed input; read with bio.getb() */
		wrf : bio.file# /* decompressed output; written with bio.putb() */
		/* one-byte buffer */
		bval : int /* in [0, 255] current input byte */
		bpos : int /* in [0, 8]; # of bits of bval already consumed,
		              8 means that a new byte has to be read */
		/* ring buffer window for the bytes output */
		wpos : std.size /* in [0, WinSz]; index in wbuf following the
		                   last byte written; c.f., outb() */
		wlen : std.size /* total # of bytes output so far; it is
		                   allowed to grow past WinSz */
		wbuf : byte[WinSz]
		/* Huffman codes */
		lenhc : hufc /* for literals and lengths */
		dsthc : hufc /* for distances */
	;;
	/* one Huffman code */
	type hufc = struct
		/* # of codes of length [1, MaxBits]; entry 0 is
		   always reset to 0 by mkhufc() */
		count : int[MaxBits + 1]
		/* symbols sorted lexicographically by
		   <code length, symbol value>; c.f., mkhufc()
		   and readcode() for usage */
		symbol : int[MaxSyms]
	;;
	/* decode a DEFLATE stream read from inf, and write the
	   decompressed bytes to outf; blocks are processed until
	   the one flagged as last has been handled */
	const decode : (outf : bio.file#, inf : bio.file# -> std.result(void, err))
	/* sizing constants; defined below the package block */
	pkglocal const MaxSyms : std.size
	pkglocal const MaxBits : std.size
	pkglocal const WinSz : std.size
	/* read n bits (n in [0, 31]) from the input, least
	   significant bit first */
	pkglocal const bits : (st : flatedec#, n : int -> std.result(int, err))
	/* process a non-compressed block */
	pkglocal const nocomp : (st : flatedec# -> std.result(void, err))
	/* build the Huffman code hc from a table of code lengths
	   indexed by symbol */
	pkglocal const mkhufc : (lengths : int[:], hc : hufc# -> void)
	/* read one code from the input and return the symbol it
	   stands for in hc */
	pkglocal const readcode : (st : flatedec#, hc : hufc# -> std.result(int, err))
	/* write one byte to the output and record it in the window */
	pkglocal const outb : (st : flatedec#, b : byte -> std.result(void, err))
;;
/* sizing constants for the Huffman codes and the output window */
const MaxLenSyms : std.size = 288 /* max # of length symbols */
const MaxDstSyms : std.size = 32 /* max # of distance symbols */
const MaxSyms = MaxLenSyms /* max # of symbols in a Huffman code */
const MaxBits = 15 /* max length of a code (in bits) */
const WinSz = 32 * std.KiB /* max distance for a reference; also the
                              size of the output window flatedec.wbuf */
/* the fixed Huffman code for literals and lengths; filled in
   by __init__ and copied into the decoder state by decode() */
var fixedlenhc : hufc
/* the fixed Huffman code for distances; filled in by __init__
   and copied into the decoder state by decode() */
var fixeddsthc : hufc
const __init__ = {
	/* builds the two fixed Huffman codes described in RFC
	   Section 3.2.6 and stores them in fixedlenhc and
	   fixeddsthc */
	var nbits : int[MaxSyms]

	/* literal/length symbols: 0-143 use 8 bits, 144-255 use 9,
	   256-279 use 7 and 280-287 use 8 */
	for var s = 0; s < 288; s++
		if s < 144
			nbits[s] = 8
		elif s < 256
			nbits[s] = 9
		elif s < 280
			nbits[s] = 7
		else
			nbits[s] = 8
		;;
	;;
	mkhufc(nbits[:], &fixedlenhc)

	/* all 32 distance symbols use 5 bits; the head of the same
	   array is reused for them */
	for var s = 0; s < 32; s++
		nbits[s] = 5
	;;
	mkhufc(nbits[0:32], &fixeddsthc)
}
const outb = {st, b
	/* Emits the byte b: it is written to the output file st.wrf
	   and then stored in the ring buffer st.wbuf so that later
	   references can copy it.

	   Returns `std.Ok void on success and `std.Err (`Io e) when
	   bio.putb() fails; the window is left untouched in that
	   case.  The program dies if bio.putb() reports that it
	   wrote no byte. */
	match bio.putb(st.wrf, b)
	| `std.Err e: -> `std.Err (`Io e)
	| `std.Ok n:
		if n == 0
			std.die("no bytes written by bio.putb()")
		;;
	;;

	/* the write position wraps lazily: it is only reset once
	   it has run off the end of the window */
	if st.wpos == WinSz
		std.assert(st.wlen >= WinSz, "window should be full")
		st.wpos = 0
	;;
	st.wbuf[st.wpos] = b
	st.wpos++
	/* counts every byte output; going past WinSz is harmless */
	st.wlen++
	-> `std.Ok void
}
const bits = {st, n
	/* Reads n bits from the input, n in [0, 31].  The bits left
	   over in the current byte st.bval are used first, then
	   whole bytes are fetched from st.rdf as needed; each new
	   byte is placed above the bits gathered so far, so the
	   first bit read is the least significant bit of the
	   result.

	   Returns `std.Ok with the value read, or `std.Err (`Io e)
	   when bio.getb() fails. */
	var acc, have

	std.assert(n >= 0 && n <= 31, "invalid argument")

	/* start with what is left of the current byte */
	acc = st.bval >> st.bpos
	have = 8 - st.bpos
	while have < n
		match bio.getb(st.rdf)
		| `std.Err e: -> `std.Err (`Io e)
		| `std.Ok b:
			st.bval = (b : int)
			acc += st.bval << have
		;;
		have += 8
	;;

	/* have - n bits of the last byte remain unread */
	st.bpos = 8 - (have - n)
	-> `std.Ok (acc & ((1 << n) - 1))
}
const nocomp = {st
	/* Handles a stored (non-compressed) block; c.f., RFC
	   Section 3.2.4.  The block starts at the next byte boundary
	   with a 16 bit length followed by its one's complement,
	   then that many bytes are copied verbatim to the output.

	   Returns `std.Ok void on success, a `Format error when the
	   two length fields disagree, and otherwise the first error
	   reported while reading or writing. */
	var len, nlen

	/* drop what is left of the current byte */
	st.bpos = 8

	match bits(st, 16)
	| `std.Err e: -> `std.Err e
	| `std.Ok v: len = (v : std.size)
	;;
	match bits(st, 16)
	| `std.Err e: -> `std.Err e
	| `std.Ok v: nlen = (v : std.size)
	;;
	if (nlen ^ 0xffff) != len
		-> `std.Err (`Format \
			"non-compressed block length could " \
			"not be verified")
	;;

	/* bits() left the input on a byte boundary, so the payload
	   can be read one whole byte at a time */
	for var left = len; left > 0; left--
		match bio.getb(st.rdf)
		| `std.Err e: -> `std.Err (`Io e)
		| `std.Ok b:
			match outb(st, b)
			| `std.Ok _:
			| err: -> err
			;;
		;;
	;;
	-> `std.Ok void
}
const mkhufc = {length, hc
	/* Fills hc from a table of code lengths: length[s] is the
	   number of bits of the code for symbol s.  A length of 0
	   marks a symbol that has no code, and will therefore never
	   be part of the input stream.

	   On return hc.count[l] is the number of symbols of code
	   length l (with hc.count[0] forced to 0) and hc.symbol
	   lists the symbols that have a code, ordered by code
	   length and then by symbol value. */
	var slot : int[MaxBits + 1]
	var start, nbits

	std.assert(length.len <= hc.symbol.len, "invalid argument")

	/* histogram of the code lengths */
	std.slfill(hc.count[:], 0)
	for var s = 0; s < length.len; s++
		nbits = length[s]
		hc.count[nbits]++
	;;
	/* symbols without a code do not count */
	hc.count[0] = 0

	/* slot[l] is the place in hc.symbol of the next symbol
	   of code length l */
	start = 0
	slot[0] = 0
	for var l = 1; l <= MaxBits; l++
		start += hc.count[l - 1]
		slot[l] = start
	;;

	/* going through the symbols in increasing order keeps
	   each group of equal length sorted by symbol value */
	for var s = 0; s < length.len; s++
		nbits = length[s]
		if nbits == 0
			continue
		;;
		hc.symbol[slot[nbits]] = s
		slot[nbits]++
	;;
	/* TODO: validate the Huffman code */
}
const readcode = {st, hc
	/* Reads one code from the input, one bit at a time, and
	   returns the symbol it stands for in hc.  This is the
	   usual decoding loop for "canonical Huffman codes"; see,
	   for example,
	     Managing Gigabytes: Compressing and Indexing Documents
	       and Images
	     by I. H. Witten, A. Moffat, and T. C. Bell.

	   Returns `std.Ok with the symbol, the error from bits()
	   if reading fails, or a `Format error when no code
	   matches after MaxBits bits. */
	var code, first, base

	code = 0 /* bits read so far, as a number */
	first = 0 /* first code of the current length */
	base = 0 /* index in hc.symbol[:] of the symbol for `first` */
	for var nbits = 1; nbits <= MaxBits; nbits++
		match bits(st, 1)
		| `std.Err e: -> `std.Err e
		| `std.Ok bit: code += bit
		;;
		var n = hc.count[nbits]
		var off = code - first
		if off < n
			/* one of the n codes of this length matched;
			   off is its rank among them */
			-> `std.Ok (hc.symbol[base + off])
		;;
		/* move on to the codes that are one bit longer;
		   c.f. RFC Section 3.2.2 */
		base += n
		first = (first + n) << 1
		code <<= 1
	;;
	-> `std.Err (`Format "invalid code")
}
const decomp = {st
	/* process data of a compressed block; c.f., RFC Section 3.2.5

	   Symbols are decoded with st.lenhc until the end-of-block
	   symbol (256) is read.  A symbol below 256 is a literal
	   byte and is output as is.  A symbol above 256 starts an
	   LZ77 reference: its length comes from basel/extral, its
	   distance from a symbol decoded with st.dsthc and
	   based/extrad, and the bytes referred to are copied from
	   the output window.

	   Returns `std.Ok void once the end of the block is reached.
	   Errors from bits(), readcode() and outb() are returned
	   unchanged.  A `Format error is returned when a length or
	   distance symbol has no entry in the tables below, or when
	   a reference points before the first byte output. */
	const basel = [
		3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
		35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
	]
	const extral = [
		0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3,
		3, 4, 4, 4, 4, 5, 5, 5, 5, 0
	]
	const based = [
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
		257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
		8193, 12289, 16385, 24577
	]
	const extrad = [
		0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8,
		8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
	]
	var symb, len, dst, index

	symb = 0
	while symb != 256
		match readcode(st, &st.lenhc)
		| `std.Ok s: symb = s
		| `std.Err e: -> `std.Err e
		;;
		if symb < 256
			/* literal byte */
			match outb(st, (symb : byte))
			| `std.Ok _:
			| err: -> err
			;;
		elif symb > 256
			symb -= 257
			/* the fixed code built in __init__ has 288
			   symbols, but only 257 to 285 stand for
			   lengths; anything above has no table entry */
			if symb >= basel.len
				-> `std.Err (`Format "invalid length symbol")
			;;
			match bits(st, extral[symb])
			| `std.Ok x: len = basel[symb] + x
			| `std.Err e: -> `std.Err e
			;;
			match readcode(st, &st.dsthc)
			| `std.Ok s: symb = s
			| `std.Err e: -> `std.Err e
			;;
			/* likewise, the fixed distance code has 32
			   symbols, but only 0 to 29 are distances */
			if symb >= based.len
				-> `std.Err (`Format "invalid distance symbol")
			;;
			match bits(st, extrad[symb])
			| `std.Ok x: dst = (based[symb] + x : std.size)
			| `std.Err e: -> `std.Err e
			;;
			if dst > st.wlen
				-> `std.Err (`Format \
					"reference to a byte before file start")
			;;
			/* st.wpos may be equal to WinSz (c.f., outb()),
			   hence the reduction modulo WinSz */
			index = (WinSz + st.wpos - dst) % WinSz
			/* copy one byte at a time: when len > dst the
			   reference overlaps the bytes being output,
			   and must see them as they are produced */
			for ; len > 0; len--
				match outb(st, st.wbuf[index])
				| `std.Ok _:
				| err: -> err
				;;
				index = (index + 1) % WinSz
			;;
		;;
	;;
	-> `std.Ok void
}
const dynamic = {st
	/* Handles a block compressed with dynamic Huffman codes;
	   c.f., RFC Section 3.2.7.  The block header is read in
	   three steps: the number of codes of each kind, the code
	   used to compress the code lengths, and the code lengths
	   themselves.  st.lenhc and st.dsthc are then built and the
	   block data is handed over to decomp().

	   Returns what decomp() returns, the first error from
	   bits() or readcode(), or a `Format error when the header
	   is not valid. */
	const order = [
		16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2,
		14, 1, 15
	]
	/* # of extra bits following symbols 16, 17 and 18 */
	const repbits = [2, 3, 7]
	var lens : int[MaxLenSyms + MaxDstSyms]
	var clhc : hufc
	var nlen, ndst, ncode, total, pos, symb, rep, fill

	/* step 1: how many codes of each kind */
	match bits(st, 5)
	| `std.Err e: -> `std.Err e
	| `std.Ok n: nlen = n + 257
	;;
	if nlen > 286
		-> `std.Err (`Format "invalid number of length codes")
	;;
	match bits(st, 5)
	| `std.Err e: -> `std.Err e
	| `std.Ok n: ndst = n + 1
	;;
	if ndst > 30
		-> `std.Err (`Format "invalid number of distance codes")
	;;
	match bits(st, 4)
	| `std.Err e: -> `std.Err e
	| `std.Ok n: ncode = n + 4
	;;

	/* step 2: the code for code lengths; its lengths come in
	   the order given by the table above, and the ones that
	   are not transmitted are 0 */
	std.slfill(lens[:19], 0)
	for var i = 0; i < ncode; i++
		match bits(st, 3)
		| `std.Err e: -> `std.Err e
		| `std.Ok n: lens[order[i]] = n
		;;
	;;
	mkhufc(lens[:19], &clhc)

	/* step 3: the code lengths; according to the RFC: "the code
	   length repeat codes can cross from HLIT + 257 to the
	   HDIST + 1 code lengths", so both tables are read as a
	   single sequence */
	total = nlen + ndst
	pos = 0
	while pos < total
		match readcode(st, &clhc)
		| `std.Err e: -> `std.Err e
		| `std.Ok n: symb = n
		;;
		if symb > 18
			-> `std.Err (`Format "invalid code code")
		;;
		if symb < 16
			/* a plain code length */
			lens[pos] = symb
			pos++
			continue
		;;

		/* a run; its size is in the extra bits */
		match bits(st, repbits[symb - 16])
		| `std.Err e: -> `std.Err e
		| `std.Ok n: rep = n
		;;
		if symb == 16
			/* 3 to 6 copies of the previous code length */
			rep += 3
			if pos == 0 || pos + rep > total
				-> `std.Err (`Format \
					"invalid code repetition")
			;;
			fill = lens[pos - 1]
		else
			/* a run of zeros, short (17) or long (18) */
			if symb == 17
				rep += 3
			else
				rep += 11
			;;
			if pos + rep > total
				-> `std.Err (`Format \
					"invalid zero repetition")
			;;
			fill = 0
		;;
		std.slfill(lens[pos:pos + rep], fill)
		pos += rep
	;;

	if lens[256] == 0
		-> `std.Err (`Format "no code length of end-of-block symbol")
	;;
	mkhufc(lens[:nlen], &st.lenhc)
	mkhufc(lens[nlen:total], &st.dsthc)
	-> decomp(st)
}
const decode = {outf, inf
	/* Decompresses the DEFLATE stream read from inf and writes
	   the result to outf.  Blocks are handled one after the
	   other until the one whose header has the "last" bit set
	   has been processed.

	   Returns `std.Ok void on success, a `Format error for a
	   block of unknown type, and otherwise the first error
	   reported while reading the header or processing a
	   block. */
	var st, final, res

	/* with bpos at 8 no bit is pending, so the first call to
	   bits() starts by reading a byte */
	st = [
		.bpos = 8,
		.rdf = inf,
		.wrf = outf,
	]
	final = 0
	while final != 1
		/* block header: 1 bit "last" flag, 2 bits block type */
		match bits(&st, 1)
		| `std.Err e: -> `std.Err e
		| `std.Ok f: final = f
		;;
		match bits(&st, 2)
		| `std.Err e:
			-> `std.Err e
		| `std.Ok 0:
			res = nocomp(&st)
		| `std.Ok 1:
			/* fixed codes: use the ones built by __init__ */
			st.lenhc = fixedlenhc
			st.dsthc = fixeddsthc
			res = decomp(&st)
		| `std.Ok 2:
			res = dynamic(&st)
		| `std.Ok _:
			-> `std.Err (`Format "invalid block type")
		;;
		match res
		| `std.Ok _:
		| `std.Err e: -> `std.Err e
		;;
	;;
	-> `std.Ok void
}