shithub: lpa

Download patch

ref: 2b23d05d57743af57385cd42c0fd2d223b11d8c8
parent: 1da09c81867fc8416f99b8d7a24d73b2e3acf6e6
author: Peter Mikkelsen <peter@pmikkelsen.com>
date: Sun Jul 28 07:35:34 EDT 2024

Start working on constraints. Not even close to being useful yet

--- a/array.c
+++ b/array.c
@@ -20,6 +20,7 @@
 		vlong *intdata;
 		Rune *chardata;
 		Array **arraydata;
+		ConstraintVar **vardata;
 	};
 };
 
@@ -47,6 +48,9 @@
 	case TypeArray:
 		size *= sizeof(Array *);
 		break;
+	case TypeVar:
+		size *= sizeof(ConstraintVar *);
+		break;
 	}
 
 	a->shape = allocextra(a, (sizeof(usize) * rank) + size);
@@ -74,6 +78,12 @@
 }
 
+/* Store constraint variable v at flat index offset of a's TypeVar data.
+ * No bounds or type check: a must be a TypeVar array with more than offset elements. */
 void
+setvar(Array *a, usize offset, ConstraintVar *v)
+{
+	a->vardata[offset] = v;
+}
+
+void
 setshape(Array *a, int dim, usize size)
 {
 	a->shape[dim] = size;
@@ -115,7 +125,15 @@
 	return a->arraydata[i];
 }
 
+/* Return the constraint variable at flat index i of a TypeVar array
+ * (unchecked: the caller guarantees the type and the index). */
+ConstraintVar *
+getvar(Array *a, usize i)
+{
+	ConstraintVar **vars = a->vardata;
+	return vars[i];
+}
+
+/* Forward declarations: printitem below dispatches to printarraysub and
+ * printconstraintvar, and printconstraintvar prints expressions via printexpr. */
+static int printconstraintvar(char *, ConstraintVar *, int);
 static int printarraysub(char *, Array *, int);
+static int printexpr(char *, Ast *, int);
 static int
 printitem(char *p, Array *a, uvlong i, int depth)
 {
@@ -126,6 +144,8 @@
 		return sprint(p, "%C", a->chardata[i]);
 	case TypeArray:
 		return printarraysub(p, a->arraydata[i], depth);
+	case TypeVar:
+		return printconstraintvar(p, a->vardata[i], depth);
 	default:
 		return sprint(p, "???");
 	}	
@@ -141,6 +161,33 @@
 }
 
+/* Print constraint variable v into buf; returns the number of bytes written.
+ * A variable that holds a delayed expression (v->ast) prints as that expression.
+ * A plain variable prints as name⍙id. If it has constraints attached and
+ * extrainfo is set, it is followed by a { ... } block listing each constraint's
+ * expression on its own line, preceded by indent(depth+1).
+ * NOTE(review): output is written with unbounded sprint, so the caller's buffer
+ * must be large enough — confirm how printarray sizes it.
+ */
 static int
+printconstraintvar(char *buf, ConstraintVar *v, int depth)
+{
+	/* Cleared while the constraint list is printed, so that variables met
+	 * inside those expressions (presumably reached again through printexpr)
+	 * show only name⍙id instead of recursing into their own constraints. */
+	static int extrainfo = 1;
+
+	char *p = buf;
+	if(v->ast)
+		p += printexpr(p, v->ast, 0);
+	else{
+		p += sprint(p, "%s⍙%d", v->name, v->id);
+		if(v->count > 0 && extrainfo){
+			p += sprint(p, " {\n");
+			for(uvlong i = 0; i < v->count; i++){
+				p += indent(p, depth+1);
+				int ei = extrainfo;
+				extrainfo = 0;
+				/* NOTE(review): if printexpr can unwind via error(), extrainfo stays 0 */
+				p += printexpr(p, v->constraints[i]->ast, depth+1);
+				extrainfo = ei;
+				p += sprint(p, "\n");
+			}
+			p += indent(p, depth);
+			p += sprint(p, "}");
+		}
+	}
+	return p-buf;
+}
+
+static int
 printarraysub(char *buf, Array *a, int depth)
 {
 	char *p = buf;
@@ -160,7 +207,7 @@
 			p += printitem(p, a, i, depth); /* TODO: quoting */
 		p += sprint(p, "'");
 		goto end;
-	}else if(a->rank == 1 && a->type == TypeArray){
+	}else if(a->rank == 1 && (a->type == TypeArray || a->type == TypeVar)){
 		if(a->size == 0){
 			p += sprint(p, "( ⋄ )");
 			goto end;
@@ -175,6 +222,9 @@
 		}
 		p += sprint(p, ")");
 		goto end;
+	}else if(a->rank == 0 && a->type == TypeVar){
+		p += printitem(p, a, 0, depth);
+		goto end;
 	}
 
 	p += sprint(p, "Some array I can't print yet");
