shithub: MicroHs

Download patch

ref: dfda74cbaf3d2f129cc4d7da87435464f00a29fe
parent: b2c417daf11c7b5e6c369e6e4b36fb35cc18eae1
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Dec 8 06:43:32 EST 2024

More pattern synonyms

--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -33,7 +33,7 @@
   eApp2, eApp3, eApps,
   lhsToType,
   subst,
-  allVarsExpr, allVarsBind, allVarsEqns,
+  allVarsExpr, allVarsBind, allVarsEqns, allVarsPat,
   setSLocExpr,
   errorMessage,
   Assoc(..), Fixity,
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -946,9 +946,8 @@
   dst <- tcDefsType ds
 --  tcTrace ("tcDefs 2:\n" ++ showEDefs dst)
   mapM_ addTypeAndData dst
-  dstp <- tcPatSyn dst
-  dste <- tcExpand impt dstp
-  tcTrace ("tcDefs 3:\n" ++ showEDefs dste)
+  dste <- tcExpandClassInst impt dst
+--  tcTrace ("tcDefs 3:\n" ++ showEDefs dste)
   case impt of
     ImpNormal -> do
       setDefault dste
@@ -999,8 +998,8 @@
   return $ M.fromList iks'
 
 -- Expand class and instance definitions (must be done after type synonym processing)
-tcExpand :: ImpType -> [EDef] -> T [EDef]
-tcExpand impt dst = withTypeTable $ do
+tcExpandClassInst :: ImpType -> [EDef] -> T [EDef]
+tcExpandClassInst impt dst = withTypeTable $ do
   dsc <- concat <$> mapM (expandClass impt) dst       -- Expand all class definitions
   dsf <- concat <$> mapM expandField dsc              -- Add HasField instances
 --  tcTrace $ showEDefs dsf
@@ -1091,6 +1090,7 @@
     Type    lhs t          -> withLHS lhs $ \ lhs' -> first                    (Type    lhs') <$> tInferTypeT t
     Class   ctx lhs fds ms -> withLHS lhs $ \ lhs' -> flip (,) kConstraint <$> (Class         <$> tcCtx ctx <*> return lhs' <*> mapM tcFD fds <*> mapM tcMethod ms)
     Sign      is t         ->                                                  Sign      is   <$> tCheckTypeTImpl kType t
+    PatternSign is t       ->                                                  PatternSign is   <$> tCheckTypeTImpl kType t
     ForImp ie i t          ->                                                  ForImp ie i    <$> tCheckTypeTImpl kType t
     Instance ct m          ->                                                  Instance       <$> tCheckTypeTImpl kConstraint ct <*> return m
     Default mc ts          ->                                                  Default (Just c) <$> mapM (tcDefault c) ts
@@ -1276,14 +1276,17 @@
 --  tcTrace $ "tcDefsValue: ------------ start"
   -- Gather up all type signatures, and put them in the environment.
   mapM_ addValueType defs
-  let smap = M.fromList [ (i, ()) | Sign is _ <- defs, i <- is ]
+  let smap = M.fromList $ [ (i, ()) | Sign is _ <- defs, i <- is ] ++
+                          [ (i, ()) | PatternSign is _ <- defs, i <- is ]
       -- Split Fcn into those without and with type signatures
       unsigned = filter noSign defs
         where noSign (Fcn i _) = isNothing $ M.lookup i smap
+              noSign (Pattern (i, _) _ _) = isNothing $ M.lookup i smap
               noSign _ = False
       -- split the unsigned defs into strongly connected components
       sccs = stronglyConnComp $ map node unsigned
         where node d@(Fcn i e) = (d, i, allVarsEqns e)
+              node d@(Pattern (i, _) p me) = (d, i, allVarsPat p $ maybe [] allVarsEqns me)
               node _ = undefined
       tcSCC (AcyclicSCC d) = tInferDefs [d]
       tcSCC (CyclicSCC ds) = tInferDefs ds
