shithub: mc

Download patch

ref: a33e94a67326d7774e7373b83597d84ee1eb347b
parent: 7c4d2abbe14d3d0e9d2f5872e0964677c5dc783d
author: S. Gilles <sgilles@math.umd.edu>
date: Wed Mar 21 15:17:41 EDT 2018

Replace fma32 with simpler, cleaner algorithm from musl

--- a/lib/math/fpmath-fma-impl.myr
+++ b/lib/math/fpmath-fma-impl.myr
@@ -28,34 +28,73 @@
 	;;
 
 	var r : flt64 = prod + zd
+	var rn, re, rs
+	(rn, re, rs) = std.flt64explode(r)
 
-	var zn, ze, zs
-	(zn, ze, zs) = std.flt32explode(z)
-	if ze == -127
-		ze = -126
+	/*
+	   At this point, r is probably the correct answer. The
+	   only issue is that if truncating r to a flt32 causes
+	   rounding-to-even, and r was itself obtained by rounding
+	   in that same direction, then we'd be double-rounding. The
+	   only way this could possibly be a problem is if computing
+	   r = prod + zd rounded to the halfway point (a 100...0,
+	   with the 1 just outside truncation range).
+	 */
+	if re == 1024 || rs & 0x1fffffff != 0x10000000
+		-> flt32fromflt64(r)
 	;;
 
-	var beyond : int64 = 0
-	if pe - 52 > (ze : int64) - 23
-		/*
-		   Identify all the bits of z that were too far
-		   away to be added into the product. If there are
-		   any, they will be used later to influence rounding
-		   away from the halfway mark.
-		 */
-		var shift = std.min((pe - 52) - ((ze : int64) - 23) - 1, 64)
-		var zs1 : uint64 = (zs : uint64)
-		if zs1 & shr((-1 : uint64), 64 - shift) != 0
-			if zn == pn
-				beyond = 1
-			else
-				beyond = -1
-			;;
+	/* We can check whether rounding was performed by undoing the addition */
+	if flt32fromflt64(r - prod) == z
+		-> flt32fromflt64(r)
+	;;
+
+	/*
+	   At this point, there's definitely about to be a rounding
+	   error. To figure out what to do, compute prod + z with
+	   round-to-zero. If we get r again, then it's okay to round
+	   r upwards, because it hasn't been rounded away from zero
+	   yet and we allow ourselves one such rounding.
+	 */
+	var zn, ze, zs
+	(zn, ze, zs) = std.flt64explode(zd)
+	if ze == -1023
+		ze = -1022
+	;;
+	var rtn, rte, rts
+	if pe >= ze && pn == zn
+		(rtn, rte, rts) = (pn, pe, ps)
+		rts += shr(zs, pe - ze)
+	elif pe > ze || (pe == ze && ps > zs)
+		(rtn, rte, rts) = (pn, pe, ps)
+		rts -= shr(zs, pe - ze)
+		if shr((-1 : uint64), 64 - std.min(64, (pe - ze))) & zs != 0
+			rts--
 		;;
+	elif pe < ze && pn == zn
+		(rtn, rte, rts) = (zn, ze, zs)
+		rts += shr(ps, ze - pe)
+	else
+		(rtn, rte, rts) = (zn, ze, zs)
+		rts -= shr(ps, ze - pe)
+		if shr((-1 : uint64), 64 - std.min(64, (ze - pe))) & ps != 0
+			rts--
+		;;
 	;;
 
-	var r32 : flt32 = flt32fromflt64(r, beyond)
-	-> r32
+	if rts & (1 << 53) != 0
+		rts >>= 1
+		rte++
+	;;
+
+	if rn == rtn && rs == rts && re == rte
+		rts++
+		if rts & (1 << 53) != 0
+			rts >>= 1
+			rte++
+		;;
+	;;
+	-> flt32fromflt64(std.flt64assem(rtn, rte, rts))
 }
 
 const flt64fromflt32 = {f : flt32
@@ -84,7 +123,7 @@
 	-> std.flt64assem(n, xe, xs << (52 - 23))
 }
 
-const flt32fromflt64 = {f : flt64, beyond : int64
+const flt32fromflt64 = {f : flt64
 	var n : bool, e : int64, s : uint64
 	(n, e, s) = std.flt64explode(f)
 	var ts : uint32
@@ -103,7 +142,7 @@
 	if e >= -126
 		/* normal */
 		ts = ((s >> (52 - 23)) : uint32)
-		if need_round_away(0, s, 52 - 23, beyond)
+		if need_round_away(0, s, 52 - 23)
 			ts++
 			if ts & (1 << 24) != 0
 				ts >>= 1
@@ -123,7 +162,7 @@
 	var shift : int64 = (52 - 23) + (-126 - e)
 	var ts1 = shr(s, shift)
 	ts = (ts1 : uint32)
-	if need_round_away(0, s, shift, beyond)
+	if need_round_away(0, s, shift)
 		ts++
 	;;
 	-> std.flt32assem(n, te, ts)
@@ -283,7 +322,7 @@
 			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
 		;;
 
-		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e), 0)
+		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
 			res_s++
 		;;
 
@@ -306,13 +345,13 @@
 
 	if res_first1 - 52 >= 64
 		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
-		if need_round_away(res_h, res_l, res_first1 - 52, 0)
+		if need_round_away(res_h, res_l, res_first1 - 52)
 			res_s++
 		;;
 	elif res_first1 - 52 >= 0
 		res_s = shl(res_h, 64 - (res_first1 - 52))
 		res_s |= shr(res_l, res_first1 - 52)
-		if need_round_away(res_h, res_l, res_first1 - 52, 0)
+		if need_round_away(res_h, res_l, res_first1 - 52)
 			res_s++
 		;;
 	else
@@ -484,14 +523,10 @@
 
     - following bitpos_last is a 1, then a zero sequence, and the
       round would be to even
-
-   The beyond parameter indicates whether there is lingering
-   addition/subtraction past the range of l. This adds a bit more
-   information about rounding before we hit round-to-even.
  */
-const need_round_away = {h : uint64, l : uint64, bitpos_last : int64, beyond : int64
+const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
 	var first_omitted_is_1 = false
-	var nonzero_beyond = beyond > 0
+	var nonzero_beyond = false
 	if bitpos_last > 64
 		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
 		nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 2 + 64 - (bitpos_last - 64)) != 0
@@ -517,5 +552,5 @@
 		hl_is_odd = l & shl(1, bitpos_last) != 0
 	;;
 
-	-> hl_is_odd && (beyond >= 0)
+	-> hl_is_odd
 }