shithub: riscv

ref: 1c7e58e75bc6b5620984f164d66354350b28dfe0
dir: /sys/src/cmd/mpc.y/

View raw version
%{

/*
 * mpc: translate a small expression language over multi-precision
 * integers into C code that calls libmp.  The parser builds a Node tree
 * per function and fcom() prints the corresponding C on standard output.
 */

#include	<u.h>
#include	<libc.h>
#include	<bio.h>
#include 	<mp.h>

typedef struct Sym Sym;
typedef struct Node Node;

/* bits for Sym.f describing how a name is used in the current function */
enum {
	FSET	= 1,	/* name has been assigned (checked before a read) */
	FUSE	= 2,	/* name has been read */
	FARG	= 4,	/* name is a parameter of the current function */
	FLOC	= 8,	/* name is an implicitly declared local (see flocs) */
};

/* interned identifier or numeric literal text */
struct Sym
{
	Sym*	l;	/* next symbol in the same hash bucket */
	int	f;	/* F* flags */
	char	n[];	/* NUL-terminated text */
};

/* parse tree node; also used as list cell for temporaries */
struct Node
{
	int	c;	/* kind: token (NAME, NUM, LSH, ...) or operator character */
	Node*	l;	/* left operand; link field on the temporary lists */
	Node*	r;	/* right operand */
	Sym*	s;	/* symbol for NAME and NUM nodes */
	mpint*	m;	/* folded constant value (see ccom) */
	int	n;	/* source line number at creation */
};

#pragma	varargck type "N" Node*

int	ntmp;		/* counter used to name temporaries tmp1, tmp2, ... */
Node	*ftmps, *atmps;	/* free (reusable) and active temporary lists */
Node	*modulo;	/* modulus of the enclosing mod() block, or nil */

Node*	new(int, Node*, Node*);
Sym*	sym(char*);

Biobuf	bin;		/* current input file */
int	goteof;		/* set once getch has seen end of file */
int	lineno;		/* current input line, for diagnostics */
int	clevel;		/* indentation level of the generated C */
char*	filename;	/* current input file name, for diagnostics */

int	getch(void);
void	ungetc(void);
void	yyerror(char*);
int	yyparse(void);
void	diag(Node*, char*, ...);
void	com(Node*);
void	fcom(Node*,Node*,Node*);

#pragma varargck argpos cprint 1
#pragma varargck argpos diag 2

%}

/* semantic values: symbols from the lexer, tree nodes from the rules */
%union
{
	Sym*	sval;
	Node*	node;
}

%type	<node>	name num args expr bool block elif stmnt stmnts

/* operator precedence: later lines bind more tightly */
%left	'{' '}' ';'
%right	'=' ','
%right	'?' ':'
%left	EQ NEQ '<' '>'
%left	LSH RSH
%left	'+' '-'
%left	'/' '%'
%left	'*'
%left	'^'
%right	'('

/* keywords; NAME and NUM carry their interned Sym */
%token	MOD IF ELSE WHILE BREAK 
%token	<sval>	NAME NUM

%%

/* a program is a non-empty sequence of function definitions */
prog:
	prog func
|	func

/*
 * a function is a name, a parenthesized parameter list and a body;
 * it is translated to C (fcom) as soon as it has been reduced.
 */
func:
	name args stmnt
	{
		fcom($1, $2, $3);
	}

/* parenthesized list: a ','-tree of expressions, or nil when empty */
args:
	'(' expr ')'
	{
		$$ = $2;
	}
|	'(' ')'
	{
		$$ = nil;
	}

name:
	NAME
	{
		$$ = new(NAME,nil,nil);
		$$->s = $1;
	}
num:
	NUM
	{
		$$ = new(NUM,nil,nil);
		$$->s = $1;
	}

/*
 * else / else-if tails.  Conditionals are encoded as
 * '?'(condition, ':'(then, else)).
 */
elif:
	ELSE IF '(' bool ')' stmnt
	{
		$$ = new('?', $4, new(':', $6, nil));
	}
|	ELSE IF '(' bool ')' stmnt elif
	{
		$$ = new('?', $4, new(':', $6, $7));
	}
|	ELSE stmnt
	{
		$$ = $2;
	}

/* one or more semicolons end a simple statement */
sem:
	sem ';'
|	';'

/*
 * statement node kinds:
 *	'='	assignment (the left side may be a ',' list)
 *	'm'	mod(expr) stmnt: body computed modulo expr
 *	'?'	if / else
 *	'@'	while: loop over "if(cond) body else break ('b')"
 *	'e'	call; a bare name used as a statement is a call without arguments
 */
