shithub: mc

Download patch

ref: 14d4a8e93c6804eb049eec51a5102c9d8b10c1c5
parent: 0b0fb103248ec7b12b3a2b94d88b1fe7a0b4403e
author: Ori Bernstein <ori@eigenstate.org>
date: Sun Oct 18 18:00:30 EDT 2015

Work towards better match statements.

    Generate decision trees from mi/match.c. Still slightly
    broken, so not enabled.

--- a/6/isel.c
+++ b/6/isel.c
@@ -885,9 +885,9 @@
         case Obandeq: case Obxoreq: case Obsleq: case Obsreq: case Omemb:
         case Oslbase: case Osllen: case Ocast: case Outag: case Oudata: 
         case Oucon: case Otup: case Oarr: case Ostruct:
-        case Oslice: case Oidx: case Osize:
-		case Obreak: case Ocontinue:
-		case Numops:
+        case Oslice: case Oidx: case Osize: case Otupget: /* Otupget is lowered in simp.c */
+        case Obreak: case Ocontinue:
+        case Numops:
             dump(n, stdout);
             die("Should not see %s in isel", opstr[exprop(n)]);
             break;
--- a/6/simp.c
+++ b/6/simp.c
@@ -1049,29 +1049,56 @@
     return r;
 }
 
+/* Copies element idx of the tuple tup into dst, allocating a temporary
+ * when dst is NULL, and returns the node that now holds the element. */
+static Node *tupget(Simp *s, Node *tup, size_t idx, Node *dst)
+{
+    Node *plv, *prv, *sz, *stor, *dcl;
+    size_t off, i;
+    Type *ty;
+
+    off = 0;
+    ty = exprtype(tup);
+    assert(idx < ty->nsub);
+    /* walk the preceding elements to find the aligned offset of idx */
+    for (i = 0; i < ty->nsub; i++) {
+        off = alignto(off, ty->sub[i]);
+        if (i == idx)
+            break;
+        off += tysize(ty->sub[i]);
+    }
+
+    if (!dst) {
+        dst = gentemp(s, tup->loc, ty->sub[idx], &dcl);
+        if (isstacktype(ty->sub[idx]))
+            declarelocal(s, dcl);
+    }
+    prv = add(addr(s, tup, ty->sub[idx]), disp(tup->loc, off));
+    /* stack-resident destinations are blitted; others use a load + store */
+    if (stacknode(dst)) {
+        sz = disp(dst->loc, size(dst));
+        plv = addr(s, dst, exprtype(dst));
+        stor = mkexpr(dst->loc, Oblit, plv, prv, sz, NULL);
+    } else {
+        stor = set(dst, load(prv));
+    }
+    append(s, stor);
+    return dst;
+}
+
 /* Takes a tuple and binds the i'th element of it to the
  * i'th name on the rhs of the assignment. */
 static Node *destructure(Simp *s, Node *lhs, Node *rhs)
 {
-    Node *plv, *prv, *lv, *sz, *stor, **args;
-    size_t off, i;
+    Node *lv, *rv, **args;
+    size_t i;
 
     args = lhs->expr.args;
     rhs = rval(s, rhs, NULL);
-    off = 0;
     for (i = 0; i < lhs->expr.nargs; i++) {
         lv = lval(s, args[i]);
-        off = alignto(off, exprtype(lv));
-        prv = add(addr(s, rhs, exprtype(args[i])), disp(rhs->loc, off));
-        if (stacknode(args[i])) {
-            sz = disp(lhs->loc, size(lv));
-            plv = addr(s, lv, exprtype(lv));
-            stor = mkexpr(lhs->loc, Oblit, plv, prv, sz, NULL);
-        } else {
-            stor = set(lv, load(prv));
-        }
-        append(s, stor);
-        off += size(lv);
+        rv = tupget(s, rhs, i, lv);
+        assert(rv == lv);
     }
 
     return NULL;
@@ -1469,6 +1496,7 @@
     Node **args;
     size_t i;
     Type *ty;
+
     const Op fusedmap[Numops] = {
         [Oaddeq]        = Oadd,
         [Osubeq]        = Osub,
@@ -1513,7 +1541,7 @@
             r = simpucon(s, n, dst);
             break;
         case Outag:
-            die("union tags not yet supported\n");
+            r = uconid(s, args[0]);
             break;
         case Oudata:
             r = simpuget(s, n, dst);
@@ -1680,6 +1708,12 @@
         case Ogt: case Oge: case Olt: case Ole:
             r = compare(s, n, 0);
             break;
+        case Otupget:
+            /* args[0] is the tuple, args[1] the constant index (see tupelt) */
+            assert(exprop(args[1]) == Olit);
+            i = args[1]->expr.args[0]->lit.intval;
+            r = tupget(s, rval(s, args[0], NULL), i, dst);
+            break;
         case Obad:
             die("bad operator");
             break;
--- a/mi/match.c
+++ b/mi/match.c
@@ -33,6 +33,7 @@
     int id;
 };
 
+void dtdumpnode(Dtree *dt, FILE *f, int depth, int iswild);
 static Dtree *addpat(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap);
 void dtdump(Dtree *dt, FILE *f);
 
@@ -100,6 +101,50 @@
     return t;
 }
 
