ref: 55f06152abf5c462d459acbc33ac2fc09e37fe46
dir: /lib/math/fpmath-fma-impl.myr/
use std
/*
  Package-local interface of this file: fma32 and fma64 each take
  three floats (x, y, z) of one width and return a float of the same
  width. Both implementations below evaluate x * y + z.
*/
pkg math =
pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;
/* 0xff placed at bits 23..30; not referenced in the code visible here */
const exp_mask32 : uint32 = 0xff << 23
/*
  0x7ff placed at bits 52..62. fma64 tests raw flt64 bit patterns
  against this mask to detect NaNs and infinities (all bits set) and
  subnormals/zeros (no bits set).
*/
const exp_mask64 : uint64 = 0x7ff << 52
/*
  Computes x * y + z for flt32 arguments. The arithmetic is carried
  out in flt64 and the result is narrowed back to flt32; the bulk of
  the function exists to avoid rounding twice (once in the flt64
  addition and again when narrowing).

  x, y, z : the flt32 operands
  returns : x * y + z as a flt32
*/
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
var xn, yn
/* Only the sign flags of x and y are needed from the explode */
(xn, _, _) = std.flt32explode(x)
(yn, _, _) = std.flt32explode(y)
/* Widen all three operands with the local conversion helper */
var xd : flt64 = flt64fromflt32(x)
var yd : flt64 = flt64fromflt32(y)
var zd : flt64 = flt64fromflt32(z)
var prod : flt64 = xd * yd
var pn, pe, ps
(pn, pe, ps) = std.flt64explode(prod)
/*
  An exponent of -1023 is replaced by -1022 before pe is used in
  the alignment arithmetic further down (the same substitution is
  applied to ze below and to all three operands in fma64).
*/
if pe == -1023
pe = -1022
;;
if pn != (xn != yn)
/* In case of NaNs, sign might not have been preserved */
pn = (xn != yn)
prod = std.flt64assem(pn, pe, ps)
;;
var r : flt64 = prod + zd
var rn, re, rs
(rn, re, rs) = std.flt64explode(r)
/*
  At this point, r is probably the correct answer. The only
  issue is double rounding: if r was itself obtained by
  rounding, and narrowing r to a flt32 then hits a
  round-to-even tie, we would be over-rounding. This can only
  be a problem if the bits of r that narrowing discards are
  exactly the halfway pattern (a 100...0, with the 1 just
  outside truncation range).

  0x1fffffff selects the low 29 (= 52 - 23) significand bits
  and 0x10000000 is the top one of those alone. An exponent
  of 1024 is also passed straight through to the narrowing.
*/
if re == 1024 || rs & 0x1fffffff != 0x10000000
-> flt32fromflt64(r)
;;
/*
  We can check if rounding was performed by undoing the
  addition: if r - prod narrows back to z, r is used as is.
*/
if flt32fromflt64(r - prod) == z
-> flt32fromflt64(r)
;;
/*
  At this point, there's definitely about to be a rounding
  error. To figure out what to do, compute prod + z with
  round-to-zero. If we get r again, then it's okay to round
  r upwards, because it hasn't been rounded away from zero
  yet and we allow ourselves one such rounding.
*/
var zn, ze, zs
(zn, ze, zs) = std.flt64explode(zd)
if ze == -1023
ze = -1022
;;
/* (rtn, rte, rts): sign, exponent, significand of the truncated sum */
var rtn, rte, rts
if pe >= ze && pn == zn
/* Same sign, prod has the larger exponent: add z shifted down */
(rtn, rte, rts) = (pn, pe, ps)
rts += shr(zs, pe - ze)
elif pe > ze || (pe == ze && ps > zs)
/* Opposite signs, prod has the larger magnitude: subtract z */
(rtn, rte, rts) = (pn, pe, ps)
rts -= shr(zs, pe - ze)
/* If the shift dropped any set bits of z, take one more off */
if shr((-1 : uint64), 64 - std.min(64, (pe - ze))) & zs != 0
rts--
;;
elif pe < ze && pn == zn
/* Same sign, z has the larger exponent: add prod shifted down */
(rtn, rte, rts) = (zn, ze, zs)
rts += shr(ps, ze - pe)
else
/* Remaining case: start from z and subtract prod */
(rtn, rte, rts) = (zn, ze, zs)
rts -= shr(ps, ze - pe)
/* If the shift dropped any set bits of prod, take one more off */
if shr((-1 : uint64), 64 - std.min(64, (ze - pe))) & ps != 0
rts--
;;
;;
/* A carry into bit 53 is shifted back down and the exponent bumped */
if rts & (1 << 53) != 0
rts >>= 1
rte++
;;
/*
  The truncated sum equals r: bump the significand by one step,
  re-normalizing again if that carries into bit 53.
*/
if rn == rtn && rs == rts && re == rte
rts++
if rts & (1 << 53) != 0
rts >>= 1
rte++
;;
;;
-> flt32fromflt64(std.flt64assem(rtn, rte, rts))
}
/*
  Widens a flt32 to a flt64 by taking it apart with std.flt32explode
  and reassembling the pieces with std.flt64assem.

  f       : the value to widen
  returns : the flt64 built from f's sign, exponent and significand
*/
const flt64fromflt32 = {f : flt32
	var neg, exp32, sig32
	(neg, exp32, sig32) = std.flt32explode(f)
	var wide : uint64 = (sig32 : uint64)
	var wexp : int64 = (exp32 : int64)

	/* Exponent 128 is reassembled with exponent 1024, significand as is */
	if exp32 == 128
		-> std.flt64assem(neg, 1024, wide)
	;;

	/* Anything other than exponent -127 only needs the significand realigned */
	if exp32 != -127
		-> std.flt64assem(neg, wexp, wide << (52 - 23))
	;;

	/*
	  All subnormals in single precision (except 0.0s) can be
	  upgraded to double precision, since the exponent range is
	  so much wider: locate the top set bit, move it to bit 52
	  and lower the exponent by however far it sat below bit 23.
	*/
	var top = find_first1_64(wide, 23)
	if top < 0
		/* No bit set at all: assemble with exponent -1023 and a zero significand */
		-> std.flt64assem(neg, -1023, 0)
	;;
	wide = wide << (52 - (top : uint64))
	wexp = -126 - (23 - top)
	-> std.flt64assem(neg, wexp, wide)
}
/*
  Narrows a flt64 to a flt32, rounding to nearest with ties to even
  (the rounding decision is delegated to need_round_away).

  f       : the value to narrow
  returns : the flt32 assembled from the narrowed sign, exponent and
            significand. Exponents of 128 and above produce an
            infinity, except for NaN inputs which produce a NaN.
*/
const flt32fromflt64 = {f : flt64
	var n : bool, e : int64, s : uint64
	(n, e, s) = std.flt64explode(f)
	var ts : uint32
	var te : int32 = (e : int32)

	if e >= 128
		/*
		  The rest of this file uses 1024 as the exploded exponent
		  of non-finite flt64s (fma32 tests `re == 1024`, and
		  flt64fromflt32 assembles with 1024), so that is the value
		  to test for here; 1023 is still a finite exponent. Only
		  the 52 fraction bits decide NaN versus infinity, so mask
		  off anything at bit 52 and above before testing.
		*/
		if e == 1024 && s & ((1 << 52) - 1) != 0
			/* NaN */
			-> std.flt32assem(n, 128, 1)
		else
			/* infinity, or a finite value too large for flt32 */
			-> std.flt32assem(n, 128, 0)
		;;
	;;

	if e >= -126
		/* normal: drop the low 52 - 23 significand bits, then round */
		ts = ((s >> (52 - 23)) : uint32)
		if need_round_away(0, s, 52 - 23)
			ts++
			/* rounding carried out of the significand: renormalize */
			if ts & (1 << 24) != 0
				ts >>= 1
				te++
			;;
		;;
		-> std.flt32assem(n, te, ts)
	;;

	/* subnormal already, will have to go to 0 */
	if e == -1023
		-> std.flt32assem(n, -127, 0)
	;;

	/* subnormal (at least, it will be) */
	te = -127
	/* the further e lies below -126, the more significand bits are dropped */
	var shift : int64 = (52 - 23) + (-126 - e)
	var ts1 = shr(s, shift)
	ts = (ts1 : uint32)
	if need_round_away(0, s, shift)
		ts++
	;;
	-> std.flt32assem(n, te, ts)
}
/*
  Computes x * y + z for flt64 arguments with a single rounding at
  the end. The product of the two significands is formed exactly in
  a pair of uint64s ([ prod_h ][ prod_l ]), z is aligned against it
  and added or subtracted, and the two-word result is then rounded
  (via need_round_away) and assembled into a flt64.

  x, y, z : the flt64 operands
  returns : x * y + z as a flt64
*/
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		/* a zero addend or a zero product: plain arithmetic is used */
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		/* x and y are finite here, so a non-finite z is the answer */
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	/* exponent -1023 is replaced by -1022 before any alignment arithmetic */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/* Keep product in high/low uint64s */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	/* schoolbook multiplication on 32-bit halves */
	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* the low-word addition wrapped: carry into the high word */
	if t_l > prod_l
		prod_h++
	;;

	/*
	  *_lastbit_e is the exponent of bit 0 of a bit string,
	  *_first1 the index of its highest set bit, and
	  *_firstbit_e the exponent of that highest set bit.
	*/
	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	if prod_n == zn
		res_n = prod_n

		/*
		   Align prod and z so that the top bit of the
		   result is either 53 or 54, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			        [ prod_h ][ prod_l ]
			    [ z...
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/* opposite signs: subtract the smaller magnitude from the larger */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	/* find_first1_64 returning -1 (i.e. 63 here) means res_h is empty */
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of
	   the result. They now need to be assembled into a flt64.
	   Subnormals and infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/* Subnormal case */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/* No need for exponents, they are all zero */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/* Move the leading 1 of [ res_h ][ res_l ] to bit 52 of res_s */
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		/*
		  Fewer than 53 significant bits, all of them in res_l:
		  shift them up so the leading 1 lands on bit 52. Nothing
		  is discarded, so no rounding is needed. (Shifting res_h
		  by the negative amount res_first1 - 52 would always
		  produce 0, because shl treats negative shifts as >= 64.)
		*/
		res_s = shl(res_l, 52 - res_first1)
	;;

	/* The res_s++s might have messed everything up */
	if res_s & (1 << 53) != 0
		/* rounding carried into bit 53: renormalize the significand */
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}
/*
  Right shift that does not wrap: a shift count of 64 or more
  yields 0. The count is compared as a uint64, so negative counts
  also yield 0.
*/
const shr : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	var count = (s : uint64)
	if count < 64
		-> u >> count
	;;
	-> 0
}
/*
  Left shift that does not wrap: a shift count of 64 or more
  yields 0. The count is compared as a uint64, so negative counts
  also yield 0.
*/
const shl : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	var count = (s : uint64)
	if count < 64
		-> u << count
	;;
	-> 0
}
/*
  Adds (a << s) to the two-word value [ h ][ l ] and returns the new
  (high, low) pair. A negative s means a is shifted right by -s
  instead. The alignment is such that s == 0 gives [ h ][ l + a ].
*/
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lands entirely in the high word */
	if s >= 64
		-> (h + shl(a, s - 64), l)
	;;

	var hi = h
	var addend
	if s >= 0
		/* the bits of a pushed past bit 63 of the low word go to the high word */
		hi += shr(a, 64 - s)
		addend = shl(a, s)
	else
		addend = shr(a, -s)
	;;

	var lo = l + addend
	/* the low word wrapped around: carry into the high word */
	if lo < l
		hi++
	;;
	-> (hi, lo)
}
/*
  Subtracts (a << s) from the two-word value [ h ][ l ] and returns
  the new (high, low) pair. A negative s means a is shifted right by
  -s instead. The alignment is such that s == 0 gives [ h ][ l - a ].
*/
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* a lines up entirely with the high word */
	if s >= 64
		-> (h - shl(a, s - 64), l)
	;;

	var hi = h
	var subtrahend
	if s >= 0
		/* the bits of a pushed past bit 63 of the low word come off the high word */
		hi -= shr(a, 64 - s)
		subtrahend = shl(a, s)
	else
		subtrahend = shr(a, -s)
	;;

	var lo = l - subtrahend
	/* the low word would go below zero: borrow from the high word */
	if subtrahend > l
		hi--
	;;
	-> (hi, lo)
}
/*
  Compares the magnitude of the two-word bit string [ h ][ l ] with
  that of the single-word bit string z.

  hl_firstbit_e / z_firstbit_e : exponent of each string's highest set bit
  hl_lastbit_e / z_lastbit_e   : exponent of each string's bit 0

  Returns `std.Before if [ h ][ l ] is larger, `std.After if z is
  larger, and `std.Equal if they represent the same magnitude.
*/
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	/* Different leading exponents settle it immediately */
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/* Index of the leading bit within h (negative if it lies in l) and within z */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)

	/*
	  Where the scan of l has to start so that it stays lined up
	  with z: at the top of l when the leading bit is in h (or is
	  bit 63 of l), otherwise at the leading bit itself.
	*/
	var l_k : int64 = 63
	if h_k < -1
		l_k = 64 + h_k
	;;

	/* Walk h and z downwards together, bit by bit */
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	if z_k < 0
		/*
		  z is exhausted. Bits h_k..0 of h have not been looked at
		  yet, so the mask must cover h_k + 1 bits; any set bit
		  there, or anywhere in l, makes [ h ][ l ] larger.
		*/
		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/* h is exhausted; keep walking through l */
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/*
	  Whichever string still has uncompared bits (index k down to 0,
	  i.e. k + 1 of them) is larger if any of those bits is set.
	*/
	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
		-> `std.After
	;;
	-> `std.Equal
}
/*
  Returns the index of the highest set bit of b at or below index
  start, or -1 if no bit in that range is set.
*/
const find_first1_64 : (b : uint64, start : int64 -> int64) = {b : uint64, start : int64
	var pos = start
	while pos >= 0
		if b & shl(1, pos) != 0
			-> pos
		;;
		pos--
	;;
	-> -1
}
/*
  Returns the index of the highest set bit of the two-word value
  [ h ][ l ] at or below index start, where bit 0 of l has index 0
  and bit 0 of h has index 64. If h has no set bit in range, the
  whole of l is searched; -1 means nothing was found.
*/
const find_first1_64_hl = {h, l, start
	var hi_pos = find_first1_64(h, start - 64)
	if hi_pos < 0
		-> find_first1_64(l, 63)
	;;
	-> hi_pos + 64
}
/*
  For [ h ][ l ], where bitpos_last is the position of the last
  bit that was included in the truncated result (l's last bit has
  position 0), decide whether rounding up/away is needed. This is
  true if
   - following bitpos_last is a 1, then a non-zero sequence, or
   - following bitpos_last is a 1, then a zero sequence, and the
     round would be to even

  Returns true when the truncated value must be incremented.
*/
const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
	var first_omitted_is_1 = false
	var nonzero_beyond = false
	if bitpos_last > 64
		/* the first omitted bit lives in h, at index bitpos_last - 65 */
		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
		/*
		  Everything in h below that bit counts as "beyond": that
		  is bitpos_last - 65 bits, so the all-ones word is shifted
		  right by 64 - (bitpos_last - 65). This mirrors the
		  low-word case below.
		*/
		nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
		nonzero_beyond = nonzero_beyond || (l != 0)
	else
		/* the first omitted bit lives in l, at index bitpos_last - 1 */
		first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
		nonzero_beyond = nonzero_beyond || l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
	;;

	/* below the halfway point: truncation is already correct */
	if !first_omitted_is_1
		-> false
	;;

	/* above the halfway point */
	if nonzero_beyond
		-> true
	;;

	/* exactly halfway: round away only if the kept value is odd */
	var hl_is_odd = false
	if bitpos_last >= 64
		hl_is_odd = h & shl(1, bitpos_last - 64) != 0
	else
		hl_is_odd = l & shl(1, bitpos_last) != 0
	;;
	-> hl_is_odd
}