shithub: MicroHs

Download patch

ref: 63e3e99fb622bd61349ec3c83e81ade81fbaa205
parent: 814d813eaf5a5f10c9df88e6336166c210632459
author: Lennart Augustsson <lennart@augustsson.net>
date: Fri Sep 20 07:44:35 EDT 2024

Implement dedicated `compare` primitives (icmp/ucmp) for Int, Word, and their sized variants.

--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -34,7 +34,7 @@
 bini8 :: (Int -> Int -> Int) -> (Int8 -> Int -> Int8)
 bini8 op (I8 x) y = i8 (x `op` y)
 
-cmp8 :: (Int -> Int -> Bool) -> (Int8 -> Int8 -> Bool)
+cmp8 :: (Int -> Int -> a) -> (Int8 -> Int8 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp8 op (I8 x) (I8 y) = x `op` y
 
 una8 :: (Int -> Int) -> (Int8 -> Int8)
@@ -87,6 +87,7 @@
   (/=) = cmp8 primIntNE
 
 instance Ord Int8 where
+  compare = cmp8 primIntCompare  -- three-way comparison via the single "icmp" primitive
   (<)  = cmp8 primIntLT
   (<=) = cmp8 primIntLE
   (>)  = cmp8 primIntGT
@@ -121,7 +122,7 @@
 bini16 :: (Int -> Int -> Int) -> (Int16 -> Int -> Int16)
 bini16 op (I16 x) y = i16 (x `op` y)
 
-cmp16 :: (Int -> Int -> Bool) -> (Int16 -> Int16 -> Bool)
+cmp16 :: (Int -> Int -> a) -> (Int16 -> Int16 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp16 op (I16 x) (I16 y) = x `op` y
 
 una16 :: (Int -> Int) -> (Int16 -> Int16)
@@ -174,6 +175,7 @@
   (/=) = cmp16 primIntNE
 
 instance Ord Int16 where
+  compare = cmp16 primIntCompare  -- three-way comparison via the single "icmp" primitive
   (<)  = cmp16 primIntLT
   (<=) = cmp16 primIntLE
   (>)  = cmp16 primIntGT
@@ -208,7 +210,7 @@
 bini32 :: (Int -> Int -> Int) -> (Int32 -> Int -> Int32)
 bini32 op (I32 x) y = i32 (x `op` y)
 
-cmp32 :: (Int -> Int -> Bool) -> (Int32 -> Int32 -> Bool)
+cmp32 :: (Int -> Int -> a) -> (Int32 -> Int32 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp32 op (I32 x) (I32 y) = x `op` y
 
 una32 :: (Int -> Int) -> (Int32 -> Int32)
@@ -261,6 +263,7 @@
   (/=) = cmp32 primIntNE
 
 instance Ord Int32 where
+  compare = cmp32 primIntCompare  -- three-way comparison via the single "icmp" primitive
   (<)  = cmp32 primIntLT
   (<=) = cmp32 primIntLE
   (>)  = cmp32 primIntGT
@@ -293,7 +296,7 @@
 bini64 :: (Int -> Int -> Int) -> (Int64 -> Int -> Int64)
 bini64 op (I64 x) y = i64 (x `op` y)
 
-cmp64 :: (Int -> Int -> Bool) -> (Int64 -> Int64 -> Bool)
+cmp64 :: (Int -> Int -> a) -> (Int64 -> Int64 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp64 op (I64 x) (I64 y) = x `op` y
 
 una64 :: (Int -> Int) -> (Int64 -> Int64)
@@ -346,6 +349,7 @@
   (/=) = cmp64 primIntNE
 
 instance Ord Int64 where
+  compare = cmp64 primIntCompare  -- three-way comparison via the single "icmp" primitive
   (<)  = cmp64 primIntLT
   (<=) = cmp64 primIntLE
   (>)  = cmp64 primIntGT
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -74,6 +74,7 @@
   (/=) = primWordNE
 
 instance Ord Word where
+  compare = primWordCompare  -- three-way unsigned comparison via the "ucmp" primitive
   (<)  = primWordLT
   (<=) = primWordLE
   (>)  = primWordGT
@@ -113,7 +114,7 @@
 bini8 :: (Word -> Int -> Word) -> (Word8 -> Int -> Word8)
 bini8 op (W8 x) y = w8 (x `op` y)
 
-cmp8 :: (Word -> Word -> Bool) -> (Word8 -> Word8 -> Bool)
+cmp8 :: (Word -> Word -> a) -> (Word8 -> Word8 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp8 op (W8 x) (W8 y) = x `op` y
 
 una8 :: (Word -> Word) -> (Word8 -> Word8)
@@ -166,6 +167,7 @@
   (/=) = cmp8 primWordNE
 
 instance Ord Word8 where
+  compare = cmp8 primWordCompare  -- three-way unsigned comparison via the "ucmp" primitive
   (<)  = cmp8 primWordLT
   (<=) = cmp8 primWordLE
   (>)  = cmp8 primWordGT
@@ -202,7 +204,7 @@
 bini16 :: (Word -> Int -> Word) -> (Word16 -> Int -> Word16)
 bini16 op (W16 x) y = w16 (x `op` y)
 
-cmp16 :: (Word -> Word -> Bool) -> (Word16 -> Word16 -> Bool)
+cmp16 :: (Word -> Word -> a) -> (Word16 -> Word16 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp16 op (W16 x) (W16 y) = x `op` y
 
 una16 :: (Word -> Word) -> (Word16 -> Word16)
@@ -255,6 +257,7 @@
   (/=) = cmp16 primWordNE
 
 instance Ord Word16 where
+  compare = cmp16 primWordCompare  -- three-way unsigned comparison via the "ucmp" primitive
   (<)  = cmp16 primWordLT
   (<=) = cmp16 primWordLE
   (>)  = cmp16 primWordGT
@@ -291,7 +294,7 @@
 bini32 :: (Word -> Int -> Word) -> (Word32 -> Int -> Word32)
 bini32 op (W32 x) y = w32 (x `op` y)
 
-cmp32 :: (Word -> Word -> Bool) -> (Word32 -> Word32 -> Bool)
+cmp32 :: (Word -> Word -> a) -> (Word32 -> Word32 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp32 op (W32 x) (W32 y) = x `op` y
 
 una32 :: (Word -> Word) -> (Word32 -> Word32)
@@ -344,6 +347,7 @@
   (/=) = cmp32 primWordNE
 
 instance Ord Word32 where
+  compare = cmp32 primWordCompare  -- three-way unsigned comparison via the "ucmp" primitive
   (<)  = cmp32 primWordLT
   (<=) = cmp32 primWordLE
   (>)  = cmp32 primWordGT
@@ -380,7 +384,7 @@
 bini64 :: (Word -> Int -> Word) -> (Word64 -> Int -> Word64)
 bini64 op (W64 x) y = w64 (x `op` y)
 
-cmp64 :: (Word -> Word -> Bool) -> (Word64 -> Word64 -> Bool)
+cmp64 :: (Word -> Word -> a) -> (Word64 -> Word64 -> a)  -- result type generalized from Bool so compare can reuse it
 cmp64 op (W64 x) (W64 y) = x `op` y
 
 una64 :: (Word -> Word) -> (Word64 -> Word64)
@@ -433,6 +437,7 @@
   (/=) = cmp64 primWordNE
 
 instance Ord Word64 where
+  compare = cmp64 primWordCompare  -- three-way unsigned comparison via the "ucmp" primitive
   (<)  = cmp64 primWordLT
   (<=) = cmp64 primWordLE
   (>)  = cmp64 primWordGT
--- a/lib/Primitives.hs
+++ b/lib/Primitives.hs
@@ -185,6 +185,8 @@
 primIntCompare  = primitive "icmp"
 primCharCompare :: forall a . Char -> Char -> Ordering
 primCharCompare  = primitive "icmp"
+primWordCompare :: Word -> Word -> Ordering
+primWordCompare  = primitive "ucmp"  -- unsigned three-way comparison (runtime op T_UCMP)
 
 primStringEQ  :: [Char] -> [Char] -> Bool
 primStringEQ  = primitive "sequal"
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -172,7 +172,7 @@
                 T_K2, T_K3, T_K4, T_CCB,
                 T_ADD, T_SUB, T_MUL, T_QUOT, T_REM, T_SUBR, T_UQUOT, T_UREM, T_NEG,
                 T_AND, T_OR, T_XOR, T_INV, T_SHL, T_SHR, T_ASHR,
-                T_EQ, T_NE, T_LT, T_LE, T_GT, T_GE, T_ULT, T_ULE, T_UGT, T_UGE,
+                T_EQ, T_NE, T_LT, T_LE, T_GT, T_GE, T_ULT, T_ULE, T_UGT, T_UGE, T_ICMP, T_UCMP,
                 T_FPADD, T_FP2P, T_FPNEW, T_FPFIN,
                 T_TOPTR, T_TOINT, T_TODBL, T_TOFUNPTR,
                 T_BININT2, T_BININT1, T_UNINT1,
@@ -718,7 +718,8 @@
   { "sequal", T_EQUAL, T_EQUAL },
   { "compare", T_COMPARE },
   { "scmp", T_COMPARE },
-  { "icmp", T_COMPARE },
+  { "icmp", T_ICMP },
+  { "ucmp", T_UCMP },
   { "rnf", T_RNF },
   { "fromUTF8", T_BSFROMUTF8 },
   { "toUTF8", T_BSTOUTF8 },
@@ -2124,6 +2125,8 @@
   case T_ULE: putsb("u<=", f); break;
   case T_UGT: putsb("u>", f); break;
   case T_UGE: putsb("u>=", f); break;
+  case T_ICMP: putsb("icmp", f); break;
+  case T_UCMP: putsb("ucmp", f); break;
   case T_FPADD: putsb("fp+", f); break;
   case T_FP2P: putsb("fp2p", f); break;
   case T_FPNEW: putsb("fpnew", f); break;
@@ -3020,10 +3023,12 @@
   case T_LE:
   case T_GT:
   case T_GE:
+  case T_ICMP:
   case T_ULT:
   case T_ULE:
   case T_UGT:
   case T_UGE:
+  case T_UCMP:
     CHECK(2);
     n = ARG(TOP(1));
     if (GETTAG(n) == T_INT) {
@@ -3371,10 +3376,12 @@
       case T_ULE:   GOIND(xu <= yu ? combTrue : combFalse);
       case T_UGT:   GOIND(xu >  yu ? combTrue : combFalse);
       case T_UGE:   GOIND(xu >= yu ? combTrue : combFalse);
+      case T_UCMP:  GOIND(xu <  yu ? combLT   : xu > yu ? combGT : combEQ);
       case T_LT:    GOIND((value_t)xu <  (value_t)yu ? combTrue : combFalse);
       case T_LE:    GOIND((value_t)xu <= (value_t)yu ? combTrue : combFalse);
       case T_GT:    GOIND((value_t)xu >  (value_t)yu ? combTrue : combFalse);
       case T_GE:    GOIND((value_t)xu >= (value_t)yu ? combTrue : combFalse);
+      case T_ICMP:  GOIND((value_t)xu <  (value_t)yu ? combLT   : (value_t)xu > (value_t)yu ? combGT : combEQ);
 
       default:
         //fprintf(stderr, "tag=%d\n", GETTAG(FUN(TOP(0))));
--