shithub: mc

ref: 65ec134c5fc370d6df3382edb4cc633bd7802e3e
dir: /lib/math/pown-impl.myr/

View raw version
use std

use "fpmath"
use "impls"
use "log-overkill"
use "log-impl"
use "sum-impl"
use "util"

/*
   This is an implementation of pown: computing x^n where n is an
   integer. We sort of follow [PEB04], but without their high-radix
   log_2. Instead, we use log-overkill, which should be good enough.
 */
/*
   Exported (package-local) entry points. Each one is a thin wrapper that
   forwards to the generic implementation together with the matching
   float descriptor (desc32 or desc64).
 */
pkg math =
	/* x^n for an integer exponent n */
	pkglocal const pown32 : (x : flt32, n : int32 -> flt32)
	pkglocal const pown64 : (x : flt64, n : int64 -> flt64)

	/* x^(1/q) for an unsigned integer root index q */
	pkglocal const rootn32 : (x : flt32, q : uint32 -> flt32)
	pkglocal const rootn64 : (x : flt64, q : uint64 -> flt64)
;;

/*
   Bundle of type-specific helpers and constants that lets powngen and
   rootngen be written once for both flt32 and flt64.

     @f : the floating point type
     @u : the unsigned integer type holding @f's bit pattern
     @i : the signed integer type used for exponents
 */
type fltdesc(@f, @u, @i) = struct
	/* Split a float into (is-negative, exponent, significand) */
	explode : (f : @f -> (bool, @i, @u))
	/* Build a float from (is-negative, exponent, significand) */
	assem : (n : bool, e : @i, s : @u -> @f)
	/* Reinterpret between a float and its raw bit pattern */
	tobits : (f : @f -> @u)
	frombits : (u : @u -> @f)
	/*
	   Table of (hi, lo) bit-pattern pairs; see log-impl.myr. In this file
	   only entry 128 is read, and it is treated as log(2) in hi/lo form.
	 */
	C : (@u, @u)[:]
	/* Bit patterns of 1/ln(2), split into a top and a bottom part */
	one_over_ln2_hi : @u
	one_over_ln2_lo : @u
	/* Bit patterns of the special values returned by the special cases */
	nan : @u
	inf : @u
	neginf : @u
	/* Not called in this file. Presumably compares magnitudes -- TODO confirm */
	magcmp : (f : @f, g : @f -> std.order)
	/* Not called in this file; semantics are defined with the implementation */
	two_by_two : (x : @f, y : @f -> (@f, @f))
	/* Add / multiply two values that are each given as a (hi, lo) pair */
	split_add : (x_h : @f, x_l : @f, y_h : @f, y_l : @f -> (@f, @f))
	split_mul : (x_h : @f, x_l : @f, y_h : @f, y_l : @f -> (@f, @f))
	/* Not called in this file */
	floor : (x : @f -> @f)
	/* Natural log returned as a (hi, lo) pair; see log-overkill */
	log_overkill : (x : @f -> (@f, @f))
	/* Significand width in bits (24 and 53 in the descriptors below) */
	precision : @i
	/* Mask selecting every bit except the sign bit */
	nosgn_mask : @u
	/* Index of the significand's leading (implicit) bit; used as a shift count */
	implicit_bit : @u
	/* Exponent bounds used for the overflow/underflow shortcuts */
	emin : @i
	emax : @i
	/* Extreme values of @i, for detecting overflow in the final exponent */
	imax : @i
	imin : @i
;;

/*
   Descriptor for single precision: flt32 values, uint32 bit patterns,
   int32 exponents. The hex constants are raw flt32 bit patterns and must
   stay bit-exact.
 */