@@ -311,6 +361,7 @@
 		goto end;
 
 	type = a->arraydata[0]->type;
+
 	b = allocarray(type, a->rank, a->size);
 	for(uvlong dim = 0; dim < a->rank; dim++)
 		b->shape[dim] = a->shape[dim];
@@ -327,6 +378,12 @@
 		case TypeChar:
 			b->chardata[i] = a->arraydata[i]->chardata[0];
 			break;
+		case TypeVar:
+			b->vardata[i] = a->arraydata[i]->vardata[0];
+			break;
+		default:
+			b = a;
+			goto end;
 		}
 	}
 end:
--- /dev/null
+++ b/constraint.c
@@ -1,0 +1,218 @@
+#include <u.h>
+#include <libc.h>
+#include <thread.h>
+
+#include "dat.h"
+#include "fns.h"
+
+/* Constraint constructors, defined at the end of this file.
+ * constrain() selects one based on the primitive being asserted. */
+/* monadic constraints */
+
+/* dyadic constraints */
+static void constraint_equal(Ast *, Array *, Array *);
+
+
+/* Create a new constraint variable and wrap it in a rank-0 TypeVar array.
+ * A nil name defaults to "⎕var". Each call takes the next id from a counter
+ * private to this function, starting at 0.
+ */
+Array *
+allocvar(char *name)
+{
+	static int nextid = 0;
+	ConstraintVar *v;
+	Array *a;
+
+	v = alloc(DataConstraintVar);
+	v->name = (name != nil) ? name : "⎕var";
+	v->id = nextid++;
+
+	a = allocarray(TypeVar, 0, 1);
+	setvar(a, 0, v);
+	return a;
+}
+
+/* Build the AST node that stands for array a inside a delayed expression.
+ * A rank-0 variable that already carries an expression contributes that
+ * expression directly. Anything else is wrapped in a new AstConst leaf.
+ * Returns nil for a nil array (the absent left operand of a monadic call).
+ */
+static Ast *
+varast(Array *a)
+{
+	Ast *leaf;
+
+	if(a == nil)
+		return nil;
+	if(gettype(a) == TypeVar && getrank(a) == 0 && getvar(a, 0)->ast != nil)
+		return getvar(a, 0)->ast;
+
+	leaf = alloc(DataAst);
+	leaf->tag = AstConst;
+	leaf->val = a;
+	return leaf;
+}
+
+/* Defer primitive prim instead of applying it. Returns a new anonymous
+ * variable whose ast is the call prim(x, y), or prim(y) when x is nil.
+ * Operands that are themselves delayed expressions are spliced in by varast,
+ * so nested deferred calls build one larger tree.
+ */
+Array *
+delayedexpr(int prim, Array *x, Array *y)
+{
+	Array *res;
+	Ast *fn, *call;
+
+	res = allocvar(nil);
+
+	fn = alloc(DataAst);
+	fn->tag = AstPrim;
+	fn->prim = prim;
+
+	call = alloc(DataAst);
+	getvar(res, 0)->ast = call;
+	call->func = fn;
+	call->tag = (x != nil) ? AstDyadic : AstMonadic;
+	call->left = varast(x);
+	call->right = varast(y);
+
+	return res;
+}
+
+/* Add constraint c to graph g, then recursively pull in every variable c
+ * mentions and every constraint attached to those variables. g therefore ends
+ * up holding everything reachable from c through shared variables.
+ * Raises EInternal if either fixed-size table in g would overflow.
+ */
+void
+graphadd(ConstraintGraph *g, Constraint *c)
+{
+	for(uvlong i = 0; i < g->ccount; i++){
+		if(g->cs[i] == c)
+			return; /* The constraint is already there. TODO: make a better test */
+	}
+
+	if(g->ccount == nelem(g->cs))
+		error(EInternal, "not enough space in the constraint graph");
+	g->cs[g->ccount] = c;
+	g->ccount++;
+
+	for(uvlong i = 0; i < nelem(c->vars); i++){
+		ConstraintVar *v = c->vars[i];
+		if(v == nil)
+			continue;
+		int new = 1;
+		for(uvlong j = 0; j < g->vcount && new; j++){
+			if(g->vs[j] == v)
+				new = 0;
+		}
+		if(!new)
+			continue;
+		/* vs has a fixed capacity just like cs; check it before writing */
+		if(g->vcount == nelem(g->vs))
+			error(EInternal, "not enough space in the constraint graph");
+		g->vs[g->vcount] = v;
+		g->vcount++;
+		for(uvlong j = 0; j < v->count; j++)
+			graphadd(g, v->constraints[j]);
+	}
+}
+
+/* Find one value for the plain constraint variable v.
+ * Raises EDomain if v is itself a delayed expression. Only the unconstrained
+ * case is handled so far, and it yields the scalar 0. A variable with any
+ * constraint raises EInternal, reporting the size of the collected graph.
+ */
+Array *
+solve(ConstraintVar *v)
+{
+	ConstraintGraph *g;
+	Array *zero;
+
+	if(v->ast)
+		error(EDomain, "Cannot solve expression. Use ⎕assert first.");
+
+	/* Future work: search the constraints on v for a single solution, failing
+	 * with an appropriate error when there is none. Several search strategies
+	 * are possible, and ⎕solve might let the user choose one as a left argument.
+	 */
+
+	/* Collect every variable and constraint reachable from v.
+	 * Capacity is fixed by ConstraintGraph for now. */
+	g = alloc(DataConstraintGraph);
+	for(uvlong i = 0; i < v->count; i++)
+		graphadd(g, v->constraints[i]);
+
+	if(g->ccount != 0)
+		error(EInternal, "⎕solve not implemented (%ulld vars and %ulld constraints)", g->vcount, g->ccount);
+
+	/* No constraints: any value works, so pick 0 */
+	zero = allocarray(TypeNumber, 0, 1);
+	setint(zero, 0, 0);
+	return zero;
+}
+
+/* Turn the delayed expression held by v into constraints attached to the
+ * variables it mentions. Raises EDomain if v is a plain variable (no ast).
+ * Only a dyadic ≡ (PMatch) whose operands are both AstConst leaves is
+ * understood so far. Every other shape raises EInternal.
+ */
+void
+constrain(ConstraintVar *v)
+{
+	if(!v->ast)
+		error(EDomain, "Expected a constraint expression, not a variable.");
+
+	/* Analyse the AST and add the appropriate constraints to the variables involved.
+	 * Also simplify with the constraints already there, and give an error if
+	 * the simplifications show that no solutions are possible.
+	 */
+	int prim, dyadic;
+	Array *left = nil;
+	Array *right = nil;
+
+	/* Only a primitive applied directly to its operands is recognised */
+	if(!(v->ast->tag == AstMonadic || v->ast->tag == AstDyadic))
+		goto fail;
+	if(v->ast->func->tag != AstPrim)
+		goto fail;
+	prim = v->ast->func->prim;
+	dyadic = 0;
+	/* AstConst leaves are what varast() builds for plain values and for
+	 * variables without an expression. Operands that are themselves delayed
+	 * expressions are not handled yet. */
+	switch(v->ast->tag){
+	case AstDyadic:
+		dyadic = 1;
+		if(v->ast->left->tag != AstConst)
+			goto fail;
+		left = v->ast->left->val;
+		/* fall through */
+	case AstMonadic:
+		if(v->ast->right->tag != AstConst)
+			goto fail;
+		right = v->ast->right->val;
+	}
+
+	switch(prim){
+	case PMatch:
+		if(dyadic)
+			constraint_equal(v->ast, left, right);
+		else
+			goto fail;
+		break;
+	default:
+		goto fail;
+	}
+	return;
+
+fail:
+	error(EInternal, "don't know how to assert the given constraint");
+}
+
+/* Register constraint c with the constraint variables among its operands.
+ * Each distinct rank-0 TypeVar operand is recorded in c->vars, and c is
+ * appended to that variable's constraint list. Operands that are nil or are
+ * not rank-0 TypeVar arrays are skipped.
+ */
+static void
+applyconstraint(Constraint *c)
+{
+	/* Find the variables involved */
+	Array *args[2];
+	args[0] = c->left;
+	args[1] = c->right;
+	int nvars = 0;
+
+	for(int i = 0; i < nelem(args); i++){
+		Array *a = args[i];
+		if(a == nil || gettype(a) != TypeVar || getrank(a) != 0)
+			continue;
+		ConstraintVar *v = getvar(a, 0);
+
+		/* The same variable on both sides (e.g. x≡x) is registered once.
+		 * Otherwise c would appear twice in v's constraint list. */
+		int seen = 0;
+		for(int j = 0; j < nvars; j++){
+			if(c->vars[j] == v)
+				seen = 1;
+		}
+		if(seen)
+			continue;
+
+		c->vars[nvars] = v;
+		nvars++;
+
+		v->count++;
+		/* NOTE(review): assumes allocextra keeps the existing entries when
+		 * growing (realloc-like); confirm in memory.c. */
+		v->constraints = allocextra(v, sizeof(*v->constraints) * v->count);
+		v->constraints[v->count-1] = c;
+	}
+
+	/* Should simplify here as well */
+}
+
+/* monadic constraints */
+
+/* dyadic constraints */
+
+/* Build the equality constraint x≡y (a is the expression it came from) and
+ * attach it to the constraint variables among x and y. */
+static void
+constraint_equal(Ast *a, Array *x, Array *y)
+{
+	Constraint *eq;
+
+	eq = alloc(DataConstraint);
+	eq->tag = CEqual;
+	eq->ast = a;
+	eq->left = x;
+	eq->right = y;
+
+	applyconstraint(eq);
+}
\ No newline at end of file
--- a/dat.h
+++ b/dat.h
@@ -18,6 +18,9 @@
 	DataLocalList,
 	DataErrorCtx,
 	DataErrorTrap,
