shithub: mc

ref: 0f3f41fae775b2647cd0cdaf0a9ef395c19b9878
dir: /lib/math/util.myr/

use std

pkg math =
	const flt32fromflt64 : (f : flt64 -> flt32)
	const flt64fromflt32 : (f : flt32 -> flt64)

	/* For use in various normalizations */
	const find_first1_64 : (b : uint64, start : int64 -> int64)
	const find_first1_64_hl : (h : uint64, l : uint64, start : int64 -> int64)

	/* >> and <<, but without wrapping when the shift is >= 64 */
	const shr : (u : uint64, s : int64 -> uint64)
	const shl : (u : uint64, s : int64 -> uint64)

	/* Whether RN() requires incrementing after truncating */
	const need_round_away : (h : uint64, l : uint64, bitpos_last : int64 -> bool)

	/* Multiply x * y to z1 + z2 */
	const two_by_two64 : (x : flt64, y : flt64 -> (flt64, flt64))
	const two_by_two32 : (x : flt32, y : flt32 -> (flt32, flt32))

	/* Multiply (a_hi + a_lo) * (b_hi + b_lo) to (z_hi + z_lo) */
	const hl_mult : (a_h : flt64, a_l : flt64, b_h : flt64, b_l : flt64 -> (flt64, flt64))

	/* Add (a_hi + a_lo) + (b_hi + b_lo) to (z_hi + z_lo). Must have |a_hi| > |b_hi|. */
	const hl_add : (a_h : flt64, a_l : flt64, b_h : flt64, b_l : flt64 -> (flt64, flt64))

	/* Compare by magnitude, in descending order */
	const mag_cmp32 : (f : flt32, g : flt32 -> std.order)
	const mag_cmp64 : (f : flt64, g : flt64 -> std.order)

	/* Return (s, t) such that s + t = x + y, with s = rn(x + y). fast2sum requires |x| >= |y|. */
	generic fast2sum : (x : @f, y : @f -> (@f, @f)) :: floating, numeric @f
	generic slow2sum : (x : @f, y : @f -> (@f, @f)) :: floating, numeric @f

	/* Return (a, b, c), a three-term high-precision sum of the values in q */
	const triple_compensated_sum : (q : flt64[:] -> (flt64, flt64, flt64))

	/* Rounds a + b (as flt64s) to a flt32. */
	const round_down : (a : flt64, b : flt64 -> flt32)
;;
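
/*
   Usage sketch (illustrative; x, y, and z are hypothetical caller
   values, not names from this file): products are carried as
   high/low pairs, and further terms are folded in with the 2sums:

	var ph, pl, sh, sl
	(ph, pl) = two_by_two64(x, y)
	(sh, sl) = slow2sum(ph, z)

   ph + pl reproduces x*y to roughly twice working precision, and
   sh + sl is exactly ph + z.
 */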

/* Split precision down the middle: clearing the low 27 mantissa bits leaves 26 significand bits (implicit bit included) in the high part */
const twentysix_bits_mask = (0xffffffffffffffff << 27)
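
/*
   Illustrative consequence (not from the original source): with the
   low 27 bits cleared, x = xh + xl exactly, where xh carries at most
   26 significand bits and xl at most 27. The products xh*yh, xh*yl,
   and xl*yh in two_by_two64 below then fit in flt64's 53 bits and
   are exact; only xl*yl can round, far below the final ulp.
 */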

const flt64fromflt32 = {f : flt32
	var n, e, s
	(n, e, s) = std.flt32explode(f)
	var xs : uint64 = (s : uint64)
	var xe : int64 = (e : int64)

	if e == 128
		/* Inf or NaN: 1024 marks Inf/NaN in flt64 */
		-> std.flt64assem(n, 1024, xs)
	elif e == -127
		/*
		  All subnormals in single precision (except 0.0s)
		  can be upgraded to double precision, since the
		  exponent range is so much wider.
		 */
		var first1 = find_first1_64(xs, 23)
		if first1 < 0
			-> std.flt64assem(n, -1023, 0)
		;;
		xs = xs << (52 - (first1 : uint64))
		xe = -126 - (23 - first1)
		-> std.flt64assem(n, xe, xs)
	;;

	-> std.flt64assem(n, xe, xs << (52 - 23))
}
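
/*
   Worked example (illustrative): the smallest positive flt32
   subnormal has s = 0x1, e = -127, so find_first1_64 returns 0,
   giving xs = 1 << 52 and xe = -126 - 23 = -149: the normal flt64
   value 2^-149.
 */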

const flt32fromflt64 = {f : flt64
	var n : bool, e : int64, s : uint64
	(n, e, s) = std.flt64explode(f)
	var ts : uint32
	var te : int32 = (e : int32)

	if e >= 128
		if e == 1024 && s != 0
			/* NaN */
			-> std.flt32assem(n, 128, 1)
		else
			/* infinity */
			-> std.flt32assem(n, 128, 0)
		;;
	;;

	if e >= -127
		/* normal */
		ts = ((s >> (52 - 23)) : uint32)
		if need_round_away(0, s, 52 - 23)
			ts++
			if ts & (1 << 24) != 0
				ts >>= 1
				te++
			;;
		;;
		if te >= -126
			-> std.flt32assem(n, te, ts)
		;;
	;;

	/* subnormal already, will have to go to 0 */
	if e == -1023
		-> std.flt32assem(n, -127, 0)
	;;

	/* subnormal (at least, it will be) */
	te = -127
	var shift : int64 = (52 - 23) + (-126 - e)
	var ts1 = shr(s, shift)
	ts = (ts1 : uint32)
	if need_round_away(0, s, shift)
		ts++
		if ts & (1 << 23) != 0
			/* false alarm, it's normal again */
			te++
		;;
	;;
	-> std.flt32assem(n, te, ts)
}
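
/*
   Worked example (illustrative): a flt64 with e = -130 is below
   flt32's normal range, so shift = 29 + (-126 - -130) = 33 bits of
   the significand are dropped, need_round_away supplies the final
   ulp, and the result is assembled as a subnormal with te = -127.
 */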

/* >> and <<, but without wrapping when the shift is >= 64 */
const shr = {u : uint64, s : int64
	if (s : uint64) >= 64
		-> 0
	else
		-> u >> (s : uint64)
	;;
}

const shl = {u : uint64, s : int64
	if (s : uint64) >= 64
		-> 0
	else
		-> u << (s : uint64)
	;;
}
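
/*
   Note: the unsigned cast also catches negative shift counts, which
   become huge as uint64, so shr(x, 64) and shl(x, -1) are both 0
   where the bare operators could wrap the shift amount.
 */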

/* Find the first 1 bit in a bitstring */
const find_first1_64 = {b : uint64, start : int64
	for var j = start; j >= 0; j--
		var m = shl(1, j)
		if b & m != 0
			-> j
		;;
	;;

	-> -1
}

const find_first1_64_hl = {h, l, start
	var first1_h = find_first1_64(h, start - 64)
	if first1_h >= 0
		-> first1_h + 64
	;;

	-> find_first1_64(l, 63)
}
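
/* Examples (illustrative): find_first1_64_hl(0x1, 0x0, 70) == 64,
   and find_first1_64_hl(0x0, 0x8000, 70) == 15. */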

/*
   For [ h ][ l ], where bitpos_last is the position of the last
   bit that was included in the truncated result (l's last bit has
   position 0), decide whether rounding up/away is needed. This is
   true if

    - the first omitted bit is 1 and some later bit is non-zero
      (the discarded part is above the halfway point), or

    - the first omitted bit is 1 and all later bits are 0 (an exact
      tie), and the last included bit is 1, so rounding away makes
      the result even
 */
const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
	var first_omitted_is_1 = false
	var nonzero_beyond = false
	if bitpos_last > 64
		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
		nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
		nonzero_beyond = nonzero_beyond || (l != 0)
	else
		first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
		nonzero_beyond = nonzero_beyond || l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
	;;

	if !first_omitted_is_1
		-> false
	;;

	if nonzero_beyond
		-> true
	;;

	var hl_is_odd = false

	if bitpos_last >= 64
		hl_is_odd = h & shl(1, bitpos_last - 64) != 0
	else
		hl_is_odd = l & shl(1, bitpos_last) != 0
	;;

	-> hl_is_odd
}
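
/*
   Worked example (illustrative): with l = 0b1100 and bitpos_last = 3,
   the omitted bits are 0b100: the first is 1 and the rest are 0, an
   exact tie. The last included bit (bit 3) is 1, so rounding away
   makes the result even -> true. With l = 0b0100 the kept part is
   already even, and the same tie returns false.
 */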

/*
   Perform high-prec multiplication: x * y = z1 + z2.
 */
const two_by_two64 = {x : flt64, y : flt64
	var xh : flt64 = std.flt64frombits(std.flt64bits(x) & twentysix_bits_mask)
	var xl : flt64 = x - xh
	var yh : flt64 = std.flt64frombits(std.flt64bits(y) & twentysix_bits_mask)
	var yl : flt64 = y - yh

	/* Multiply out */
	var a1 : flt64 = xh * yh
	var a2 : flt64 = xh * yl
	var a3 : flt64 = xl * yh
	var a4 : flt64 = xl * yl

	/* By-hand compensated summation */
	var yy, u, t, v, z, s, c
	if a2 < a3
		std.swap(&a3, &a2)
	;;

	s = a1
	c = 0.0

	/* a2 */
	(s, c) = fast2sum(s, a2)

	/* a3 */
	(yy, u) = slow2sum(c, a3)
	(t, v) = fast2sum(s, yy)
	z = u + v
	(s, c) = fast2sum(t, z)

	/* a4 */
	(yy, u) = slow2sum(c, a4)
	(t, v) = fast2sum(s, yy)
	z = u + v
	(s, c) = fast2sum(t, z)

	-> (s, c)
}
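
/*
   Sanity sketch (illustrative): s matches the plain flt64 product
   and c holds the discarded residual, so for example

	var s, c
	(s, c) = two_by_two64(1.0 / 3.0, 3.0)

   leaves s equal to the ordinary product (1.0 / 3.0) * 3.0, with c
   the small correction that makes s + c a much closer value for the
   real-number product.
 */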

/*
   The same, for flt32s. Here the product can be computed exactly in
   flt64 (24 + 24 significand bits fit in 53), so c is exactly the
   rounding error of narrowing back to flt32.
 */
const two_by_two32 = {x : flt32, y : flt32
	var xL : flt64 = (x : flt64)
	var yL : flt64 = (y : flt64)
	var zL : flt64 = xL * yL
	var s  : flt32 = (zL : flt32)
	var sL : flt64 = (s : flt64)
	var cL : flt64 = zL - sL
	var c  : flt32 = (cL : flt32)

	-> (s, c)
}

const hl_mult = {a_h : flt64, a_l : flt64, b_h : flt64, b_l : flt64
	/*
	       [     a_h    ][     a_l    ] * [     b_h    ][     b_l    ]
	         =
	   (A) [          a_h*b_h         ]
	   (B)   +           [          a_h*b_l         ]
	   (C)   +           [          a_l*b_h         ]
	   (D)   +                         [          a_l*b_l         ]

	   We therefore want to keep all of A, and the top halves of the two
	   smaller products B and C.

	   To be pedantic, *_l could be much smaller than pictured above; *_h and
	   *_l need not butt up right against each other. But that won't cause
	   any problems; there's no way we'll ignore important information.
	 */
	var Ah, Al
	(Ah, Al) = two_by_two64(a_h, b_h)
	var Bh = a_h * b_l
	var Ch = a_l * b_h
	var resh, resl, t1, t2
	(resh, resl) = slow2sum(Bh, Ch)
	(resl, t1) = fast2sum(Al, resl)
	(resh, t2) = slow2sum(resh, resl)
	(resh, resl) = slow2sum(Ah, resh)
	-> (resh, resl + (t1 + t2))
}

const hl_add = {a_h : flt64, a_l : flt64, b_h : flt64, b_l : flt64
	/*
	  Not terribly clever, we just chain a couple of 2sums together. We are
	  free to impose the requirement that |a_h| > |b_h|, because we'll only
	  be using this for a = 1/5, 1/3, and the log(Fi)s from the C tables.
	  However, we can't guarantee that a_l > b_l. For example, compare C1[10]
	  to C2[18].
	 */

	var resh, resl, t1, t2
	(t1, t2) = slow2sum(a_l, b_l)
	(resl, t1) = slow2sum(b_h, t1)
	(resh, resl) = fast2sum(a_h, resl)
	-> (resh, resl + (t1 + t2))
}
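
/*
   Note (illustrative): the final fast2sum is safe because
   |a_h| > |b_h| keeps a_h dominant over resl; the leftover error
   terms t1 and t2 only perturb the low word, so they are folded in
   at the very end.
 */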

const mag_cmp32 = {f : flt32, g : flt32
	var u = std.flt32bits(f) & ~(1 << 31)
	var v = std.flt32bits(g) & ~(1 << 31)
	-> std.numcmp(v, u)
}

const mag_cmp64 = {f : flt64, g : flt64
	var u = std.flt64bits(f) & ~(1 << 63)
	var v = std.flt64bits(g) & ~(1 << 63)
	-> std.numcmp(v, u)
}
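
/*
   Note (illustrative): with the sign bit cleared, IEEE 754 bit
   patterns order the same way as the magnitudes they encode (NaNs
   aside), so an integer compare suffices. The reversed numcmp
   arguments make the order descending: after std.sort with
   mag_cmp64, the largest-magnitude element comes first.
 */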

/* Return (s, t) such that s + t = a + b, with s = rn(a + b). */
generic slow2sum = {a : @f, b : @f :: floating, numeric @f
	var s = a + b
	var aa = s - b
	var bb = s - aa
	var da = a - aa
	var db = b - bb
	var t = da + db
	-> (s, t)
}
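
/*
   Worked example (illustrative): slow2sum(2^-60, 1.0) -- small term
   first, an order fast2sum would not tolerate -- returns s = 1.0 and
   t = 2^-60: the addend is too small to register in s and is
   recovered exactly in t.
 */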

/* Return (s, t) such that s + t = a + b, with s = rn(a + b), when you KNOW |a| >= |b|. */
generic fast2sum = {a : @f, b : @f :: floating, numeric @f
	var s = a + b
	var z = s - a
	var t = b - z
	-> (s, t)
}
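
/*
   Why the ordering matters (illustrative): fast2sum(2^-60, 1.0)
   returns s = 1.0 but t = 0.0, silently losing the small term that
   slow2sum would preserve; the cheaper three-operation version is
   only valid when |a| >= |b|.
 */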

const triple_compensated_sum = {q : flt64[:]
	/* TODO: verify, with GAPPA or something, that this is correct. */
	std.sort(q, mag_cmp64)
	var s1 : flt64, s2 : flt64, s3
	var t1 : flt64, t2 : flt64, t3 : flt64, t4 : flt64, t5 : flt64, t6
	s1 = q[0]
	s2 = 0.0
	s3 = 0.0
	for qq : q[1:]
		(t5, t6) = slow2sum(s3, qq)
		(t3, t4) = slow2sum(s2, t5)
		(t1, t2) = slow2sum(s1, t3)
		s1 = t1
		(s2, s3) = slow2sum(t2, t4 + t6)
	;;

	-> (s1, s2, s3)
}
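
/*
   Usage note (illustrative): std.sort reorders q in place, so pass a
   copy if the original order still matters. The cascade keeps s1 as
   the rounded running sum, with s2 and s3 absorbing successively
   smaller corrections.
 */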

/*
   Round a + b to a flt32. Only interesting when rounding a alone
   would hit a tie; a non-zero b then breaks the tie.
 */
const round_down = {a : flt64, b : flt64
	var au : uint64 = std.flt64bits(a)
	/* a sits exactly halfway between two adjacent flt32s */
	if au & 0x000000001fffffff == 0x0000000010000000
		if b > 0.0
			au++
		elif b < 0.0
			au--
		;;
		-> (std.flt64frombits(au) : flt32)
	;;

	-> (a : flt32)
}
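
/*
   Worked example (illustrative): if a sits exactly halfway between
   two adjacent flt32s, the bare cast would round to even; a positive
   b nudges au up by one flt64 ulp first, so the tie resolves upward
   instead, and a negative b resolves it downward.
 */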