shithub: MicroHs

Download patch

ref: 01e3ea89201c49757154ce04c475a95b0040d3df
parent: 0a6f76bc5c3e2d48b503734bc7ea7ed39292d182
parent: 83c472492f864dce83e948d6968417bc3538db18
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Jan 12 01:14:20 EST 2025

Merge pull request #87 from konsumlamm/Integer

Optimize `Integer` a bit

--- a/lib/Data/Integer.hs
+++ b/lib/Data/Integer.hs
@@ -14,27 +14,29 @@
 import Prelude()              -- do not import Prelude
 import Primitives
 import Control.Error
+import Data.Bits
 import Data.Bool
 import Data.Char
 import Data.Enum
 import Data.Eq
 import Data.Function
-import Data.Int
 import Data.Integer_Type
 import Data.Integral
 import Data.List
+import Data.Maybe_Type
 import Data.Num
 import Data.Ord
 import Data.Ratio_Type
 import Data.Real
+import Data.Word ()
 import Numeric.Show
 import Text.Show
 
 --
--- The Integer is stored in sign-magniture format with digits in base maxD (2^31)
+-- The Integer is stored in sign-magnitude format with digits in base maxD (2^32 on 64-bit platforms, 2^16 on 32-bit ones)
 -- It has the following invariants:
 --  * each digit is >= 0 and < maxD