+	DataConstraint,
+	DataConstraintVar,
+	DataConstraintGraph,
 
 	DataMax,
 };
@@ -151,6 +154,7 @@
 	TypeNumber,
 	TypeChar,
 	TypeArray,
+	TypeVar,
 };
 
 typedef struct Array Array;
@@ -230,6 +234,7 @@
 	ILocal,
 	IPop,
 	IDisplay,
+	IPushVar,
 };
 
 typedef struct ValueStack ValueStack;
@@ -317,4 +322,59 @@
 
 	uvlong count;
 	ErrorTrap **traps;
+};
+
+/* Kinds of constraint, stored in Constraint.tag */
+enum ConstraintType
+{
+	CEqual,	/* left ≡ right */
+};
+
+typedef struct Constraint Constraint;
+typedef struct ConstraintVar ConstraintVar;
+
+/* One asserted constraint between two operand arrays */
+struct Constraint
+{
+	int tag;	/* enum ConstraintType */
+	Ast *ast;	/* the expression the constraint was built from */
+
+	Array *left;	/* operands; either may be a rank-0 TypeVar array */
+	Array *right;
+
+	/* variables found among left/right by applyconstraint; graphadd skips nil slots */
+	ConstraintVar *vars[2]; /* max 2 vars for now */
+};
+
+/* A constraint variable, or a delayed expression when ast is non-nil */
+struct ConstraintVar
+{
+	char *name;	/* display name; "⎕var" for anonymous variables */
+	int id;	/* number from allocvar's counter, printed after ⍙ */
+
+	Ast *ast;	/* non-nil: this is a delayed expression (see delayedexpr) */
+
+	uvlong count;	/* number of entries in constraints */
+	Constraint **constraints;	/* constraints that mention this variable */
+};
+
+/* Primitive ids; prim.c uses them as designated indices into primspecs */
+enum PrimitiveId
+{
+	PRight,
+	PLeft,
+	PPlus,
+	PMinus,
+	PRho,
+	PMatch,
+
+	/* constraint system functions (⎕assert, ⎕all, ⎕solve, ⎕var) */
+	PAssert,
+	PAll,
+	PSolve,
+	PVar,
+};
+
+/* Variables and constraints collected by graphadd for solve.
+ * Both tables have a fixed capacity of 128 entries. */
+typedef struct ConstraintGraph ConstraintGraph;
+struct ConstraintGraph
+{
+	uvlong vcount;	/* used entries in vs */
+	uvlong ccount;	/* used entries in cs */
+
+	ConstraintVar *vs[128];
+	Constraint *cs[128];
 };
