shithub: mc

Download patch

ref: debd13efe12eaca072229e9859aaa6df8708a09f
parent: 408acb5431ff66619be76109f0a83f7f06df5d69
author: S. Gilles <sgilles@math.umd.edu>
date: Thu Mar 22 17:32:51 EDT 2018

Rewrite fma32.

The previous algorithm (from musl) fails for

    0xa4932927 * 0xc565bc34 + 0x316887af

which should return 0x31688bcf (it instead returns 0x31688bd0).

--- a/lib/math/fma-impl.myr
+++ b/lib/math/fma-impl.myr
@@ -33,63 +33,108 @@
 
 	/*
 	   At this point, r is probably the correct answer. The
-	   only issue is that if truncating r to a flt32 causes
-	   rounding-to-even, and it was obtained by rounding in the
-	   first place, direction, then we'd be over-rounding. The
-	   only way this could possibly be a problem is if the
-	   product step rounded to the halfway point (a 100...0,
-	   with the 1 just outside truncation range).
-         */
+	   only issue is the rounding.
+
+	   Ex 1: If x*y > 0 and z is a tiny, negative number, then
+	   adding z probably does no rounding. However, if
+	   truncating to 23 bits of precision would cause round-to-even,
+	   and that round would be upwards, then we need to remember
+	   those trailing bits of z and cancel the rounding.
+
+	   Ex 2: If x, y, z > 0, and z is small, with
+	                 last bit in flt64 |
+	          last bit in flt32 v      v
+	   x * y = ...............101011..11
+	       z =                          10000...,
+	   then x * y + z will be rounded to
+	           ...............101100..00,
+	   and then as a flt32 it will become
+	           ...............110,
+	   even though, looking at the original bits, it doesn't
+	   "deserve" the final rounding.
+
+	   These can only happen if r is non-inf, non-NaN, and the
+	   lower 29 bits correspond to "exactly halfway".
+	 */
 	if re == 1024 || rs & 0x1fffffff != 0x10000000
 		-> flt32fromflt64(r)
 	;;
 
 	/*
-	   At this point, there's definitely about to be a rounding
-	   error. To figure out what to do, compute prod + z with
-	   round-to-zero. If we get r again, then it's okay to round
-	   r upwards, because it hasn't been rounded away from zero
-	   yet and we allow ourselves one such rounding.
+	   At this point, a rounding is about to happen. We need
+	   to know what direction that rounding is, so that we can
+	   tell if it's wrong. +1 means "away from 0", -1 means
+	   "towards 0".
 	 */
 	var zn, ze, zs
 	(zn, ze, zs) = std.flt64explode(zd)
-	if ze == -1023
-		ze = -1022
+	var round_direction = 0
+	if rs & 0x20000000 == 0
+		round_direction = -1
+	else
+		round_direction = 1
 	;;
-	var rtn, rte, rts
-	if pe >= ze && pn == zn
-		(rtn, rte, rts) = (pn, pe, ps)
-		rts += shr(zs, pe - ze)
-	elif pe > ze || (pe == ze && ps > zs)
-		(rtn, rte, rts) = (pn, pe, ps)
-		rts -= shr(zs, pe - ze)
-		if shr((-1 : uint64), 64 - std.min(64, (pe - ze))) & zs != 0
-			rts--
-		;;
-	elif pe < ze && pn == zn
-		(rtn, rte, rts) = (zn, ze, zs)
-		rts += shr(ps, ze - pe)
+
+	var smaller, larger, smaller_e, larger_e
+	if pe > ze || (pe == ze && ps > zs)
+		(smaller, larger, smaller_e, larger_e) = (zs, ps, ze, pe)
 	else
-		(rtn, rte, rts) = (zn, ze, zs)
-		rts -= shr(ps, ze - pe)
-		if shr((-1 : uint64), 64 - std.min(64, (ze - pe))) & ps != 0
-			rts--
-		;;
+		(smaller, larger, smaller_e, larger_e) = (ps, zs, pe, ze)
 	;;
+	var mask = shr((-1 : uint64), 64 - std.min(64, larger_e - smaller_e))
+	var prevent_rounding = false
+	if (round_direction > 0 && pn != zn) || (round_direction < 0 && pn == zn)
+		/*
+		   Rounding away from 0 with x*y and z of opposite
+		   sign, or towards 0 with equal sign: we are
+		   potentially in the case of Ex 1.
 
-	if rts & (1 << 53) != 0
-		rts >>= 1
-		rte++
+		   Look at the bits (of the smaller flt64) that are
+		   outside the range of r. If there are any such
+		   bits, we need to cancel the rounding.
+
+		   We certainly need to consider bits very far to
+		   the right, but there's an awkwardness concerning
+		   the bit just outside the flt64 range: it governed
+		   round-to-even, so it might have had an effect.
+		   We only care about bits which did not have an
+		   effect. Therefore, we perform the subtraction
+		   using only the bits from smaller that lie in
+		   larger's range, then check whether the result
+		   is susceptible to round-to-even.
+
+		   (Since we only care about the last bit, and the
+		   base is 2, subtraction and addition are equally
+		   useful.)
+		*/
+		if (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
+			mask >>= 1
+		;;
+		prevent_rounding = smaller & mask != 0
+	else
+		/*
+		   Rounding away from 0 with x*y and z of equal sign, or
+		   towards 0 with opposite sign: potentially Ex 2.
+
+		   We just need to check if r was obtained by
+		   rounding in the addition step. In this case, we
+		   still check the smaller/larger, and we only
+		   care about round-to-even. Any
+		   rounding that happened previously is enough
+		   reason to disqualify this next rounding.
+		*/
+		prevent_rounding = (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
 	;;
 
-	if rn == rtn && rs == rts && re == rte
-		rts++
-		if rts & (1 << 53) != 0
-			rts >>= 1
-			rte++
+	if prevent_rounding
+		if round_direction > 0
+			rs--
+		else
+			rs++
 		;;
 	;;
-	-> flt32fromflt64(std.flt64assem(rtn, rte, rts))
+
+	-> flt32fromflt64(std.flt64assem(rn, re, rs))
 }
 
 const flt64fromflt32 = {f : flt32
--- a/lib/math/test/fma-impl.myr
+++ b/lib/math/test/fma-impl.myr
@@ -35,7 +35,6 @@
 		(0x3745461a, 0x4db9b736, 0xb6d7deff, 0x458f1cd8),
 		(0xa3ccfd37, 0x7f800000, 0xed328e70, 0xff800000),
 		(0xa3790205, 0x5033a3e6, 0xa001fd11, 0xb42ebbd5),
-		(0xa4932927, 0xc565bc34, 0x316887af, 0x31688bcf),
 		(0x83dd6ede, 0x31ddf8e6, 0x01fea4c8, 0x01fea4c7),
 		(0xa4988128, 0x099a41ad, 0x00800000, 0x00800000),
 		(0x1e0479cd, 0x91d5fcb4, 0x00800000, 0x00800000),
@@ -45,6 +44,10 @@
 		(0xa19e9a6f, 0xb49af3e3, 0xa2468b59, 0xa2468b57),
 		(0xd119e996, 0x8e5ad0e3, 0x247e0028, 0x247e83b7),
 		(0x381adbc6, 0x00ee4f61, 0x005f2aeb, 0x005f2d2c),
+
+		/* These ones are especially tricky */
+		(0x65dbf098, 0xd5beb8b4, 0x7c23db61, 0x73027654),
+		(0xa4932927, 0xc565bc34, 0x316887af, 0x31688bcf),
 	][:]
 
 	for (x, y, z, r) : inputs