ref: 7e935fc4a7e5cdcd55180f6ec4821c8848c30b0e
parent: 260efa662329b449683813b6662d79e1df932b66
author: Ori Bernstein <ori@eigenstate.org>
date: Thu Jan 8 06:29:19 EST 2015
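Rework how checkns() works. checkns() now returns the (possibly substituted) node, and inferexpr() and inferpat() call it themselves, taking a Node ** so they can replace the node in place. This way it is much harder to forget the namespace check at some call site.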
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -45,8 +45,8 @@
size_t nspecializationscope;
};
-static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret);
-static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void infernode(Inferstate *st, Node **np, Type *ret, int *sawret);
+static void inferexpr(Inferstate *st, Node **np, Type *ret, int *sawret);
static void inferdecl(Inferstate *st, Node *n);
static void typesub(Inferstate *st, Node *n);
static void tybind(Inferstate *st, Type *t);
@@ -365,7 +365,7 @@
/* Walk through aggregate type members */
if (t->type == Tystruct) {
for (i = 0; i < t->nmemb; i++)
- infernode(st, t->sdecls[i], NULL, NULL);
+ infernode(st, &t->sdecls[i], NULL, NULL);
} else if (t->type == Tyunion) {
for (i = 0; i < t->nmemb; i++) {
t->udecls[i]->utype = t;
@@ -376,7 +376,7 @@
}
}
} else if (t->type == Tyarray) {
- infernode(st, t->asize, NULL, NULL);
+ infernode(st, &t->asize, NULL, NULL);
}
for (i = 0; i < t->nsub; i++)
@@ -890,7 +890,7 @@
if (ft->sub[i]->type == Tyvalist)
break;
- inferexpr(st, n->expr.args[i], NULL, NULL);
+ inferexpr(st, &n->expr.args[i], NULL, NULL);
unify(st, n->expr.args[0], ft->sub[i], type(st, n->expr.args[i]));
}
if (i < ft->nsub && ft->sub[i]->type != Tyvalist)
@@ -967,7 +967,7 @@
* member. If it is, it transforms it into the variable
* reference we should have, instead of the Omemb expr
* that we do have */
-static void checkns(Inferstate *st, Node *n, Node **ret)
+static Node *checkns(Inferstate *st, Node *n, Node **ret)
{
Node *var, *name, *nsname;
Node **args;
@@ -976,18 +976,18 @@
/* check that this is a namespaced declaration */
if (n->type != Nexpr)
- return;
+ return n;
if (exprop(n) != Omemb)
- return;
+ return n;
if (!n->expr.nargs)
- return;
+ return n;
args = n->expr.args;
if (args[0]->type != Nexpr || exprop(args[0]) != Ovar)
- return;
+ return n;
name = args[0]->expr.args[0];
stab = getns(curstab(), name);
if (!stab)
- return;
+ return n;
/* substitute the namespaced name */
nsname = mknsname(n->loc, namestr(name), namestr(args[1]));
@@ -998,6 +998,7 @@
var->expr.idx = n->expr.idx;
initvar(st, var, s);
*ret = var;
+ return var;
}
static void inferstruct(Inferstate *st, Node *n, int *isconst)
@@ -1006,7 +1007,7 @@
*isconst = 1;
for (i = 0; i < n->expr.nargs; i++) {
- infernode(st, n->expr.args[i], NULL, NULL);
+ infernode(st, &n->expr.args[i], NULL, NULL);
if (!n->expr.args[i]->expr.isconst)
*isconst = 0;
}
@@ -1024,7 +1025,7 @@
len = mkintlit(n->loc, n->expr.nargs);
t = mktyarray(n->loc, mktyvar(n->loc), len);
for (i = 0; i < n->expr.nargs; i++) {
- infernode(st, n->expr.args[i], NULL, NULL);
+ infernode(st, &n->expr.args[i], NULL, NULL);
unify(st, n, t->sub[0], type(st, n->expr.args[i]));
if (!n->expr.args[i]->expr.isconst)
*isconst = 0;
@@ -1040,7 +1041,7 @@
*isconst = 1;
types = xalloc(sizeof(Type *)*n->expr.nargs);
for (i = 0; i < n->expr.nargs; i++) {
- infernode(st, n->expr.args[i], NULL, NULL);
+ infernode(st, &n->expr.args[i], NULL, NULL);
n->expr.isconst = n->expr.isconst && n->expr.args[i]->expr.isconst;
types[i] = type(st, n->expr.args[i]);
}
@@ -1057,8 +1058,7 @@
t = tyfreshen(st, tf(st, uc->utype));
uc = tybase(t)->udecls[uc->id];
if (uc->etype) {
- checkns(st, n->expr.args[1], &n->expr.args[1]);
- inferexpr(st, n->expr.args[1], NULL, NULL);
+ inferexpr(st, &n->expr.args[1], NULL, NULL);
unify(st, n, uc->etype, type(st, n->expr.args[1]));
}
*isconst = n->expr.args[0]->expr.isconst;
@@ -1065,17 +1065,19 @@
settype(st, n, delayeducon(st, t));
}
-static void inferpat(Inferstate *st, Node *n, Node *val, Node ***bind, size_t *nbind)
+static void inferpat(Inferstate *st, Node **np, Node *val, Node ***bind, size_t *nbind)
{
size_t i;
Node **args;
- Node *s;
+ Node *s, *n;
Type *t;
+ n = *np;
+ n = checkns(st, n, np);
args = n->expr.args;
for (i = 0; i < n->expr.nargs; i++)
if (args[i]->type == Nexpr)
- inferpat(st, args[i], val, bind, nbind);
+ inferpat(st, &args[i], val, bind, nbind);
switch (exprop(n)) {
case Otup:
case Ostruct:
@@ -1082,7 +1084,8 @@
case Oarr:
case Olit:
case Omemb:
- infernode(st, n, NULL, NULL); break;
+ infernode(st, np, NULL, NULL);
+ break;
/* arithmetic expressions just need to be constant */
case Oneg:
case Oadd:
@@ -1095,7 +1098,7 @@
case Obor:
case Obxor:
case Obnot:
- infernode(st, n, NULL, NULL);
+ infernode(st, np, NULL, NULL);
if (!n->expr.isconst)
fatal(n, "matching against non-constant expression");
break;
@@ -1152,7 +1155,7 @@
/* Omemb can sometimes resolve to a namespace. We have to check
* this. Icky. */
checkns(st, args[i], &args[i]);
- inferexpr(st, args[i], ret, sawret);
+ inferexpr(st, &args[i], ret, sawret);
isconst = isconst && args[i]->expr.isconst;
}
}
@@ -1163,21 +1166,20 @@
*exprconst = n->expr.isconst;
}
-static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret)
+static void inferexpr(Inferstate *st, Node **np, Type *ret, int *sawret)
{
Node **args;
size_t i, nargs;
- Node *s;
+ Node *s, *n;
Type *t;
int isconst;
+ n = *np;
assert(n->type == Nexpr);
args = n->expr.args;
nargs = n->expr.nargs;
- infernode(st, n->expr.idx, NULL, NULL);
- for (i = 0; i < nargs; i++)
- if (args[i]->type == Nexpr && exprop(args[i]) == Omemb)
- checkns(st, args[i], &args[i]);
+ infernode(st, &n->expr.idx, NULL, NULL);
+ n = checkns(st, n, np);
switch (exprop(n)) {
/* all operands are same type */
case Oadd: /* @a + @a -> @a */
@@ -1355,7 +1357,7 @@
infersub(st, n, ret, sawret, &isconst);
switch (args[0]->lit.littype) {
case Lfunc:
- infernode(st, args[0]->lit.fnval, NULL, NULL); break;
+ infernode(st, &args[0]->lit.fnval, NULL, NULL); break;
/* FIXME: env capture means this is non-const */
n->expr.isconst = 1;
default:
@@ -1385,8 +1387,8 @@
sawret = 0;
for (i = 0; i < n->func.nargs; i++)
- infernode(st, n->func.args[i], NULL, NULL);
- infernode(st, n->func.body, n->func.type->sub[0], &sawret);
+ infernode(st, &n->func.args[i], NULL, NULL);
+ infernode(st, &n->func.body, n->func.type->sub[0], &sawret);
/* if there's no return stmt in the function, assume void ret */
if (!sawret)
unify(st, n, type(st, n)->sub[0], mktype(Zloc, Tyvoid));
@@ -1463,8 +1465,7 @@
}
settype(st, n, t);
if (n->decl.init) {
- checkns(st, n->decl.init, &n->decl.init);
- inferexpr(st, n->decl.init, NULL, NULL);
+ inferexpr(st, &n->decl.init, NULL, NULL);
unify(st, n, type(st, n), type(st, n->decl.init));
if (n->decl.isconst && !n->decl.init->expr.isconst)
fatal(n, "non-const initializer for \"%s\"", ctxstr(st, n));
@@ -1491,12 +1492,13 @@
free(k);
}
-static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret)
+static void infernode(Inferstate *st, Node **np, Type *ret, int *sawret)
{
size_t i, nbound;
- Node **bound;
+ Node **bound, *n;
Type *t;
+ n = *np;
if (!n)
return;
switch (n->type) {
@@ -1504,7 +1506,7 @@
pushstab(n->file.globls);
inferstab(st, n->file.globls);
for (i = 0; i < n->file.nstmts; i++)
- infernode(st, n->file.stmts[i], NULL, sawret);
+ infernode(st, &n->file.stmts[i], NULL, sawret);
popstab();
break;
case Ndecl:
@@ -1525,22 +1527,21 @@
pushstab(n->block.scope);
inferstab(st, n->block.scope);
for (i = 0; i < n->block.nstmts; i++) {
- checkns(st, n->block.stmts[i], &n->block.stmts[i]);
- infernode(st, n->block.stmts[i], ret, sawret);
+ infernode(st, &n->block.stmts[i], ret, sawret);
}
popstab();
break;
case Nifstmt:
- infernode(st, n->ifstmt.cond, NULL, sawret);
- infernode(st, n->ifstmt.iftrue, ret, sawret);
- infernode(st, n->ifstmt.iffalse, ret, sawret);
+ infernode(st, &n->ifstmt.cond, NULL, sawret);
+ infernode(st, &n->ifstmt.iftrue, ret, sawret);
+ infernode(st, &n->ifstmt.iffalse, ret, sawret);
unify(st, n, type(st, n->ifstmt.cond), mktype(n->loc, Tybool));
break;
case Nloopstmt:
- infernode(st, n->loopstmt.init, ret, sawret);
- infernode(st, n->loopstmt.cond, NULL, sawret);
- infernode(st, n->loopstmt.step, ret, sawret);
- infernode(st, n->loopstmt.body, ret, sawret);
+ infernode(st, &n->loopstmt.init, ret, sawret);
+ infernode(st, &n->loopstmt.cond, NULL, sawret);
+ infernode(st, &n->loopstmt.step, ret, sawret);
+ infernode(st, &n->loopstmt.body, ret, sawret);
unify(st, n, type(st, n->loopstmt.cond), mktype(n->loc, Tybool));
break;
case Niterstmt:
@@ -1547,12 +1548,11 @@
bound = NULL;
nbound = 0;
- inferpat(st, n->iterstmt.elt, NULL, &bound, &nbound);
+ inferpat(st, &n->iterstmt.elt, NULL, &bound, &nbound);
addbindings(st, n->iterstmt.body, bound, nbound);
- checkns(st, n->iterstmt.seq, &n->iterstmt.seq);
- infernode(st, n->iterstmt.seq, NULL, sawret);
- infernode(st, n->iterstmt.body, ret, sawret);
+ infernode(st, &n->iterstmt.seq, NULL, sawret);
+ infernode(st, &n->iterstmt.body, ret, sawret);
t = mktyidxhack(n->loc, mktyvar(n->loc));
constrain(st, n, type(st, n->iterstmt.seq), traittab[Tcidx]);
@@ -1560,11 +1560,11 @@
unify(st, n, type(st, n->iterstmt.elt), t->sub[0]);
break;
case Nmatchstmt:
- infernode(st, n->matchstmt.val, NULL, sawret);
+ infernode(st, &n->matchstmt.val, NULL, sawret);
if (tybase(type(st, n->matchstmt.val))->type == Tyvoid)
fatal(n, "Can't match against a void type near %s", ctxstr(st, n->matchstmt.val));
for (i = 0; i < n->matchstmt.nmatches; i++) {
- infernode(st, n->matchstmt.matches[i], ret, sawret);
+ infernode(st, &n->matchstmt.matches[i], ret, sawret);
unify(st, n, type(st, n->matchstmt.val), type(st, n->matchstmt.matches[i]->match.pat));
}
break;
@@ -1571,12 +1571,12 @@
case Nmatch:
bound = NULL;
nbound = 0;
- inferpat(st, n->match.pat, NULL, &bound, &nbound);
+ inferpat(st, &n->match.pat, NULL, &bound, &nbound);
addbindings(st, n->match.block, bound, nbound);
- infernode(st, n->match.block, ret, sawret);
+ infernode(st, &n->match.block, ret, sawret);
break;
case Nexpr:
- inferexpr(st, n, ret, sawret);
+ inferexpr(st, np, ret, sawret);
break;
case Nfunc:
setsuper(n->func.scope, curstab());
@@ -2173,7 +2173,7 @@
/* do the inference */
applytraits(&st, file);
- infernode(&st, file, NULL, NULL);
+ infernode(&st, &file, NULL, NULL);
postcheck(&st, file);
/* and replace type vars with actual types */