\ No newline at end of file
--- a/eval.c
+++ b/eval.c
@@ -47,11 +47,13 @@
 	uvlong id = sym(s, a->name);
 	emitbyte(c, ILocal);
 	emituvlong(c, id);
-	if(assign){
-		emitbyte(c, IAssign);
+	if(!assign){ /* create a new constraint var */
+		emitbyte(c, IPushVar);
 		emituvlong(c, id);
-		emitbyte(c, IPop);
 	}
+	emitbyte(c, IAssign);
+	emituvlong(c, id);
+	emitbyte(c, IPop);
 }
 
 static void
@@ -418,6 +420,10 @@
 			break;
 		case IDisplay:
 			/* nothing to do, IPop checks for it */
+			break;
+		case IPushVar:
+			o += getuvlong(c->instrs+o, &v);
+			pushval(values, allocvar(symname(m->symtab, v)));
 			break;
 		default:
 			error(EInternal, "unknown instruction in evalbc: %d", instr);
--- a/fns.h
+++ b/fns.h
@@ -4,6 +4,7 @@
 void setint(Array *, usize, vlong);
 void setchar(Array *, usize, Rune);
 void setarray(Array *, usize, Array *);
+void setvar(Array *, usize, ConstraintVar *);
 void setshape(Array *, int, usize);
 int gettype(Array *);
 int getrank(Array *);