const desc32 : fltdesc(flt32, uint32, int32) =  [
	.explode = std.flt32explode,
	.assem = std.flt32assem,
	.tobits = std.flt32bits,
	.frombits = std.flt32frombits,
	.C = accurate_logs32[0:130], /* See log-impl.myr */
	.one_over_ln2_hi = 0x3fb8aa3b, /* 1/ln(2), top part */
	.one_over_ln2_lo = 0x32a57060, /* 1/ln(2), bottom part */
	.nan = 0x7fc00000,
	.inf = 0x7f800000,
	.neginf = 0xff800000,
	.magcmp = mag_cmp32,
	.two_by_two = two_by_two32,
	.split_add = split_add32, /* defined below; goes through flt64 */
	.split_mul = split_mul32, /* defined below; goes through flt64 */
	.floor = floor32,
	.log_overkill = logoverkill32,
	.precision = 24,
	.nosgn_mask = 0x7fffffff,
	.implicit_bit = 23,
	.emin = -126,
	.emax = 127,
	.imax = 2147483647, /* For detecting overflow in final exponent */
	.imin = -2147483648, /* Likewise, at the negative end */
]

/*
   Descriptor for double precision: flt64 values, uint64 bit patterns,
   int64 exponents. The hex constants are raw flt64 bit patterns and must
   stay bit-exact.
 */
const desc64 : fltdesc(flt64, uint64, int64) =  [
	.explode = std.flt64explode,
	.assem = std.flt64assem,
	.tobits = std.flt64bits,
	.frombits = std.flt64frombits,
	.C = accurate_logs64[0:130], /* See log-impl.myr */
	.one_over_ln2_hi = 0x3ff71547652b82fe, /* 1/ln(2), top part */
	.one_over_ln2_lo = 0x3c7777d0ffda0d24, /* 1/ln(2), bottom part */
	.nan = 0x7ff8000000000000,
	.inf = 0x7ff0000000000000,
	.neginf = 0xfff0000000000000,
	.magcmp = mag_cmp64,
	.two_by_two = two_by_two64,
	.split_add = hl_add, /* not defined in this file */
	.split_mul = hl_mult, /* not defined in this file */
	.floor = floor64,
	.log_overkill = logoverkill64,
	.precision = 53,
	.nosgn_mask = 0x7fffffffffffffff,
	.implicit_bit = 52,
	.emin = -1022,
	.emax = 1023,
	.imax = 9223372036854775807, /* For detecting overflow in final exponent */
	.imin = -9223372036854775808, /* Likewise, at the negative end */
]

/*
   Add two single-precision (hi, lo) pairs. Each pair is collapsed into a
   flt64, the sum is formed in flt64, and the result is split back into a
   flt32 head plus a flt32 correction term.
 */
const split_add32 = {x_h : flt32, x_l : flt32, y_h : flt32, y_l : flt32
	var wide : flt64 = ((x_h : flt64) + (x_l : flt64)) + ((y_h : flt64) + (y_l : flt64))

	/* Head: the wide sum rounded to flt32. Tail: what the rounding dropped. */
	var head : flt32 = (wide : flt32)
	var tail : flt32 = ((wide - (head : flt64)) : flt32)
	-> (head, tail)
}

/*
   Multiply two single-precision (hi, lo) pairs. Each pair is collapsed
   into a flt64, the product is formed in flt64, and the result is split
   back into a flt32 head plus a flt32 correction term.
 */
const split_mul32 = {x_h : flt32, x_l : flt32, y_h : flt32, y_l : flt32
	var wide : flt64 = ((x_h : flt64) + (x_l : flt64)) * ((y_h : flt64) + (y_l : flt64))

	/* Head: the wide product rounded to flt32. Tail: what the rounding dropped. */
	var head : flt32 = (wide : flt32)
	var tail : flt32 = ((wide - (head : flt64)) : flt32)
	-> (head, tail)
}

/* Single-precision x^n: the generic implementation with the flt32 descriptor. */
const pown32 = {x : flt32, n : int32
	var result : flt32 = powngen(x, n, desc32)
	-> result
}

/* Double-precision x^n: the generic implementation with the flt64 descriptor. */
const pown64 = {x : flt64, n : int64
	var result : flt64 = powngen(x, n, desc64)
	-> result
}

/*
   Generic x^n for an integer exponent n.

     x : the base
     n : the integer exponent
     d : descriptor supplying the helpers and constants for @f

   After the special cases, the sign of the result is fixed up front, the
   exponent of x is used for a cheap overflow/underflow test, and the
   remaining work is done on the significand via a (hi, lo) logarithm.
   The final scaling by a power of two is delegated to scale2.
 */