stmnt:
	expr '=' expr sem
	{
		$$ = new('=', $1, $3);
	}
|	MOD args stmnt
	{
		$$ = new('m', $2, $3);
	}
|	IF '(' bool ')' stmnt
	{
		$$ = new('?', $3, new(':', $5, nil));
	}
|	IF '(' bool ')' stmnt elif
	{
		$$ = new('?', $3, new(':', $5, $6));
	}
|	WHILE '(' bool ')' stmnt
	{
		$$ = new('@', new('?', $3, new(':', $5, new('b', nil, nil))), nil);
	}
|	BREAK sem
	{
		$$ = new('b', nil, nil);
	}
|	expr sem
	{
		if($1->c == NAME)
			$$ = new('e', $1, nil);
		else
			$$ = $1;
	}
|	block

block:
	'{' stmnts '}'
	{
		$$ = $2;
	}

/* statement sequences are left-nested '\n' nodes */
stmnts:
	stmnts stmnt
	{
		$$ = new('\n', $1, $2);
	}
|	stmnt

/* expressions; unary minus is built as the subtraction 0 - expr */
expr:
	'(' expr ')'
	{
		$$ = $2;
	}
|	name
	{
		$$ = $1;
	}
|	num
	{
		$$ = $1;
	}
|	'-' expr
	{
		$$ = new(NUM, nil, nil);
		$$->s = sym("0");
		$$->s->f = 0;
		$$ = new('-', $$, $2);
	}
|	expr ',' expr
	{
		$$ = new(',', $1, $3);
	}
|	expr '^' expr
	{
		$$ = new('^', $1, $3);
	}
|	expr '*' expr
	{
		$$ = new('*', $1, $3);
	}
|	expr '/' expr
	{
		$$ = new('/', $1, $3);
	}
|	expr '%' expr
	{
		$$ = new('%', $1, $3);
	}
|	expr '+' expr
	{
		$$ = new('+', $1, $3);
	}
|	expr '-' expr
	{
		$$ = new('-', $1, $3);
	}
|	bool '?' expr ':' expr
	{
		$$ = new('?', $1, new(':', $3, $5));
	}
|	name args
	{
		$$ = new('e', $1, $2);
	}
|	expr LSH expr
	{
		$$ = new(LSH, $1, $3);
	}
|	expr RSH expr
	{
		$$ = new(RSH, $1, $3);
	}

/* comparisons; != is encoded as !(==) */
bool:
	'(' bool ')'
	{
		$$ = $2;
	}
|	'!' bool
	{
		$$ = new('!', $2, nil);
	}
|	expr EQ expr
	{
		$$ = new(EQ, $1, $3);
	}
|	expr NEQ expr
	{
		$$ = new('!', new(EQ, $1, $3), nil);
	}
|	expr '>' expr
	{
		$$ = new('>', $1, $3);
	}
|	expr '<' expr
	{
		$$ = new('<', $1, $3);
	}

%%

int
yylex(void)
{
	static char buf[200];
	char *p;
	int c;

Loop:
	c = getch();
	switch(c){
	case -1:
		return -1;
	case ' ':
	case '\t':
	case '\n':
		goto Loop;
	case '#':
		while((c = getch()) > 0)
			if(c == '\n')
				break;
		goto Loop;
	}

	switch(c){
	case '?': case ':':
	case '+': case '-':
	case '*': case '^':
	case '/': case '%':
	case '{': case '}':
	case '(': case ')':
	case ',': case ';':
		return c;
	case '<':
		if(getch() == '<') return LSH;
		ungetc();
		return '<';
	case '>': 
		if(getch() == '>') return RSH;
		ungetc();
		return '>';
	case '=':
		if(getch() == '=') return EQ;
		ungetc();
		return '=';
	case '!':
		if(getch() == '=') return NEQ;
		ungetc();
		return '!';
	}

	ungetc();
	p = buf;
	for(;;){
		c = getch();
		if((c >= Runeself)
		|| (c == '_')
		|| (c >= 'a' && c <= 'z')
		|| (c >= 'A' && c <= 'Z')
		|| (c >= '0' && c <= '9')){
			*p++ = c;
			continue;
		}
		ungetc();
		break;
	}
	*p = '\0';

	if(strcmp(buf, "mod") == 0)
		return MOD;
	if(strcmp(buf, "if") == 0)
		return IF;
	if(strcmp(buf, "else") == 0)
		return ELSE;
	if(strcmp(buf, "while") == 0)
		return WHILE;
	if(strcmp(buf, "break") == 0)
		return BREAK;

	yylval.sval = sym(buf);
	yylval.sval->f = 0;
	return (buf[0] >= '0' && buf[0] <= '9') ? NUM : NAME;
}