+/* Returns an Otupget node for element i of tuple n, typed as that element. */
+static Node *tupelt(Node *n, size_t i)
+{
+    Node *idx, *elt;
+
+    idx = mkintlit(n->loc, i);
+    idx->expr.type = mktype(n->loc, Tyuint64);
+    elt = mkexpr(n->loc, Otupget, n, idx, NULL);
+    elt->expr.type = tybase(exprtype(n))->sub[i];
+    return elt;
+}
+
+/* Returns an Oidx node for element i of array n, typed as its element type. */
+static Node *arrayelt(Node *n, size_t i)
+{
+    Node *idx, *elt;
+
+    idx = mkintlit(n->loc, i);
+    idx->expr.type = mktype(n->loc, Tyuint64);
+    elt = mkexpr(n->loc, Oidx, n, idx, NULL);
+    elt->expr.type = tybase(exprtype(n))->sub[0];
+    return elt;
+}
+
+/* Returns an Omemb node selecting member `name` of n, typed as ty. */
+static Node *structmemb(Node *n, Node *name, Type *ty)
+{
+    Node *elt;
+
+    elt = mkexpr(n->loc, Omemb, n, name, NULL);
+    elt->expr.type = ty;
+    return elt;
+}
+
+/* Returns an Oudata node for the payload of union n, typed as ty. */
+static Node *uvalue(Node *n, Type *ty)
+{
+    Node *elt;
+
+    elt = mkexpr(n->loc, Oudata, n, NULL);
+    elt->expr.type = ty;
+    return elt;
+}
+
 static Dtree *addwild(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap)
 {
     if (t->any)
@@ -113,6 +158,7 @@
 
 static Dtree *addunion(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap)
 {
+    Node *elt, *tag;
     Dtree *sub;
     size_t i;
 
@@ -122,19 +168,26 @@
     sub = NULL;
     for (i = 0; i < t->nval; i++) {
         if (nameeq(t->val[i], pat->expr.args[0])) {
-            if (pat->expr.nargs > 1)
-                return addpat(t->sub[i], pat->expr.args[1], NULL, cap, ncap);
-            else
+            if (pat->expr.nargs > 1) {
+                elt = uvalue(val, exprtype(pat->expr.args[1]));
+                return addpat(t->sub[i], pat->expr.args[1], elt, cap, ncap);
+            } else {
                 return t->sub[i];
+            }
         }
     }
 
     sub = mkdtree();
     sub->patexpr = pat;
+    tag = mkexpr(pat->loc, Outag, val, NULL);
+    tag->expr.type = mktype(pat->loc, Tyint32);
     lappend(&t->val, &t->nval, pat->expr.args[0]);
     lappend(&t->sub, &t->nsub, sub);
-    if (pat->expr.nargs == 2)
-        sub = addpat(sub, pat->expr.args[1], NULL, cap, ncap);
+    lappend(&t->load, &t->nload, tag);
+    if (pat->expr.nargs == 2) {
+        elt = uvalue(val, exprtype(pat->expr.args[1]));
+        sub = addpat(sub, pat->expr.args[1], elt, cap, ncap);
+    }
     return sub;
 }
 
@@ -153,6 +206,7 @@
     sub = mkdtree();
     sub->patexpr = pat;
     lappend(&t->val, &t->nval, pat);
+    lappend(&t->load, &t->nload, val);
     lappend(&t->sub, &t->nsub, sub);
     return sub;
 }
@@ -160,11 +214,14 @@
 static Dtree *addtup(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap)
 {
     size_t i;
+    Node *elt;
 
     if (t->any)
         return t->any;
-    for (i = 0; i < pat->expr.nargs; i++)
-        t = addpat(t, pat->expr.args[i], NULL, cap, ncap);
+    for (i = 0; i < pat->expr.nargs; i++) {
+        elt = tupelt(val, i);
+        t = addpat(t, pat->expr.args[i], elt, cap, ncap);
+    }
     return t;
 }
 