generic powngen = {x : @f, n : @i, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	/* NOTE(review): xb is assigned here but never read again in this function. */
	var xb
	xb = d.tobits(x)

	/* xn: sign flag, xe: exponent, xs: significand */
	var xn : bool, xe : @i, xs : @u
	(xn, xe, xs) = d.explode(x)

	var nf : @f = (n : @f)

	/*
	   Special cases. Note we do not follow IEEE exceptions.
	 */
	if n == 0
		/*
		   Anything^0 is 1. We're taking the view that x is a tiny range of reals,
		   so a dense subset of them are 1, even if x is 0.0.
		 */
		-> 1.0
	elif std.isnan(x)
		/* Propagate NaN (why doesn't this come first? Ask IEEE.) */
		-> d.frombits(d.nan)
	elif (x == 0.0 || x == -0.0)
		if n < 0 && (n % 2 == -1) && xn
			/* (-0)^n = -oo (n negative and odd) */
			-> d.frombits(d.neginf)
		elif n < 0
			/* Every other negative power of a zero is +oo */
			-> d.frombits(d.inf)
		elif n % 2 == 1
			/* (+/- 0)^n = +/- 0 (n odd) */
			-> d.assem(xn, d.emin - 1, 0)
		else
			/* Positive even power: +0 */
			-> 0.0
		;;
	elif n == 1
		/* Anything^1 is itself */
		-> x
	elif n == -1
		/* The CPU is probably better at division than we are at pow(). */
		-> 1.0/x
	;;

	/*
	   (-f)^n = (-1)^n * (f)^n. Figure this out now, then pretend f >= 0.0.
	   Both remainders are tested because n % 2 is -1 for negative odd n.
	 */
	var ult_sgn = 1.0
	if xn && (n % 2 == 1 || n % 2 == -1)
		ult_sgn = -1.0
	;;

	/*
	   Compute (with x = xs * 2^e)

	     x^n  = 2^(n*log2(xs)) * 2^(n*e)

	          = 2^(I + F) * 2^(n*e)
	          = 2^(F) * 2^(I+n*e)

	   Since n and e, and I are all integers, we can get the last part from
	   scale2. The hard part is computing I and F, and then computing 2^F.
	 */
	if xe > 0
		/*
		   But first: do some rough calculations: if we can show n*log(xs) has the
		   same sign as n*e, and n*e would cause overflow, then we might as well
		   return right now.

		   This also takes care of subnormals very nicely, so we don't have to do
		   any special handling to reconstitute xs "right", as we do in rootn.

		   The "estimate / n != xe" tests catch wraparound in the integer
		   multiplication itself.
		 */
		var exp_rough_estimate = n * xe
		if n > 0 && (exp_rough_estimate > d.emax + 1 || (exp_rough_estimate / n != xe))
			-> ult_sgn * d.frombits(d.inf)
		elif n < 0 && (exp_rough_estimate < d.emin - d.precision - 1 || (exp_rough_estimate / n != xe))
			-> ult_sgn * 0.0
		;;
	elif xe < 0
		/*
		   Also, if we consider xs/2 and xe + 1, we can analyze the case in which
		   n*log(xs) has a different sign from n*e.
		 */
		var exp_rough_estimate = n * (xe + 1)
		if n > 0 && (exp_rough_estimate < d.emin - d.precision - 1 || (exp_rough_estimate / n != (xe + 1)))
			-> ult_sgn * 0.0
		elif n < 0 && (exp_rough_estimate > d.emax + 1 || (exp_rough_estimate / n != (xe + 1)))
			-> ult_sgn * d.frombits(d.inf)
		;;
	;;

	/* Log of the significand alone: reassemble xs as positive with exponent 0 */
	var ln_xs_hi, ln_xs_lo
	(ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs))

	/* Now x^n = 2^(n * [ ln_xs / ln(2) ]) * 2^(n * e) */
	var E1, E2
	(E1, E2) = d.split_mul(ln_xs_hi, ln_xs_lo, d.frombits(d.one_over_ln2_hi), d.frombits(d.one_over_ln2_lo))

	/*
	   Now log2(xs) = E1 + E2, so

	     x^n = 2^(n * (E1 + E2)) * 2^(n * e)
	 */

	var F1, F2
	(F1, F2) = d.split_mul(E1, E2, nf, 0.0)

	/*
	   Peel the integer part I off F1 + F2, leaving a fraction in (F1, F2).
	   rn is defined elsewhere; presumably round-to-nearest -- TODO confirm.
	 */
	var I = rn(F1)
	(F1, F2) = d.split_add(-1.0 * (I : @f), 0.0, F1, F2)

	/* Now, x^n = 2^(F1 + F2) * 2^(I + n*e). */
	/* Entry 128 of the table is used as log(2) in hi/lo form (see log-impl.myr) */
	var log2_hi, log2_lo
	(log2_hi, log2_lo) = d.C[128]
	/* Convert the base-2 exponent F1 + F2 to a natural one, G1 + G2, for exp() */
	var G1, G2
	(G1, G2) = d.split_mul(F1, F2, d.frombits(log2_hi), d.frombits(log2_lo))

	/* The low part G2 is added to exp(G1) rather than exponentiated separately */
	var base = exp(G1) + G2
	var pow_xen = xe * n
	var pow = pow_xen + I
	if pow_xen / n != xe || (I > 0 && d.imax - I < pow_xen) || (I < 0 && d.imin - I > pow_xen)
		/*
		   The exponent overflowed. There's no way this is representable. We need
		   to at least recover the correct sign. If the overflow was from the
		   multiplication, then the sign we want is the sign that pow_xen should
		   have been. If the overflow was from the addition, then we still want
		   the sign that pow_xen should have had.
		 */
		if (xe > 0) == (n > 0)
			pow = 2 * d.emax
		else
			pow = 2 * d.emin
		;;
	;;

	-> ult_sgn * scale2(base, pow)
}

