shithub: mc

Download patch

ref: 8a2b92eff783f9489be2d6ed28187c0e69cc26ca
parent: c73cc580f7e048a81571b602ff0eae60afd2a5e3
author: Ori Bernstein <ori@eigenstate.org>
date: Sun Dec 11 14:34:34 EST 2016

Fix sign extension bugs in division.

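	Also does some cleanup: bigshri and trim now truncate the digit
	slice instead of zero-filling or regrowing it, new Modpow regression
	tests are added, and the test failure message uses {} format
	specifiers.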

--- a/lib/std/bigint.myr
+++ b/lib/std/bigint.myr
@@ -86,7 +86,11 @@
 	const bigbitcount	: (a : bigint# -> size)
 ;;
 
+/* put for debugging */
+extern const put : (fmt : byte[:], args : ... -> size)
+
 const Base = 0x100000000ul
+
 generic mkbigint = {v : @a::(integral,numeric)
 	var a
 	var val
@@ -518,6 +522,7 @@
 	else
 		a.sign = 1
 	;;
+
 	w  = slzalloc(a.dig.len + b.dig.len)
 	for j = 0; j < b.dig.len; j++
 		carry = 0
@@ -526,10 +531,10 @@
 			bj = (b.dig[j]  : uint64)
 			wij = (w[i+j]  : uint64)
 			t = ai * bj + wij + carry
-			w[i + j] = ((t  : uint32))
+			w[i+j] = (t  : uint32)
 			carry = t >> 32
 		;;
-		w[i+j] = (carry  : uint32)
+		w[i + j] = (carry  : uint32)
 	;;
 	slfree(a.dig)
 	a.dig = w
@@ -561,6 +566,7 @@
 	var m : int64, n : int64
 	var qhat, rhat, carry, shift
 	var x, y, z, w, p, t /* temporaries */
+	var pt, tt
 	var b0, aj
 	var u, v
 	var i, j : int64
@@ -600,13 +606,15 @@
 	m = u.dig.len
 	n = v.dig.len
 
+	/* normalize */
 	shift = nlz(v.dig[n - 1])
 	bigshli(u, shift)
 	bigshli(v, shift)
 	slzgrow(&u.dig, u.dig.len + 1)
 
+	/* Since we're little endian, we iterate backwards from Knuth */
 	for j = m - n; j >= 0; j--
-		/* load a few temps */
+		/* load a few temps for less casting */
 		x = (u.dig[j + n]  : uint64)
 		y = (u.dig[j + n - 1]  : uint64)
 		z = (v.dig[n - 1]  : uint64)
@@ -613,7 +621,7 @@
 		w = (v.dig[n - 2]  : uint64)
 		t = (u.dig[j + n - 2]  : uint64)
 
-		/* estimate qhat */
+		/* calculate qhat */
 		qhat = (x*Base + y)/z
 		rhat = (x*Base + y) - qhat*z
 :divagain
@@ -630,12 +638,15 @@
 		for i = 0; i < n; i++
 			p = (qhat * (v.dig[i]  : uint64))
 
-			t = ((u.dig[i+j]  : uint64)) - carry - (p % Base)
+			t = (u.dig[i+j]  : uint64) - carry - (p % Base)
 			u.dig[i+j] = (t  : uint32)
-			carry = (((p : int64) >> 32) - ((t : int64) >> 32) : uint64)
+			tt = (t : int64) >> 32
+			pt = (p >> 32)
+			carry = ((pt : int64) - (tt : int64) : uint64)
 		;;
 		t = (u.dig[j + n] : uint64) - carry
 		u.dig[j + n] = (t  : uint32)
+
 		q.dig[j] = (qhat  : uint32)
 		/* adjust */
 		if (t : int64) < 0
@@ -651,7 +662,8 @@
 
 	;;
 	/* undo the biasing for remainder */
-	u = bigshri(u, shift)
+	bigshri(u, shift)
+	trim(q)
 	-> (trim(q), trim(u))
 }
 
@@ -824,9 +836,8 @@
 	for var i = 0; i < a.dig.len - off; i++
 		a.dig[i] = a.dig[i + off]
 	;;
-	for var i = a.dig.len - off; i < a.dig.len; i++
-		a.dig[i] = 0
-	;;
+	a.dig = a.dig[:a.dig.len - off]
+
 	/* and shift over by the remainder */
 	carry = 0
 	for var i = a.dig.len; i > 0; i--
@@ -880,11 +891,9 @@
 			break
 		;;
 	;;
-	slgrow(&a.dig, i)
+	a.dig = a.dig[:i]
 	if i == 0
 		a.sign = 0
-	elif a.sign == 0
-		a.sign = 1
 	;;
 	-> a
 }
--- a/lib/std/test/bigint.myr
+++ b/lib/std/test/bigint.myr
@@ -96,6 +96,16 @@
 		std.mk(`Val "755578"))), \
 		"49054")
 	run(std.mk(`Modpow (\
+		std.mk(`Val "2393"), \
+		std.mk(`Val "2"), \
+		std.mk(`Val "6737"))), \
+		"6736")
+	run(std.mk(`Modpow (\
+		std.mk(`Val "6193257528475266832463188301662235"), \
+		std.mk(`Val "6157075615645799356061575607567581"), \
+		std.mk(`Val "12314151231291598712123151215135163"))), \
+		"1540381241336817586803754632242117")
+	run(std.mk(`Modpow (\
 		std.mk(`Val "7220"), \
 		std.mk(`Val "755578"), \
 		std.mk(`Val "75557863709417659441940"))), \
@@ -110,7 +120,7 @@
 	v = eval(e)
 	n = std.bigbfmt(buf[:], v, 0)
 	if !std.sleq(buf[:n], res)
-		std.fatal("%s != %s\n", buf[:n], res)
+		std.fatal("{} != {}\n", buf[:n], res)
 	;;
 }