shithub: mc

ref: 56dffc40e38779981b50628d0c4781255410a357
dir: /lib/std/fltparse.myr/

View raw version
use "alloc"
use "bigint"
use "chartype"
use "chomp"
use "extremum"
use "fltbits"
use "hasprefix"
use "intparse"
use "option"
use "utf"
use "striter"

pkg std =
	/*
	 parse a decimal float literal held in `str`. the result is
	 `std.None when the text is rejected by the parser, otherwise
	 `std.Some of the converted value.
	*/
	const flt64parse	: (str : byte[:] -> std.option(flt64))
	/* same parser as flt64parse; the converted value is narrowed to flt32 */
	const flt32parse	: (str : byte[:] -> std.option(flt32))
;;

/*
 what fltparse hands back to flt64parse/flt32parse.

   `Flt:    (sign, decimal mantissa digits, power-of-ten exponent).
            both callers std.bigfree() the mantissa after converting.
   `Inf:    infinity; the payload is the sign (+1 or -1).
   `Nan:    not-a-number; the payload is the sign (+1 or -1).
   `Badflt: the input was not accepted as a float.
*/
type parseresult = union
	`Flt (int, std.bigint#, int16)
	`Inf int
	`Nan int
	`Badflt
;;

/*
 parses `str` as a flt64.

 returns `std.None when fltparse rejects the text. infinities and
 nans are produced by multiplying the parsed sign into
 std.flt64inf()/std.flt64nan(); finite values go through toflt with
 the flt64 limits.
*/
const flt64parse = {str
	var parsed

	match fltparse(str)
	| `Flt (sgn, digits, e10):
		parsed = toflt(sgn, digits, e10, &lim64)
		/* the mantissa bigint came from fltparse and is released here */
		std.bigfree(digits)
		-> `std.Some parsed
	| `Nan sgn:	-> `std.Some ((sgn : flt64) * std.flt64nan())
	| `Inf sgn:	-> `std.Some ((sgn : flt64) * std.flt64inf())
	| `Badflt:	-> `std.None
	;;
}

/*
 parses `str` as a flt32.

 returns `std.None when fltparse rejects the text. finite values are
 converted by toflt with the flt32 limits, which yields a flt64 that
 is then cast down to flt32.
*/
const flt32parse = {str
	var wide

	match fltparse(str)
	| `Flt (sgn, digits, e10):
		wide = toflt(sgn, digits, e10, &lim32)
		/* the mantissa bigint came from fltparse and is released here */
		std.bigfree(digits)
		-> `std.Some (wide : flt32)
	| `Nan sgn:	-> `std.Some ((sgn : flt32) * std.flt32nan())
	| `Inf sgn:	-> `std.Some ((sgn : flt32) * std.flt32inf())
	| `Badflt:	-> `std.None
	;;
}