@@ -171,17 +228,21 @@
 static Dtree *addarr(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap)
 {
     size_t i;
+    Node *elt;
 
     if (t->any)
         return t->any;
-    for (i = 0; i < pat->expr.nargs; i++)
-        t = addpat(t, pat->expr.args[i], NULL, cap, ncap);
+    for (i = 0; i < pat->expr.nargs; i++) {
+        elt = arrayelt(val, i);
+        t = addpat(t, pat->expr.args[i], elt, cap, ncap);
+    }
     return t;
 }
 
 static Dtree *addstruct(Dtree *t, Node *pat, Node *val, Node ***cap, size_t *ncap)
 {
-    Node *elt;
+    Node *elt, *memb;
+    Type *ty;
     size_t i, j;
 
     if (t->any)
@@ -190,7 +251,9 @@
         elt = pat->expr.args[i];
         for (j = 0; j < t->nval; j++) {
             if (!strcmp(namestr(elt->expr.idx), namestr(t->val[j]->expr.idx))) {
-                t = addpat(t, pat->expr.args[i], NULL, cap, ncap);
+                ty = exprtype(pat->expr.args[i]);
+                memb = structmemb(val, elt->expr.idx, ty);
+                t = addpat(t, pat->expr.args[i], memb, cap, ncap);
                 break;
             }
         }
@@ -303,9 +366,39 @@
     return 1;
 }
 
-static Node *genmatch(Dtree *dt)
+/* Lowers the decision tree rooted at dt into a chain of if statements:
+ * edge i is guarded by an equality test of dt->load[i] against
+ * dt->val[i], and dt->any, if present, becomes the final else. */
+static Node *genmatch(Srcloc loc, Dtree *dt)
 {
-    return NULL;
+    Node *lastcmp, *cmp, *eq, *pat;
+    size_t i;
+
+    dtdumpnode(dt, stdout, 0, 0);
+
+    lastcmp = NULL;
+    cmp = NULL;
+    pat = NULL;
+    if (dt->nsub == 0) {
+        /* Nothing to test: either a leaf, or a node whose only edge
+         * is the wildcard, which always matches. */
+        if (!dt->act && dt->any)
+            return genmatch(loc, dt->any);
+        return dt->act;
+    }
+    for (i = 0; i < dt->nsub; i++) {
+        eq = mkexpr(loc, Oeq, dt->load[i], dt->val[i], NULL);
+        cmp = mkifstmt(loc, eq, genmatch(loc, dt->sub[i]), NULL);
+        /* chain each test onto the false branch of the previous one */
+        if (lastcmp)
+            lastcmp->ifstmt.iffalse = cmp;
+        else
+            pat = cmp;
+        lastcmp = cmp;
+    }
+    if (dt->any)
+        lastcmp->ifstmt.iffalse = genmatch(loc, dt->any);
+    return pat;
 }
 
Node *gensimpmatch(Node *m)
@@ -314,6 +407,7 @@
     Node **pat, **cap;
     size_t npat, ncap;
     size_t i;
+    Node *n;
 
     pat = m->matchstmt.matches;
     npat = m->matchstmt.nmatches;
@@ -321,7 +415,7 @@
     for (i = 0; i < npat; i++) {
         cap = NULL;
         ncap = 0;
-        leaf = addpat(t, pat[i]->match.pat, NULL, &cap, &ncap);
+        leaf = addpat(t, pat[i]->match.pat, m->matchstmt.val, &cap, &ncap);
         /* TODO: NULL is returned by unsupported patterns. */
         if (!leaf)
             return NULL;
@@ -333,7 +427,9 @@
     }
     if (!exhaustivematch(m, t, exprtype(m->matchstmt.val)))
         fatal(m, "nonexhaustive pattern set in match statement");
-    return genmatch(t);
+    n = genmatch(m->loc, t);
+    dump(n, stdout);
+    return n;
 }
 
char *dtnodestr(Node *n)
@@ -351,6 +447,8 @@
             return "array";
         case Ostruct:
             return "struct";
+        case Ogap:
+            return "_";
         default:
             die("Invalid pattern in exhaustivenes check. BUG.");
             break;
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -1487,6 +1487,7 @@
         case Ofadd: case Ofsub: case Ofmul: case Ofdiv: case Ofneg:
         case Ofeq: case Ofne: case Ofgt: case Ofge: case Oflt: case Ofle:
         case Oueq: case Oune: case Ougt: case Ouge: case Oult: case Oule:
+        case Otupget: /* synthesized by tupelt() in mi/match.c */
         case Numops:
             die("Should not see %s in fe", opstr[exprop(n)]);
             break;
--- a/parse/ops.def
+++ b/parse/ops.def
@@ -67,6 +67,7 @@
 O(Oslbase,	1,	OTpre,  "SLBASE")       /* base of sice */
 O(Outag,	1,	OTpre,  "UTAG")	        /* tag of union */
 O(Oudata,	1,	OTpre,  "UDATA")        /* pointer to contents of union */
+O(Otupget,	1,	OTpre,  "TUPGET")        /* element of a tuple, by constant index */
 O(Oblit,	1,	OTbin,  "BLIT")         /* blit memory */
 O(Oclear,       1,      OTpre,  "CLEAR")        /* zero */
 O(Ocallind,     1,      OTpre,  "CALL")         /* call with environment */