@@ -11,10 +12,17 @@
 vlong getint(Array *, usize);
 Rune getchar(Array *, usize);
 Array *getarray(Array *, usize);
+ConstraintVar *getvar(Array *, usize);
 
 Array *simplifyarray(Array *);
 char *printarray(Array *);
 char *printfunc(Function *);
+
+/* constraint.c */
+Array *allocvar(char *);	/* new rank-0 variable; nil name means "⎕var" */
+Array *delayedexpr(int, Array *, Array *);	/* record prim(x, y) as a variable; nil x = monadic */
+Array *solve(ConstraintVar *);	/* find one value for a plain variable */
+void constrain(ConstraintVar *);	/* attach the constraints in an expression variable */
 
 /* error.c */
 #define trap(num) (setjmp(setuptrap(1, num)->env))
--- a/memory.c
+++ b/memory.c
@@ -41,6 +41,9 @@
 	[DataLocalList] = {.size = sizeof(LocalList) },
 	[DataErrorCtx] = {.size = sizeof(ErrorCtx) },
 	[DataErrorTrap] = {.size = sizeof(ErrorTrap) },
+	[DataConstraint] = {.size = sizeof(Constraint) },
+	[DataConstraintVar] = {.size = sizeof(ConstraintVar) },
+	[DataConstraintGraph] = {.size = sizeof(ConstraintGraph) },
 };
 
 void *
