ref: a68e92f1e80f262e04d6e6c1d27af80199afbdc0
dir: /lib/math/fma-impl.myr/
use std
use "util"
/*
  Package-local interface of this file: fma32 and fma64 each take
  three floats (x, y, z) of the same width and return one float of
  that width.
*/
pkg math =
pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;
/* Mask with the 8 bits 23..30 set: the exponent field of a flt32 bit pattern. */
const exp_mask32 : uint32 = 0xff << 23
/* Mask with the 11 bits 52..62 set: the exponent field of a flt64 bit pattern. */
const exp_mask64 : uint64 = 0x7ff << 52
/*
  Single-precision x * y + z.

  The operands are widened to flt64, multiplied and added there, and
  the flt64 sum is then narrowed back to flt32. The only delicate part
  is the narrowing: if the flt64 sum sits exactly halfway between two
  flt32 values, the narrowing would round to even, and that decision
  may be wrong because it cannot see bits that the flt64 addition
  already discarded. In that case the sum's significand is nudged by
  one unit so the narrowing rounds the right way.
*/
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	var x_neg, y_neg
	(x_neg, _, _) = std.flt32explode(x)
	(y_neg, _, _) = std.flt32explode(y)

	var wide_x : flt64 = flt64fromflt32(x)
	var wide_y : flt64 = flt64fromflt32(y)
	var wide_z : flt64 = flt64fromflt32(z)
	var prod : flt64 = wide_x * wide_y

	var p_neg, p_exp, p_sig
	(p_neg, p_exp, p_sig) = std.flt64explode(prod)
	if p_exp == -1023
		p_exp = -1022
	;;

	/* With NaNs involved the product's sign may not be the xor of the input signs */
	var want_neg = (x_neg != y_neg)
	if p_neg != want_neg
		p_neg = want_neg
		prod = std.flt64assem(p_neg, p_exp, p_sig)
	;;

	var sum : flt64 = prod + wide_z
	var s_neg, s_exp, s_sig
	(s_neg, s_exp, s_sig) = std.flt64explode(sum)

	/*
	  sum is almost certainly the right answer already; only the final
	  rounding can go wrong, in one of two ways.

	  Ex 1: x*y > 0 and z is a tiny negative number. Adding z
	  probably rounds nothing away, but if narrowing to flt32 then
	  hits a round-to-even tie that resolves upwards, the forgotten
	  trailing bits of z mean the tie was never a real tie, and the
	  upward rounding has to be cancelled.

	  Ex 2: x, y, z > 0 with z small:

	                     last bit in flt64        |
	                     last bit in flt32  v     v
	      x * y = ...............101011..11
	      z     =                          10000...,

	  The flt64 addition rounds x * y + z up to
	              ...............101100..00,
	  and narrowing then rounds again to
	              ...............110,
	  although the original bits never justified that second rounding.

	  Both need a finite, non-NaN sum whose low 29 bits are exactly
	  the halfway pattern; anything else can be narrowed directly.
	*/
	if s_exp == 1024 || s_sig & 0x1fffffff != 0x10000000
		-> flt32fromflt64(sum)
	;;

	/*
	  A tie-breaking rounding is about to happen. Work out which way
	  it would go: away from zero when the lowest kept bit is set,
	  towards zero otherwise.
	*/
	var z_neg, z_exp, z_sig
	(z_neg, z_exp, z_sig) = std.flt64explode(wide_z)
	var rounds_away = (s_sig & 0x20000000 != 0)

	/* Order the two addends by magnitude (exponent first, then significand) */
	var small_sig, big_sig, small_exp, big_exp
	if p_exp > z_exp || (p_exp == z_exp && p_sig > z_sig)
		small_sig = z_sig
		big_sig = p_sig
		small_exp = z_exp
		big_exp = p_exp
	else
		small_sig = p_sig
		big_sig = z_sig
		small_exp = p_exp
		big_exp = z_exp
	;;

	var gap = big_exp - small_exp
	/* Selects the bits of the smaller addend that lie below the larger one's last bit */
	var tail_mask = shr((-1 : uint64), 64 - std.min(64, gap))
	/*
	  Last bit of (larger +/- aligned smaller). Only the lowest bit
	  matters and the base is 2, so xor serves for both addition
	  and subtraction.
	*/
	var low_bits_differ = (big_sig ^ shr(small_sig, gap)) & 0x1 != 0

	var prevent_rounding = false
	if rounds_away == (p_neg != z_neg)
		/*
		  The pending rounding disagrees with the signs: possibly
		  Ex 1. Any bits of the smaller addend that fell outside
		  the range of sum mean the rounding must be cancelled.
		  The bit just outside the flt64 range is awkward because
		  it steered round-to-even and so may already have had an
		  effect; only bits that had no effect count. Hence the
		  check is made only when the in-range result is itself
		  susceptible to round-to-even.
		*/
		if !low_bits_differ
			prevent_rounding = small_sig & tail_mask != 0
		;;
	else
		/*
		  The pending rounding agrees with the signs: possibly
		  Ex 2. It is enough to know whether sum was itself
		  produced by a round-to-even in the addition; any
		  earlier rounding disqualifies this one.
		*/
		prevent_rounding = low_bits_differ
	;;

	if prevent_rounding
		/* Step off the exact tie in the direction opposite to the pending rounding */
		if rounds_away
			s_sig--
		else
			s_sig++
		;;
	;;

	-> flt32fromflt64(std.flt64assem(s_neg, s_exp, s_sig))
}
/*
  Double-precision x * y + z.

  Special cases are delegated to ordinary floating point: if x or y
  has an all-ones exponent field (NaN or infinity), or if z or the
  rounded product x * y is zero, the result is x * y + z; if only z
  has an all-ones exponent field, the result is z.

  Otherwise the significands of x and y are multiplied into a
  two-word value [ prod_h ][ prod_l ], z is aligned against it,
  the two are added or subtracted as [ res_h ][ res_l ], and that
  value is rounded once into a flt64 significand.
*/
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)

	/* An exploded exponent of -1023 is replaced by -1022 before any alignment arithmetic */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/* Keep product in high/low uint64s */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* Unsigned wrap-around in the low word means a carry into the high word */
	if t_l > prod_l
		prod_h++
	;;

	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e
	if prod_n == zn
		/* Same sign: magnitudes add */
		res_n = prod_n

		/*
		   Align prod and z so that the top bit of the
		   result is either 53 or 54, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			        [ prod_h ][ prod_l ]
			    [ z...
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				/* carry out of the low word */
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/* Opposite signs: subtract the smaller magnitude from the larger */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	/* Locate the leading 1 of [ res_h ][ res_l ]; 63 here means nothing was found in res_h */
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of
	   the result. They now need to be assembled into a flt64.
	   Subnormals and infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/* Subnormal case */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/* No need for exponents, they are all zero */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/* Move the leading 1 to bit 52 of res_s, rounding off whatever falls below bit 0 */
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		res_s = shl(res_h, res_first1 - 52)
	;;

	/* The res_s++s might have messed everything up */
	if res_s & (1 << 53) != 0
		/*
		   Rounding carried out of the significand: shift it back
		   down and bump the exponent. (This used to read
		   "res_s >= 1", a comparison with no effect.)
		 */
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}
/*
  Returns the two-word value [ h ][ l ] plus (a << s) as a
  (high, low) pair. A negative s means a is shifted right by -s
  instead. The alignment is such that s == 0 yields [ h ][ l + a ].
  A carry out of the low word is propagated into the high word.
*/
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lands entirely in the high word */
	if s >= 64
		-> (h + shl(a, s - 64), l)
	;;

	var hi = h
	var addend : uint64
	if s >= 0
		/* a straddles both words: top part goes to hi, the rest to lo */
		hi += shr(a, 64 - s)
		addend = shl(a, s)
	else
		addend = shr(a, -s)
	;;

	var lo = l + addend
	/* unsigned wrap-around means the low word carried */
	if lo < l
		hi++
	;;
	-> (hi, lo)
}
/*
  Returns the two-word value [ h ][ l ] minus (a << s) as a
  (high, low) pair, using the same alignment as add_shifted: a
  negative s means a is shifted right by -s, and s == 0 yields
  [ h ][ l - a ]. A borrow from the low word is taken out of the
  high word.
*/
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lies entirely in the high word */
	if s >= 64
		-> (h - shl(a, s - 64), l)
	;;

	var hi = h
	var subtrahend : uint64
	if s >= 0
		/* a straddles both words: top part comes off hi, the rest off lo */
		hi -= shr(a, 64 - s)
		subtrahend = shl(a, s)
	else
		subtrahend = shr(a, -s)
	;;

	/* the low word would go below zero, so borrow from hi */
	if subtrahend > l
		hi--
	;;
	-> (hi, l - subtrahend)
}
/*
  Compares the magnitude of the two-word value [ h ][ l ] with that
  of z, bit by bit from the most significant end.

  hl_firstbit_e / hl_lastbit_e are the exponents of the leading 1
  and of bit 0 of l; z_firstbit_e / z_lastbit_e are the same for z.
  Bit k of h therefore has exponent hl_lastbit_e + 64 + k.

  Returns `std.Before if [ h ][ l ] > z, `std.After if z > [ h ][ l ],
  and `std.Equal if they are the same.
*/
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	/* Different leading-bit exponents settle it immediately */
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/*
	   Index of the leading 1 within h (negative if it lies in l)
	   and within z. Both indices name bits of equal exponent.
	 */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	if z_k < 0
		/*
		   z is exhausted. Bits h_k..0 of h (h_k + 1 of them) and
		   all of l have not been compared; any 1 among them makes
		   [ h ][ l ] larger. The mask must include bit h_k itself.
		 */
		if (h_k >= 0 && (h & shr((-1 : uint64), 63 - h_k)) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/*
	   Continue into l at the bit aligned with z's bit z_k. If the
	   loop above consumed h, h_k is -1 and this is bit 63. If the
	   leading 1 was in l to begin with, h_k never changed and
	   64 + h_k is the index of that leading 1.
	 */
	var l_k : int64 = 64 + h_k
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/* Whichever side still has uncompared bits k..0 wins if any of those k + 1 bits is set */
	if (z_k < 0) && (l_k >= 0) && ((l & shr((-1 : uint64), 63 - l_k)) != 0)
		-> `std.Before
	elif (l_k < 0) && (z_k >= 0) && ((z & shr((-1 : uint64), 63 - z_k)) != 0)
		-> `std.After
	;;

	-> `std.Equal
}