@@ -1293,7 +1296,7 @@
   --  type check all definitions (the inferred ones will be rechecked)
 --  tcTrace $ "tcDefsValue: ------------ check"
   defs' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs
-  return $ concat signDefs ++ defs'
+  tcPatSyn $ concat signDefs ++ defs'
 
 -- Infer a type for a definition
 tInferDefs :: [EDef] -> T [EDef]
@@ -1300,13 +1303,19 @@
 tInferDefs fcns = do
   tcReset
   -- Invent type variables for the definitions
-  xts <- mapM (\ (Fcn i _) -> (,) i <$> newUVar) fcns
+  let idOf (Fcn i _) = i
+      idOf (Pattern (i, _) _ _) = i
+      idOf _ = undefined
+  xts <- mapM (\ d -> (,) (idOf d) <$> newUVar) fcns
   --tcTrace $ "tInferDefs: " ++ show (map fst xts)
   -- Temporarily extend the local environment with the type variables
   withExtVals xts $ do
     -- Infer types for all the Fcns, ignore the new bodies.
     -- The bodies will be re-typecked in tcDefsValues.
-    zipWithM_ (\ (Fcn _ eqns) (_, t) -> tcEqns False t eqns) fcns xts
+    let tc (Fcn _ eqns) (_, t) = do _ <- tcEqns False t eqns; return ()
+        tc d@(Pattern _ _ _) (_, t) = do _ <- tcPattern d t; return ()
+        tc _ _ = impossible
+    zipWithM_ tc fcns xts
   -- Get the unsolved constraints
   ctx <- getUnsolved
   -- For each definition, quantify over the free meta variables, and include
@@ -1363,6 +1372,14 @@
       addConFields tycon con
     ForImp _ i t -> extValQTop i t
     Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