--- a/mkfile
+++ b/mkfile
@@ -4,6 +4,7 @@
 SCRIPTS=lpa
 OFILES=\
 	array.$O\
+	constraint.$O\
 	error.$O\
 	eval.$O\
 	fs.$O\
--- a/parse.c
+++ b/parse.c
@@ -38,6 +38,7 @@
 	else
 		ast = parseprog(tokens);
 	match(tokens, TokEnd);
+
 	return ast;
 }
 
--- a/prim.c
+++ b/prim.c
@@ -7,15 +7,21 @@
 
 /* NOTE: In LPA, system functions are treated as primitives as well */
 
+/* niladic functions */
+static Array *primfn_var(void);
+
 /* monadic functions */
 static Array *primfn_same(Array *);
 static Array *primfn_shape(Array *);
 
+static Array *primfn_assert(Array *);
+static Array *primfn_allsolutions(Array *);
+static Array *primfn_solve(Array *);
+
 /* dyadic functions */
 static Array *primfn_left(Array *, Array *);
 static Array *primfn_right(Array *, Array *);
 static Array *primfn_match(Array *, Array *);
-
 struct {
 	char *spelling;
 	int nameclass;
@@ -23,12 +29,38 @@
 	Array *(*monad)(Array *);
 	Array *(*dyad)(Array *, Array *);
 } primspecs[] = {
-	"⊢", NameclassFunc, nil, primfn_same, primfn_right,
-	"⊣", NameclassFunc, nil, primfn_same, primfn_left,
-	"+", NameclassFunc, nil, nil, nil,
-	"-", NameclassFunc, nil, nil, nil,
-	"⍴", NameclassFunc, nil, primfn_shape, nil,
-	"≡", NameclassFunc, nil, nil, primfn_match,
+	[PRight] = {
+		"⊢", NameclassFunc, nil, primfn_same, primfn_right
+	},
+	[PLeft] = {
+		"⊣", NameclassFunc, nil, primfn_same, primfn_left,
+	},
+	[PPlus] = {
+		"+", NameclassFunc, nil, nil, nil
+	},
+	[PMinus] = {
+		"-", NameclassFunc, nil, nil, nil
+	},
+	[PRho] = {
+		"⍴", NameclassFunc, nil, primfn_shape, nil
+	},
+	[PMatch] = {
+		"≡", NameclassFunc, nil, nil, primfn_match
+	},
+
+	/* Constraint stuff. Pick glyphs for them later */
+	[PAssert] = {
+		"⎕assert",	NameclassFunc, nil, primfn_assert, nil
+	},
+	[PAll] = {
+		"⎕all",		NameclassFunc, nil, primfn_allsolutions, nil
+	},
+	[PSolve] = {
+		"⎕solve",	NameclassFunc, nil, primfn_solve, nil
+	},
+	[PVar] = {
+		"⎕var",		NameclassFunc, primfn_var, nil, nil
+	}
 };
 
 char *