/*
 * lastc is the character most recently returned by getch, or -1 after
 * end of file.  ungetc uses it to undo getch's side effects.
 */
static int lastc;

/*
 * getch returns the next input byte, or -1 at end of file (also setting
 * goteof).  Newlines advance lineno.
 */
int
getch(void)
{
	int c;

	c = Bgetc(&bin);
	if(c == Beof){
		goteof = 1;
		lastc = -1;
		return -1;
	}
	if(c == '\n')
		lineno++;
	lastc = c;
	return c;
}

/*
 * ungetc pushes back the character returned by the last getch.
 * A pushed-back newline is uncounted so that it is not counted twice
 * when it is read again.  Pushing back end of file is a no-op: nothing
 * was consumed, so the next getch reports end of file again.
 */
void
ungetc(void)
{
	if(lastc == -1)
		return;
	if(lastc == '\n')
		lineno--;
	lastc = 0;
	Bungetc(&bin);
}

/*
 * new allocates a tree node of kind c with children l and r.  It records
 * the current input line for diagnostics.  Exits if memory is exhausted.
 */
Node*
new(int c, Node *l, Node *r)
{
	Node *n;

	n = malloc(sizeof(*n));
	if(n == nil)
		sysfatal("new: %r");
	n->c = c;
	n->l = l;
	n->r = r;
	n->s = nil;
	n->m = nil;
	n->n = lineno;
	return n;
}

/*
 * sym interns the string n.  It returns the existing Sym with that text,
 * or creates one with flags cleared.  Symbols are never freed.  Exits if
 * memory is exhausted.
 */
Sym*
sym(char *n)
{
	static Sym *tab[128];
	Sym *s;
	ulong h, t;
	int i;

	/* hash the name (also leaves i = strlen(n)) */
	h = 0;
	for(i=0; n[i] != '\0'; i++){
		t = h & 0xf8000000;
		h <<= 5;
		h ^= t>>27;
		h ^= (ulong)n[i];
	}
	h %= nelem(tab);
	for(s = tab[h]; s != nil; s = s->l)
		if(strcmp(s->n, n) == 0)
			return s;
	s = malloc(sizeof(Sym)+i+1);
	if(s == nil)
		sysfatal("sym: %r");
	memmove(s->n, n, i+1);
	s->f = 0;
	s->l = tab[h];
	tab[h] = s;
	return s;
}

/*
 * yyerror reports a parse error at the current file and line.  It then
 * terminates, using the message as the exit status.
 */
void
yyerror(char *s)
{
	char *where = filename;

	fprint(2, "%s:%d: %s\n", where, lineno, s);
	exits(s);
}
/*
 * cprint formats its arguments like print and writes the result to
 * standard output.  After every newline it emits clevel tabs, so the
 * following line of generated C is indented to the current nesting
 * level.
 */
void
cprint(char *fmt, ...)
{
	static char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
	char *buf, *p, *x;
	va_list a;
	int n;

	/* heap-format so long lines (e.g. big strtomp constants) are not truncated */
	va_start(a, fmt);
	buf = vsmprint(fmt, a);
	va_end(a);
	if(buf == nil)
		sysfatal("cprint: %r");

	/* clamp the indent so tabs[] is never indexed out of bounds */
	n = clevel;
	if(n < 0)
		n = 0;
	if(n > (int)sizeof(tabs)-1)
		n = sizeof(tabs)-1;

	p = buf;
	while((x = strchr(p, '\n')) != nil){
		x++;
		write(1, p, x-p);
		if(n > 0)
			write(1, &tabs[sizeof(tabs)-1 - n], n);
		p = x;
	}
	if(*p != '\0')
		write(1, p, strlen(p));
	free(buf);
}

/*
 * alloctmp returns a temporary mpint variable for an intermediate result.
 * A temporary from the free list ftmps is reused when one is available.
 * Otherwise a new "tmpN" name is created and its declaration is emitted.
 * Either way the generated code initializes it with mpnew(0).  Its
 * set/use flags are cleared and it is pushed onto the active list atmps.
 */