/*
 splits a float literal into a sign, a decimal mantissa and a
 power-of-ten exponent.

 accepted input: an optional '-' or '+', followed by either one of
 "inf", "Inf", "nan", "NaN", or a run of decimal digits containing
 at most one '.' and any number of '_' separators, optionally
 followed by 'e' or 'E' and an integer exponent (which may carry a
 single '+' or '-').

 the mantissa must contain at least one digit: ".", "_" and "e5"
 are rejected, while "5.", ".5" and "0" are accepted.

 returns `Badflt for anything else. for `Flt, the mantissa bigint is
 allocated here and must be released by the caller with std.bigfree().
*/
const fltparse = {str
	var sign, exp, mant, expinc, i
	var seendigit

	if std.chomp(&str, "-")
		sign = -1
	else
		sign = 1
		std.chomp(&str, "+")
	;;

	match str
	| "inf":	-> `Inf sign
	| "Inf":	-> `Inf sign
	| "nan":	-> `Nan sign
	| "NaN":	-> `Nan sign
	| "":	-> `Badflt
	| _:	/* nothing */
	;;

	i = 0
	exp = 0
	expinc = 0
	mant = std.mkbigint(0)
	/* leading zeros do not change the value, so skip them without touching mant */
	for var c = std.decode(str[i:]); c == '0'; c = std.decode(str[i:])
		i++
	;;
	/* a skipped zero is still a mantissa digit, so "0" and "00.0" stay valid */
	seendigit = i > 0
	for var c = std.decode(str[i:]); fltchar(c); c = std.decode(str[i:])
		/* every character fltchar admits is a single byte */
		i++
		if c == '_'
			continue
		elif c == '.'
			/* fail if we've already seen the '.' */
			if expinc != 0
				goto error
			;;
			/* from here on, each digit scales the value down by ten */
			expinc = -1
		elif std.isdigit(c)
			seendigit = true
			exp += expinc
			std.bigmuli(mant, 10)
			std.bigaddi(mant, std.charval(c, 10))
		else
			goto error
		;;
	;;

	/* a literal made only of '.', '_' or an exponent is not a number */
	if !seendigit
		goto error
	;;

	if std.hasprefix(str[i:], "e") || std.hasprefix(str[i:], "E")
		/* intparse doesn't accept '+', so if we have 'e+', consume the '+' */
		if std.hasprefix(str[i+1:], "+")
			i++
			/* only one sign is allowed: don't let "e+-5" through as "e-5" */
			if std.hasprefix(str[i+1:], "-")
				goto error
			;;
		;;
		match std.intparsebase(str[i+1:], 10)
		| `std.Some n:	exp += n
		| `std.None:	goto error
		;;
		/* the exponent parse covered the rest of the string */
		i = str.len
	;;
	/* anything left over is trailing garbage */
	if i != str.len
		goto error
	;;

	-> `Flt (sign, mant, exp)

:error
	std.bigfree(mant)
	-> `Badflt
}

/*
 true for the characters that may appear in the mantissa of a float
 literal: decimal digits, the decimal point, and the '_' separator.
*/
const fltchar = {c
	match c
	| '.':	-> true
	| '_':	-> true
	| _:	-> std.isdigit(c)
	;;
}

/*
 per-format parameters for the decimal to binary conversion.

   minsig, maxsig: bounds that fallback drives the integer quotient
                   u/v into by doubling u or v.
   loshift:        left shift denormal applies to a quotient held in
                   a single bigint digit.
   nextinc:        amount nextfloat adds to the flt64 mantissa to
                   step to the next value.
   minexp, maxexp: bounds on the binary exponent k in fallback; at
                   minexp the denormal path may be taken, above
                   maxexp the result is infinity.
   sigbits:        bit length of u/v that estimate aims for.
*/
type lim = struct
	minsig	: uint64
	maxsig	: uint64
	loshift	: uint64
	nextinc	: uint64
	minexp	: int16
	maxexp	: int16
	sigbits	: int16
;;

/*
 conversion limits used by flt64parse. the quotient window is
 2^52..2^53 and the exponent bounds are offset by 52. fields are
 listed in the order the struct declares them.
*/
const lim64 : lim = [
	.minsig=0x10000000000000ul,
	.maxsig=0x20000000000000ul,
	.loshift=0,
	.nextinc=1,
	.minexp=-1022 - 52,
	.maxexp=1023 - 52,
	.sigbits=53
]

/*
 conversion limits used by flt32parse. the quotient window is
 2^22..2^23 and the exponent bounds are offset by 22. loshift
 matches the 30 bit shift mkflt applies to single-digit mantissas.
 fields are listed in the order the struct declares them.

 NOTE(review): nextinc is 2^29 while a quotient scaled by 2^30 has a
 spacing of 2^30 in the flt64 mantissa; confirm the intended
 precision of this window against mkflt and nextfloat.
*/
const lim32 : lim = [
	.minsig=0x400000,
	.maxsig=0x800000,
	.loshift=30,
	.nextinc=0x20000000,
	.minexp=-126-22,
	.maxexp=127-22,
	.sigbits=22
]

/*
 converts sign * mant * 10^exp to a flt64 using the limits in `lim`.

 a zero mantissa short-circuits to a signed zero; everything else is
 handed to the bignum fallback. the sign is applied by multiplication
 in both cases.
*/
const toflt = {sign, mant, exp, lim
	var magnitude

	/* todo: add fast parsing for common cases */
	if std.bigeqi(mant, 0)
		/* if it's all zero, the magnitude is just 0 */
		magnitude = 0.0
	else
		magnitude = fallback(mant, exp, lim)
	;;
	-> (sign : flt64) * magnitude
}

/*
 converts mant * 10^exp to a flt64 with bignum arithmetic.

 the value is held as the ratio u/v. u or v is doubled (adjusting the
 binary exponent k to match) until the integer quotient x = u/v lies
 in the half-open window [lim.minsig, lim.maxsig); assemble then
 turns quotient, remainder and k into a float. the upper bound must
 be exclusive: a quotient equal to maxsig has one bit too many, and
 mkflt would mask that bit away and return half the intended value.

 special cases:
   - k reaches lim.minexp while x is still below minsig: the value
     is handed to denormal.
   - k exceeds lim.maxexp: the result is infinity.

 mant is modified in place (u aliases it) and is not freed here; v,
 x and r are freed before returning.
*/
const fallback = {mant, exp, lim
	var u, v, k : int16, eabs
	var x, r, xprime, rprime
	var f

	u = mant
	k = 0
	v = std.mkbigint(1)
	x = std.mkbigint(0)
	r = std.mkbigint(0)
	eabs = std.abs(exp)
	/* fold the power of ten into the numerator or the denominator */
	if exp >= 0
		/* can be optimized */
		mulexp(u, eabs, 10)
	else
		mulexp(v, eabs, 10)
	;;

	/* pre-shift u or v so the loop below starts near the target window */
	estimate(u, v, &k, lim)

	while true
		(xprime, rprime) = std.bigdivmod(u, v)
		std.bigsteal(x, xprime)
		std.bigsteal(r, rprime)
		if k == lim.minexp && std.biglti(x, lim.minsig)
			/* can't scale u up any further without leaving the exponent range */
			f = denormal(x, v, r, lim)
			goto done
		elif k > lim.maxexp
			f = std.flt64inf()
			goto done
		;;
		if std.biglti(x, lim.minsig)
			std.bigmuli(u, 2)
			k--
		elif std.biggei(x, lim.maxsig)
			/* x == maxsig is already too wide, so it is scaled down too */
			std.bigmuli(v, 2)
			k++
		else
			break
		;;
	;;
	f = assemble(u, v, r, k, lim)
:done
	std.bigfree(v)
	std.bigfree(x)
	std.bigfree(r)
	-> f
}

/*
 coarse pre-scaling for fallback: compares the bit lengths of u and
 v and shifts one of them left so that u/v is roughly lim.sigbits
 bits long, recording the shift in k#.

 if we deal with denormals, we just punt to the 'k == minexp' test
 in the caller and take the slow path.
*/
const estimate = {u, v, k, lim
	var want, ratbits
	var lift, drop

	want = lim.sigbits
	/* approximate bit length of the quotient u/v */
	ratbits = (std.bigbitcount(u) - std.bigbitcount(v) : int16)
	lift = 0
	drop = 0
	if ratbits > want + 1
		/* quotient too long: grow the denominator */
		drop = std.clamp(ratbits - want + 1, lim.minexp, lim.maxexp)
		k# += drop
	elif ratbits < want - 1
		/* quotient too short: grow the numerator */
		lift = std.clamp(want - ratbits - 1, lim.minexp, lim.maxexp)
		k# -= lift
	;;
	/* at most one of these shifts is nonzero */
	std.bigshli(u, lift)
	std.bigshli(v, drop)
}

/*
 builds the final float from the scaled ratio u/v, its remainder r
 and the binary exponent k, rounding to nearest with ties going to
 the even quotient.

 u is overwritten by std.bigdiv(u, v) before being passed to mkflt.
*/
const assemble = {u, v, r, k, lim
	var z, comp, bump

	/* comp = v - r; comparing r against it decides which side of the midpoint we are on */
	comp = std.bigdup(v)
	std.bigsub(comp, r)

	std.bigdiv(u, v)
	z = mkflt(u, k)

	match std.bigcmp(r, comp)
	| `std.After:	bump = true
	| `std.Equal:	bump = !std.bigiseven(u)
	| `std.Before:	bump = false
	;;
	std.bigfree(comp)

	if bump
		z = nextfloat(z, lim)
	;;
	-> z
}

