shithub: sl

Download patch

ref: 99a1b91c3e6b40fe1726c14aa703c3dbbbe75901
parent: 6cf5e215a4f19645a018eba3b696ce06dd8b4268
author: Sigrid Solveig Haflínudóttir <sigrid@ftrv.se>
date: Sun May 11 19:16:43 EDT 2025

rewrite numerical comparison logic

This should fix many corner cases.
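One big change from how this previously worked is that no comparison
operator will return "these are equal" if at least one of the arguments
is NaN: eq?, equal?, eqv? and = all return NIL. Likewise, equal? and eqv?
no longer distinguish -0.0 from 0.0, nor exact numbers from inexact ones
of the same value; for example, (eqv? -0.0 0) is now true.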

--- a/src/cvalues.c
+++ b/src/cvalues.c
@@ -414,7 +414,7 @@
 tosize(sl_v n)
 {
 	if(sl_unlikely(iscvalue(n))){
-		if(sl_likely(numeric_compare(n, fixnum(0), false, false, true) >= 0)){
+		if(sl_likely(numeric_compare(n, fixnum(0), false, true) >= 0)){
 			sl_cv *cv = ptr(n);
 			if(sizeof(usize) > 4)
 				return conv_to_u64(n, cv_data(cv), cv_numtype(cv));
@@ -1088,13 +1088,10 @@
 /*
   returns -1, 0, or 1 based on ordering of a and b
   eq: consider equality only, returning 0 or nonzero
-  eqnans: NaNs considered equal to each other
-		  -0.0 not considered equal to 0.0
-		  inexact not considered equal to exact
   typeerr: if not 0, throws type errors, else returns 2 for type errors
 */
 int
-numeric_compare(sl_v a, sl_v b, bool eq, bool eqnans, bool typeerr)
+numeric_compare(sl_v a, sl_v b, bool eq, bool typeerr)
 {
 	sl_fx ai, bi;
 	sl_numtype ta, tb;
@@ -1117,15 +1114,7 @@
 			cthrow(type_error("b", "num", b), a);
 		return 2;
 	}
-	if(eq && eqnans && ((ta >= T_F32) != (tb >= T_F32)))
-		return 1;
-	if(cmp_eq(aptr, ta, bptr, tb, eqnans))
-		return 0;
-	if(eq)
-		return 1;
-	if(cmp_lt(aptr, ta, bptr, tb))
-		return -1;
-	return 1;
+	return cmp_numeric(aptr, ta, bptr, tb);
 }
 
 _Noreturn void
@@ -1370,7 +1359,7 @@
 	if(isfixnum(a))
 		return fixnum(~numval(a));
 	if(isubnum(a)){
-		if((ubnumtype(a) & 1) == 0){
+		if(is_nt_sint(ubnumtype(a))){
 			sl_fx s = ubnumval(a);
 			s = (sl_v)~s << TAG_EXT_BITS;
 			return ((sl_v)s & ~(sl_v)TAG_EXT_MASK) | (a & TAG_EXT_MASK);
--- a/src/equal.c
+++ b/src/equal.c
@@ -80,7 +80,7 @@
 		if(isfixnum(b))
 			return (sl_fx)a < (sl_fx)b ? fixnum(-1) : fixnum(1);
 		if(isubnum(b) || (iscvalue(b) && (cv = ptr(b), valid_numtype(cv_numtype(cv)))))
-			return fixnum(numeric_compare(a, b, eq, true, false));
+			return fixnum(numeric_compare(a, b, eq, false));
 		if(isrune(b))
 			return fixnum(1);
 		return fixnum(-1);
@@ -90,7 +90,7 @@
 		if(isrune(b))
 			return fixnum(1);
 		if(isfixnum(b) || isubnum(b) || (iscvalue(b) && (cv = ptr(b), valid_numtype(cv_numtype(cv)))))
-			return fixnum(numeric_compare(a, b, eq, true, false));
+			return fixnum(numeric_compare(a, b, eq, false));
 		return fixnum(-1);
 	case TAG_SYM:
 		if(eq || tagb < TAG_SYM)
@@ -105,7 +105,7 @@
 	case TAG_CVALUE:
 		cv = ptr(a);
 		if(valid_numtype(cv_numtype(cv))){
-			if((c = numeric_compare(a, b, eq, true, false)) != 2)
+			if((c = numeric_compare(a, b, eq, false)) != 2)
 				return fixnum(c);
 		}
 		if(tagb == TAG_CVALUE)
--- a/src/equal.h
+++ b/src/equal.h
@@ -8,5 +8,5 @@
 uintptr hash_lispvalue(sl_v a);
 sl_v bounded_compare(sl_v a, sl_v b, int bound, bool eq);
 sl_v sl_compare(sl_v a, sl_v b, bool eq);
-int numeric_compare(sl_v a, sl_v b, bool eq, bool eqnans, bool typeerr);
+int numeric_compare(sl_v a, sl_v b, bool eq, bool typeerr);
 void comparehash_init(void);
--- a/src/operators.c
+++ b/src/operators.c
@@ -34,7 +34,7 @@
 	case T_U32: case T_P32: return *(u32int*)data;
 	case T_S64:
 		d = *(s64int*)data;
-		if(d > 0 && *(s64int*)data < 0)  // can happen!
+		if(d > 0 && *(s64int*)data < 0) // can happen!
 			d = -d;
 		return d;
 	case T_U64: case T_P64: return *(u64int*)data;
@@ -103,7 +103,7 @@
 	cthrow(type_error(nil, "num", v), v);
 }
 
-sl_purefn
+static sl_purefn
 bool
 cmp_same_lt(void *a, void *b, sl_numtype tag)
 {
@@ -117,13 +117,13 @@
 	case T_S64:             return *(s64int*)a < *(s64int*)b;
 	case T_U64: case T_P64: return *(u64int*)a < *(u64int*)b;
 	case T_BIG:             return mpcmp(*(mpint**)a, *(mpint**)b) < 0;
-	case T_F32:             return *(float*)a < *(float*)b;
-	case T_F64:             return *(double*)a < *(double*)b;
+	case T_F32:             return *(float*)a < *(float*)b && !isnan(*(float*)a) && !isnan(*(float*)b);
+	case T_F64:             return *(double*)a < *(double*)b && !isnan(*(double*)a) && !isnan(*(double*)b);
 	}
 	return false;
 }
 
-sl_purefn
+static sl_purefn
 bool
 cmp_same_eq(void *a, void *b, sl_numtype tag)
 {
@@ -137,8 +137,8 @@
 	case T_S64:             return *(s64int*)a == *(s64int*)b;
 	case T_U64: case T_P64: return *(u64int*)a == *(u64int*)b;
 	case T_BIG:             return mpcmp(*(mpint**)a, *(mpint**)b) == 0;
-	case T_F32:             return *(float*)a == *(float*)b && !isnan(*(float*)a);
-	case T_F64:             return *(double*)a == *(double*)b && !isnan(*(double*)b);
+	case T_F32:             return *(float*)a == *(float*)b && !isnan(*(float*)a) && !isnan(*(float*)b);
+	case T_F64:             return *(double*)a == *(double*)b && !isnan(*(double*)a) && !isnan(*(double*)b);
 	}
 	return false;
 }
@@ -146,120 +146,156 @@
 /* FIXME one is allocated for all compare ops */
 static mpint *cmpmpint;

-bool
-cmp_lt(void *a, sl_numtype atag, void *b, sl_numtype btag)
+/* Widen the number at x (of type nt) into *d as a u64, s64 or double, or
+   store its mpint pointer in *m when it is a bignum; returns the new type. */
+static sl_numtype
+promote(void *x, sl_numtype nt, void *d, mpint **m)
 {
-	if(atag == btag)
-		return cmp_same_lt(a, b, atag);
-
-	double da = conv_to_double(sl_nil, a, atag);
-	double db = conv_to_double(sl_nil, b, btag);
-
-	if(isnan(da) || isnan(db))
-		return false;
-
-	// casting to double will only get the wrong answer for big int64s
-	// that differ in low bits
-	if(da < db)
-		return true;
-	if(db < da)
-		return false;
-
-	if(cmpmpint == nil && (atag == T_BIG || btag == T_BIG))
-		cmpmpint = mpnew(0);
-
-	if(atag == T_U64){
-		if(btag == T_S64)
-			return *(s64int*)b >= 0 && *(u64int*)a < (u64int)*(s64int*)b;
-		if(btag == T_F64)
-			return db >= 0 ? *(u64int*)a < (u64int)*(double*)b : 0;
-		if(btag == T_BIG)
-			return mpcmp(uvtomp(*(u64int*)a, cmpmpint), *(mpint**)b) < 0;
-	}
-	if(atag == T_S64){
-		if(btag == T_U64)
-			return *(s64int*)a >= 0 && (u64int)*(s64int*)a < *(u64int*)b;
-		if(btag == T_F64)
-			return *(s64int*)a < (s64int)*(double*)b;
-		if(btag == T_BIG)
-			return mpcmp(vtomp(*(s64int*)a, cmpmpint), *(mpint**)b) < 0;
-	}
-	if(btag == T_U64){
-		if(atag == T_F64)
-			return da >= 0 ? *(u64int*)b > (u64int)*(double*)a : 0;
-		if(atag == T_BIG)
-			return mpcmp(*(mpint**)a, uvtomp(*(u64int*)b, cmpmpint)) < 0;
-	}
-	if(btag == T_S64){
-		if(atag == T_F64)
-			return *(s64int*)b > (s64int)*(double*)a;
-		if(atag == T_BIG)
-			return mpcmp(*(mpint**)a, vtomp(*(s64int*)b, cmpmpint)) < 0;
-	}
-	return false;
+	if(nt >= T_F32){
+		*(double*)d = conv_to_double(sl_nil, x, nt);
+		nt = T_F64;
+	}else if(is_nt_uint(nt)){
+		*(u64int*)d = conv_to_u64(sl_nil, x, nt);
+		nt = T_U64;
+	}else if(is_nt_sint(nt)){
+		*(s64int*)d = conv_to_s64(sl_nil, x, nt);
+		nt = T_S64;
+	}else if(nt == T_BIG){
+		if(cmpmpint == nil)
+			cmpmpint = mpnew(128);
+		*m = *(mpint**)x;
+	}else
+		abort();
+	return nt;
 }

+/* Clamp a comparison result to -1, 0 or 1. mpcmp is only relied upon for
+   its sign here, so a larger magnitude can never reach callers of
+   numeric_compare, which treat results above 1 as a type error. */
+static int
+clampcmp(int c)
+{
+	return (c > 0) - (c < 0);
+}
+
+/* Order integer u against the non-NaN double d without converting u to
+   double, which would round it once it exceeds 2^53. */
+static int
+cmp_u64_f64(u64int u, double d)
+{
+	double ip, frac;
+	u64int t;
+
+	if(d < 0)
+		return 1;
+	if(d >= 18446744073709551616.0) /* 2^64; also catches +inf */
+		return -1;
+	frac = modf(d, &ip); /* d == ip + frac, 0 <= frac < 1 */
+	t = (u64int)ip; /* exact, since 0 <= ip < 2^64 */
+	if(u != t)
+		return u < t ? -1 : 1;
+	return frac > 0 ? -1 : 0;
+}
+
+/* Order integer s against the non-NaN double d, as cmp_u64_f64 does. */
+static int
+cmp_s64_f64(s64int s, double d)
+{
+	double ip, frac;
+	s64int t;
+
+	if(d < -9223372036854775808.0) /* -2^63; also catches -inf */
+		return 1;
+	if(d >= 9223372036854775808.0) /* 2^63; also catches +inf */
+		return -1;
+	frac = modf(d, &ip); /* frac has the sign of d, |frac| < 1 */
+	t = (s64int)ip; /* exact, since -2^63 <= ip < 2^63 */
+	if(s != t)
+		return s < t ? -1 : 1;
+	return frac > 0 ? -1 : (frac < 0 ? 1 : 0);
+}
+
+/* Order bignum m against the non-NaN double d. An mpint can hold neither
+   an infinity nor a fraction, so infinities are decided up front, the
+   integral part is compared exactly and the fraction breaks ties. */
+static int
+cmp_big_f64(mpint *m, double d)
+{
+	double ip, frac;
+	int c;
+
+	if(isinf(d))
+		return d > 0 ? -1 : 1;
+	frac = modf(d, &ip);
+	if((c = mpcmp(m, dtomp(ip, cmpmpint))) != 0)
+		return clampcmp(c);
+	return frac > 0 ? -1 : (frac < 0 ? 1 : 0);
+}
+
-bool
-cmp_eq(void *a, sl_numtype atag, void *b, sl_numtype btag, bool equalnans)
+/* Order the numbers at ap (of type ant) and bp (of type bnt): returns -1,
+   0 or 1 as the first is less than, equal to or greater than the second.
+   NaN is unordered, so any comparison involving it yields 1. */
+int
+cmp_numeric(void *ap, sl_numtype ant, void *bp, sl_numtype bnt)
 {
-	union {
+	if(ant == bnt){
+		if(cmp_same_eq(ap, bp, ant))
+			return 0;
+		return cmp_same_lt(ap, bp, ant) ? -1 : 1;
+	}
+
+	union{
+		u64int u;
+		s64int s;
 		double d;
-		s64int i64;
-	}u, v;
+		mpint *m;
+	}a, b;
+	ant = promote(ap, ant, &a.u, &a.m);
+	bnt = promote(bp, bnt, &b.u, &b.m);

-	if(atag == btag && (!equalnans || atag < T_F32))
-		return cmp_same_eq(a, b, atag);
-
-	double da = conv_to_double(sl_nil, a, atag);
-	double db = conv_to_double(sl_nil, b, btag);
-
-	if((int)atag >= T_F32 && (int)btag >= T_F32){
-		if(equalnans){
-			u.d = da; v.d = db;
-			return u.i64 == v.i64;
+	/* NaN is unordered: never equal to or less than anything */
+	if((ant == T_F64 && isnan(a.d)) || (bnt == T_F64 && isnan(b.d)))
+		return 1;
+
+	switch(ant){
+	case T_U64:
+		switch(bnt){
+		case T_U64: return a.u < b.u ? -1 : (a.u == b.u ? 0 : 1);
+		case T_S64: return a.u < b.u && b.s > 0 ? -1 : (a.u == b.u && b.s >= 0 ? 0 : 1);
+		case T_F64: return cmp_u64_f64(a.u, b.d);
+		case T_BIG: return clampcmp(mpcmp(uvtomp(a.u, cmpmpint), b.m));
+		default: break;
 		}
-		return da == db;
+		break;
+	case T_S64:
+		switch(bnt){
+		case T_S64: return a.s < b.s ? -1 : (a.s == b.s ? 0 : 1);
+		case T_U64: return (a.s < 0 || a.u < b.u) ? -1 : (a.u == b.u && a.s >= 0 ? 0 : 1);
+		case T_F64: return cmp_s64_f64(a.s, b.d);
+		case T_BIG: return clampcmp(mpcmp(vtomp(a.s, cmpmpint), b.m));
+		default: break;
+		}
+		break;
+	case T_F64:
+		switch(bnt){
+		case T_F64: return a.d < b.d ? -1 : (a.d == b.d ? 0 : 1);
+		case T_U64: return -cmp_u64_f64(b.u, a.d);
+		case T_S64: return -cmp_s64_f64(b.s, a.d);
+		case T_BIG: return -cmp_big_f64(b.m, a.d);
+		default: break;
+		}
+		break;
+	case T_BIG:
+		switch(bnt){
+		case T_U64: return clampcmp(mpcmp(a.m, uvtomp(b.u, cmpmpint)));
+		case T_S64: return clampcmp(mpcmp(a.m, vtomp(b.s, cmpmpint)));
+		case T_F64: return cmp_big_f64(a.m, b.d);
+		default: break;
+		}
+		break;
+	default:
+		break;
 	}

-	if(da != db)
-		return false;
-
-	if(cmpmpint == nil && (atag == T_BIG || btag == T_BIG))
-		cmpmpint = mpnew(0);
-
-	if(atag == T_U64){
-		// this is safe because if a had been bigger than INT64_MAX,
-		// we would already have concluded that it's bigger than b.
-		if(btag == T_S64)
-			return *(s64int*)b >= 0 && *(u64int*)a == *(u64int*)b;
-		if(btag == T_F64)
-			return *(double*)b >= 0 && *(u64int*)a == (u64int)*(double*)b;
-		if(btag == T_BIG)
-			return mpcmp(uvtomp(*(u64int*)a, cmpmpint), *(mpint**)b) == 0;
-	}
-	if(atag == T_S64){
-		if(btag == T_U64)
-			return *(s64int*)a >= 0 && *(u64int*)a == *(u64int*)b;
-		if(btag == T_F64)
-			return *(s64int*)a == (s64int)*(double*)b;
-		if(btag == T_BIG)
-			return mpcmp(vtomp(*(s64int*)a, cmpmpint), *(mpint**)b) == 0;
-	}
-	if(btag == T_U64){
-		if(atag == T_S64)
-			return *(s64int*)a >= 0 && *(u64int*)b == *(u64int*)a;
-		if(atag == T_F64)
-			return *(double*)a >= 0 && *(u64int*)b == (u64int)*(double*)a;
-		if(atag == T_BIG)
-			return mpcmp(*(mpint**)a, uvtomp(*(u64int*)b, cmpmpint)) == 0;
-	}
-	if(btag == T_S64){
-		if(atag == T_U64)
-			return *(s64int*)b >= 0 && *(u64int*)b == *(u64int*)a;
-		if(atag == T_F64)
-			return *(s64int*)b == (s64int)*(double*)a;
-		if(atag == T_BIG)
-			return mpcmp(*(mpint**)a, vtomp(*(s64int*)b, cmpmpint)) == 0;
-	}
-	return true;
+	abort();
 }
--- a/src/operators.h
+++ b/src/operators.h
@@ -16,7 +16,4 @@
 #define conv_to_ptr conv_to_p32
 #endif
 
-bool cmp_same_lt(void *a, void *b, sl_numtype tag) sl_nonnull(1, 2);
-bool cmp_same_eq(void *a, void *b, sl_numtype tag) sl_nonnull(1, 2);
-bool cmp_lt(void *a, sl_numtype atag, void *b, sl_numtype btag) sl_nonnull(1, 3);
-bool cmp_eq(void *a, sl_numtype atag, void *b, sl_numtype btag, bool equalnans) sl_nonnull(1, 3);
+int cmp_numeric(void *ap, sl_numtype ant, void *bp, sl_numtype bnt);
--- a/src/sl.h
+++ b/src/sl.h
@@ -58,9 +58,14 @@
 	T_S64, T_U64,
 	T_P64,
 	T_UNBOXED_NUM,
-	T_BIG = T_UNBOXED_NUM,
 	T_F32, T_F64,
+
+	// always a cvalue
+	T_BIG = T_UNBOXED_NUM,
 }sl_numtype;
+
+#define is_nt_sint(nt) ((1<<(nt)) & (1<<T_S8 | 1<<T_S16 | 1<<T_S32 | 1<<T_S64))
+#define is_nt_uint(nt) ((1<<(nt)) & (1<<T_U8 | 1<<T_U16 | 1<<T_U32 | 1<<T_U64 | 1<<T_P32 | 1<<T_P64))
 
 #if defined(BITS64)
 typedef s64int sl_fx;
--- a/src/vm.h
+++ b/src/vm.h
@@ -195,7 +195,7 @@
 				break;
 			}
 		}else{
-			int x = numeric_compare(a, b, false, false, false);
+			int x = numeric_compare(a, b, false, false);
 			if(x > 1)
 				x = numval(sl_compare(a, b, false));
 			if(x >= 0){
@@ -663,7 +663,7 @@
 				v = sl_nil;
 				break;
 			}
-		}else if(numeric_compare(a, b, true, false, true) != 0){
+		}else if(numeric_compare(a, b, true, true) != 0){
 			v = sl_nil;
 			break;
 		}
--- a/test/unittest.sl
+++ b/test/unittest.sl
@@ -217,10 +217,10 @@
 (assert (nan? -nan.0))
 (assert (nan? (f32 +nan.0)))
 (assert (nan? (f32 -nan.0)))
-(assert (equal? +nan.0 +nan.0))
-(assert (equal? -nan.0 -nan.0))
-(assert (equal? (f32 +nan.0) (f32 +nan.0)))
-(assert (equal? (f32 -nan.0) (f32 -nan.0)))
+(assert (not (equal? +nan.0 +nan.0)))
+(assert (not (equal? -nan.0 -nan.0)))
+(assert (not (equal? (f32 +nan.0) (f32 +nan.0))))
+(assert (not (equal? (f32 -nan.0) (f32 -nan.0))))
 (assert (not (= +nan.0 +nan.0)))
 (assert (not (= +nan.0 -nan.0)))
 (assert (not (= -nan.0 -nan.0)))
@@ -261,22 +261,22 @@
 (assert (not (and (>= 2 1 2) (<= 2 1 2))))
 
 ; -0.0 etc.
-(assert (not (equal? 0.0 0)))
+(assert (equal? 0.0 0))
 (assert (equal? 0.0 0.0))
-(assert (not (equal? -0.0 0.0)))
-(assert (not (equal? -0.0 0)))
-(assert (not (eqv? 0.0 0)))
-(assert (not (eqv? -0.0 0)))
-(assert (not (eqv? -0.0 0.0)))
+(assert (equal? -0.0 0.0))
+(assert (equal? -0.0 0))
+(assert (eqv? 0.0 0))
+(assert (eqv? -0.0 0))
+(assert (eqv? -0.0 0.0))
 (assert (= 0.0 -0.0))
 ; same but f32
-(assert (not (equal? 0.0f 0)))
+(assert (equal? 0.0f 0))
 (assert (equal? 0.0f 0.0f))
-(assert (not (equal? -0.0f 0.0f)))
-(assert (not (equal? -0.0f 0)))
-(assert (not (eqv? 0.0f 0)))
-(assert (not (eqv? -0.0f 0)))
-(assert (not (eqv? -0.0f 0.0f)))
+(assert (equal? -0.0f 0.0f))
+(assert (equal? -0.0f 0))
+(assert (eqv? 0.0f 0))
+(assert (eqv? -0.0f 0))
+(assert (eqv? -0.0f 0.0f))
 (assert (= 0.0f -0.0f))
 
 ; and, or
--