---  * least signification digits first, most significant last
+--  * least significant digits first, most significant last
 --  * no trailing 0s in the digits
 --  * 0 is positive
 {- These definitions are in Integer_Type
@@ -41,10 +43,10 @@
 data Integer = I Sign [Digit]
   --deriving Show
 
-type Digit = Int
+type Digit = Word
 
 maxD :: Digit
-maxD = 2147483648  -- 2^31, this is used so multiplication of two digit doesn't overflow a 64 bit Int
+maxD = 4294967296 -- 2^32 on 64-bit platforms, chosen so multiplying two digits doesn't overflow a 64 bit Word
 
 data Sign = Plus | Minus
   --deriving Show
@@ -111,7 +113,8 @@
 -- Trim off 0s and make an Integer
 sI :: Sign -> [Digit] -> Integer
 sI s ds =
-  case trim0 ds of
+  -- Remove trailing (most significant) 0s; an all-zero list becomes I Plus [] since 0 is positive
+  case dropWhileEnd (== (0 :: Word)) ds of
     []  -> I Plus []
     ds' -> I s    ds'
 
@@ -148,7 +151,7 @@
 
 -- Add 3 digits with carry
 addD :: Digit -> Digit -> Digit -> (Digit, Digit)
-addD x y z = (quot s maxD, rem s maxD)  where s = x + y + z
+addD x y z = (quotMaxD s, remMaxD s)  where s = x + y + z  -- result is (carry out, digit)
 
 -- Invariant: xs >= ys, so result is always >= 0
 sub :: [Digit] -> [Digit] -> [Digit]
@@ -158,21 +161,14 @@
 sub' bi (x : xs) (y : ys) = d : sub' bo xs ys  where (bo, d) = subW bi x y
 sub' bi (x : xs) []       = d : sub' bo xs []  where (bo, d) = subW bi x zeroD
 sub' 0  []       []       = []
-sub' _  []       _        = undefined
+sub' _  []       _        = error "impossible: xs >= ys"
 
 -- Subtract with borrow
 subW :: Digit -> Digit -> Digit -> (Digit, Digit)
 subW b x y =
-  let d = x - y + b
-  in  if d < 0 then
-        (quot d maxD - 1, rem d maxD + maxD)
-      else
-        (quot d maxD, rem d maxD)
+  let d = maxD + x - y - b  -- bias by maxD so the unsigned Word never goes negative (assumes x, y < maxD and b is 0 or 1)
+  in (1 - quotMaxD d, remMaxD d)  -- (borrow out: 1 iff x - y - b < 0, digit)
 
--- Remove trailing 0s
-trim0 :: [Digit] -> [Digit]
-trim0 = reverse . dropWhile (== (0::Int)) . reverse
-
 -- Is axs < ays?
 ltW :: [Digit] -> [Digit] -> Bool
 ltW axs ays = lxs < lys || lxs == lys && cmp (reverse axs) (reverse ays)
@@ -182,7 +178,7 @@
     cmp (x:xs) (y:ys) = x < y || x == y && cmp xs ys
     cmp []     []     = False
     cmp _      _      = error "ltW.cmp"
-    
+
 mulI :: Integer -> Integer -> Integer
 mulI (I _ []) _ = I Plus []         -- 0 * x = 0
 mulI _ (I _ []) = I Plus []         -- x * 0 = 0
@@ -199,13 +195,13 @@
 mulD ci (x:xs) y = r : mulD q xs y
   where
     xy = x * y + ci
-    q = quot xy maxD
-    r = rem  xy maxD
+    q = quotMaxD xy
+    r = remMaxD xy
 
 mulM :: [Digit] -> [Digit] -> [Digit]
 mulM xs ys =
   let rs = map (mulD zeroD xs) ys
-      ss = zipWith (++) (map (`replicate` (0::Int)) [0::Int ..]) rs
+      ss = zipWith (++) (map (`replicate` (0 :: Word)) [0 :: Int ..]) rs  -- shift the i-th partial product up by i digits
   in  foldl1 add ss
 
 -- Signs:
@@ -216,17 +212,21 @@
 quotRemI :: Integer -> Integer -> (Integer, Integer)
 quotRemI _         (I _  [])  = error "Integer: division by 0" -- n / 0
 quotRemI (I _  [])          _ = (I Plus [], I Plus [])         -- 0 / n
-quotRemI (I sx xs) (I sy ys) | all (== (0::Int)) ys' =
+quotRemI (I sx xs) (I sy ys) | Just (y, n) <- msd ys =  -- divisor is y * maxD^n
   -- All but the MSD are 0.  Scale numerator accordingly and divide.
   -- Then add back (the ++) the remainder we scaled off.
-    case quotRemD xs' y of
-      (q, r) -> qrRes sx sy (q, rs ++ r)
-  where ys'       = init ys
-        y         = last ys
-        n         = length ys'
-        (rs, xs') = splitAt n xs  -- xs' is the scaled number
+  let (rs, xs') = splitAt n xs  -- xs' is the scaled number
+  in case quotRemD xs' y of
+    (q, r) -> qrRes sx sy (q, rs ++ r)
 quotRemI (I sx xs) (I sy ys)  = qrRes sx sy (quotRemB xs ys)
 
+msd :: [Digit] -> Maybe (Digit, Int)  -- Just (d, n) iff the digits are n zeros followed by exactly one final digit d
+msd = go 0
+  where
+    go _ [] = Nothing
+    go n [d] = Just (d, n)  -- the final digit is returned as-is, even if it is 0
+    go n (d : ds) = if d == 0 then go (n + 1) ds else Nothing
+
 qrRes :: Sign -> Sign -> ([Digit], [Digit]) -> (Integer, Integer)
 qrRes sx sy (ds, rs) = (sI (mulSign sx sy) ds, sI sx rs)
 
@@ -243,7 +243,7 @@
     qr ci []     res = (res, [ci])
     qr ci (x:xs) res = qr r xs (q:res)
       where
-        cx = ci * maxD + x
+        cx = ci `shiftL` shiftD + x
         q = quot cx y
         r = rem cx y
 
@@ -252,7 +252,7 @@
 quotRemB xs ys =
   let n  = I Plus xs
       d  = I Plus ys
-      a  = I Plus $ replicate (length ys - (1::Int)) (0::Int) ++ [last ys]  -- only MSD of ys
+      a  = I Plus $ replicate (length ys - (1 :: Int)) (0 :: Word) ++ [last ys]  -- only MSD of ys
       aq = quotI n a
       ar = addI d oneI
       loop q r =
@@ -411,7 +411,6 @@
 {-
 sanity :: HasCallStack => Integer -> Integer
 sanity (I Minus []) = undefined
-sanity (I _ ds) | any (< 0) ds = undefined
 sanity (I _ ds) | length ds > 1 && last ds == 0 = undefined
 sanity i = i
 -}
@@ -438,7 +437,7 @@
 prop_div x (NonZero y) =
   to (quotRemI x y) == toInteger x `quotRem` toInteger y
   where to (a, b) = (toInteger a, toInteger b)
-  
+
 prop_muldiv :: Integer -> NonZero Integer -> Bool
 prop_muldiv x (NonZero y) =
   let (q, r) = quotRemI x y
@@ -472,5 +471,5 @@
   mapM_ qc [prop_add, prop_sub, prop_mul,
             prop_eq, prop_ne, prop_lt, prop_gt, prop_le, prop_ge]
   mapM_ qc [prop_div, prop_muldiv]
-  
+
 -}
--- a/lib/Data/Integer_Type.hs
+++ b/lib/Data/Integer_Type.hs
@@ -11,57 +11,76 @@
 
 data Sign = Plus | Minus
 
-type Digit = Int
+type Digit = Word  -- each digit of an Integer is kept below maxD
 
 maxD :: Digit
 maxD =
   if _wordSize `primIntEQ` 64 then
-    (2147483648::Int)  -- 2^31, this is used so multiplication of two digits doesn't overflow a 64 bit Int
+    (4294967296 :: Word) -- 2^32 (must equal 2^shiftD), so multiplication of two digits doesn't overflow a 64 bit Word
   else if _wordSize `primIntEQ` 32 then
-    (32768::Int)       -- 2^15, this is used so multiplication of two digits doesn't overflow a 32 bit Int
+    (65536 :: Word)      -- 2^16 (must equal 2^shiftD), so multiplication of two digits doesn't overflow a 32 bit Word
   else
     error "Integer: unsupported word size"
 
+shiftD :: Int  -- log2 of maxD: quotMaxD shifts by it while remMaxD masks with maxD - 1, so the two must agree
+shiftD =
+  if _wordSize `primIntEQ` 64 then
+    (32::Int)
+  else if _wordSize `primIntEQ` 32 then
+    (16::Int)
+  else
+    error "Integer: unsupported word size"
+
+quotMaxD :: Digit -> Digit  -- same as quot d maxD (maxD == 2^shiftD), computed with a right shift
+quotMaxD d = d `primWordShr` shiftD
+
+remMaxD :: Digit -> Digit  -- same as rem d maxD, computed with a mask since maxD is a power of 2
+remMaxD d = d `primWordAnd` (maxD `primWordSub` 1)
+
 -- Sadly, we also need a bunch of functions.
 
 _intToInteger :: Int -> Integer
-_intToInteger i | i `primIntGE` 0  = I Plus  (f i)
-                | i `primIntEQ` ni = I Minus [0::Int,0::Int,2::Int]  -- we are at minBound::Int.
-                | True             = I Minus (f ni)
+_intToInteger i  -- the magnitude of every Int, minBound included, fits in two digits (maxD^2 == 2^wordSize)
+  | i `primIntEQ` 0 = I Plus []
+  | i `primIntGE` 0 = f Plus (primIntToWord i)
+  | True            = f Minus (primIntToWord (0 `primIntSub` i))  -- assuming two's-complement wraparound, minBound negates to itself and its Word pattern is |minBound|
   where
-    ni = (0::Int) `primIntSub` i
-    f :: Int -> [Int]
-    f x = if primIntEQ x (0::Int) then [] else primIntRem x maxD : f (primIntQuot x maxD)
+    f sign w =  -- w is a non-zero magnitude; split it into a low digit and an optional high digit
+      let
+        high = quotMaxD w
+        low = remMaxD w
+      in if high `primWordEQ` 0 then I sign [low] else I sign [low, high]
 
 _integerToInt :: Integer -> Int
-_integerToInt (I sign ds) = s `primIntMul` i
-  where
-    i =
-      case ds of
-        []         -> 0::Int
-        [d1]       -> d1
-        [d1,d2]    -> d1 `primIntAdd` (maxD `primIntMul` d2)
-        d1:d2:d3:_ -> d1 `primIntAdd` (maxD `primIntMul` (d2 `primIntAdd` (maxD `primIntMul` d3)))
-    s =
-      case sign of
-        Plus  -> 1::Int
-        Minus -> 0 `primIntSub` 1
+_integerToInt x = primWordToInt (_integerToWord x)  -- reinterpret the (wrapped) Word result as an Int
 
 _wordToInteger :: Word -> Integer
-_wordToInteger i = I Plus  (f i)
+_wordToInteger i  -- any Word fits in two digits, since maxD^2 == 2^wordSize
+  | i    `primWordEQ` 0 = I Plus []
+  | high `primWordEQ` 0 = I Plus [low]
+  | True                = I Plus [low, high]
   where
-    f :: Word -> [Int]
-    f x = if x `primWordEQ` (0::Word) then [] else primWordToInt (primWordRem x (primIntToWord maxD)) : f (primWordQuot x (primIntToWord maxD))
+    high = quotMaxD i  -- shift instead of dividing by maxD
+    low = remMaxD i    -- mask instead of taking the remainder by maxD
 
 _integerToWord :: Integer -> Word
-_integerToWord x = primIntToWord (_integerToInt x)
+_integerToWord (I sign ds) =  -- the Integer modulo 2^wordSize, assuming primWord* arithmetic wraps
+  case sign of
+    Plus  -> i
+    Minus -> 0 `primWordSub` i  -- two's-complement negation via Word wraparound
+  where
+    i =
+      case ds of
+        []          -> 0 :: Word
+        [d1]        -> d1
+        d1 : d2 : _ -> d1 `primWordAdd` (maxD `primWordMul` d2)  -- higher digits are multiples of maxD^2 == 2^wordSize, i.e. 0 in a Word
 
 _integerToFloatW :: Integer -> FloatW
 _integerToFloatW (I sign ds) = s `primFloatWMul` loop ds
   where
-    loop [] = 0.0::FloatW
-    loop (i : is) = primFloatWFromInt i `primFloatWAdd` (primFloatWFromInt maxD `primFloatWMul` loop is)
+    loop [] = 0.0 :: FloatW  -- Horner's rule over the digits, least significant first
+    loop (d : rest) = primFloatWFromInt (primWordToInt d) `primFloatWAdd` (primFloatWFromInt (primWordToInt maxD) `primFloatWMul` loop rest)  -- d + maxD * (value of the remaining digits)
     s =
       case sign of
-        Plus  -> 1.0::FloatW
+        Plus  -> 1.0 :: FloatW
         Minus -> 0.0 `primFloatWSub` 1.0
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -12,12 +12,13 @@
 import Data.Eq
 import Data.Function
 import Data.Int()  -- instances only
-import Data.Integer
+import Data.Integer_Type
 import Data.Integral
 import Data.List
 import Data.Maybe_Type
 import Data.Num
 import Data.Ord
+import Data.Ratio_Type
 import Data.Real
 import Numeric.Show
 import Text.Show
@@ -28,7 +29,7 @@
   (*)  = primWordMul
   abs x = x
   signum x = if x == 0 then 0 else 1
-  fromInteger x = primIntToWord (_integerToInt x)
+  fromInteger = _integerToWord
 
 instance Integral Word where
   quot = primWordQuot