/*
 returns the float one step above z, where a step is lim.nextinc
 units of the flt64 mantissa.

 only the low 52 bits of the mantissa returned by std.flt64explode
 are used for the carry test. when adding nextinc would carry out of
 those 52 bits, the fraction is reset to zero and the exponent is
 incremented instead.

 the previous test was written as `mant - (1l << 52) - 1`, which
 subtracts 2^52 + 1 rather than the all-ones fraction 2^52 - 1, so
 it did not fire when the fraction was about to overflow.
*/
const nextfloat = {z, lim
	var sign, mant, exp
	var frac

	(sign, exp, mant) = std.flt64explode(z)
	/* mask so the test is the same with or without bit 52 set in mant */
	frac = mant & ((1ul << 52) - 1)
	if frac + lim.nextinc > (1ul << 52) - 1
		/* carry out of the fraction: next value is 1.0 * 2^(exp + 1) */
		mant = 0
		exp++
	else
		mant += lim.nextinc
	;;
	-> std.flt64assem(sign, exp, mant)
}

/*
 multiplies the bigint `val` in place by pow^exp using repeated
 std.bigmuli calls. does nothing when exp is zero or negative.
*/
const mulexp = {val, exp, pow
	for var left = exp; left > 0; left--
		std.bigmuli(val, pow)
	;;
}

