ref: dfda74cbaf3d2f129cc4d7da87435464f00a29fe
parent: b2c417daf11c7b5e6c369e6e4b36fb35cc18eae1
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Dec 8 06:43:32 EST 2024
More pattern synonyms
--- a/src/MicroHs/Expr.hs
+++ b/src/MicroHs/Expr.hs
@@ -33,7 +33,7 @@
eApp2, eApp3, eApps,
lhsToType,
subst,
- allVarsExpr, allVarsBind, allVarsEqns,
+ allVarsExpr, allVarsBind, allVarsEqns, allVarsPat,
setSLocExpr,
errorMessage,
Assoc(..), Fixity,
--- a/src/MicroHs/TypeCheck.hs
+++ b/src/MicroHs/TypeCheck.hs
@@ -946,9 +946,8 @@
dst <- tcDefsType ds
-- tcTrace ("tcDefs 2:\n" ++ showEDefs dst)
mapM_ addTypeAndData dst
- dstp <- tcPatSyn dst
- dste <- tcExpand impt dstp
- tcTrace ("tcDefs 3:\n" ++ showEDefs dste)
+ dste <- tcExpandClassInst impt dst
+-- tcTrace ("tcDefs 3:\n" ++ showEDefs dste)
case impt of
ImpNormal -> do
setDefault dste
@@ -999,8 +998,8 @@
return $ M.fromList iks'
-- Expand class and instance definitions (must be done after type synonym processing)
-tcExpand :: ImpType -> [EDef] -> T [EDef]
-tcExpand impt dst = withTypeTable $ do
+tcExpandClassInst :: ImpType -> [EDef] -> T [EDef]
+tcExpandClassInst impt dst = withTypeTable $ do
dsc <- concat <$> mapM (expandClass impt) dst -- Expand all class definitions
dsf <- concat <$> mapM expandField dsc -- Add HasField instances
-- tcTrace $ showEDefs dsf
@@ -1091,6 +1090,7 @@
Type lhs t -> withLHS lhs $ \ lhs' -> first (Type lhs') <$> tInferTypeT t
Class ctx lhs fds ms -> withLHS lhs $ \ lhs' -> flip (,) kConstraint <$> (Class <$> tcCtx ctx <*> return lhs' <*> mapM tcFD fds <*> mapM tcMethod ms)
Sign is t -> Sign is <$> tCheckTypeTImpl kType t
+ PatternSign is t -> PatternSign is <$> tCheckTypeTImpl kType t
ForImp ie i t -> ForImp ie i <$> tCheckTypeTImpl kType t
Instance ct m -> Instance <$> tCheckTypeTImpl kConstraint ct <*> return m
Default mc ts -> Default (Just c) <$> mapM (tcDefault c) ts
@@ -1276,14 +1276,17 @@
-- tcTrace $ "tcDefsValue: ------------ start"
-- Gather up all type signatures, and put them in the environment.
mapM_ addValueType defs
- let smap = M.fromList [ (i, ()) | Sign is _ <- defs, i <- is ]
+ let smap = M.fromList $ [ (i, ()) | Sign is _ <- defs, i <- is ] ++
+ [ (i, ()) | PatternSign is _ <- defs, i <- is ]
-- Split Fcn into those without and with type signatures
unsigned = filter noSign defs
where noSign (Fcn i _) = isNothing $ M.lookup i smap
+ noSign (Pattern (i, _) _ _) = isNothing $ M.lookup i smap
noSign _ = False
-- split the unsigned defs into strongly connected components
sccs = stronglyConnComp $ map node unsigned
where node d@(Fcn i e) = (d, i, allVarsEqns e)
+ node d@(Pattern (i, _) p me) = (d, i, allVarsPat p $ maybe [] allVarsEqns me)
node _ = undefined
tcSCC (AcyclicSCC d) = tInferDefs [d]
tcSCC (CyclicSCC ds) = tInferDefs ds
@@ -1293,7 +1296,7 @@
-- type check all definitions (the inferred ones will be rechecked)
-- tcTrace $ "tcDefsValue: ------------ check"
defs' <- mapM (\ d -> do { tcReset; tcDefValue d}) defs
- return $ concat signDefs ++ defs'
+ tcPatSyn $ concat signDefs ++ defs'
-- Infer a type for a definition
tInferDefs :: [EDef] -> T [EDef]
@@ -1300,13 +1303,19 @@
tInferDefs fcns = do
tcReset
-- Invent type variables for the definitions
- xts <- mapM (\ (Fcn i _) -> (,) i <$> newUVar) fcns
+ let idOf (Fcn i _) = i
+ idOf (Pattern (i, _) _ _) = i
+ idOf _ = undefined
+ xts <- mapM (\ d -> (,) (idOf d) <$> newUVar) fcns
--tcTrace $ "tInferDefs: " ++ show (map fst xts)
-- Temporarily extend the local environment with the type variables
withExtVals xts $ do
-- Infer types for all the Fcns, ignore the new bodies.
-- The bodies will be re-typecked in tcDefsValues.
- zipWithM_ (\ (Fcn _ eqns) (_, t) -> tcEqns False t eqns) fcns xts
+ let tc (Fcn _ eqns) (_, t) = do _ <- tcEqns False t eqns; return ()
+ tc d@(Pattern _ _ _) (_, t) = do _ <- tcPattern d t; return ()
+ tc _ _ = impossible
+ zipWithM_ tc fcns xts
-- Get the unsolved constraints
ctx <- getUnsolved
-- For each definition, quantify over the free meta variables, and include
@@ -1363,6 +1372,14 @@
addConFields tycon con
ForImp _ i t -> extValQTop i t
Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
+ PatternSign is at -> do
+ let t' =
+ -- Patterns must have two universals.
+ -- XXX Add double contexts
+ case at of
+ EForall b vs t -> EForall b vs $ EForall False [] t
+ _ -> EForall False [] $ EForall False [] at
+ mapM_ (\ i -> extValQTop i t') is
_ -> return ()
-- XXX FunDep
@@ -1415,8 +1432,22 @@
mn <- gets moduleName
t' <- expandSyn t
return (ForImp ie (qualIdent mn i) t')
+ Pattern (i, _) _ _ -> do
+ (_, t) <- tLookup "pattern type signature" i
+ t' <- expandSyn t
+ tcPattern adef t'
_ -> return adef
+-- | Check a pattern synonym definition against the type @t@ recorded for it.
+-- Only the optional builder equations are type checked (against @t@); the
+-- pattern itself is passed through unchecked for now (XXX).  The synonym's
+-- name is qualified with the current module name, and checkConstraints is
+-- run before returning.
+tcPattern :: EDef -> EType -> T EDef
+tcPattern (Pattern (i, vks) p me) t = do
+--  tcTrace ("Pattern " ++ show (i, vks, p, t))
+  let p' = p  -- XXX the synonym's pattern is not type checked yet
+  me' <- traverse (tcEqns True t) me
+  mn <- gets moduleName
+  checkConstraints
+  return $ Pattern (qualIdent mn i, vks) p' me'
+-- Only called with Pattern definitions.
+tcPattern _ _ = impossible
+
-- Add implicit forall and type check.
tCheckTypeTImpl :: HasCallStack => EType -> EType -> T EType
tCheckTypeTImpl tchk t@(EForall _ _ _) = tCheckTypeT tchk t
@@ -2947,14 +2978,66 @@
tcPatSyn :: [EDef] -> T [EDef]
tcPatSyn ds = do
- let one d@(Pattern (i, iks) p me) = do
- addPatSyn i (map idKindIdent iks, p)
- case me of
- Nothing -> return [d]
- Just e -> return [d, Fcn i e]
- one d = return [d]
- concat <$> mapM one ds
+  -- Pattern synonym definitions: name, parameters, pattern, builder equations.
+  let patSyns = [ (i, iks, p, mes) | Pattern (i, iks) p mes <- ds ]
+      -- A synonym with explicit builder equations also becomes an ordinary function.
+      ds' = ds ++ [ Fcn i es | (i, _, _, Just es) <- patSyns ]
+      -- Synonym name -> (parameter names, right-hand side pattern).
+      ps = M.fromList [ (i, (map idKindIdent iks, p)) | (i, iks, p, _) <- patSyns ]
+      -- Expand synonym uses in a pattern.  'as' holds the arguments the
+      -- pattern currently being walked is applied to, leftmost first.
+      -- Arguments are expanded before they are collected, so synonyms nested
+      -- inside constructor applications are replaced too; the result of a
+      -- substitution is expanded again in case the synonym's right-hand side
+      -- mentions other synonyms.
+      tr as (EVar i) | Just (vs, p) <- M.lookup i ps =
+        if length as == length vs then tr [] $ subst (zip vs as) p
+        else errorMessage (getSLoc i) "Bad synonym arity"
+      tr as (EApp f a) = tr (tr [] a : as) f
+      -- Any other head (constructor, ordinary variable, ...) is rebuilt
+      -- unchanged around its already expanded arguments.
+      tr as p = foldl EApp p as
+  -- Without synonyms there is nothing to expand.
+  return $ if null patSyns then ds else transformPat (tr []) ds'
-addPatSyn :: Ident -> ([Ident], EPat) -> T ()
-addPatSyn i ps =
- modify $ \ ts -> ts{ patSynTable = M.insert i ps (patSynTable ts) }
+-- | Rewrite the patterns found at binding sites of a syntax tree: equation
+-- arguments, case-arm patterns, @<-@ statements and pattern bindings.  The
+-- rewrite function receives each such pattern whole; this traversal does
+-- not descend into sub-patterns on its own.
+class TransformPat a where
+  transformPat :: (EPat -> EPat) -> a -> a
+
+-- Lists: rewrite every element.
+instance (TransformPat a) => TransformPat [a] where
+  transformPat f = map (transformPat f)
+
+-- Definitions: only function definitions are traversed; all others are kept.
+instance TransformPat EDef where
+  transformPat f def =
+    case def of
+      Fcn name eqns -> Fcn name (transformPat f eqns)
+      _             -> def
+
+-- Expressions.  NOTE(review): presumably only the forms that remain after
+-- type checking are listed here (tcPatSyn runs on checked definitions);
+-- any other form is reported with impossibleShow.
+instance TransformPat Expr where
+  transformPat f expr =
+    case expr of
+      EApp fun arg     -> EApp (transformPat f fun) (transformPat f arg)
+      ELam eqns        -> ELam (transformPat f eqns)
+      ECase scrut arms -> ECase (transformPat f scrut) (transformPat f arms)
+      ELet binds body  -> ELet (transformPat f binds) (transformPat f body)
+      EListish lst     -> EListish (transformPat f lst)
+      EIf cnd thn els  -> EIf (transformPat f cnd) (transformPat f thn) (transformPat f els)
+      EVar _           -> expr
+      ELit _ _         -> expr
+      ECon _           -> expr
+      _                -> impossibleShow expr
+
+-- List expressions: explicit lists and comprehensions; anything else is impossible.
+instance TransformPat Listish where
+  transformPat f lst =
+    case lst of
+      LList elems   -> LList (transformPat f elems)
+      LCompr e stms -> LCompr (transformPat f e) (transformPat f stms)
+      _             -> impossible
+
+-- Case arms: rewrite the arm's pattern, then traverse its alternatives.
+instance TransformPat ECaseArm where
+  transformPat f (pat, alts) = (f pat, transformPat f alts)
+
+-- Statements: the pattern of a bind is rewritten, sub-expressions traversed.
+instance TransformPat EStmt where
+  transformPat f stmt =
+    case stmt of
+      SBind pat e -> SBind (f pat) (transformPat f e)
+      SThen e     -> SThen (transformPat f e)
+      SLet binds  -> SLet (transformPat f binds)
+
+-- Local bindings: function bindings and pattern bindings; others are kept.
+instance TransformPat EBind where
+  transformPat f bnd =
+    case bnd of
+      BFcn name eqns -> BFcn name (transformPat f eqns)
+      BPat pat e     -> BPat (f pat) (transformPat f e)
+      _              -> bnd
+
+-- Equations: every argument pattern is rewritten, then the alternatives.
+instance TransformPat Eqn where
+  transformPat f (Eqn pats alts) = Eqn (map f pats) (transformPat f alts)
+
+-- Alternatives together with their local bindings.
+instance TransformPat EAlts where
+  transformPat f (EAlts alts binds) = EAlts (transformPat f alts) (transformPat f binds)
+
+-- A single alternative: its statements (presumably guards) and its expression.
+instance TransformPat EAlt where
+  transformPat f (guards, rhs) = (transformPat f guards, transformPat f rhs)
+
+