shithub: mc

Download patch

ref: 056ef88fc6d281ae626ddf2c598afe2abdb06888
parent: cde7e0d999cf3d56f48d830cf3149f779d86fb31
author: Ori Bernstein <ori@eigenstate.org>
date: Sat Sep 26 22:29:24 EDT 2015

Working closures.

--- a/6/asm.h
+++ b/6/asm.h
@@ -268,6 +268,7 @@
 char *genlocallblstr(char *buf, size_t sz);
 char *genlblstr(char *buf, size_t sz);
 Type *codetype(Type *ft);
+Type *closuretype(Type *ft);
 Node *genlbl(Srcloc loc);
 Loc *loclbl(Node *lbl);
 Loc *locstrlbl(char *lbl);
--- a/6/gen.c
+++ b/6/gen.c
@@ -55,6 +55,18 @@
     return ft;
 }
 
+/* Returns the closure (Tyfunc) form of ft's base type: a Tyfunc is returned
+ * as is, a Tycode is duplicated and retagged as Tyfunc; anything else asserts. */
+Type *closuretype(Type *ft)
+{
+    ft = tybase(ft);
+    if (ft->type == Tyfunc)
+        return ft;
+    assert(ft->type == Tycode);
+    ft = tydup(ft);
+    ft->type = Tyfunc;
+    return ft;
+}
 
 static int islocal(Node *dcl)
 {
--- a/6/insns.def
+++ b/6/insns.def
@@ -320,7 +320,8 @@
     "\tcall *%v\n",
     "\tCALL *%V\n",
     Use(.l={1}, .r={
-            Rrdi, Rrsi, Rrdx, Rrcx, Rr8, Rr9,
+            /* ABI note: Rrax holds the env */
+            Rrax, Rrdi, Rrsi, Rrdx, Rrcx, Rr8, Rr9,
             Rxmm0d, Rxmm1d, Rxmm2d, Rxmm3d,
             Rxmm4d, Rxmm5d, Rxmm6d, Rxmm7d,
     }),
--- a/6/isel.c
+++ b/6/isel.c
@@ -94,7 +94,7 @@
 
 static Loc *varloc(Isel *s, Node *n)
 {
-    ssize_t stkoff;
+    ssize_t off;
     Loc *l, *rip;
 
     /* we need to try getting it from the stack first, in case we
@@ -103,11 +103,11 @@
         rip = locphysreg(Rrip);
         l = locmeml(htget(s->globls, n), rip, NULL, mode(n));
     } else if (hthas(s->envoff, n)) {
-        stkoff = ptoi(htget(s->envoff, n));
-        l = locmem(-stkoff, locphysreg(Rrax), NULL, mode(n));
+        off = ptoi(htget(s->envoff, n));
+        l = locmem(off, locphysreg(Rrax), NULL, mode(n));
     } else if (hthas(s->stkoff, n)) {
-        stkoff = ptoi(htget(s->stkoff, n));
-        l = locmem(-stkoff, locphysreg(Rrbp), NULL, mode(n));
+        off = ptoi(htget(s->stkoff, n));
+        l = locmem(-off, locphysreg(Rrbp), NULL, mode(n));
     }  else {
         l = htget(s->reglocs, n);
         if (!l) {
@@ -482,7 +482,7 @@
 {
     AsmOp op;
     Node *fn;
-    Loc *f;
+    Loc *f, *e;
 
     if (exprop(n) == Ocall) {
         op = Icall;
@@ -492,6 +492,8 @@
     } else {
         op = Icallind;
         f = selexpr(s, n->expr.args[0]);
+        e = selexpr(s, n->expr.args[1]);
+        g(s, Imov, e, locphysreg(Rrax), NULL);
     }
     g(s, op, f, NULL);
 }
@@ -514,13 +516,14 @@
 static Loc *gencall(Isel *s, Node *n)
 {
     Loc *src, *dst, *arg;  /* values we reduced */
-    size_t argsz, argoff, nargs;
+    size_t argsz, argoff, nargs, vasplit;
     size_t nfloats, nints;
     Loc *retloc, *rsp, *ret;       /* hard-coded registers */
     Loc *stkbump;        /* calculated stack offset */
+    Type *t, *fn;
+    Node **args;
     size_t i, a;
     int vararg;
-    Type *t, *fn;
 
     rsp = locphysreg(Rrsp);
     t = exprtype(n);
@@ -537,8 +540,15 @@
     fn = tybase(exprtype(n->expr.args[0]));
     /* calculate the number of args we expect to see, adjust
      * for a hidden return argument. */
-    nargs = countargs(fn);
+    vasplit = countargs(fn);
     argsz = 0;
+    if (exprop(n) == Ocall) {
+        args = &n->expr.args[1];
+        nargs = n->expr.nargs - 1;
+    } else {
+        args = &n->expr.args[2];
+        nargs = n->expr.nargs - 2;
+    }
     /* Have to calculate the amount to bump the stack
      * pointer by in one pass first, otherwise if we push
      * one at a time, we evaluate the args in reverse order.
@@ -545,9 +555,9 @@
      * Not good.
      *
      * Skip the first operand, since it's the function itself */
-    for (i = 1; i < n->expr.nargs; i++) {
-        argsz = align(argsz, min(size(n->expr.args[i]), Ptrsz));
-        argsz += size(n->expr.args[i]);
+    for (i = 0; i < nargs; i++) {
+        argsz = align(argsz, min(size(args[i]), Ptrsz));
+        argsz += size(args[i]);
     }
     argsz = align(argsz, 16);
     stkbump = loclit(argsz, ModeQ);
@@ -559,17 +569,17 @@
     nfloats = 0;
     nints = 0;
     vararg = 0;
-    for (i = 1; i < n->expr.nargs; i++) {
-        arg = selexpr(s, n->expr.args[i]);
-        argoff = alignto(argoff, exprtype(n->expr.args[i]));
-        if (i > nargs)
+    for (i = 0; i < nargs; i++) {
+        arg = selexpr(s, args[i]);
+        argoff = alignto(argoff, exprtype(args[i]));
+        if (i >= vasplit)
             vararg = 1;
-        if (stacknode(n->expr.args[i])) {
+        if (stacknode(args[i])) {
             src = locreg(ModeQ);
             g(s, Ilea, arg, src, NULL);
-            a = tyalign(exprtype(n->expr.args[i]));
-            blit(s, rsp, src, argoff, 0, size(n->expr.args[i]), a);
-            argoff += size(n->expr.args[i]);
+            a = tyalign(exprtype(args[i]));
+            blit(s, rsp, src, argoff, 0, size(args[i]), a);
+            argoff += size(args[i]);
         } else if (!vararg && isfloatmode(arg->mode) && nfloats < Nfloatregargs) {
             dst = coreg(floatargregs[nfloats], arg->mode);
             arg = inri(s, arg);
@@ -584,7 +594,7 @@
             dst = locmem(argoff, rsp, NULL, arg->mode);
             arg = inri(s, arg);
             stor(s, arg, dst);
-            argoff += size(n->expr.args[i]);
+            argoff += size(args[i]);
         }
     }
     call(s, n);
--- a/6/simp.c
+++ b/6/simp.c
@@ -62,6 +62,8 @@
 static Node *rval(Simp *s, Node *n, Node *dst);
 static Node *lval(Simp *s, Node *n);
 static Node *assign(Simp *s, Node *lhs, Node *rhs);
+static Node *assignat(Simp *s, Node *r, size_t off, Node *val);
+static Node *getcode(Simp *s, Node *n);
 static void simpcond(Simp *s, Node *n, Node *ltrue, Node *lfalse);
 static void simpconstinit(Simp *s, Node *dcl);
 static Node *simpcast(Simp *s, Node *val, Type *to);
@@ -329,9 +331,22 @@
     return load(addk(addr(s, sl, tyintptr), Ptrsz));
 }
 
-Node *loadvar(Simp *s, Node *n)
+Node *loadvar(Simp *s, Node *n, Node *dst)
 {
-    return n;
+    Node *p, *f, *r;
+
+    if (isconstfn(n)) {
+        if (dst)
+            r = dst;
+        else
+            r = temp(s, n);
+        f = getcode(s, n);
+        p = addr(s, r, exprtype(n));
+        assignat(s, p, Ptrsz, f);
+    } else {
+        r = n;
+    }
+    return r;
 }
 
 static Node *seqlen(Simp *s, Node *n, Type *ty)
@@ -680,29 +695,48 @@
     }
 }
 
-static Node *simpblob(Simp *s, Node *blob, Node ***l, size_t *nl)
+static Node *geninitdecl(Node *init, Type *ty, Node **dcl)
 {
     Node *n, *d, *r;
     char lbl[128];
 
-    n = mkname(blob->loc, genlblstr(lbl, 128));
-    d = mkdecl(blob->loc, n, blob->expr.type);
-    r = mkexpr(blob->loc, Ovar, n, NULL);
+    n = mkname(init->loc, genlblstr(lbl, 128));
+    d = mkdecl(init->loc, n, ty);
+    r = mkexpr(init->loc, Ovar, n, NULL);
 
-    d->decl.init = blob;
-    d->decl.type = blob->expr.type;
+    d->decl.init = init;
+    d->decl.type = ty;
     d->decl.isconst = 1;
     d->decl.isglobl = 1;
-    htput(s->globls, d, asmname(d));
 
     r->expr.did = d->decl.did;
-    r->expr.type = blob->expr.type;
+    r->expr.type = ty;
     r->expr.isconst = 1;
+    if (dcl)
+        *dcl = d;
+    return r;
+}
 
-    lappend(l, nl, d);
+static Node *simpcode(Simp *s, Node *fn)
+{
+    Node *r, *d;
+
+    r = geninitdecl(fn, codetype(exprtype(fn)), &d);
+    htput(s->globls, d, asmname(d));
+    lappend(&file->file.stmts, &file->file.nstmts, d);
     return r;
 }
 
+static Node *simpblob(Simp *s, Node *blob)
+{
+    Node *r, *d;
+
+    r = geninitdecl(blob, exprtype(blob), &d);
+    htput(s->globls, d, asmname(d));
+    lappend(&s->blobs, &s->nblobs, d);
+    return r;
+}
+
 static Node *ptrsized(Simp *s, Node *v)
 {
     if (size(v) == Ptrsz)
@@ -817,7 +851,7 @@
 
     args = n->expr.args;
     switch (exprop(n)) {
-        case Ovar:      r = loadvar(s, n);  break;
+        case Ovar:      r = loadvar(s, n, NULL);  break;
         case Oidx:      r = deref(idxaddr(s, args[0], args[1]), NULL); break;
         case Oderef:    r = deref(rval(s, args[0], NULL), NULL); break;
         case Omemb:     r = rval(s, n, NULL); break;
@@ -1069,7 +1103,6 @@
     Node *sz;
     Node *st;
 
-    val = rval(s, val, NULL);
     pdst = add(r, disp(val->loc, off));
 
     if (stacknode(val)) {
@@ -1101,7 +1134,7 @@
     off = 0;
     for (i = 0; i < n->expr.nargs; i++) {
         off = alignto(off, exprtype(args[i]));
-        assignat(s, r, off, args[i]);
+        assignat(s, r, off, rval(s, args[i], NULL));
         off += size(args[i]);
     }
     return dst;
@@ -1289,48 +1322,65 @@
 
 static Node *capture(Simp *s, Node *n, Node *dst)
 {
+    Node *fn, *t, *f, *e, *val, *dcl, *fp;
     size_t nenv, nenvt, off, i;
-    Node *fn, *t, *f, *e, *val, *dcl;
     Type **envt;
     Node **env;
 
-    f = simpblob(s, n, &file->file.stmts, &file->file.nstmts);
+    f = simpcode(s, n);
     fn = n->expr.args[0];
     fn = fn->lit.fnval;
+    if (!dst) {
+        dst = gentemp(s, n->loc, closuretype(exprtype(f)), &dcl);
+        forcelocal(s, dcl);
+    }
+    fp = addr(s, dst, exprtype(dst));
+
     env = getclosure(fn->func.scope, &nenv);
-    if (!env)
-        return f;
-    /* we need these in a deterministic order so that we can
-       put them in the right place both when we use them and
-       when we capture them.  */
-    qsort(env, nenv, sizeof(Node*), envcmp);
+    if (env) {
+        /* we need these in a deterministic order so that we can
+           put them in the right place both when we use them and
+           when we capture them.  */
+        qsort(env, nenv, sizeof(Node*), envcmp);
 
-    /* make the tuple that will hold the environment */
-    envt = NULL;
-    nenvt = 0;
-    for (i = 0; i < nenv; i++)
-        lappend(&envt, &nenvt, decltype(env[i]));
+        /* make the tuple that will hold the environment */
+        envt = NULL;
+        nenvt = 0;
+        /* reserve space for size */
+        lappend(&envt, &nenvt, tyintptr);
+        for (i = 0; i < nenv; i++)
+            lappend(&envt, &nenvt, decltype(env[i]));
 
-    t = gentemp(s, n->loc, mktytuple(n->loc, envt, nenvt), &dcl);
-    forcelocal(s, dcl);
-    e = addr(s, t, exprtype(t));
+        t = gentemp(s, n->loc, mktytuple(n->loc, envt, nenvt), &dcl);
+        forcelocal(s, dcl);
+        e = addr(s, t, exprtype(t));
 
-    off = Ptrsz;    /* we start with the size of the env */
-    for (i = 0; i < nenv; i++) {
-        off = alignto(off, decltype(env[i]));
-        val = mkexpr(n->loc, Ovar, env[i]->decl.name, NULL);
-        val->expr.type = env[i]->decl.type;
-        val->expr.did = env[i]->decl.did;
-        assignat(s, e, off, val);
-        off += size(env[i]);
+        off = Ptrsz;    /* we start with the size of the env */
+        for (i = 0; i < nenv; i++) {
+            off = alignto(off, decltype(env[i]));
+            val = mkexpr(n->loc, Ovar, env[i]->decl.name, NULL);
+            val->expr.type = env[i]->decl.type;
+            val->expr.did = env[i]->decl.did;
+            assignat(s, e, off, rval(s, val, NULL));
+            off += size(env[i]);
+        }
+        free(env);
+        assignat(s, fp, 0, e);
     }
-    free(env);
-    return f;
+    assignat(s, fp, Ptrsz, f);
+    return dst;
 }
 
+static Node *getenvptr(Simp *s, Node *n)
+{
+    assert(tybase(exprtype(n))->type == Tyfunc);
+    return load(addr(s, n, tyintptr));
+}
+
 static Node *getcode(Simp *s, Node *n)
 {
-    Node *r, *d;
+    Node *r, *p, *d;
+    Type *ty;
 
     if (isconstfn(n)) {
         d = decls[n->expr.did];
@@ -1338,7 +1388,10 @@
         r->expr.did = d->decl.did;
         r->expr.type = codetype(exprtype(n));
     } else {
-        r = rval(s, n, NULL);
+        ty = tybase(exprtype(n));
+        assert(ty->type == Tyfunc);
+        p = addr(s, rval(s, n, NULL), codetype(ty));
+        r = load(addk(p, Ptrsz));
     }
     return r;
 }
@@ -1351,7 +1404,11 @@
     Type *ft;
     Op op;
 
+    /* NB: calling rval() on a const function would build a stack-allocated
+     * closure for it; const functions are instead called directly (Ocall). */
     fn = n->expr.args[0];
+    if (!isconstfn(fn))
+        fn = rval(s, fn, NULL);
     ft = tybase(exprtype(fn));
     if (exprtype(n)->type == Tyvoid)
         r = NULL;
@@ -1362,11 +1419,12 @@
 
     args = NULL;
     nargs = 0;
-    if (isconstfn(fn))
-        op = Ocall;
-    else
-        op = Ocallind;
+    op = Ocall;
     lappend(&args, &nargs, getcode(s, fn));
+    if (!isconstfn(fn)) {
+        op = Ocallind;
+        lappend(&args, &nargs, getenvptr(s, fn));
+    }
 
     if (exprtype(n)->type != Tyvoid && isstacktype(exprtype(n)))
         lappend(&args, &nargs, addr(s, r, exprtype(n)));
@@ -1459,7 +1517,7 @@
                 dst = temp(s, n);
             t = addr(s, dst, exprtype(dst));
             for (i = 0; i < n->expr.nargs; i++)
-                assignat(s, t, size(n->expr.args[i])*i, n->expr.args[i]);
+                assignat(s, t, size(n->expr.args[i])*i, rval(s, n->expr.args[i], NULL));
             r = dst;
             break;
         case Ostruct:
@@ -1474,7 +1532,7 @@
             if (tybase(ty)->nmemb != n->expr.nargs)
                 append(s, mkexpr(n->loc, Oclear, t, mkintlit(n->loc, size(n)), NULL));
             for (i = 0; i < n->expr.nargs; i++)
-                assignat(s, t, offset(n, n->expr.args[i]->expr.idx), n->expr.args[i]);
+                assignat(s, t, offset(n, n->expr.args[i]->expr.idx), rval(s, n->expr.args[i], NULL));
             r = dst;
             break;
         case Ocast:
@@ -1534,10 +1592,10 @@
                     if ((uint64_t)args[0]->lit.intval < 0x7fffffffULL)
                         r = n;
                     else
-                        r = simpblob(s, n, &s->blobs, &s->nblobs);
+                        r = simpblob(s, n);
                     break;
                 case Lstr: case Lflt:
-                    r = simpblob(s, n, &s->blobs, &s->nblobs);
+                    r = simpblob(s, n);
                     break;
                 case Lfunc:
                     r = capture(s, n, dst);
@@ -1545,7 +1603,7 @@
             }
             break;
         case Ovar:
-            r = loadvar(s, n);
+            r = loadvar(s, n, dst);
             break;
         case Ogap:
             fatal(n, "'_' may not be an rvalue");
@@ -1590,7 +1648,7 @@
                 u = mkexpr(n->loc, Olit, t, NULL);
                 t->lit.type = n->expr.type;
                 u->expr.type = n->expr.type;
-                v = simpblob(s, u, &s->blobs, &s->nblobs);
+                v = simpblob(s, u);
                 r = mkexpr(n->loc, Ofmul, v, rval(s, args[0], NULL), NULL);
                 r->expr.type = n->expr.type;
             } else {
@@ -1830,7 +1888,7 @@
     return fn;
 }
 
-static void extractsub(Simp *s, Node ***blobs, size_t *nblobs, Node *e)
+static void extractsub(Simp *s, Node *e)
 {
     size_t i;
 
@@ -1838,12 +1896,12 @@
     switch (exprop(e)) {
         case Oslice:
             if (exprop(e->expr.args[0]) == Oarr)
-                e->expr.args[0] = simpblob(s, e->expr.args[0], blobs, nblobs);
+                e->expr.args[0] = simpblob(s, e->expr.args[0]);
             break;
         case Oarr:
         case Ostruct:
             for (i = 0; i < e->expr.nargs; i++)
-                extractsub(s, blobs, nblobs, e->expr.args[i]);
+                extractsub(s, e->expr.args[i]);
             break;
         default:
             break;
@@ -1858,7 +1916,7 @@
     e = dcl->decl.init;
     if (e && exprop(e) == Olit) {
         if (e->expr.args[0]->lit.littype == Lfunc)
-            simpblob(s, e, &file->file.stmts, &file->file.nstmts);
+            simpcode(s, e);
         else
             lappend(&s->blobs, &s->nblobs, dcl);
     } else if (dcl->decl.isconst) {
@@ -1866,7 +1924,7 @@
             case Oarr:
             case Ostruct:
             case Oslice:
-                extractsub(s, &s->blobs, &s->nblobs, e);
+                extractsub(s, e);
                 lappend(&s->blobs, &s->nblobs, dcl);
                 break;
             default:
--- a/6/typeinfo.c
+++ b/6/typeinfo.c
@@ -260,10 +260,10 @@
         case Tyflt64:
             return 8;
 
-        case Tyfunc:
-            return Ptrsz;
         case Tycode:
             return Ptrsz;
+        case Tyfunc:
+            return 2*Ptrsz;
         case Tyslice:
             return 2*Ptrsz; /* len; ptr */
         case Tyname:
--- a/mi/cfg.c
+++ b/mi/cfg.c
@@ -266,49 +266,56 @@
     return cfg;
 }
 
-void dumpcfg(Cfg *cfg, FILE *fd)
+/* Print one basic block: its id and labels, pred/succ block ids, then its nodes. */
+void dumpbb(Bb *bb, FILE *fd)
 {
-    size_t i, j;
-    Bb *bb;
+    size_t i;
     char *sep;
 
-    for (j = 0; j < cfg->nbb; j++) {
-        bb = cfg->bb[j];
-        if (!bb)
-            continue;
-        fprintf(fd, "\n");
-        fprintf(fd, "Bb: %d labels=(", bb->id);
-        sep = "";
-        for (i = 0; i < bb->nlbls; i++) {;
-            fprintf(fd, "%s%s", bb->lbls[i], sep);
+    fprintf(fd, "Bb: %d labels=(", bb->id);
+    sep = "";
+    for (i = 0; i < bb->nlbls; i++) {
+        fprintf(fd, "%s%s", sep, bb->lbls[i]);
+        sep = ",";
+    }
+    fprintf(fd, ")\n");
+
+    /* in edges */
+    fprintf(fd, "Pred: ");
+    sep = "";
+    for (i = 0; i < bsmax(bb->pred); i++) {
+        if (bshas(bb->pred, i)) {
+            fprintf(fd, "%s%zu", sep, i);
             sep = ",";
         }
-        fprintf(fd, ")\n");
+    }
+    fprintf(fd, "\n");
 
-        /* in edges */
-        fprintf(fd, "Pred: ");
-        sep = "";
-        for (i = 0; i < bsmax(bb->pred); i++) {
-            if (bshas(bb->pred, i)) {
-                fprintf(fd, "%s%zd", sep, i);
-                sep = ",";
-            }
+    /* out edges */
+    fprintf(fd, "Succ: ");
+    sep = "";
+    for (i = 0; i < bsmax(bb->succ); i++) {
+        if (bshas(bb->succ, i)) {
+            fprintf(fd, "%s%zu", sep, i);
+            sep = ",";
         }
-        fprintf(fd, "\n");
+    }
+    fprintf(fd, "\n");
 
-        /* out edges */
-        fprintf(fd, "Succ: ");
-        sep = "";
-        for (i = 0; i < bsmax(bb->succ); i++) {
-             if (bshas(bb->succ, i)) {
-                fprintf(fd, "%s%zd", sep, i);
-                sep = ",";
-             }
-        }
-        fprintf(fd, "\n");
+    for (i = 0; i < bb->nnl; i++)
+        dump(bb->nl[i], fd);
+    fprintf(fd, "\n");
+}
 
-        for (i = 0; i < bb->nnl; i++)
-            dump(bb->nl[i], fd);
+/* Print every non-NULL basic block of cfg, each preceded by a blank line. */
+void dumpcfg(Cfg *cfg, FILE *fd)
+{
+    size_t i;
+
+    for (i = 0; i < cfg->nbb; i++) {
+        if (!cfg->bb[i])
+            continue;
         fprintf(fd, "\n");
+        dumpbb(cfg->bb[i], fd);
     }
 }
--- a/mi/dfcheck.c
+++ b/mi/dfcheck.c
@@ -70,7 +70,6 @@
     Bb *bb;
 
     r = reaching(cfg);
-//    dumpcfg(cfg, stdout);
     for (i = 0; i < cfg->nbb; i++) {
         bb = cfg->bb[i];
         if (!bb)
--- a/parse/types.def
+++ b/parse/types.def
@@ -30,7 +30,7 @@
 
 /* end atomic types */
 Ty(Typtr, NULL, 0)
-Ty(Tyfunc, NULL, 0)
+Ty(Tyfunc, NULL, 1)
 
 /* these types live on the stack */
 Ty(Tyslice, NULL, 1)
--- a/test/closure.myr
+++ b/test/closure.myr
@@ -1,3 +1,5 @@
+use std
+
-/* checks that functions with environment capture work. should exit with 42. */
+/* checks that functions with environment capture work. should exit with 55 (42 + 13). */
 const main = {
 	var a = 42
@@ -4,5 +6,5 @@
 	var f = {b
 		-> a + b
 	}
-	f(13)
+	std.exit(f(13))
 }
--- a/test/tests
+++ b/test/tests
@@ -58,7 +58,7 @@
 B callbig	E	42
 B nestfn	E	42
 B foldidx	P	123,456
-# B closure	E	55      ## BUGGERED
+B closure	E	55
 B loop		P	0123401236789
 B subrangefor	P       12
 B patiter	P	23512