shithub: mc

Download patch

ref: 7e935fc4a7e5cdcd55180f6ec4821c8848c30b0e
parent: 260efa662329b449683813b6662d79e1df932b66
author: Ori Bernstein <ori@eigenstate.org>
date: Thu Jan 8 06:29:19 EST 2015

Rework how checkns() works.

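    checkns() now returns the node to use: either the substituted
    variable reference or, when no namespace applies, the original node.
    infernode(), inferexpr() and inferpat() now take a Node ** so that
    inferexpr() and inferpat() can run the namespace check themselves and
    write the result back in place. This way we can't forget to check for
    namespaces at every call site, at least not as easily.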

--- a/parse/infer.c
+++ b/parse/infer.c
@@ -45,8 +45,8 @@
     size_t nspecializationscope;
 };
 
-static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret);
-static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void infernode(Inferstate *st, Node **np, Type *ret, int *sawret);
+static void inferexpr(Inferstate *st, Node **np, Type *ret, int *sawret);
 static void inferdecl(Inferstate *st, Node *n);
 static void typesub(Inferstate *st, Node *n);
 static void tybind(Inferstate *st, Type *t);
@@ -365,7 +365,7 @@
     /* Walk through aggregate type members */
     if (t->type == Tystruct) {
         for (i = 0; i < t->nmemb; i++)
-            infernode(st, t->sdecls[i], NULL, NULL);
+            infernode(st, &t->sdecls[i], NULL, NULL);
     } else if (t->type == Tyunion) {
         for (i = 0; i < t->nmemb; i++) {
             t->udecls[i]->utype = t;
@@ -376,7 +376,7 @@
             }
         }
     } else if (t->type == Tyarray) {
-        infernode(st, t->asize, NULL, NULL);
+        infernode(st, &t->asize, NULL, NULL);
     }
 
     for (i = 0; i < t->nsub; i++)
@@ -890,7 +890,7 @@
 
         if (ft->sub[i]->type == Tyvalist)
             break;
-        inferexpr(st, n->expr.args[i], NULL, NULL);
+        inferexpr(st, &n->expr.args[i], NULL, NULL);
         unify(st, n->expr.args[0], ft->sub[i], type(st, n->expr.args[i]));
     }
     if (i < ft->nsub && ft->sub[i]->type != Tyvalist)
@@ -967,7 +967,7 @@
  * member. If it is, it transforms it into the variable
  * reference we should have, instead of the Omemb expr
  * that we do have */
-static void checkns(Inferstate *st, Node *n, Node **ret)
+static Node *checkns(Inferstate *st, Node *n, Node **ret)
 {
     Node *var, *name, *nsname;
     Node **args;
@@ -976,18 +976,18 @@
 
     /* check that this is a namespaced declaration */
     if (n->type != Nexpr)
-        return;
+        return n;
     if (exprop(n) != Omemb)
-        return;
+        return n;
     if (!n->expr.nargs)
-        return;
+        return n;
     args = n->expr.args;
     if (args[0]->type != Nexpr || exprop(args[0]) != Ovar)
-        return;
+        return n;
     name = args[0]->expr.args[0];
     stab = getns(curstab(), name);
     if (!stab)
-        return;
+        return n;
 
     /* substitute the namespaced name */
     nsname = mknsname(n->loc, namestr(name), namestr(args[1]));
@@ -998,6 +998,7 @@
     var->expr.idx = n->expr.idx;
     initvar(st, var, s);
     *ret = var;
+    return var;
 }
 
 static void inferstruct(Inferstate *st, Node *n, int *isconst)
@@ -1006,7 +1007,7 @@
 
     *isconst = 1;
     for (i = 0; i < n->expr.nargs; i++) {
-        infernode(st, n->expr.args[i], NULL, NULL);
+        infernode(st, &n->expr.args[i], NULL, NULL);
         if (!n->expr.args[i]->expr.isconst)
             *isconst = 0;
     }
@@ -1024,7 +1025,7 @@
     len = mkintlit(n->loc, n->expr.nargs);
     t = mktyarray(n->loc, mktyvar(n->loc), len);
     for (i = 0; i < n->expr.nargs; i++) {
-        infernode(st, n->expr.args[i], NULL, NULL);
+        infernode(st, &n->expr.args[i], NULL, NULL);
         unify(st, n, t->sub[0], type(st, n->expr.args[i]));
         if (!n->expr.args[i]->expr.isconst)
             *isconst = 0;
@@ -1040,7 +1041,7 @@
     *isconst = 1;
     types = xalloc(sizeof(Type *)*n->expr.nargs);
     for (i = 0; i < n->expr.nargs; i++) {
-        infernode(st, n->expr.args[i], NULL, NULL);
+        infernode(st, &n->expr.args[i], NULL, NULL);
         n->expr.isconst = n->expr.isconst && n->expr.args[i]->expr.isconst;
         types[i] = type(st, n->expr.args[i]);
     }
@@ -1057,8 +1058,7 @@
     t = tyfreshen(st, tf(st, uc->utype));
     uc = tybase(t)->udecls[uc->id];
     if (uc->etype) {
-        checkns(st, n->expr.args[1], &n->expr.args[1]);
-        inferexpr(st, n->expr.args[1], NULL, NULL);
+        inferexpr(st, &n->expr.args[1], NULL, NULL);
         unify(st, n, uc->etype, type(st, n->expr.args[1]));
     }
     *isconst = n->expr.args[0]->expr.isconst;
@@ -1065,17 +1065,19 @@
     settype(st, n, delayeducon(st, t));
 }
 
-static void inferpat(Inferstate *st, Node *n, Node *val, Node ***bind, size_t *nbind)
+static void inferpat(Inferstate *st, Node **np, Node *val, Node ***bind, size_t *nbind)
 {
     size_t i;
     Node **args;
-    Node *s;
+    Node *s, *n;
     Type *t;
 
+    n = *np;
+    n = checkns(st, n, np);
     args = n->expr.args;
     for (i = 0; i < n->expr.nargs; i++)
         if (args[i]->type == Nexpr)
-            inferpat(st, args[i], val, bind, nbind);
+            inferpat(st, &args[i], val, bind, nbind);
     switch (exprop(n)) {
         case Otup:
         case Ostruct:
@@ -1082,7 +1084,8 @@
         case Oarr:
         case Olit:
         case Omemb:
-            infernode(st, n, NULL, NULL);   break;
+            infernode(st, np, NULL, NULL);
+            break;
         /* arithmetic expressions just need to be constant */
         case Oneg:
         case Oadd:
@@ -1095,7 +1098,7 @@
         case Obor:
         case Obxor:
         case Obnot:
-            infernode(st, n, NULL, NULL);
+            infernode(st, np, NULL, NULL);
             if (!n->expr.isconst)
                 fatal(n, "matching against non-constant expression");
             break;
@@ -1152,7 +1155,7 @@
             /* Omemb can sometimes resolve to a namespace. We have to check
              * this. Icky. */
             checkns(st, args[i], &args[i]);
-            inferexpr(st, args[i], ret, sawret);
+            inferexpr(st, &args[i], ret, sawret);
             isconst = isconst && args[i]->expr.isconst;
         }
     }
@@ -1163,21 +1166,20 @@
     *exprconst = n->expr.isconst;
 }
 
-static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret)
+static void inferexpr(Inferstate *st, Node **np, Type *ret, int *sawret)
 {
     Node **args;
     size_t i, nargs;
-    Node *s;
+    Node *s, *n;
     Type *t;
     int isconst;
 
+    n = *np;
     assert(n->type == Nexpr);
     args = n->expr.args;
     nargs = n->expr.nargs;
-    infernode(st, n->expr.idx, NULL, NULL);
-    for (i = 0; i < nargs; i++)
-        if (args[i]->type == Nexpr && exprop(args[i]) == Omemb)
-            checkns(st, args[i], &args[i]);
+    infernode(st, &n->expr.idx, NULL, NULL);
+    n = checkns(st, n, np);
     switch (exprop(n)) {
         /* all operands are same type */
         case Oadd:      /* @a + @a -> @a */
@@ -1355,7 +1357,7 @@
             infersub(st, n, ret, sawret, &isconst);
             switch (args[0]->lit.littype) {
                 case Lfunc:
-                    infernode(st, args[0]->lit.fnval, NULL, NULL); break;
+                    infernode(st, &args[0]->lit.fnval, NULL, NULL); break;
                     /* FIXME: env capture means this is non-const */
                     n->expr.isconst = 1;
                 default:
@@ -1385,8 +1387,8 @@
 
     sawret = 0;
     for (i = 0; i < n->func.nargs; i++)
-        infernode(st, n->func.args[i], NULL, NULL);
-    infernode(st, n->func.body, n->func.type->sub[0], &sawret);
+        infernode(st, &n->func.args[i], NULL, NULL);
+    infernode(st, &n->func.body, n->func.type->sub[0], &sawret);
     /* if there's no return stmt in the function, assume void ret */
     if (!sawret)
         unify(st, n, type(st, n)->sub[0], mktype(Zloc, Tyvoid));
@@ -1463,8 +1465,7 @@
     }
     settype(st, n, t);
     if (n->decl.init) {
-        checkns(st, n->decl.init, &n->decl.init);
-        inferexpr(st, n->decl.init, NULL, NULL);
+        inferexpr(st, &n->decl.init, NULL, NULL);
         unify(st, n, type(st, n), type(st, n->decl.init));
         if (n->decl.isconst && !n->decl.init->expr.isconst)
             fatal(n, "non-const initializer for \"%s\"", ctxstr(st, n));
@@ -1491,12 +1492,13 @@
     free(k);
 }
 
-static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret)
+static void infernode(Inferstate *st, Node **np, Type *ret, int *sawret)
 {
     size_t i, nbound;
-    Node **bound;
+    Node **bound, *n;
     Type *t;
 
+    n = *np;
     if (!n)
         return;
     switch (n->type) {
@@ -1504,7 +1506,7 @@
             pushstab(n->file.globls);
             inferstab(st, n->file.globls);
             for (i = 0; i < n->file.nstmts; i++)
-                infernode(st, n->file.stmts[i], NULL, sawret);
+                infernode(st, &n->file.stmts[i], NULL, sawret);
             popstab();
             break;
         case Ndecl:
@@ -1525,22 +1527,21 @@
             pushstab(n->block.scope);
             inferstab(st, n->block.scope);
             for (i = 0; i < n->block.nstmts; i++) {
-                checkns(st, n->block.stmts[i], &n->block.stmts[i]);
-                infernode(st, n->block.stmts[i], ret, sawret);
+                infernode(st, &n->block.stmts[i], ret, sawret);
             }
             popstab();
             break;
         case Nifstmt:
-            infernode(st, n->ifstmt.cond, NULL, sawret);
-            infernode(st, n->ifstmt.iftrue, ret, sawret);
-            infernode(st, n->ifstmt.iffalse, ret, sawret);
+            infernode(st, &n->ifstmt.cond, NULL, sawret);
+            infernode(st, &n->ifstmt.iftrue, ret, sawret);
+            infernode(st, &n->ifstmt.iffalse, ret, sawret);
             unify(st, n, type(st, n->ifstmt.cond), mktype(n->loc, Tybool));
             break;
         case Nloopstmt:
-            infernode(st, n->loopstmt.init, ret, sawret);
-            infernode(st, n->loopstmt.cond, NULL, sawret);
-            infernode(st, n->loopstmt.step, ret, sawret);
-            infernode(st, n->loopstmt.body, ret, sawret);
+            infernode(st, &n->loopstmt.init, ret, sawret);
+            infernode(st, &n->loopstmt.cond, NULL, sawret);
+            infernode(st, &n->loopstmt.step, ret, sawret);
+            infernode(st, &n->loopstmt.body, ret, sawret);
             unify(st, n, type(st, n->loopstmt.cond), mktype(n->loc, Tybool));
             break;
         case Niterstmt:
@@ -1547,12 +1548,11 @@
             bound = NULL;
             nbound = 0;
 
-            inferpat(st, n->iterstmt.elt, NULL, &bound, &nbound);
+            inferpat(st, &n->iterstmt.elt, NULL, &bound, &nbound);
             addbindings(st, n->iterstmt.body, bound, nbound);
 
-            checkns(st, n->iterstmt.seq, &n->iterstmt.seq);
-            infernode(st, n->iterstmt.seq, NULL, sawret);
-            infernode(st, n->iterstmt.body, ret, sawret);
+            infernode(st, &n->iterstmt.seq, NULL, sawret);
+            infernode(st, &n->iterstmt.body, ret, sawret);
 
             t = mktyidxhack(n->loc, mktyvar(n->loc));
             constrain(st, n, type(st, n->iterstmt.seq), traittab[Tcidx]);
@@ -1560,11 +1560,11 @@
             unify(st, n, type(st, n->iterstmt.elt), t->sub[0]);
             break;
         case Nmatchstmt:
-            infernode(st, n->matchstmt.val, NULL, sawret);
+            infernode(st, &n->matchstmt.val, NULL, sawret);
             if (tybase(type(st, n->matchstmt.val))->type == Tyvoid)
                 fatal(n, "Can't match against a void type near %s", ctxstr(st, n->matchstmt.val));
             for (i = 0; i < n->matchstmt.nmatches; i++) {
-                infernode(st, n->matchstmt.matches[i], ret, sawret);
+                infernode(st, &n->matchstmt.matches[i], ret, sawret);
                 unify(st, n, type(st, n->matchstmt.val), type(st, n->matchstmt.matches[i]->match.pat));
             }
             break;
@@ -1571,12 +1571,12 @@
         case Nmatch:
             bound = NULL;
             nbound = 0;
-            inferpat(st, n->match.pat, NULL, &bound, &nbound);
+            inferpat(st, &n->match.pat, NULL, &bound, &nbound);
             addbindings(st, n->match.block, bound, nbound);
-            infernode(st, n->match.block, ret, sawret);
+            infernode(st, &n->match.block, ret, sawret);
             break;
         case Nexpr:
-            inferexpr(st, n, ret, sawret);
+            inferexpr(st, np, ret, sawret);
             break;
         case Nfunc:
             setsuper(n->func.scope, curstab());
@@ -2173,7 +2173,7 @@
 
     /* do the inference */
     applytraits(&st, file);
-    infernode(&st, file, NULL, NULL);
+    infernode(&st, &file, NULL, NULL);
     postcheck(&st, file);
 
     /* and replace type vars with actual types */