/*
   Rootn is barely different enough from pown to justify being split out
   into an entirely separate function.
 */
/* Single-precision q-th root: the generic implementation with the flt32 descriptor. */
const rootn32 = {x : flt32, q : uint32
	var result : flt32 = rootngen(x, q, desc32)
	-> result
}

/* Double-precision q-th root: the generic implementation with the flt64 descriptor. */
const rootn64 = {x : flt64, q : uint64
	var result : flt64 = rootngen(x, q, desc64)
	-> result
}

/*
   Generic x^(1/q) for an unsigned integer root index q.

     x : the radicand
     q : the root index
     d : descriptor supplying the helpers and constants for @f

   Special values:
     q == 0                  -> 1.0 (kept as the original convention)
     x is NaN                -> NaN
     x is +/-0               -> x itself for odd q, +0 for even q
     x < 0 and q even        -> NaN (there is no real root)
     q == 1                  -> x
     x is +/-oo              -> x (only reached for +oo, or -oo with odd q)

   Otherwise the result carries the sign of x when q is odd.
 */
generic rootngen = {x : @f, q : @u, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	var xb
	xb = d.tobits(x)

	/* xn: sign flag, xe: exponent, xs: significand */
	var xn : bool, xe : @i, xs : @u
	(xn, xe, xs) = d.explode(x)

	var qf : @f = (q : @f)

	/*
	   Special cases. Note we do not follow IEEE exceptions.
	 */
	if q == 0
		/* "for any x (even a zero, quiet NaN, or infinity" */
		-> 1.0
	elif std.isnan(x)
		-> d.frombits(d.nan)
	elif (x == 0.0 || x == -0.0)
		if q % 2 == 1
			/* (+/- 0)^(1/q) = +/- 0 (q odd): the input already has the right sign */
			-> x
		else
			/* (+/- 0)^(1/q) = +0 (q even) */
			-> 0.0
		;;
	elif xn && q % 2 == 0
		/* An even root of a negative number is not real */
		-> d.frombits(d.nan)
	elif q == 1
		/* Anything^1/1 is itself */
		-> x
	elif (xb & d.nosgn_mask) == d.inf
		/*
		   Any root of an infinity is that infinity; -oo with even q was
		   rejected above. Returning here keeps infinities out of the general
		   path, which takes the log of the significand alone.
		 */
		-> x
	elif xe < d.emin
		/*
		   Subnormals are actually a problem. If we naively reconstitute xs, it
		   will be wildly wrong and won't match up with the exponent. So let's
		   pretend we have unbounded exponent range. We know the loop terminates
		   because we covered the +/-0.0 case above.
		 */
		xe++
		var check = 1 << d.implicit_bit
		while xs & check == 0
			xs <<= 1
			xe--
		;;
	;;

	/*
	   As in pown: settle the sign now, then work with the magnitude. Only
	   odd roots of negative numbers get this far with xn set.
	 */
	var ult_sgn = 1.0
	if xn && (q % 2 == 1)
		ult_sgn = -1.0
	;;

	/*
	   If we're looking at (1 + 2^-h)^1/q, and the answer will be 1 + e, with
	   (1 + e)^q = 1 + 2^-h, then for q and h large enough, e might be below
	   the representable range. Specifically,

	     (1 + e)^q ≅ 1 + qe + (q choose 2)e^2 + ...

	  So (using single-precision as the example)

	    (1 + 2^-23)^q ≅ 1 + q 2^-23 + (absolutely tiny terms)

	  And anything in [1, 1 + q 2^-24) will just truncate to 1.0 when
	  calculated.

	  The comparisons below are on |x|, so the shortcut must hand back
	  ult_sgn (that is, +/- 1.0) rather than a bare 1.0: the odd root of a
	  negative number close to -1 is -1.
	 */
	if xe == 0
		var cutoff = scale2(qf, -1 * d.precision - 1) + 1.0
		if (xb & d.nosgn_mask) < d.tobits(cutoff)
			-> ult_sgn
		;;
	elif xe == -1
		/* Something similar for (1 - e)^q */
		var cutoff = 1.0 - scale2(qf, -1 * d.precision - 1)
		if (xb & d.nosgn_mask) > d.tobits(cutoff)
			-> ult_sgn
		;;
	;;

	/* Similar to pown. Let e/q = E + psi, with E an integer.

	   x^(1/q) = e^(log(xs)/q) * 2^(e/q)
	           = e^(log(xs)/q) * 2^(psi) * 2^E
	           = e^(log(xs)/q) * e^(log(2) * psi) * 2^E
	           = e^( log(xs)/q  +  log(2) * psi ) * 2^E

	   I've opted to do things just in terms of natural base here because we
	   don't have an integer part, I, that we can slide over in infinite
	   precision.
	 */

	/* Calculate 1/q in very high precision: qinv_lo corrects the rounding of qinv_hi */
	var qinv_hi = 1.0 / qf
	var qinv_lo = -math.fma(qinv_hi, qf, -1.0) / qf
	/* Log of the significand alone: reassemble xs as positive with exponent 0 */
	var ln_xs_hi, ln_xs_lo
	(ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs))

	/* G1 + G2 approximates log(xs)/q */
	var G1, G2
	(G1, G2) = d.split_mul(ln_xs_hi, ln_xs_lo, qinv_hi, qinv_lo)

	var E : @i
	if q > std.abs(xe)
		/* Don't cast q to @i unless we're sure it's in small range */
		E = 0
	else
		E = xe / (q : @i)
	;;
	/* qpsi is the remainder of e divided by q, so psi = qpsi/q */
	var qpsi = xe - q * E
	var psi_hi = (qpsi : @f) / qf
	var psi_lo = -math.fma(psi_hi, qf, -(qpsi : @f)) / qf
	/* Entry 128 of the table is used as log(2) in hi/lo form (see log-impl.myr) */
	var log2_hi, log2_lo
	(log2_hi, log2_lo) = d.C[128]
	/* H1 + H2 approximates log(2) * psi */
	var H1, H2
	(H1, H2) = d.split_mul(psi_hi, psi_lo, d.frombits(log2_hi), d.frombits(log2_lo))

	var J1, J2, t1, t2
	/*
	   We can't use split_add; we don't know the relative magnitudes of G and H
	 */
	(t1, t2) = slow2sum(G2, H2)
	(J2, t1) = slow2sum(H1, t1)
	(J1, J2) = slow2sum(G1, J2)
	J2 = J2 + (t1 + t2)

	/* J1 + J2 approximates log(xs)/q + log(2)*psi */
	var base = exp(J1) + J2

	-> ult_sgn * scale2(base, E)
}