ref: d0c579ea046b64d878d6fd0d9bcc8f0af9fdde73
dir: /lib/regex/compile.myr/
use std
use "types"
use "ranges"
pkg regex =
	/* parses a pattern and returns its AST without generating a program */
	const parse	: (re : byte[:]	-> std.result(ast#, status))
	/* parses a pattern and generates a matching program for it */
	const compile	: (re : byte[:] -> std.result(regex#, status))
	/* like compile, but with debug bookkeeping on; `trace` enables dumps */
	const dbgcompile	: (re : byte[:], trace : bool -> std.result(regex#, status))
	/* releases a regex returned by compile or dbgcompile */
	const free	: (re : regex# -> void)
;;
/*
  Result of one parser step:
    `Some: a node was parsed.
    `None: nothing was parsed at the current position (end of the
           pattern, or a lower precedence token such as '|' or ')').
    `Err:  the pattern is malformed.
*/
type parseresult = union
	`Some ast#
	`None
	`Err status
;;
/*
  Compiles a pattern into a regex. The compile state starts with
  nmatch = 1 and the program is generated with match id 0.
*/
const compile = {pat
	var re

	re = std.mk([.pat = pat, .nmatch = 1])
	-> regexcompile(re, 0)
}
/*
  Parses `pat` and returns its AST without generating any bytecode.

  Returns `std.Ok with the AST on success. Returns `std.Err `Incomplete
  if nothing could be parsed or the pattern was not fully consumed, and
  `std.Err with the parser's status for malformed patterns.

  The temporary parser state is released before returning, and the AST
  is released on the incomplete path so that failures do not leak.
*/
const parse = {pat
	var re, res

	re = std.mk([.pat = pat, .nmatch = 1])
	match regexparse(re)
	| `None:
		res = `std.Err `Incomplete
	| `Err f:
		res = `std.Err f
	| `Some t:
		if re.pat.len != re.idx
			/* trailing input: the tree is useless to the caller */
			astfree(t)
			res = `std.Err `Incomplete
		else
			res = `std.Ok t
		;;
	;;
	/* only the AST escapes; the parser state is ours to free */
	std.free(re)
	-> res
}
/*
  Compiles a pattern into a debug regex. This can be verbose.
  Debug mode allocates the AST location table, which mk() and
  append() consult; `trace` turns on the AST/instruction dumps.
*/
const dbgcompile = {pat, trace
	var re

	re = std.mk([
		.pat = pat,
		.nmatch = 1,
		.debug = true,
		.trace = trace,
		.astloc = std.mkht(std.ptrhash, std.ptreq),
	])
	-> regexcompile(re, 0)
}
/*
  Compiles the pattern held in `re` into an allocated regex.

  Takes ownership of `re`: on success it is returned inside `std.Ok,
  and on failure it is released here, since the caller only gets a
  status back and would have no way to free it.

  `id` is the value stored in the final `Imatch instruction.
*/
const regexcompile = {re, id
	match regexparse(re)
	| `None:
		free(re)
		-> `std.Err (`Incomplete)
	| `Err f:
		free(re)
		-> `std.Err f
	| `Some t:
		/*
		we can bail out of parsing early if we get 
		an incorrectly encoded char
		*/
		if re.pat.len != re.idx
			astfree(t)
			free(re)
			-> `std.Err (`Incomplete)
		;;
		dump(re, t, 0)
		/* the whole match is wrapped in capture group 0 */
		append(re, `Ilbra 0, t)
		gen(re, t)
		append(re, `Irbra 0, t)
		append(re, `Imatch id, t)
		idump(re)
		/* the program is self-contained; the tree is no longer needed */
		astfree(t)
		-> `std.Ok re
	;;
	-> `std.Err (`Noimpl)
}
/*
  Releases a regex and everything it owns. All the threads should
  be dead by now, so no thread state is freed here.
*/
const free = {re
	/* these allocations are only touched in debug mode */
	if re.debug
		for var i = 0; i < re.traces.len; i++
			std.bsfree(re.traces[i])
		;;
		std.slfree(re.traces)
		std.slfree(re.pcidx)
		std.htfree(re.astloc)
	;;
	std.slfree(re.prog)
	std.free(re)
}
/*
  Generates bytecode from the AST node `t`, appending it to the
  program in `re`.

  Returns the program length after generation; callers use it as
  the address of the next instruction to be emitted.
*/
const gen = {re, t
	match t#
	|`Alt	(a, b): genalt(re, a, b, t)
	|`Cat	(a, b): gen(re, a); gen(re, b)
	/* repetition */
	|`Star	(a, r):	genstar(re, a, r, t)
	/* a+ is emitted as a followed by a* */
	|`Plus	(a, r):	gen(re, a); genstar(re, a, r, t)
	|`Quest	a:	genquest(re, a)
	/* end matches */
	|`Chr	c:	genchar(re, c, t)
	|`Ranges  sl:	genranges(re, sl, t)
	/* meta */
	|`Bol:	append(re, `Ibol, t)
	/* end of line needs its own instruction, not the start-of-line one */
	|`Eol:	append(re, `Ieol, t)
	|`Bow:	append(re, `Ibow, t)
	|`Eow:	append(re, `Ieow, t)
	|`Cap	(m, a):
		append(re, `Ilbra m, t)
		gen(re, a)
		append(re, `Irbra m, t)
	;;
	-> re.proglen
}
/*
  Generates bytecode matching any character in the ranges `sl`.

  Every range is encoded to bytes and inserted into a trie of byte
  ranges, and the program is then generated from that trie. A range
  whose endpoints encode to different lengths is split into one piece
  per encoded length. Returns the program length after generation.
*/
const genranges = {re, sl, ast
	var lbuf : byte[4], hbuf : byte[4], boundbuf : byte[4]
	var lsz, hsz, bsz, i
	var rt : rangetrie#
	/* generate a trie of ranges */
	rt = std.zalloc()
	for r : sl
		/* 
		encode:
			lo => bounds[loidx] - 1
			bounds[loidx] => bounds[loidx + 1] - 1
			...
			bounds[hiidx - 1] => hi
		*/
		lsz = std.encode(lbuf[:], r[0])
		hsz = std.encode(hbuf[:], r[1])
		/*
		one pass per encoded length below that of hi: close the piece
		at the 0xff-filled bound of length i, then restart lo at the
		0x00-filled bound of length i + 1.
		*/
		for i = lsz; i < hsz; i++
			bsz = bound(boundbuf[:], i, 0xff)
			rtinsert(rt, lbuf[:lsz], boundbuf[:bsz])
			lsz = bound(lbuf[:], i + 1, 0x00)
		;;
		/* final piece: lo and hi now have the same encoded length */
		rtinsert(rt, lbuf[:lsz], hbuf[:hsz])
	;;
	if re.trace
		rtdump(rt, 0)
	;;
	/* the last argument is the address just past the generated code */
	rangegen(re, ast, rt, rt.ranges, rt.link, rangeprogsize(rt) + re.proglen)
	rtfree(rt)
	-> re.proglen
}
/*
  Writes a `len`-byte boundary sequence into `buf` and returns `len`.
  The payload bits are taken from `fill`; genranges passes 0xff for
  an upper bound and 0x00 for a lower bound. A one-byte sequence is
  always 0x7f, regardless of `fill`.
*/
const bound = {buf, len, fill
	var n

	if len != 1
		n = (len : byte)
		/* lead byte: `len` high bits set, the rest taken from fill */
		buf[0] = (0xff << (8 - n)) | (fill >> (n + 1))
		/* trailing bytes: 0b10 prefix plus six bits of fill */
		for var k = 1; k < len; k++
			buf[k] = 0x80 | (fill >> 2)
		;;
	else
		buf[0] = 0x7f
	;;
	-> len
}
/*
  Trie of byte ranges, one level per byte of an encoded sequence.
  ranges[i] is the (lo, hi) byte range for this position and link[i]
  is the node for the bytes that follow it. `end` is set on a node
  when an inserted sequence stops there.
*/
type rangetrie = struct
	ranges	: (byte, byte)[:]
	link	: rangetrie#[:]
	end	: bool
;;
/* prints a range trie, nesting each level one indent deeper */
const rtdump = {rt, ind
	var lo, hi
	var i

	indent(ind)
	std.put("Range (end = {}) {{\n", rt.end)
	i = 0
	while i < rt.ranges.len
		(lo, hi) = rt.ranges[i]
		indent(ind + 1)
		std.put("0x{x}-0x{x}: \n", lo, hi)
		rtdump(rt.link[i], ind + 1)
		i++
	;;
	indent(ind)
	std.put("}\n")
}
/* prints `ind` tab characters */
const indent = {ind
	var left

	left = ind
	while left > 0
		std.put("\t")
		left--
	;;
}
/*
  Inserts the byte sequence range lo..hi into the trie, one byte
  position per level. `lo` and `hi` must have the same length.
*/
const rtinsert = {rt, lo, hi
	var plo, phi
	var n, fresh

	std.assert(lo.len == hi.len, "range sizes differ")
	if lo.len == 0
		rt.end = true
		-> void
	;;

	/*
	Only the most recent entry is compared: ranges should always
	be coming in ordered, so equal values are added one after
	the other.
	*/
	n = rt.ranges.len
	fresh = true
	if n != 0
		(plo, phi) = rt.ranges[n - 1]
		fresh = plo != lo[0] || phi != hi[0]
	;;
	if fresh
		std.slpush(&rt.ranges, (lo[0], hi[0]))
		std.slpush(&rt.link, std.zalloc())
	;;
	/* the remaining bytes go under the last (possibly new) entry */
	rtinsert(rt.link[rt.link.len - 1], lo[1:], hi[1:])
}
/* frees a range trie node along with all of its children */
const rtfree = {rt
	for var i = 0; i < rt.link.len; i++
		rtfree(rt.link[i])
	;;
	std.slfree(rt.ranges)
	std.slfree(rt.link)
	std.free(rt)
}
/*
  Generates the program for the trie entries `ranges`/`links`.

  A single entry emits an `Irange for its byte and then the code for
  the bytes that follow it. Several entries are split in half behind
  an `Ifork, recursively. `end` is the address that a completed match
  jumps to. Returns the program length after generation.
*/
const rangegen = {re, ast, rt, ranges, links, end
	var alt, l0, l1
	var a, b
	var n

	n = ranges.len
	if n == 0
		-> re.proglen
	elif n == 1
		(a, b) = ranges[0]
		append(re, `Irange (a, b), ast)
		if links[0].end
			if links[0].ranges.len > 0
				/*
				The sequence may stop here or continue. The fork
				lands at index re.proglen, so the continuation is
				at re.proglen + 1. This must use the instruction
				count and not re.prog.len, which is the size of
				the grown buffer.
				*/
				append(re, `Ifork (re.proglen + 1, end), ast)
			else
				append(re, `Ijmp end, ast)
			;;
		;;
		rangegen(re, ast, links[0], links[0].ranges, links[0].link, end)
	else
		alt = re.proglen
		/* placeholder: patched once both halves have been emitted */
		l0 = append(re, `Ifork (-1, -1), ast)
		l1 = rangegen(re, ast, rt, ranges[0:n/2], links[0:n/2], end)
		rangegen(re, ast, rt, ranges[n/2:n], links[n/2:n], end)
		re.prog[alt] = `Ifork (l0, l1)
	;;
	-> re.proglen
}
/*
  Computes the number of instructions that rangegen emits for a trie:
  2*n - 1 for a node with n entries, plus the sizes of its children,
  plus one more for every node marked as an end.
*/
const rangeprogsize = {rt
	var total

	total = 0
	if rt.ranges.len != 0
		total = 2*rt.ranges.len - 1
		for child : rt.link
			total += rangeprogsize(child)
		;;
	;;
	if rt.end
		total++
	;;
	-> total
}
/*
  Calculates the forward jump distance for a utf8 character range:
  (n - 1) plus the sum of 1 through n - 1.
*/
const jmpdist = {n
	var dist, step

	step = n - 1
	dist = step
	while step > 0
		dist += step
		step--
	;;
	-> dist
}
/*
  Generates an alternation:

	fork lstart, rstart
  lstart:	<l>
	jmp done
  rstart:	<r>
  done:
*/
const genalt = {re, l, r, t
	var forkpc, jmppc
	var lstart, rstart, done

	forkpc = re.proglen
	lstart = append(re, `Ifork (-1, -1), t) /* placeholder */
	gen(re, l)
	jmppc = re.proglen
	rstart = append(re, `Ijmp -1, t) /* placeholder */
	done = gen(re, r)

	/* both targets are known now, so patch the placeholders */
	re.prog[jmppc] = `Ijmp done
	re.prog[forkpc] = `Ifork (lstart, rstart)
	-> re.proglen
}
/*
  Generates a repetition operator:

  top:	fork body, out
  body:	<rep>
	jmp top
  out:

  For a reluctant repetition the fork targets are swapped.
*/
const genstar = {re, rep, reluct, t
	var top, body, backpc, out

	top = re.proglen
	body = append(re, `Ifork (-1, -1), t) /* placeholder */
	/* gen returns the address where the back jump is about to go */
	backpc = gen(re, rep)
	out = append(re, `Ijmp -1, t)

	re.prog[backpc] = `Ijmp top
	/* reluctant matches should prefer jumping to the end. */
	if !reluct
		re.prog[top] = `Ifork (body, out)
	else
		re.prog[top] = `Ifork (out, body)
	;;
	-> re.proglen
}
/*
  Generates a question mark operator: a fork that either runs
  the code for `q` or goes straight to the instruction after it.
*/
const genquest = {re, q
	var forkpc, taken, skipped

	forkpc = re.proglen
	taken = append(re, `Ifork (-1, -1), q) /* placeholder */
	skipped = gen(re, q)
	re.prog[forkpc] = `Ifork (taken, skipped)
	-> re.proglen
}
/*
  Generates a single char match: one `Ibyte instruction for each
  byte of the encoded character `c`. Returns the program length
  after generation.
*/
const genchar = {re, c, ast
	var b : byte[4]
	var n

	n = std.encode(b[:], c)
	/* the buffer holds 4 bytes, so a 4 byte encoding is valid too */
	std.assert(n > 0 && n <= 4, "non-utf character in regex\n")
	for var i = 0; i < n; i++
		append(re, `Ibyte b[i], ast)
	;;
	-> re.proglen
}
/*
  Appends an instruction to an re program, doubling the program
  buffer when it is full. In debug mode the location recorded for
  `ast` (or -1 if there is none) is stored alongside the instruction.
  Returns the new program length.
*/
const append = {re, insn, ast
	var cap, pc

	pc = re.proglen
	if pc == re.prog.len
		cap = std.max(1, 2*pc)
		std.slgrow(&re.prog, cap)
		if re.debug
			std.slgrow(&re.pcidx, cap)
		;;
	;;
	re.prog[pc] = insn
	if re.debug
		re.pcidx[pc] = std.htgetv(re.astloc, ast, -1)
	;;
	re.proglen = pc + 1
	-> re.proglen
}
/*
  Instruction dump: prints each instruction of the program together
  with its address and recorded AST location. Only runs when tracing.
*/
const idump = {re
	var pc

	if re.trace
		pc = 0
		while pc < re.proglen
			std.put("{}@{}:\t", pc, re.pcidx[pc])
			match re.prog[pc]
			/* Char matching. Consume exactly one byte from the string. */
			| `Ibyte b:
				std.put("`Ibyte {} ({})\n", b, (b : char))
			| `Irange (lo, hi):
				std.put("`Irange ({},{})", lo, hi)
				if std.isalnum((lo : char)) && std.isalnum((hi : char))
					std.put("\t/* {}-{} */", (lo : char), (hi : char))
				;;
				std.put("\n")
			/* capture groups */
			| `Ilbra m:
				std.put("`Ilbra {}\n", m)
			| `Irbra m:
				std.put("`Irbra {}\n", m)
			/* anchors */
			| `Ibol:
				std.put("`Ibol\n")
			| `Ieol:
				std.put("`Ieol\n")
			| `Ibow:
				std.put("`Ibow\n")
			| `Ieow:
				std.put("`Ieow\n")
			/* control flow */
			| `Ifork (lip, rip):
				std.put("`Ifork ({},{})\n", lip, rip)
			| `Ijmp ip:
				std.put("`Ijmp {}\n", ip)
			| `Imatch id:
				std.put("`Imatch {}\n", id)
			;;
			pc++
		;;
	;;
}
/*
  AST dump: prints the tree rooted at `t`, one node per line, with
  two spaces per level of `indent` and each node prefixed by its
  recorded location. Only runs when tracing.
*/
const dump = {re, t, indent
	var col

	if !re.trace
		-> void
	;;
	col = 0
	while col < indent
		std.put("  ")
		col++
	;;
	std.put("{}/", std.htgetv(re.astloc, t, -1))
	match t#
	| `Alt	(lhs, rhs):
		std.put("Alt\n")
		dump(re, lhs, indent + 1)
		dump(re, rhs, indent + 1)
	| `Cat	(lhs, rhs):
		std.put("Cat\n")
		dump(re, lhs, indent + 1)
		dump(re, rhs, indent + 1)
	/* repetition */
	| `Star	(sub, reluct):
		std.put("Star reluct={}\n", reluct)
		dump(re, sub, indent + 1)
	| `Plus	(sub, reluct):
		std.put("Plus reluct={}\n", reluct)
		dump(re, sub, indent + 1)
	| `Quest	sub:
		std.put("Quest\n")
		dump(re, sub, indent + 1)
	/* anchors */
	| `Bol:
		std.put("Bol\n")
	| `Eol:
		std.put("Eol\n")
	| `Bow:
		std.put("Bow\n")
	| `Eow:
		std.put("Eow\n")
	/* end matches */
	| `Chr	c:
		std.put("Chr {}\n", c)
	| `Ranges rl:
		std.put("Ranges")
		for r : rl
			col = 0
			while col < indent + 1
				std.put("  ")
				col++
			;;
			std.put("\t({}-{})\n", r[0], r[1])
		;;
	/* meta */
	| `Cap	(m, sub):
		std.put("Cap {}\n", m)
		dump(re, sub, indent + 1)
	;;
}
/*
  Parses an expression. The whole pattern must be consumed: if
  anything is left over the tree is freed and `Incomplete is
  reported.
*/
const regexparse = {re
	match altexpr(re)
	| `Err st:
		-> `Err st
	| `None:
		-> `None
	| `Some t:
		if re.pat.len != re.idx
			astfree(t)
			-> `Err `Incomplete
		;;
		-> `Some t
	;;
}
/*
  Parses an alternation: catexpr ('|' altexpr)?

  Returns whatever catexpr returned if it did not produce a node.
  A '|' with nothing after it is reported as `Incomplete. Whenever
  the right hand side fails, the already parsed left hand side is
  freed so that it does not leak.
*/
const altexpr = {re
	var ret
	var idx

	match catexpr(re)
	| `Some t:
		ret = t
		idx = re.idx
		if matchc(re, '|')
			match altexpr(re)
			| `Some rhs:
				ret = mk(re, `Alt (ret, rhs), idx)
			| `None:
				astfree(ret)
				-> `Err (`Incomplete)
			| `Err f:
				/* the left side has no owner once we fail */
				astfree(ret)
				-> `Err f
			;;
		;;
	| other:
		-> other
	;;
	-> `Some ret
}
/*
  Parses a concatenation: repexpr catexpr?

  Returns whatever repexpr returned if it did not produce a node.
  If the tail fails, the already parsed head is freed so that it
  does not leak.
*/
const catexpr = {re
	var ret
	var idx

	idx = re.idx
	match repexpr(re)
	| `Some t: 
		ret = t
		match catexpr(re)
		| `Some rhs:
			ret = mk(re, `Cat (t, rhs), idx)
		| `Err f:
			/* the head has no owner once we fail */
			astfree(t)
			-> `Err f
		| `None:	/* nothing follows; the head stands alone */
		;;
	| other:
		-> other
	;;
	-> `Some ret
}
/*
  Parses a base expression with an optional repetition suffix.
  A '?' directly after '*' or '+' marks the repetition as reluctant.
*/
const repexpr = {re
	var idx

	match baseexpr(re)
	| `Some t:
		idx = re.idx
		if matchc(re, '*')
			-> `Some mk(re, `Star (t, matchc(re, '?')), idx)
		elif matchc(re, '+')
			-> `Some mk(re, `Plus (t, matchc(re, '?')), idx)
		elif matchc(re, '?')
			-> `Some mk(re, `Quest t, idx)
		;;
		/* no suffix: pass the base expression through */
		-> `Some t
	| other:
		-> other
	;;
}
/*
  Parses a base expression: a literal character, an anchor, '.',
  a character class, an escape, or a parenthesized group.

  Returns `None at the end of the pattern and at '|' or ')', which
  belong to an enclosing expression. A repetition operator with
  nothing to repeat is reported as `Badrep.

  Groups: '(' starts a capture group numbered from re.nmatch, and
  '(?:' starts a non-capturing group. If the closing ')' is missing,
  the parsed group body is freed before the error is returned.
*/
const baseexpr = {re
	var ret, m, idx
	var nocap
	if re.pat.len == re.idx
		-> `None
	;;
	idx = re.idx
	match peekc(re)
	/* lower prec operators */
	| '|':	-> `None
	| ')':	-> `None
	| '*':	-> `Err `Badrep '*'
	| '+':	-> `Err `Badrep '+'
	| '?':	-> `Err `Badrep '?'
	| '[':	-> chrclass(re)
	| '^':	getc(re); ret = mk(re, `Bol, re.idx)
	| '$':	getc(re); ret = mk(re, `Eol, re.idx)
	| '.':	
		getc(re)
		/* any character: the full range of char values */
		ret = mk(re, `Ranges std.sldup([[0, std.Maxcharval]][:]), idx)
	| '(':	
		nocap = false
		m = 0
		getc(re)
		if matchc(re, '?')
			if !matchc(re, ':')
				-> `Err `Badrep '?'
			;;
			nocap = true
		else
			m = re.nmatch++
		;;
		match altexpr(re)
		| `Some s:
			if matchc(re, ')')
				if nocap
					-> `Some s
				else
					-> `Some mk(re, `Cap (m, s), idx)
				;;
			else
				/* nobody else owns the group body on failure */
				astfree(s)
				-> `Err `Unbalanced '('
			;;
		| `None:	-> `Err `Emptyparen
		| `Err st:	-> `Err st
		;;
	| '\\':
		getc(re) /* consume the slash */
		if re.pat.len == re.idx
			-> `Err `Incomplete
		;;
		-> escaped(re)
	| c:
		getc(re)
		ret = mk(re, `Chr c, idx)
	;;
	-> `Some ret
}
/*
  Parses the character following a backslash; the backslash itself
  has already been consumed by the caller. Unknown escapes are
  reported as `Badescape.
*/
const escaped = {re
	var idx

	idx = re.idx
	match getc(re)
	/* character classes */
	| 'd':	-> `Some mk(re, `Ranges std.sldup(_ranges.tabasciidigit[:]), idx)
	| 'x':	-> `Some mk(re, `Ranges std.sldup(_ranges.tabasciixdigit[:]), idx)
	| 's':	-> `Some mk(re, `Ranges std.sldup(_ranges.tabasciispace[:]), idx)
	| 'w':	-> `Some mk(re, `Ranges std.sldup(_ranges.tabasciiword[:]), idx)
	| 'h':	-> `Some mk(re, `Ranges std.sldup(_ranges.tabasciiblank[:]), idx)
	/* negated character classes */
	| 'W':	-> `Some mk(re, `Ranges negate(_ranges.tabasciiword[:]), idx)
	| 'S':	-> `Some mk(re, `Ranges negate(_ranges.tabasciispace[:]), idx)
	| 'D':	-> `Some mk(re, `Ranges negate(_ranges.tabasciidigit[:]), idx)
	| 'X':	-> `Some mk(re, `Ranges negate(_ranges.tabasciixdigit[:]), idx)
	| 'H':	-> `Some mk(re, `Ranges negate(_ranges.tabasciiblank[:]), idx)
	/* unicode character classes */
	| 'p':	-> unicodeclass(re, false)
	| 'P':	-> unicodeclass(re, true)
	/* operators that need an escape */
	| '<':	-> `Some mk(re, `Bow, idx)
	| '>':	-> `Some mk(re, `Eow, idx)
	/* escaped metachars */
	| '^':	-> `Some mk(re, `Chr '^', idx)
	| '$':	-> `Some mk(re, `Chr '$', idx)
	| '.':	-> `Some mk(re, `Chr '.', idx)
	| '+':	-> `Some mk(re, `Chr '+', idx)
	| '?':	-> `Some mk(re, `Chr '?', idx)
	| '*':	-> `Some mk(re, `Chr '*', idx)
	/* escaped nonprintable characters */
	| 'r':	-> `Some mk(re, `Chr '\r', idx)
	| 'n':	-> `Some mk(re, `Chr '\n', idx)
	| 'b':	-> `Some mk(re, `Chr '\b', idx)
	| ')':	-> `Some mk(re, `Chr ')', idx)
	| 'u':	-> unichar(re, idx)
	| chr:	-> `Err `Badescape chr
	;;
}
/*
  Parses the {hex} part of a \u{hex} escape into a `Chr node.
  A missing '{' or '}' is reported as `Badescape 'u'.
*/
const unichar = {re, idx
	var val

	if matchc(re, '{')
		val = 0
		/* accumulate the hex digits, most significant first */
		while std.isxdigit(peekc(re))
			val = val*16 + std.charval(getc(re), 16)
		;;
		if matchc(re, '}')
			-> `Some mk(re, `Chr val, idx)
		;;
	;;
	-> `Err `Badescape 'u'
}
/*
  Parses the class name following \p or \P and returns a `Ranges
  node for it, negated when `neg` is set.

  The name is either a single character or a {name} in braces.
  Returns `Incomplete at the end of the pattern, `Unbalanced '{' if
  the closing brace is missing, and `Badrange for an unknown name.
*/
const unicodeclass = {re, neg
	var c, s
	var tab
	var t
	var n
	var idx
	var closed

	if re.pat.len == re.idx
		-> `Err (`Incomplete)
	;;
	n = 0
	idx = re.idx
	s = re.pat[idx:]
	/* either a single char pattern, or {pat} */
	match getc(re)
	| '{':
		s = s[1:]
		closed = false
		while re.pat.len > re.idx
			c = getc(re)
			if c == '}'
				closed = true
				break
			;;
			n += std.charlen(c)
		;;
		/* running off the end of the pattern is not a valid name */
		if !closed
			-> `Err (`Unbalanced '{')
		;;
	| r:
		n += std.charlen(r)
	;;
	/* n is the byte length of the name, excluding any braces */
	s = s[:n]
	/* letters */
	if std.sleq(s, "L") || std.sleq(s, "Letter")
		tab = _ranges.tabalpha[:]
	elif std.sleq(s, "Lu") || std.sleq(s, "Uppercase_Letter")
		tab = _ranges.tabupper[:]
	elif std.sleq(s, "Ll") || std.sleq(s, "Lowercase_Letter")
		tab = _ranges.tablower[:]
	elif std.sleq(s, "Lt") || std.sleq(s, "Titlecase_Letter")
		/*
		NOTE(review): this reuses the lowercase table, which looks
		like a copy/paste slip. Confirm whether _ranges provides a
		titlecase table that should be used here instead.
		*/
		tab = _ranges.tablower[:]
	/* numbers (incomplete) */
	elif std.sleq(s, "N") || std.sleq(s, "Number")
		tab = _ranges.tabdigit[:]
	elif std.sleq(s, "Z") || std.sleq(s, "Separator")
		tab = _ranges.tabspace[:]
	elif std.sleq(s, "Zs") || std.sleq(s, "Space_Separator")
		tab = _ranges.tabblank[:]
	else
		-> `Err (`Badrange s)
	;;
	/* the node owns its ranges, so the table is always copied */
	if !neg
		t = mk(re, `Ranges std.sldup(tab), idx)
	else
		t = mk(re, `Ranges negate(tab), idx)
	;;
	-> `Some t
}
/*
  Parses a bracketed character class, [...] or [^...], into a
  `Ranges node. The entries are sorted by their low end and merged;
  a negated class is then inverted. A missing ']' is reported as
  `Unbalanced '['.
*/
const chrclass = {re
	var raw, merged, inverted
	var neg
	var idx

	/* we know we saw '[' on entry */
	idx = re.idx
	matchc(re, '[')
	neg = matchc(re, '^')

	/* the first entry is read unconditionally */
	raw = rangematch(re, [][:])
	while peekc(re) != ']' && re.pat.len > re.idx
		raw = rangematch(re, raw)
	;;
	if !matchc(re, ']')
		std.slfree(raw)
		-> `Err `Unbalanced '['
	;;

	std.sort(raw, {a, b;
		if a[0] == b[0]
			-> `std.Equal
		elif a[0] < b[0]
			-> `std.Before
		else
			-> `std.After
		;;})
	merged = merge(raw)
	std.slfree(raw)
	if !neg
		-> `Some mk(re, `Ranges merged, idx)
	;;
	inverted = negate(merged)
	std.slfree(merged)
	-> `Some mk(re, `Ranges inverted, idx)
}
/*
  Reads one character class entry, either `c` or `a-b`, and pushes
  it onto `sl` as a [lo, hi] pair. Reversed bounds are swapped.
  Returns the grown slice.
*/
const rangematch = {re, sl
	var lo
	var hi

	lo = getc(re)
	hi = lo
	if matchc(re, '-')
		hi = getc(re)
		if hi < lo
			-> std.slpush(&sl, [hi, lo])
		;;
	;;
	-> std.slpush(&sl, [lo, hi])
}
/*
  Returns a newly allocated list of ranges covering every character
  in 0..std.Maxcharval that is not covered by `rng`.

  `rng` is expected to be sorted and non-overlapping. Empty gaps are
  skipped, so a range starting at 0 or ending at std.Maxcharval does
  not produce an inverted range, and negating an empty list gives
  the full range.
*/
const negate = {rng
	var start
	var neg

	neg = [][:]
	/* first character not yet accounted for */
	start = 0
	for r : rng
		/* only emit the gap if something lies before this range */
		if r[0] > start
			std.slpush(&neg, [start, r[0] - 1])
		;;
		start = r[1] + 1
	;;
	/* whatever is left after the last range */
	if start <= std.Maxcharval
		std.slpush(&neg, [start, std.Maxcharval])
	;;
	-> neg
}
/*
  Merges overlapping and abutting ranges, returning a newly
  allocated list. `rl` must be sorted by the low end of each range.
  An empty input gives an empty result.
*/
const merge = {rl
	var lo, hi
	var ret
	if rl.len == 0
		-> [][:]
	;;
	ret = [][:]
	lo = rl[0][0]
	hi = rl[0][1]
	for r : rl[1:]
		/* if it overlaps or abuts, merge */
		if r[0] <= hi + 1
			/*
			only ever extend: a range contained in the current
			one (a-z followed by c-d) must not shrink it.
			*/
			if r[1] > hi
				hi = r[1]
			;;
		else
			std.slpush(&ret, [lo, hi])
			lo = r[0]
			hi = r[1]
		;;
	;;
	-> std.slpush(&ret, [lo, hi])
}
/*
  Consumes the next character of the pattern if it equals `c`.
  Returns whether a character was consumed.
*/
const matchc = {re, c
	var chr

	chr = std.decode(re.pat[re.idx:])
	if chr == c
		re.idx += std.charlen(chr)
		-> true
	;;
	-> false
}
/*
  Returns the next character of the pattern and advances past it.

  At the end of the pattern the position is left untouched, so that
  re.idx never moves outside of re.pat and later re.pat[re.idx:]
  slices stay in bounds. The value returned in that case is whatever
  std.decode yields for an empty slice.
*/
const getc = {re
	var c

	c = std.decode(re.pat[re.idx:])
	if re.idx < re.pat.len
		re.idx += std.charlen(c)
	;;
	-> c
}
/* returns the next character of the pattern without consuming it */
const peekc = {re
	var rest

	rest = re.pat[re.idx:]
	-> std.decode(rest)
}
/* frees an AST node together with everything below it */
const astfree = {t
	match t#
	| `Alt	(lhs, rhs):
		astfree(lhs)
		astfree(rhs)
	| `Cat	(lhs, rhs):
		astfree(lhs)
		astfree(rhs)
	/* repetition */
	| `Star	(sub, r):
		astfree(sub)
	| `Plus	(sub, r):
		astfree(sub)
	| `Quest sub:
		astfree(sub)
	/* end matches */
	| `Ranges rl:
		std.slfree(rl)
	/* meta */
	| `Cap	(m, sub):
		astfree(sub)
	| _:	/* other types, `Chr included, have no suballocations */
	;;
	std.free(t)
}
/*
  Formatter for status values: takes the next value from `ap` and
  writes a human readable description of it into `sb`.
*/
const fmtfail = {sb, ap, opt
	var st : status

	st = std.vanext(ap)
	match st
	| `Noimpl:
		std.sbfmt(sb, "no implementation")
	| `Incomplete:
		std.sbfmt(sb, "regex ended before input fully parsed")
	| `Unbalanced c:
		std.sbfmt(sb, "unbalanced {}", c)
	| `Emptyparen:
		std.sbfmt(sb, "empty parentheses")
	| `Badrep c:
		std.sbfmt(sb, "invalid repetition {}", c)
	| `Badrange s:
		std.sbfmt(sb, "invalid range name {}", s)
	| `Badescape c:
		std.sbfmt(sb, "invalid escape code {}", c)
	;;
}
/*
  Allocates an AST node holding `ast`. In debug mode the node is
  also recorded in re.astloc with the location `loc`.
*/
const mk = {re, ast, loc
	var node

	node = std.mk(ast)
	if re.debug
		std.htput(re.astloc, node, loc)
	;;
	-> node
}
/* installs fmtfail as the formatter for values of type status */
const __init__ = {
	var proto : status

	std.fmtinstall(std.typeof(proto), fmtfail, [][:])
}