+    PatternSign is at -> do
+      let t' =
+            -- Patterns must have two universals.
+            -- XXX Add double contexts
+            case at of
+              EForall b vs t -> EForall b     vs $ EForall False []  t
+              _              -> EForall False [] $ EForall False [] at
+      mapM_ (\ i -> extValQTop i t') is
     _ -> return ()
 
 -- XXX FunDep
@@ -1415,8 +1432,22 @@
       mn <- gets moduleName
       t' <- expandSyn t
       return (ForImp ie (qualIdent mn i) t')
+    Pattern (i, _) _ _ -> do
+      (_, t) <- tLookup "pattern type signature" i
+      t' <- expandSyn t
+      tcPattern adef t'
     _ -> return adef
 
+-- Type check the equations (if any) of a pattern synonym against type t and qualify its name.
+tcPattern :: EDef -> EType -> T EDef
+tcPattern (Pattern (i, vks) p me) t = do
+--  traceM ("Pattern " ++ show (i, vks, p, t))
+  me' <- traverse (tcEqns True t) me
+  mn <- gets moduleName
+  checkConstraints
+  return $ Pattern (qualIdent mn i, vks) p me'  -- XXX the pattern p itself is not type checked yet
+tcPattern _ _ = error "tcPattern"
+
 -- Add implicit forall and type check.
 tCheckTypeTImpl :: HasCallStack => EType -> EType -> T EType
 tCheckTypeTImpl tchk t@(EForall _ _ _) = tCheckTypeT tchk t
@@ -2947,14 +2978,66 @@
 
 tcPatSyn :: [EDef] -> T [EDef]
 tcPatSyn ds = do
-  let one d@(Pattern (i, iks) p me) = do
-        addPatSyn i (map idKindIdent iks, p)
-        case me of
-          Nothing -> return [d]
-          Just e  -> return [d, Fcn i e]
-      one d = return [d]
-  concat <$> mapM one ds
+  -- Expand pattern synonym uses in the patterns of ds, and add a Fcn
+  -- definition for each synonym that comes with equations.
+  let patSyns = [ (i, iks, p, mes) | Pattern (i, iks) p mes <- ds ]
+      ds' = ds ++ [ Fcn i es | (i, _, _, Just es) <- patSyns ]
+      ps = M.fromList [ (i, (map idKindIdent iks, p)) | (i, iks, p, _) <- patSyns ]
+      -- tr walks down an application, gathering its arguments in source order.
+      tr as (EVar i) | Just (vs, p) <- M.lookup i ps = if length as == length vs then tr [] $ subst (zip vs as) p
+                                                       else errorMessage (getSLoc i) "Bad synonym arity"
+      tr as (EApp f a) = tr (a : as) f
+      -- Head is not a synonym: rebuild the application, expanding synonyms in the arguments.
+      tr as p = foldl EApp p (map (tr []) as)
+  return $ if null patSyns then ds else transformPat (tr []) ds'
 
-addPatSyn :: Ident -> ([Ident], EPat) -> T ()
-addPatSyn i ps =
-  modify $ \ ts -> ts{ patSynTable = M.insert i ps (patSynTable ts) }
+class TransformPat a where  -- syntax that contains patterns which can be rewritten
+  transformPat :: (EPat -> EPat) -> a -> a  -- apply the function to the patterns found inside the value
+
+instance (TransformPat a) => TransformPat [a] where  -- lists are transformed element by element
+  transformPat f = map (transformPat f)
+
+instance TransformPat EDef where  -- only Fcn definitions are rewritten, anything else is returned as is
+  transformPat f (Fcn name eqns) = Fcn name $ transformPat f eqns
+  transformPat _ other = other
+
+instance TransformPat Expr where  -- recurse into subexpressions; EVar, ELit and ECon are leaves
+  transformPat _ v@(EVar _) = v
+  transformPat f (EApp fun arg) = EApp (transformPat f fun) $ transformPat f arg
+  transformPat f (ELam body) = ELam $ transformPat f body
+  transformPat _ l@(ELit _ _) = l
+  transformPat f (ECase scrut arms) = ECase (transformPat f scrut) $ transformPat f arms
+  transformPat f (ELet binds body) = ELet (transformPat f binds) $ transformPat f body
+  transformPat f (EListish lst) = EListish $ transformPat f lst
+  transformPat f (EIf c t e) = EIf (transformPat f c) (transformPat f t) $ transformPat f e
+  transformPat _ c@(ECon _) = c
+  transformPat _ other = impossibleShow other
+
+instance TransformPat Listish where  -- any form other than LList and LCompr hits impossible
+  transformPat f (LList elems) = LList $ transformPat f elems
+  transformPat f (LCompr res stmts) = LCompr (transformPat f res) $ transformPat f stmts
+  transformPat _ _ = impossible
+
+instance TransformPat ECaseArm where  -- rewrite the arm's pattern and recurse into its alternatives
+  transformPat f (pat, rhs) = (f pat, transformPat f rhs)
+
+instance TransformPat EStmt where  -- the pattern bound by SBind is rewritten as well
+  transformPat f (SBind pat rhs) = SBind (f pat) $ transformPat f rhs
+  transformPat f (SThen body) = SThen $ transformPat f body
+  transformPat f (SLet binds) = SLet $ transformPat f binds
+
+instance TransformPat EBind where  -- function and pattern bindings are rewritten, other bindings kept
+  transformPat f (BFcn name eqns) = BFcn name $ transformPat f eqns
+  transformPat f (BPat pat rhs) = BPat (f pat) $ transformPat f rhs
+  transformPat _ other = other
+
+instance TransformPat Eqn where  -- rewrite every argument pattern and recurse into the alternatives
+  transformPat f (Eqn pats rhs) = Eqn (f <$> pats) $ transformPat f rhs
+
+instance TransformPat EAlts where  -- recurse into both the alternatives and the attached bindings
+  transformPat f (EAlts alts binds) = EAlts (transformPat f alts) $ transformPat f binds
+
+instance TransformPat EAlt where  -- recurse into the statement list and the expression of the alternative
+  transformPat f (guards, body) = (transformPat f guards, transformPat f body)
+
+