@@ -47,6 +79,8 @@
 primvalence(int id)
 {
 	int valence = 0;
+	if(primspecs[id].nilad)
+		valence |= Niladic;
 	if(primspecs[id].monad)
 		valence |= Monadic;
 	if(primspecs[id].dyad)
@@ -68,30 +102,45 @@
+/* Apply primitive id with no arguments; raises EInternal if it has no niladic form. */
 Array *
 primnilad(int id)
 {
-	if(primspecs[id].nilad)
-		return primspecs[id].nilad();
-	else
+	Array *(*fn)(void) = primspecs[id].nilad;
+	if(fn == nil)
 		error(EInternal, "primitive %s has no niladic definition", primsymb(id));
+	/* relies on error() not returning when fn is nil */
+	return fn();	
 }
 
+/* Apply primitive id monadically to y; raises EInternal if it has no monadic
+ * form. If y is a TypeVar array, the call is recorded with delayedexpr
+ * instead of being evaluated. ⎕assert and ⎕solve are exempt because they
+ * operate on the variable itself.
+ */
 Array *
 primmonad(int id, Array *y)
 {
-	if(primspecs[id].monad)
-		return primspecs[id].monad(y);
-	else
+	Array *(*fn)(Array *) = primspecs[id].monad;
+	if(fn == nil)
 		error(EInternal, "primitive %s has no monadic definition", primsymb(id));
+
+	if(gettype(y) == TypeVar && !(id == PAssert || id == PSolve))
+		return delayedexpr(id, nil, y);
+
+	return fn(y);
 }
 
+/* Apply primitive id dyadically to x and y; raises EInternal if it has no
+ * dyadic form. If either argument is a TypeVar array, the call is recorded
+ * with delayedexpr instead of being evaluated.
+ */
 Array *
 primdyad(int id, Array *x, Array *y)
 {
-	if(primspecs[id].dyad)
-		return primspecs[id].dyad(x, y);
-	else
+	Array *(*fn)(Array *, Array *) = primspecs[id].dyad;
+	if(fn == nil)
 		error(EInternal, "primitive %s has no dyadic definition", primsymb(id));
+
+	if(gettype(x) == TypeVar || gettype(y) == TypeVar)
+		return delayedexpr(id, x, y);
+
+	return fn(x, y);	
 }
 
+/* niladic functions */
+/* ⎕var: create a new anonymous constraint variable */
+static Array *
+primfn_var(void)
+{
+	Array *fresh = allocvar(nil);
+	return fresh;
+}
+
 /* monadic functions */
 static Array *
 primfn_same(Array *a)
@@ -112,6 +161,31 @@
 	return r;
 }
 
+/* ⎕assert: Y must be a single (rank-0) TypeVar array. Its expression is
+ * turned into constraints by constrain(), and the result is the scalar 0.
+ * Raises EDomain for any other argument.
+ */
+static Array *
+primfn_assert(Array *y)
+{
+	Array *ok;
+
+	if(!(gettype(y) == TypeVar && getrank(y) == 0))
+		error(EDomain, "⎕assert expected a single constraint expression");
+	constrain(getvar(y, 0));
+
+	ok = allocarray(TypeNumber, 0, 1);
+	setint(ok, 0, 0);
+	return ok;
+}
+
+/* ⎕all: placeholder. Calling it always raises EInternal; the message states
+ * that it is never meant to be evaluated directly.
+ */
+static Array *
+primfn_allsolutions(Array *)
+{
+	error(EInternal, "⎕all should never be evaluated");
+	return nil; /* not reached: error() is assumed not to return */
+}
+
+/* ⎕solve: Y must be a single (rank-0) TypeVar array. Returns one value found
+ * by solve(). Raises EDomain for any other argument.
+ */
+static Array *
+primfn_solve(Array *y)
+{
+	if(gettype(y) != TypeVar || getrank(y) != 0)
+		error(EDomain, "⎕solve expected a single constraint variable");
+	return solve(getvar(y, 0));
+}
+
 /* dyadic functions */
 static Array *
 primfn_left(Array *x, Array *)
@@ -175,4 +249,4 @@
 	Array *z = allocarray(TypeNumber, 0, 1);
 	setint(z, 0, matches(x, y));
 	return z;
-}
\ No newline at end of file
+}
--- a/util.c
+++ b/util.c
@@ -188,6 +188,10 @@
 		case IDisplay:
 			print("DISPLAY\n");
 			break;
+		case IPushVar:
+			o += getuvlong(c->instrs+o, &v);
+			print("PUSHVAR %ulld\n", v);
+			break;
 		default:
 			print("???");
 			return;