Node*
alloctmp(void)
{
	Node *t;
	char name[16];

	if(ftmps == nil){
		snprint(name, sizeof(name), "tmp%d", ++ntmp);
		t = new(NAME, nil, nil);
		t->s = sym(name);
		cprint("mpint *");	/* first use: declare it */
	} else {
		t = ftmps;
		ftmps = ftmps->l;
	}
	cprint("%N = mpnew(0);\n", t);
	t->s->f &= ~(FSET|FUSE);
	t->l = atmps;
	atmps = t;
	return t;
}

/*
 * isconst reports whether n is a numeric literal or one of the
 * predefined constants mpzero, mpone, mptwo.
 */
int
isconst(Node *n)
{
	switch(n->c){
	case NUM:
		return 1;
	case NAME:
		return n->s == sym("mpzero")
			|| n->s == sym("mpone")
			|| n->s == sym("mptwo");
	default:
		return 0;
	}
}

/*
 * istmp reports whether n names a temporary that is currently on the
 * active list atmps.
 */
int
istmp(Node *n)
{
	Node *p;

	if(n->c != NAME)
		return 0;
	for(p = atmps; p != nil; p = p->l)
		if(p->s == n->s)
			return 1;
	return 0;
}


/*
 * freetmp releases temporaries.  t may be nil, a single node or a ','
 * list.  Each element that is itself on the active list atmps gets an
 * mpfree() emitted and is moved to the free list ftmps.  Anything else
 * is ignored.
 */
void
freetmp(Node *t)
{
	Node **pp;

	while(t != nil && t->c == ','){
		freetmp(t->l);
		t = t->r;
	}
	if(t == nil || t->c != NAME)
		return;

	for(pp = &atmps; *pp != nil; pp = &(*pp)->l){
		if(*pp == t){
			cprint("mpfree(%N);\n", t);
			*pp = t->l;
			t->l = ftmps;
			ftmps = t;
			return;
		}
	}
}

/* symref reports whether a NAME node for symbol s occurs anywhere in tree n. */
int
symref(Node *n, Sym *s)
{
	for(; n != nil; n = n->r){
		if(n->c == NAME && n->s == s)
			return 1;
		if(symref(n->l, s))
			return 1;
	}
	return 0;
}

/*
 * nodeset marks as assigned (FSET) the name n, or every name of the ','
 * list n.  Other node kinds are ignored.
 */
void
nodeset(Node *n)
{
	while(n != nil && n->c == ','){
		nodeset(n->l);
		n = n->r;
	}
	if(n != nil && n->c == NAME)
		n->s->f |= FSET;
}

/*
 * complex reports whether evaluating n needs a destination of its own.
 * Names do not.  Nor do positive-signed constants no larger than two,
 * which ecom turns into mpzero/mpone/mptwo.  n->m must already be set
 * for NUM nodes (ccom).
 */
int
complex(Node *n)
{
	switch(n->c){
	case NAME:
		return 0;
	case NUM:
		return !(n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0);
	}
	return 1;
}

void
bcom(Node *n, Node *t);

/*
 * ccom folds constant subexpressions of f in place and returns f.
 * NUM leaves get their text parsed into f->m.  LSH/RSH nodes are always
 * candidates for folding.  '+', '-', '*', '/', '%' and '^' are
 * candidates only when there is no current modulus or the modulus node
 * is a NUM.  A candidate whose operands both fold to NUM is rewritten
 * into a NUM node holding the result.  Arithmetic results are reduced by
 * the modulus when one is in effect.  Other kinds are returned without
 * folding.
 *
 * f->m doubles as a "visited" mark: it is set to (void*)~0 before the
 * node is examined, so later calls on the same node return immediately.
 */
