ref: 654eb21dce9327393d64b13d3eb3cbc8db35fe3b
dir: /lib/regex/interp.myr/
use std
use "types"
pkg regex =
	/*
	 regex execution, returning the text of each capture group.
	 exec runs the program once from the start of the string and only
	 accepts a match that ends at the end of the string. search retries
	 at each successive start offset and accepts a match that ends
	 anywhere.
	 */
	const exec	: (re : regex#, str : byte[:] -> std.option(byte[:][:]))
	const search	: (re : regex#, str : byte[:] -> std.option(byte[:][:]))
	/* regex execution returning indexes: (start, end) pairs, one per capture group */
	const iexec	: (re : regex#, str : byte[:] -> std.option((std.size, std.size)[:]))
	const isearch	: (re : regex#, str : byte[:] -> std.option((std.size, std.size)[:]))
	/*
	 substitution: subst[i - 1] replaces the text of capture group i, so
	 subst.len must equal re.nmatch - 1 (this is asserted). The sb*
	 variants append to a caller-supplied string buffer; the others
	 build and return a new string.
	 */
	const sub	: (re : regex#, str : byte[:], subst : byte[:][:] -> std.option(byte[:]))
	const sbsub	: (sb : std.strbuf#, re : regex#, str : byte[:], subst : byte[:][:] -> bool)
	const suball	: (re : regex#, str : byte[:], subst : byte[:][:] -> byte[:])
	const sbsuball	: (sb : std.strbuf#, re : regex#, str : byte[:], subst : byte[:][:] -> void)
	/* frees a match slice with std.slfree */
	const matchfree	: (pat : byte[:][:] -> void)
;;
/*
 Null thread pointer, used throughout this file as the "no thread"
 sentinel and as the list terminator for thread lists.
 Ugly: for performance. std.option() should be used instead when unions get faster.
 */
const Zthr = (0 : rethread#)
/* thrfree stops caching threads on re.free once re.nfree reaches this count */
const Maxfree = 128
/* Releases the slice of matches `pat` by handing it to std.slfree. */
const matchfree = {pat
	std.slfree(pat)
}
/*
 Runs `re` over `str` in whole-string mode and converts the resulting
 thread into a slice of captured substrings (`std.None when no thread
 matched).
 */
const exec = {re, str
	var thr
	thr = run(re, str, 0, true)
	-> getmatches(re, thr)
}
/*
 Runs `re` over `str` in whole-string mode and converts the resulting
 thread into a slice of (start, end) index pairs (`std.None when no
 thread matched).
 */
const iexec = {re, str
	var thr
	thr = run(re, str, 0, true)
	-> getidxmatches(re, thr)
}
/*
 Tries to match `re` against each suffix str[start:] in turn, stopping
 at the first start offset that yields a matching thread. The captured
 substrings of that match are returned, or `std.None if no offset
 matched. An empty `str` is never run.
 */
const search = {re, str
	var thr, start
	thr = Zthr
	start = 0
	while start < str.len
		thr = run(re, str[start:], 0, false)
		if thr != Zthr
			break
		;;
		start++
	;;
	-> getmatches(re, thr)
}
/*
 Tries to match `re` against each suffix str[i:] in turn, stopping at
 the first offset that yields a matching thread, and returns the
 (start, end) index pair of every capture group, or `std.None if no
 offset matched.

 The interpreter records group positions relative to the suffix it was
 given, so the pairs are rebased by the offset of the successful
 attempt before being returned; they index into `str` itself. Groups
 that did not participate in the match keep their (-1, -1) marker.
 */
const isearch = {re, str
	var thr = Zthr
	var base = 0
	for var i = 0; i < str.len; i++
		thr = run(re, str[i:], 0, false)
		if thr != Zthr
			base = i
			break
		;;
	;;
	match getidxmatches(re, thr)
	| `std.Some idx:
		for var j = 0; j < idx.len; j++
			match idx[j]
			| (lo, hi):
				/* leave the "unmatched" sentinel untouched */
				if lo != -1 && hi != -1
					idx[j] = (lo + base, hi + base)
				;;
			;;
		;;
		-> `std.Some idx
	| `std.None:
		-> `std.None
	;;
}
/*
 Substitutes the capture groups of a whole-string match of `re` on
 `str` with the entries of `subst`, returning the new string, or
 `std.None when `str` does not match.
 */
const sub = {re, str, subst
	var sb
	sb = std.mksb()
	if sbsub(sb, re, str, subst)
		-> `std.Some std.sbfin(sb)
	;;
	/* no match: the buffer is never finished, so release it here instead of leaking it */
	std.sbfree(sb)
	-> `std.None
}
/*
 Runs `re` over `str` in whole-string mode and, on a match, appends
 `str` to `sb` with each matched capture group replaced by the
 corresponding entry of `subst`. Returns true if there was a match and
 false otherwise. Asserts that `subst` has one entry per capture group
 other than group 0.
 */
const sbsub = {sb, re, str, subst
	var thr, matched
	std.assert(re.nmatch == subst.len + 1, "substitution length does not match capture count")
	thr = run(re, str, 0, true)
	matched = dosubst(sb, re, thr, str, subst)
	/*
	 The matching thread is owned by us once run returns it. Release it,
	 along with any threads still queued when the match was found, and
	 reset the thread count so the next run starts from a clean state.
	 */
	if thr != Zthr
		cleanup(re, thr)
	;;
	-> matched
}
/*
 Builds a fresh string buffer, lets sbsuball fill it with `str` after
 substituting every match of `re`, and returns the finished string.
 */
const suball = {re, str, subst
	var buf = std.mksb()
	sbsuball(buf, re, str, subst)
	-> std.sbfin(buf)
}
/*
 Scans `str` from left to right, appending to `sb`. At each position
 the regex is tried against the remaining suffix: when it does not
 match, the current byte is copied through unchanged; when it does,
 the matched span is appended with its capture groups replaced by the
 entries of `subst`, and scanning resumes after the match. Asserts
 that `subst` has one entry per capture group other than group 0.
 */
const sbsuball = {sb, re, str, subst
	var thr, len, i
	std.assert(re.nmatch == subst.len + 1, "substitution length does not match capture count")
	i = 0
	while i < str.len
		thr = run(re, str[i:], 0, false)
		if thr == Zthr
			std.sbputb(sb, str[i])
			i++
		else
			/* end of group 0, relative to str[i:], is the length of the match */
			len = thr.mgroup[0][1]
			dosubst(sb, re, thr, str[i:len + i], subst)
			if len == 0
				/*
				 An empty match consumes nothing. Copy the current
				 byte through and step past it, otherwise the same
				 empty match would be found here forever.
				 */
				std.sbputb(sb, str[i])
				i++
			else
				i += len
			;;
		;;
		cleanup(re, thr)
	;;
}
/*
 Appends `str` to `sb`, replacing the span of every capture group
 (1 .. re.nmatch - 1) that has both ends recorded in `thr` with the
 matching entry of `subst`. Text between groups and after the last
 group is copied verbatim. Returns false without writing anything when
 `thr` is Zthr, true otherwise.
 */
const dosubst = {sb, re, thr, str, subst
	var pos, lo, hi
	if thr == Zthr
		-> false
	;;
	pos = 0
	for var g = 1; g < re.nmatch; g++
		lo = thr.mgroup[g][0]
		hi = thr.mgroup[g][1]
		/* a group with either end unset did not take part in the match */
		if lo == -1 || hi == -1
			continue
		;;
		std.sbputs(sb, str[pos:lo])
		std.sbputs(sb, subst[g - 1])
		pos = hi
	;;
	std.sbputs(sb, str[pos:])
	-> true
}
/*
 Frees every thread still held on the run queue, the expired list and
 the free list, frees the `result` thread, and resets the list heads
 and thread counters on `re`.
 */
const cleanup = {re, result
	lfree(re.runq)
	re.runq = Zthr
	lfree(re.expired)
	re.expired = Zthr
	lfree(re.free)
	re.free = Zthr
	re.nfree = 0
	re.nthr = 0
	std.free(result)
}
/* Frees every thread on the singly linked list starting at `thr`. */
const lfree = {thr
	var next
	while thr != Zthr
		/* grab the link before the node is released */
		next = thr.next
		std.free(thr)
		thr = next
	;;
}
/*
 Converts a matching thread into a slice of captured substrings, one
 entry per capture group, each slicing into re.str. Groups that did
 not take part in the match yield an empty slice. Returns `std.None
 when `thr` is Zthr. On a match, the thread and the interpreter's
 thread lists are released through cleanup.
 */
const getmatches = {re, thr
	var ret, i
	if thr == Zthr
		-> `std.None
	;;
	i = 0
	ret = std.slalloc(re.nmatch)
	for [lo, hi] : thr.mgroup[:re.nmatch]
		if lo != -1 && hi != -1
			ret[i] = re.str[lo : hi]
		else
			/* the freshly allocated entry must not be left unset */
			ret[i] = [][:]
		;;
		i++
	;;
	cleanup(re, thr)
	-> `std.Some ret
}
/*
 Converts a matching thread into a slice of (start, end) index pairs,
 one per capture group, copied straight from the thread's match
 groups. Returns `std.None when `thr` is Zthr. On a match, the thread
 and the interpreter's thread lists are released through cleanup.
 */
const getidxmatches = {re, thr
	var idx, g
	if thr == Zthr
		-> `std.None
	;;
	idx = std.slalloc(re.nmatch)
	g = 0
	while g < re.nmatch
		idx[g] = (thr.mgroup[g][0], thr.mgroup[g][1])
		g++
	;;
	cleanup(re, thr)
	-> `std.Some idx
}
/*
 Runs the regex program over `str`, starting at input position 0.
 Returns a matching thread, or Zthr if no threads matched.

 With `wholestr` set, only a match that ends exactly at the end of
 `str` is accepted. Otherwise the most recent match found is kept, and
 is returned once every thread has died or matched.

 `idx` is not used by this function.
 */
const run = {re, str, idx, wholestr
	var bestmatch
	var consumed
	var states
	var thr
	var ip
	re.str = str
	re.strp = 0
	re.nexttid = 0
	bestmatch = Zthr
	/* set of instruction pointers already occupied at the current input position */
	states = std.mkbs()
	re.runq = mkthread(re, 0)
	if re.debug
		/* The last run could have left things here, since we need this info after the run */
		for bs : re.traces
			std.bsfree(bs)
		;;
		std.slfree(re.traces)
		re.traces = [][:]
		std.slpush(&re.traces, std.mkbs())
	;;
	/* the initial thread starts with every capture group marked as unset */
	for var i = 0; i < re.nmatch; i++
		re.runq.mgroup[i][0] = -1
		re.runq.mgroup[i][1] = -1
	;;
	/* each pass of the outer loop handles one input position, re.strp */
	while re.nthr > 0
		while re.runq != Zthr
			if re.trace
				std.put("switch\n")
			;;
			/* set up the next thread */
			thr = re.runq
			re.runq = thr.next
			ip = thr.ip
			/*
			 Step until step reports true. The first call passes -1 as
			 the current ip; later calls pass the ip the thread started
			 this position with, which step compares against fork targets.
			 */
			consumed = step(re, thr, -1)
			while !consumed
				consumed = step(re, thr, ip)
			;;
			/* another thread already reached this instruction at this position */
			if std.bshas(states, thr.ip)
				die(re, thr)
			;;
			if thr.dead
				thrfree(re, thr)
			elif thr.matched
				/* a newer match supersedes any earlier candidate */
				if bestmatch != Zthr
					thrfree(re, bestmatch)
				;;
				if re.strp == re.str.len
					/* matched at the end of the input: stop immediately */
					bestmatch = thr
					goto done
				elif !wholestr
					bestmatch = thr
				else
					/* whole-string mode rejects matches that end early */
					thrfree(re, thr)
				;;
			elif !thr.matched
				/* still alive: record its state and append it to the expired list */
				std.bsput(states, thr.ip)
				if re.expired == Zthr
					re.expired = thr
				;;
				if re.expiredtail != Zthr
					re.expiredtail.next = thr
				;;
				re.expiredtail = thr
				thr.next = Zthr
			;;
		;;
		/* move to the next input position: the expired list becomes the new run queue */
		std.bsclear(states)
		re.runq = re.expired
		re.expired = Zthr
		re.expiredtail = Zthr
		re.strp++
	;;
:done
	std.bsfree(states)
	-> bestmatch
}
/*
 Steps thread `thr` forward by one instruction of re.code.

 Returns true when the thread is finished with the current input
 position: it accepted a byte, reached a match, or died. Returns false
 for instructions that succeed without consuming input, in which case
 the caller is expected to step the thread again.

 `curip` is compared against the right-hand target of a fork: a fork
 whose right branch is `curip` does not spawn a new thread.
*/
const step = {re, thr, curip
	var str, nthr, inst
	str = re.str
	inst = re.code[thr.ip]
	if re.trace
		itrace(re, thr, re.prog[thr.ip])
	;;
	if re.debug
		std.bsput(re.traces[thr.tid], thr.ip)
	;;
	/* the low four bits of an instruction word select the opcode */
	match inst & 0xf
	/* Char matching. Consume exactly one byte from the string. */
	| OpRange:
		var lo = (inst >>  4 : byte)
		var hi = (inst >> 16 : byte)

		if !within(re, str) || lo > str[re.strp] || hi < str[re.strp]
			die(re, thr)
		else
			thr.ip++
		;;
	| OpByte:
		var b = (inst >> 4 : byte)
		if !within(re, str)
			die(re, thr)
		elif b != str[re.strp]
			die(re, thr)
		else
			thr.ip++
		;;
	| OpFork:
		var lip = ((inst >>  4) & 0x3fffffff : std.size)
		var rip = ((inst >> 34) & 0x3fffffff : std.size)
		if rip != curip
			/* the new thread takes the right branch and is queued to run next */
			nthr = mkthread(re, rip)
			nthr.next = re.runq
			nthr.mgroup = thr.mgroup
			re.runq = nthr
		;;
		if re.debug
			std.slpush(&re.traces, std.bsdup(re.traces[thr.tid]))
		;;
		/* this thread continues down the left branch */
		thr.ip = lip
		-> false
	/*
	  Non-consuming. All of these return false, and expect step to be
	  called again until exactly one byte is consumed from the string.
	 */
	| OpJmp:
		var ip = (inst >> 4 : std.size)
		thr.ip = ip
		-> false
	| OpMatch:
		re.lastthr = thr.tid
		finish(re, thr)
		-> true
	| OpLbra:
		var m = (inst >> 4 : std.size)
		thr.mgroup[m][0] = re.strp
		thr.ip++
		-> false
	| OpRbra:
		var m = (inst >> 4 : std.size)
		thr.mgroup[m][1] = re.strp
		thr.ip++
		-> false
	| OpBol:
		if re.strp == 0 || str[re.strp - 1] == ('\n' : byte)
			thr.ip++
			-> false
		else
			die(re, thr)
		;;
	| OpEol:
		if re.strp == str.len || str[re.strp] == ('\n' : byte)
			thr.ip++
			-> false
		else
			die(re, thr)
		;;
	/* check for word characters */
	| OpBow:
		if iswordchar(str[re.strp:]) && (re.strp == 0 || !iswordchar(prevchar(str, re.strp)))
			thr.ip++
			-> false
		else
			die(re, thr)
		;;
	| OpEow:
		/*
		 A word ends here when the previous character is a word
		 character and the current position is either the end of the
		 string or a non-word character. Nothing precedes position 0,
		 and prevchar asserts on it, so that case is ruled out first;
		 this matters for empty input, where position 0 is also the
		 end of the string.
		 */
		if re.strp > 0 && iswordchar(prevchar(str, re.strp)) && (re.strp == str.len || !iswordchar(str[re.strp:]))
			thr.ip++
			-> false
		else
			die(re, thr)
		;;
	| _:
		std.die("corrupt regex bytecode")
	;;
	-> true
}
/*
 Marks `thr` as dead and records where it died in re.lastip and
 re.lastthr.

 A thread can be killed more than once, e.g. when it fails a range
 check *and* lands on a state another thread already holds. The live
 thread count is only decremented the first time.
 */
const die = {re, thr
	re.lastip = thr.ip
	re.lastthr = thr.tid
	if thr.dead
		-> void
	;;
	re.nthr--
	thr.dead = true
}
/* Flags `thr` as having matched and removes it from the live thread count. */
const finish = {re, thr
	re.nthr--
	thr.matched = true
}
/*
 Produces a thread positioned at instruction `ip`, reusing one from
 re.free when available and allocating otherwise. The thread gets the
 next thread id, is unlinked, is neither dead nor matched, and is
 counted in re.nthr. Its match groups are not initialized here.
 */
const mkthread = {re, ip
	var thr : rethread#
	if re.free == Zthr
		thr = std.alloc()
	else
		/* pop the head of the free list */
		thr = re.free
		re.free = thr.next
		re.nfree--
	;;
	re.nthr++
	thr.ip = ip
	thr.tid = re.nexttid++
	thr.next = Zthr
	thr.matched = false
	thr.dead = false
	-> thr
}
/*
 Retires `thr`: it is pushed onto re.free for reuse while fewer than
 Maxfree threads are cached there, and freed outright otherwise.
 */
const thrfree = {re, thr
	if re.nfree < Maxfree
		thr.next = re.free
		re.free = thr
		re.nfree++
	else
		std.free(thr)
	;;
}
/* True while the current input position re.strp still indexes a byte of `str`. */
const within = {re, str
	-> str.len > re.strp
}
/*
 Prints a one-line trace of instruction `inst` as executed by `thr`,
 prefixed with the thread id and instruction pointer.
 */
const itrace = {re, thr, inst
	match inst
	/* exactly one argument per placeholder, so the byte slot shows the byte itself */
	| `Ibyte b:	std.put("\t{}.{}:\tByte ({})\n", thr.tid, thr.ip, b)
	| `Irange (lo, hi):	std.put("\t{}.{}:\tRange {}, {}\n", thr.tid, thr.ip, lo, hi)
	| `Ilbra m:	std.put("\t{}.{}:\tLbra {}\n", thr.tid, thr.ip, m)
	| `Irbra m:	std.put("\t{}.{}:\tRbra {}\n", thr.tid, thr.ip, m)
	/* anchors */
	| `Ibol:	std.put("\t{}.{}:\tBol\n", thr.tid, thr.ip)
	| `Ieol:	std.put("\t{}.{}:\tEol\n", thr.tid, thr.ip)
	| `Ibow:	std.put("\t{}.{}:\tBow\n", thr.tid, thr.ip)
	| `Ieow:	std.put("\t{}.{}:\tEow\n", thr.tid, thr.ip)
	/* control flow */
	| `Ifork (l, r):	std.put("\t{}.{}:\tFork {}, {}\n", thr.tid, thr.ip, l, r)
	| `Ijmp	ip:	std.put("\t{}.{}:\tJmp {}\n", thr.tid, thr.ip, ip)
	| `Imatch m:	std.put("\t{}.{}:\tMatch {}\n", thr.tid, thr.ip, m)
	;;
}
/*
 Returns the suffix of `s` that begins at the character immediately
 before byte offset `i`. Must be called with i >= 1.

 Walks backwards over UTF-8 continuation bytes (10xxxxxx) only, so it
 stops on the lead byte of the previous character rather than running
 past it into the character before.
 */
const prevchar = {s, i
	std.assert(i != 0, "prevchar must be called with i >= 1\n")
	i--
	while i != 0 && (s[i] & 0xc0) == 0x80
		i--
	;;
	-> s[i:]
}
/*
 Decodes the first character of `s` and reports whether it is a word
 character: an underscore, a letter, or a digit.
 */
const iswordchar = {s
	var c = std.decode(s)
	if c == '_'
		-> true
	;;
	-> std.isalpha(c) || std.isdigit(c)
}