/*
 builds a float from the quotient x when fallback has hit
 lim.minexp, rounding to nearest with ties going to the even
 quotient.

 x is the integer quotient, rem the remainder of the same division
 and v the divisor. the quotient's digits are placed directly into
 the flt64 bit pattern: two digits are combined as low/high halves,
 a single digit is shifted left by lim.loshift, and an empty digit
 list means zero.

 rounding compares rem against v - rem, the same test assemble
 uses. comparing rem against v itself can never ask for a round up,
 since a remainder is always smaller than the divisor. stepping up
 is done by adding lim.nextinc to the bit pattern, so a carry out of
 the mantissa lands in the exponent field.
*/
const denormal = {x, v, rem, lim
	var m, comp, bump

	if x.dig.len == 0
		m = 0
	elif x.dig.len == 2
		m = (x.dig[0] : uint64) 
		m |= (x.dig[1] : uint64) << 32
	else
		m = (x.dig[0] : uint64) << lim.loshift
	;;

	/* comp = v - rem: what is left between the remainder and the next quotient */
	comp = std.bigdup(v)
	std.bigsub(comp, rem)
	match std.bigcmp(rem, comp)
	| `std.Before:	bump = false
	| `std.After:	bump = true
	| `std.Equal:
		/* exact tie: round to even; an empty quotient is zero, which is even */
		bump = x.dig.len != 0 && !std.bigiseven(x)
	;;
	std.bigfree(comp)

	if bump
		m += lim.nextinc
	;;
	-> std.flt64frombits(m)
}

/*
 packs an integer mantissa and a binary exponent into a flt64 bit
 pattern. the sign bit is left clear.

 mant is expected to hold one or two bigint digits. exp is the
 exponent of the integer mantissa; the bias added below includes
 the extra 52 that accounts for the mantissa being an integer.
*/
const mkflt = {mant, exp
	var frac, biased

	if mant.dig.len != 2
		/* flt32: 2^22..2^23, so adjust into flt64 range */
		frac = (mant.dig[0] : uint64) << 30
		exp -= 30
	else
		/* flt64: guaranteed to be in the range 2^52..2^53 */
		frac = (mant.dig[0] : uint64) | ((mant.dig[1] : uint64) << 32)
	;;
	/* keep only the 52 stored mantissa bits */
	frac &= (1<<52) - 1

	/* exponent bias, masked to 11 bits and moved above the mantissa */
	biased = (((exp : uint64) + 1023 + 52) & 0x7ff) << 52

	-> std.flt64frombits(frac | biased)
}