Node*
ccom(Node *f)
{
	Node *l, *r;

	if(f == nil)
		return nil;

	if(f->m != nil)
		return f;
	f->m = (void*)~0;

	switch(f->c){
	case NUM:
		f->m = strtomp(f->s->n, nil, 0, nil);
		if(f->m == nil)
			diag(f, "bad constant");
		goto out;

	case LSH:
	case RSH:
		break;

	case '+':
	case '-':
	case '*':
	case '/':
	case '%':
	case '^':
		/*
		 * NOTE(review): com() sets modulo from ecom(), which appears
		 * to always return a NAME, so modulo->c == NUM may never hold
		 * here — confirm before relying on modular folding.
		 */
		if(modulo == nil || modulo->c == NUM)
			break;

		/* wet floor: fall through, a run-time modulus prevents folding */
	default:
		return f;
	}

	f->l = l = ccom(f->l);
	f->r = r = ccom(f->r);
	if(l == nil || r == nil || l->c != NUM || r->c != NUM)
		return f;

	f->m = mpnew(0);
	switch(f->c){
	case LSH:
	case RSH:
		if(mpsignif(r->m) > 32)
			diag(f, "bad shift");
		if(f->c == LSH)
			mpleft(l->m, mptoi(r->m), f->m);
		else
			mpright(l->m, mptoi(r->m), f->m);
		goto out;	/* shifts are not reduced by the modulus */

	case '+':
		mpadd(l->m, r->m, f->m);
		break;
	case '-':
		mpsub(l->m, r->m, f->m);
		break;
	case '*':
		mpmul(l->m, r->m, f->m);
		break;
	case '/':
		/* modular division multiplies by the inverse of the divisor */
		if(modulo != nil){
			mpinvert(r->m, modulo->m, f->m);
			mpmul(f->m, l->m, f->m);
		} else {
			mpdiv(l->m, r->m, f->m, nil);
		}
		break;
	case '%':
		mpmod(l->m, r->m, f->m);
		break;
	case '^':
		/* mpexp applies the modulus itself */
		mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
		goto out;
	}
	if(modulo != nil)
		mpmod(f->m, modulo->m, f->m);

out:
	/* turn f into a constant leaf */
	f->l = nil;
	f->r = nil;
	f->s = nil;
	f->c = NUM;
	return f;
}

/*
 * ecom emits C code that evaluates expression f.
 *
 * If t is non-nil the result is stored into t.  If t is nil, a plain
 * name is returned as is, and any other expression is evaluated into a
 * new temporary (alloctmp).  A ',' list is evaluated element-wise, which
 * requires t == nil.  Returns the node that holds the value: t, the
 * temporary, the name itself, or the list.  On return the operand
 * temporaries that are not the result have been freed and the result is
 * marked as set.
 */
Node*
ecom(Node *f, Node *t)
{
	Node *l, *r, *t2;

	if(f == nil)
		return nil;

	f = ccom(f);
	if(f->c == NUM){
		/* negative constant: build |f| and then flip the sign */
		if(f->m->sign < 0){
			f->m->sign = 1;
			t = ecom(f, t);
			f->m->sign = -1;
			/* never negate a shared constant such as mpone in place */
			if(isconst(t))
				t = ecom(t, alloctmp());
			cprint("%N->sign = -1;\n", t);
			return t;
		}
		/* 0, 1 and 2 are replaced by the predefined mpint constants */
		if(mpcmp(f->m, mpzero) == 0){
			f->c = NAME;
			f->s = sym("mpzero");
			f->s->f = FSET;
			return ecom(f, t);
		}
		if(mpcmp(f->m, mpone) == 0){
			f->c = NAME;
			f->s = sym("mpone");
			f->s->f = FSET;
			return ecom(f, t);
		}
		if(mpcmp(f->m, mptwo) == 0){
			f->c = NAME;
			f->s = sym("mptwo");
			f->s->f = FSET;
			return ecom(f, t);
		}
	}

	if(f->c == ','){
		if(t != nil)
			diag(f, "cannot assign list to %N", t);
		f->l = ecom(f->l, nil);
		f->r = ecom(f->r, nil);
		return f;
	}

	l = r = nil;
	if(f->c == NAME){
		if((f->s->f & FSET) == 0)
			diag(f, "name used but not set");
		f->s->f |= FUSE;
		if(t == nil)
			return f;
		if(f->s != t->s)
			cprint("mpassign(%N, %N);\n", f, t);
		goto out;
	}

	if(t == nil)
		t = alloctmp();

	/* conditional expression */
	if(f->c == '?'){
		bcom(f, t);
		goto out;
	}

	/* function call: the destination is passed as the last argument */
	if(f->c == 'e'){
		r = ecom(f->r, nil);
		if(r == nil)
			cprint("%N(%N);\n", f->l, t);
		else
			cprint("%N(%N, %N);\n", f->l, r, t);
		goto out;
	}

	if(t->c != NAME)
		diag(f, "destination %N not a name", t);

	switch(f->c){
	case NUM:
		/* pick the narrowest conversion that holds the constant */
		if(mpsignif(f->m) <= 32)
			cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
		else if(mpsignif(f->m) <= 64)
			cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
		else
			cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
		goto out;
	case LSH:
	case RSH:
		/* the shift count must fold to a constant that fits in 32 bits */
		r = ccom(f->r);
		if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
			diag(f, "bad shift");
		l = f->l->c == NAME ? f->l : ecom(f->l, t);
		if(f->c == LSH)
			cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
		else
			cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
		goto out;
	case '*':
	case '/':
		/* operands of '*' and '/' are never evaluated directly into t */
		l = ecom(f->l, nil);
		r = ecom(f->r, nil);
		break;
	default:
		/*
		 * Reuse t for an operand that needs a destination, but only
		 * when that cannot clobber a value the other operand still
		 * reads: l goes into t only if r does not mention t, and r
		 * only if l is not t.
		 */
		l = ccom(f->l);
		r = ccom(f->r);
		l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
		r = ecom(r, complex(r) && l->s != t->s ? t : nil);
		break;
	}


	/* inside mod(): use the modular libmp routines */
	if(modulo != nil){
		switch(f->c){
		case '+':
			cprint("mpmodadd(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '-':
			cprint("mpmodsub(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '*':
		Modmul:
			/* 2*x is emitted as x+x */
			if(l->s == sym("mptwo") || r->s == sym("mptwo"))
				cprint("mpmodadd(%N, %N, %N, %N); // 2*%N\n",
					r->s == sym("mptwo") ? l : r,
					r->s == sym("mptwo") ? l : r,
					modulo, t,
					r);
			else
				cprint("mpmodmul(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		case '/':
			/* l/r is l * r^-1; 1/r is just the inverse */
			if(l->s == sym("mpone")){
				cprint("mpinvert(%N, %N, %N);\n", r, modulo, t);
				goto out;
			}
			t2 = alloctmp();
			cprint("mpinvert(%N, %N, %N);\n", r, modulo, t2);
			cprint("mpmodmul(%N, %N, %N, %N);\n", l, t2, modulo, t);
			freetmp(t2);
			goto out;
		case '^':
			/* x^2 is emitted as x*x */
			if(r->s == sym("mptwo")){
				r = l;
				goto Modmul;
			}
			cprint("mpexp(%N, %N, %N, %N);\n", l, r, modulo, t);
			goto out;
		}
	}

	switch(f->c){
	case '+':
		cprint("mpadd(%N, %N, %N);\n", l, r, t);
		goto out;
	case '-':
		/* 0 - r (unary minus) is emitted as a sign flip */
		if(l->s == sym("mpzero")){
			r = ecom(r, t);
			cprint("%N->sign = -%N->sign;\n", t, t);
		} else
			cprint("mpsub(%N, %N, %N);\n", l, r, t);
		goto out;
	case '*':
	Mul:
		/* 2*x is emitted as a left shift by one */
		if(l->s == sym("mptwo") || r->s == sym("mptwo"))
			cprint("mpleft(%N, 1, %N);\n", r->s == sym("mptwo") ? l : r, t);
		else
			cprint("mpmul(%N, %N, %N);\n", l, r, t);
		goto out;
	case '/':
		cprint("mpdiv(%N, %N, %N, %N);\n", l, r, t, nil);
		goto out;
	case '%':
		cprint("mpmod(%N, %N, %N);\n", l, r, t);
		goto out;
	case '^':
		/* x^2 is emitted as x*x */
		if(r->s == sym("mptwo")){
			r = l;
			goto Mul;
		}
		cprint("mpexp(%N, %N, nil, %N);\n", l, r, t);
		goto out;
	default:
		diag(f, "unknown operation");
	}

out:
	/* release operand temporaries (never the result) and mark t set */
	if(l != t)
		freetmp(l);
	if(r != t)
		freetmp(r);
	nodeset(t);
	return t;
}

/*
 * bcom compiles the conditional n: '?'(condition, ':'(then, else)).
 *
 * With t == nil (an if statement) it emits "if(<comparison>)" followed
 * by the branches, compiled by com().  With t != nil (a ?: expression)
 * both branch values are computed, and mpsel() selects one of them into
 * t.  The selector integer is derived from mpcmp() instead of using a
 * C branch.
 *
 * The condition may be wrapped in any number of '!' nodes (each toggles
 * neg) and must end in '<', '>' or EQ.
 */
void
bcom(Node *n, Node *t)
{
	Node *f, *l, *r;
	int neg = 0;

	l = r = nil;
	f = n->l;
Loop:
	switch(f->c){
	case '!':
		neg = !neg;
		f = f->l;
		goto Loop;
	case '>':
	case '<':
	case EQ:
		l = ecom(f->l, nil);
		r = ecom(f->r, nil);
		if(t != nil) {
			Node *b1, *b2;

			b1 = ecom(n->r->l, nil);
			b2 = ecom(n->r->r, nil);
			cprint("mpsel(");

			/* comparing an operand with itself is always equal */
			if(l->s == r->s)
				cprint("0");
			else {
				/* negate for '>' so that "true" is a negative value */
				if(f->c == '>')
					cprint("-");
				cprint("mpcmp(%N, %N)", l, r);
			}
			/*
			 * For EQ the selector is nonzero when the operands
			 * differ, i.e. when the condition is false, so the
			 * branch order is swapped.  For '<' and '>' the sign
			 * bit is smeared into -1 or 0 (assumes an arithmetic
			 * right shift of int).
			 */
			if(f->c == EQ)
				neg = !neg;
			else
				cprint(" >> (sizeof(int)*8-1)");

			cprint(", %N, %N, %N);\n", neg ? b2 : b1, neg ? b1 : b2, t);
			freetmp(b1);
			freetmp(b2);
		} else {
			cprint("if(");

			if(l->s == r->s)
				cprint("0");
			else
				cprint("mpcmp(%N, %N)", l, r);
			if(f->c == EQ)
				cprint(neg ? " != 0" : " == 0");
			else if(f->c == '>')
				cprint(neg ? " <= 0" : " > 0");
			else
				cprint(neg ? " >= 0" : " < 0");

			cprint(")");
			com(n->r);
		}
		break;
	default:
		diag(n, "saw %N in boolean expression", f);
	}
	/* the comparison operands stay live until after the branches */
	freetmp(l);
	freetmp(r);
}

/*
 * com emits C code for statement tree n (see the stmnt rules for the
 * node kinds).  Expression nodes with no case here produce no code.
 */
void
com(Node *n)
{
	Node *l, *r;

Loop:
	if(n != nil)
	switch(n->c){
	case '\n':
		/* statement sequence: left part, then iterate on the right */
		com(n->l);
		n = n->r;
		goto Loop;
	case '?':
		bcom(n, nil);
		break;
	case 'b':
		/* free every active temporary before leaving the loop */
		for(l = atmps; l != nil; l = l->l)
			cprint("mpfree(%N);\n", l);
		cprint("break;\n");
		break;
	case '@':
		cprint("for(;;)");
		/* fall through: the loop body is emitted as a block */
	case ':':
		/*
		 * Block (and optional else block).  The temporary lists are
		 * saved and restored around it.  Each block starts with no
		 * reusable temporaries, so temporaries declared inside it are
		 * not reused outside.  A loop also starts with an empty active
		 * list, so "break" frees only temporaries allocated inside the
		 * loop.
		 */
		clevel++;
		cprint("{\n");
		l = ftmps;
		r = atmps;
		if(n->c == '@')
			atmps = nil;
		ftmps = nil;
		com(n->l);
		if(n->r != nil){
			cprint("}else{\n");
			ftmps = nil;
			com(n->r);
		}
		ftmps = l;
		atmps = r;
		clevel--;
		cprint("}\n");
		break;
	case 'm':
		/* mod(expr) stmnt: compile the body with expr as the modulus */
		l = modulo;
		modulo = ecom(n->l, nil);
		com(n->r);
		freetmp(modulo);
		modulo = l;
		break;
	case 'e':
		/* call statement */
		if(n->r == nil)
			cprint("%N();\n", n->l);
		else {
			r = ecom(n->r, nil);
			cprint("%N(%N);\n", n->l, r);
			freetmp(r);
		}
		break;
	case '=':
		ecom(n->r, n->l);
		break;
	}
}

/*
 * flocs walks statement tree n and collects the names assigned in it
 * that are neither parameters (FARG) nor already collected (FLOC).
 * Each such name is marked FLOC and prepended to the ',' list r.
 * Returns the extended list.
 */
Node*
flocs(Node *n, Node *r)
{
Loop:
	if(n != nil)
	switch(n->c){
	default:
		/*
		 * Recurse on the left child and iterate on the right one.
		 * Visiting n->r both recursively and iteratively made the
		 * walk exponential in the depth of right-nested trees such
		 * as else-if chains.
		 */
		r = flocs(n->l, r);
		n = n->r;
		goto Loop;
	case '=':
		if(n->l == nil)
			diag(n, "lhs is nil");
		n = n->l;
		/*
		 * a multiple assignment "a, b, ... = ..." has a ',' list on
		 * the left: treat each element as its own assignment target
		 */
		while(n->c == ','){
			n->c = '=';
			r = flocs(n, r);
			n->c = ',';
			n = n->r;
			if(n == nil)
				return r;
		}
		if(n->c == NAME && (n->s->f & (FARG|FLOC)) == 0){
			n->s->f = FLOC;
			return new(',', n, r);
		}
		break;
	}
	return r;
}

/*
 * fcom emits the C function for a parsed definition.  f is the function
 * name.  a holds the parameters: nil, a single NAME, or a ',' list of
 * NAMEs.  b is the body.  The emitted function takes an mpint* per
 * parameter.  It declares and allocates every implicit local (flocs),
 * compiles the body, and frees the locals before returning.
 */
void
fcom(Node *f, Node *a, Node *b)
{
	Node *a0, *l0, *l;

	ntmp = 0;
	ftmps = atmps = modulo = nil;
	clevel = 1;
	cprint("void %N(", f);
	a0 = a;
	while(a != nil){
		if(a != a0)
			cprint(", ");
		if(a->c == ','){
			l = a->l;
			a = a->r;
		} else {
			l = a;
			a = nil;
		}
		/* reject e.g. f(1) or f(x+y) instead of crashing or mis-parsing */
		if(l == nil || l->c != NAME)
			diag(l != nil ? l : f, "parameter is not a name");
		l->s->f = FARG|FSET;
		cprint("mpint *%N", l);
	}
	cprint("){\n");
	l0 = flocs(b, nil);
	for(a = l0; a != nil; a = a->r)
		cprint("mpint *%N = mpnew(0);\n", a->l);
	com(b);
	for(a = l0; a != nil; a = a->r)
		cprint("mpfree(%N);\n", a->l);
	clevel = 0;
	cprint("}\n");
}

/*
 * diag reports a fatal semantic error about node n (which may be nil)
 * and exits.  It uses the node's line number when n is available and
 * the current input line otherwise.
 */
void
diag(Node *n, char *fmt, ...)
{
	static char buf[1024];
	va_list a;
	
	va_start(a, fmt);
	vsnprint(buf, sizeof(buf), fmt, a);
	va_end(a);

	fprint(2, "%s:%d: for %N; %s\n", filename, n != nil ? n->n : lineno, n, buf);
	exits("error");
}

/*
 * Nfmt is the %N format verb.  It prints a node as generated C or
 * diagnostic text: a name, a constant (in %B form once folded), a ','
 * list, or the operator or keyword the node stands for.
 */
int
Nfmt(Fmt *f)
{
	Node *n = va_arg(f->args, Node*);

	if(n == nil)
		return fmtprint(f, "nil");

	if(n->c == ',')
		return fmtprint(f, "%N, %N", n->l, n->r);

	switch(n->c){
	case NUM:
		if(n->m != nil)
			return fmtprint(f, "%B", n->m);
		/* wet floor: not folded yet, print the literal text */
	case NAME:
		return fmtprint(f, "%s", n->s->n);
	case EQ:
		return fmtprint(f, "==");
	case NEQ:
		return fmtprint(f, "!=");
	case LSH:
		return fmtprint(f, "<<");
	case RSH:
		return fmtprint(f, ">>");
	case IF:
		return fmtprint(f, "if");
	case ELSE:
		return fmtprint(f, "else");
	case WHILE:
		return fmtprint(f, "while");
	case BREAK:
		return fmtprint(f, "break");
	case MOD:
		return fmtprint(f, "mod");
	default:
		return fmtprint(f, "%c", (char)n->c);
	}
}

/*
 * parse compiles every function in the file open on fd, using file as
 * the name in diagnostics.  It keeps parsing until getch has seen end
 * of file.
 */
void
parse(int fd, char *file)
{
	filename = file;
	lineno = 1;
	clevel = 0;
	Binit(&bin, fd, OREAD);
	for(goteof = 0; !goteof; )
		yyparse();
	Bterm(&bin);
}

/* usage prints the command synopsis and exits with status "usage". */
void
usage(void)
{
	char *prog = argv0;

	fprint(2, "%s [file ...]\n", prog);
	exits("usage");
}

/*
 * main installs the %N (node) and %B (mpint) format verbs.  It then
 * translates each named file, or standard input when no file is given.
 * Generated C goes to standard output.
 */
void
main(int argc, char *argv[])
{
	int fd;

	fmtinstall('N', Nfmt);
	fmtinstall('B', mpfmt);

	ARGBEGIN {
	default:
		usage();
	} ARGEND;

	if(argc == 0){
		parse(0, "<stdin>");
		exits(nil);
	}
	for(; *argv != nil; argv++){
		fd = open(*argv, OREAD);
		if(fd < 0){
			fprint(2, "%s: %r\n", *argv);
			exits("error");
		}
		parse(fd, *argv);
		close(fd);
	}
	exits(nil);
}