shithub: pprolog

ref: 13efe91101a11f41caf6321a8b2fbdd96ef9927a
dir: pprolog/arithmetic.c

View raw version
#include <u.h>
#include <libc.h>
#include <bio.h>

#include "dat.h"
#include "fns.h"

/*
 * Dispatch records for evaluable functors.  Each slot holds the
 * implementation for one combination of operand types; a nil slot
 * means the functor does not accept that combination, and the
 * evaluator reports a type error instead.
 */

/* Binary functor, one slot per (left, right) operand type pair. */
typedef struct ArithFunc2 {
	Term	*(*intint)(vlong, vlong);	/* integer op integer */
	Term	*(*floatfloat)(double, double);	/* float op float */
	Term	*(*floatint)(double, vlong);	/* float op integer */
	Term	*(*intfloat)(vlong, double);	/* integer op float */
} ArithFunc2;

/* Unary functor, one slot per operand type. */
typedef struct ArithFunc1 {
	Term	*(*i)(vlong);	/* integer operand */
	Term	*(*f)(double);	/* float operand */
} ArithFunc1;

/*
 * Implementations of the evaluable functors.  Names are mostly
 * <op><operand types>: a trailing i takes integers, f takes floats,
 * and fi/if take a mixed (left, right) pair.  Several of them may
 * return an error term (from evaluationerror) instead of a number.
 *
 * NOTE(review): floorf, roundf, sinf, cosf, atanf, expf, logf and
 * sqrtf reuse C99 <math.h> names; harmless under Plan 9 libc, but
 * they would clash if this file were built against a C99 math.h.
 */

/* addition, subtraction, multiplication */
static Term *addi(vlong x, vlong y);
static Term *addf(double x, double y);
static Term *addfi(double x, vlong y);
static Term *addif(vlong x, double y);
static Term *subi(vlong x, vlong y);
static Term *subf(double x, double y);
static Term *subfi(double x, vlong y);
static Term *subif(vlong x, double y);
static Term *muli(vlong x, vlong y);
static Term *mulf(double x, double y);
static Term *mulfi(double x, vlong y);
static Term *mulif(vlong x, double y);

/* integer division, float division, remainder, modulus */
static Term *intdivi(vlong x, vlong y);
static Term *divii(vlong x, vlong y);
static Term *divf(double x, double y);
static Term *divfi(double x, vlong y);
static Term *divif(vlong x, double y);
static Term *remi(vlong x, vlong y);
static Term *modi(vlong x, vlong y);

/* exponentiation */
static Term *poweri(vlong x, vlong y);
static Term *powerf(double x, double y);
static Term *powerfi(double x, vlong y);
static Term *powerif(vlong x, double y);

/* bit shifts and bitwise and/or/complement */
static Term *shiftlefti(vlong x, vlong y);
static Term *shiftrighti(vlong x, vlong y);
static Term *bitandi(vlong x, vlong y);
static Term *bitori(vlong x, vlong y);
static Term *bitcompli(vlong x);

/* sign manipulation */
static Term *negi(vlong x);
static Term *negf(double x);
static Term *absi(vlong x);
static Term *absf(double x);
static Term *signi(vlong x);
static Term *signf(double x);

/* float decomposition, conversion and rounding */
static Term *intpartf(double x);
static Term *fractpartf(double x);
static Term *floati(vlong x);
static Term *floatf(double x);
static Term *floorf(double x);
static Term *truncatef(double x);
static Term *roundf(double x);
static Term *ceilingf(double x);

/* transcendental functions */
static Term *sini(vlong x);
static Term *sinf(double x);
static Term *cosi(vlong x);
static Term *cosf(double x);
static Term *atani(vlong x);
static Term *atanf(double x);
static Term *expi(vlong x);
static Term *expf(double x);
static Term *logi(vlong x);
static Term *logf(double x);
static Term *sqrti(vlong x);
static Term *sqrtf(double x);

/* functor application, defined below aritheval */
Term *binaryeval(Rune *f, Term *a, Term *b, int *waserror);
Term *unaryeval(Rune *f, Term *a, int *waserror);

/*
 * Evaluate the arithmetic expression expr.
 *
 * Returns an IntegerTerm or FloatTerm and leaves *waserror == 0 on
 * success.  On failure it returns an error term and sets *waserror = 1:
 *	- unbound variable		-> instantiation error
 *	- atom Name			-> type_error(evaluable, Name/0)
 *	- compound F/N, N not 1 or 2	-> type_error(evaluable, F/N)
 *	- any other kind of term	-> type_error(evaluable, expr)
 * Errors from evaluating the arguments, or from applying the functor
 * (binaryeval/unaryeval), are passed through unchanged.
 *
 * A number evaluates to itself: the same Term is returned, not a copy.
 */
Term *
aritheval(Term *expr, int *waserror)
{
	Term *A, *B, *functor, *pi;

	/* Not every arithmetic operation is defined right now. */
	*waserror = 0;

	if(expr->tag == VariableTerm){
		*waserror = 1;
		return instantiationerror();
	}else if(expr->tag == FloatTerm || expr->tag == IntegerTerm)
		return expr;
	else if(expr->tag == CompoundTerm && expr->arity == 2){
		A = aritheval(expr->children, waserror);
		if(*waserror)
			return A;
		B = aritheval(expr->children->next, waserror);
		if(*waserror)
			return B;
		return binaryeval(expr->text, A, B, waserror);
	}else if(expr->tag == CompoundTerm && expr->arity == 1){
		A = aritheval(expr->children, waserror);
		if(*waserror)
			return A;
		return unaryeval(expr->text, A, waserror);
	}

	*waserror = 1;
	if(expr->tag == AtomTerm || expr->tag == CompoundTerm){
		/*
		 * An atom is a callable term of arity 0, so ISO reports it as a
		 * non-evaluable Name/0 rather than as a non-number.  The
		 * indicator's name is always a fresh atom: expr must not be
		 * linked into it, because expr->next is the caller's sibling
		 * link and overwriting it would corrupt the enclosing term.
		 */
		functor = mkatom(expr->text);
		functor->next = mkinteger(expr->tag == AtomTerm ? 0 : expr->arity);
		pi = mkcompound(L"/", 2, functor);
		return typeerror(L"evaluable", pi);
	}
	/* Some other kind of term: report it as-is, without touching it. */
	return typeerror(L"evaluable", expr);
}

/*
 * Apply the binary evaluable functor named f to the already evaluated
 * numbers a and b.
 *
 * The implementation is picked by the operand types.  An unknown f
 * yields type_error(evaluable, f/2).  When f has no implementation for
 * this pair of types, the error names the first operand (a, then b)
 * that does not match the functor's first populated signature, in
 * ArithFunc2 slot order.  *waserror is set to 1 whenever the returned
 * term is not a number; it is never cleared here.
 */
Term *
binaryeval(Rune *f, Term *a, Term *b, int *waserror)
{
	static struct {
		Rune		*name;
		ArithFunc2	impl;
	} ops[] = {
		{L"+",		{addi, addf, addfi, addif}},
		{L"-",		{subi, subf, subfi, subif}},
		{L"*",		{muli, mulf, mulfi, mulif}},
		{L"//",		{intdivi, nil, nil, nil}},
		{L"/",		{divii, divf, divfi, divif}},
		{L"rem",	{remi, nil, nil, nil}},
		{L"mod",	{modi, nil, nil, nil}},
		{L"**",		{poweri, powerf, powerfi, powerif}},
		{L"<<",		{shiftlefti, nil, nil, nil}},
		{L">>",		{shiftrighti, nil, nil, nil}},
		{L"/\\",	{bitandi, nil, nil, nil}},
		{L"\\/",	{bitori, nil, nil, nil}},
	};
	ArithFunc2 *fn;
	Term *out, *functor, *culprit;
	int k, want1, want2, want;

	fn = nil;
	for(k = 0; k < sizeof ops / sizeof ops[0]; k++){
		if(runestrcmp(f, ops[k].name) == 0){
			fn = &ops[k].impl;
			break;
		}
	}

	if(fn == nil){
		*waserror = 1;
		functor = mkatom(f);
		functor->next = mkinteger(2);
		return typeerror(L"evaluable", mkcompound(L"/", 2, functor));
	}

	if(a->tag == IntegerTerm && b->tag == IntegerTerm && fn->intint != nil)
		out = fn->intint(a->ival, b->ival);
	else if(a->tag == FloatTerm && b->tag == FloatTerm && fn->floatfloat != nil)
		out = fn->floatfloat(a->dval, b->dval);
	else if(a->tag == FloatTerm && b->tag == IntegerTerm && fn->floatint != nil)
		out = fn->floatint(a->dval, b->ival);
	else if(a->tag == IntegerTerm && b->tag == FloatTerm && fn->intfloat != nil)
		out = fn->intfloat(a->ival, b->dval);
	else{
		/* No slot fits: blame the operand that breaks the preferred signature. */
		if(fn->intint != nil)
			want1 = want2 = IntegerTerm;
		else if(fn->floatfloat != nil)
			want1 = want2 = FloatTerm;
		else if(fn->floatint != nil){
			want1 = FloatTerm;
			want2 = IntegerTerm;
		}else{
			want1 = IntegerTerm;
			want2 = FloatTerm;
		}

		if(a->tag != want1){
			culprit = a;
			want = want1;
		}else{
			culprit = b;
			want = want2;
		}
		out = typeerror(want == IntegerTerm ? L"integer" : L"float", culprit);
	}

	if(out->tag != IntegerTerm && out->tag != FloatTerm)
		*waserror = 1;
	return out;
}

/*
 * Apply the unary evaluable functor named f to the already evaluated
 * number a.
 *
 * An unknown f yields type_error(evaluable, f/1).  When f has no
 * implementation for a's type, the error asks for an integer if f has
 * an integer implementation, otherwise for a float.  *waserror is set
 * to 1 whenever the returned term is not a number; it is never
 * cleared here.
 */
Term *
unaryeval(Rune *f, Term *a, int *waserror)
{
	static struct {
		Rune		*name;
		ArithFunc1	impl;
	} ops[] = {
		{L"-",				{negi, negf}},
		{L"abs",			{absi, absf}},
		{L"sign",			{signi, signf}},
		{L"float_integer_part",		{nil, intpartf}},
		{L"float_fractional_part",	{nil, fractpartf}},
		{L"float",			{floati, floatf}},
		{L"floor",			{nil, floorf}},
		{L"truncate",			{nil, truncatef}},
		{L"round",			{nil, roundf}},
		{L"ceiling",			{nil, ceilingf}},
		{L"sin",			{sini, sinf}},
		{L"cos",			{cosi, cosf}},
		{L"atan",			{atani, atanf}},
		{L"exp",			{expi, expf}},
		{L"log",			{logi, logf}},
		{L"sqrt",			{sqrti, sqrtf}},
		{L"\\",				{bitcompli, nil}},
	};
	ArithFunc1 *fn;
	Term *out, *functor;
	int k;

	fn = nil;
	for(k = 0; k < sizeof ops / sizeof ops[0]; k++){
		if(runestrcmp(f, ops[k].name) == 0){
			fn = &ops[k].impl;
			break;
		}
	}

	if(fn == nil){
		*waserror = 1;
		functor = mkatom(f);
		functor->next = mkinteger(1);
		return typeerror(L"evaluable", mkcompound(L"/", 2, functor));
	}

	if(a->tag == IntegerTerm && fn->i != nil)
		out = fn->i(a->ival);
	else if(a->tag == FloatTerm && fn->f != nil)
		out = fn->f(a->dval);
	else
		out = typeerror(fn->i != nil ? L"integer" : L"float", a);

	if(out->tag != IntegerTerm && out->tag != FloatTerm)
		*waserror = 1;
	return out;
}

/*
 * Addition (+).  integer + integer stays an integer; any float operand
 * makes the result a float.
 */
static Term *
addi(vlong x, vlong y)
{
	vlong r;

	/*
	 * Signed overflow is undefined in C, so add as unsigned (which
	 * wraps) and detect overflow afterwards.  It happened exactly when
	 * x and y share a sign and r's sign differs from both.  ISO Prolog
	 * reports this as evaluation_error(int_overflow).
	 */
	r = (vlong)((uvlong)x + (uvlong)y);
	if(((x ^ r) & (y ^ r)) < 0)
		return evaluationerror(L"int_overflow");
	return mkinteger(r);
}

static Term *
addf(double x, double y)
{
	return mkfloat(x + y);
}

static Term *
addfi(double x, vlong y)
{
	return addf(x, y);
}

static Term *
addif(vlong x, double y)
{
	return addf(x, y);
}

/*
 * Subtraction (-).  integer - integer stays an integer; any float
 * operand makes the result a float.
 */
static Term *
subi(vlong x, vlong y)
{
	vlong r;

	/*
	 * Subtract as unsigned (wraps, well defined) and detect signed
	 * overflow afterwards.  It happened exactly when x and y differ in
	 * sign and r's sign differs from x's.
	 */
	r = (vlong)((uvlong)x - (uvlong)y);
	if(((x ^ y) & (x ^ r)) < 0)
		return evaluationerror(L"int_overflow");
	return mkinteger(r);
}

static Term *
subf(double x, double y)
{
	return addf(x, -y);
}

static Term *
subfi(double x, vlong y)
{
	return subf(x, y);
}

static Term *
subif(vlong x, double y)
{
	return subf(x, y);
}

/*
 * Multiplication (*).  integer * integer stays an integer; any float
 * operand makes the result a float.
 */
static Term *
muli(vlong x, vlong y)
{
	vlong r, vmin;

	/*
	 * Signed overflow is undefined in C.  So multiply as unsigned
	 * (which wraps) and verify the product by dividing back.  That
	 * division would itself overflow for x == -1 with r == min, which
	 * only happens when y == min, so that pair is rejected first.
	 */
	vmin = (vlong)((uvlong)1 << 63);
	r = (vlong)((uvlong)x * (uvlong)y);
	if((x == -1 && y == vmin) || (x != 0 && r / x != y))
		return evaluationerror(L"int_overflow");
	return mkinteger(r);
}

static Term *
mulf(double x, double y)
{
	return mkfloat(x * y);
}

static Term *
mulfi(double x, vlong y)
{
	return mulf(x, y);
}

static Term *
mulif(vlong x, double y)
{
	return mulf(x, y);
}

/*
 * Integer division (//) using C integer division (truncation toward zero).
 *
 * Errors:
 *	- y == 0		-> evaluation_error(zero_divisor)
 *	- min // -1		-> evaluation_error(int_overflow), because the
 *				   true quotient 2^63 does not fit in a vlong.
 *				   The x86 idiv instruction faults on it rather
 *				   than wrapping.
 */
static Term *
intdivi(vlong x, vlong y)
{
	vlong vmin;

	if(y == 0)
		return evaluationerror(L"zero_divisor");
	vmin = (vlong)((uvlong)1 << 63);
	if(y == -1 && x == vmin)
		return evaluationerror(L"int_overflow");
	return mkinteger(x / y);
}

/*
 * Float division (/).  The result is always a float, even for two
 * integers: both operands are widened to double first, so 7 / 2 gives
 * 3.5.  A zero divisor is evaluation_error(zero_divisor).
 */
static Term *
divii(vlong x, vlong y)
{
	double num, den;

	num = x;
	den = y;
	return divf(num, den);
}

static Term *
divf(double x, double y)
{
	if(y == 0)
		return evaluationerror(L"zero_divisor");
	return mkfloat(x / y);
}

static Term *
divfi(double x, vlong y)
{
	double den;

	den = y;
	return divf(x, den);
}

static Term *
divif(vlong x, double y)
{
	double num;

	num = x;
	return divf(num, y);
}

/*
 * Remainder (rem): x - (x // y) * y, taking the sign of x.
 * A zero divisor is evaluation_error(zero_divisor).
 */
static Term *
remi(vlong x, vlong y)
{
	if(y == 0)
		return evaluationerror(L"zero_divisor");
	/*
	 * Every x rem -1 is 0.  Computing it through x / y would overflow
	 * for x == min; the x86 idiv instruction faults on that.
	 */
	if(y == -1)
		return mkinteger(0);
	return mkinteger(x - (x/y) * y);
}

/*
 * Modulus (mod): x - floor(x / y) * y, taking the sign of y.
 * A zero divisor is evaluation_error(zero_divisor).
 *
 * Computed entirely in integer arithmetic.  Going through double
 * loses precision once |x| exceeds 2^53 and gives wrong results.
 */
static Term *
modi(vlong x, vlong y)
{
	vlong r;

	if(y == 0)
		return evaluationerror(L"zero_divisor");
	/* x mod -1 is always 0; x / -1 would overflow for x == min. */
	if(y == -1)
		return mkinteger(0);

	/*
	 * Truncating division gives a remainder with x's sign.  When that
	 * remainder is nonzero and its sign differs from y's, shift it by
	 * one y.  This cannot overflow: |r| < |y| and r, y have opposite
	 * signs.
	 */
	r = x - (x / y) * y;
	if(r != 0 && (r < 0) != (y < 0))
		r += y;
	return mkinteger(r);
}

/*
 * Power (**).  The result is always a float, whatever the operand
 * types (all variants go through powerf).
 *
 * Special cases:
 *	- 0 ** 0 is 1.0
 *	- 0 raised to a negative power is evaluation_error(undefined)
 *	- a negative base raised to a non-integral power is
 *	  evaluation_error(undefined): it has no real result, and pow
 *	  would otherwise hand back a NaN as if it were a number.
 */
static Term *
poweri(vlong x, vlong y)
{
	return powerf(x, y);
}

static Term *
powerf(double x, double y)
{
	if(x == 0 && y == 0)
		return mkfloat(1);
	else if(x == 0 && y < 0)
		return evaluationerror(L"undefined");
	else if(x < 0 && y != floor(y))
		return evaluationerror(L"undefined");
	else
		return mkfloat(pow(x, y));
}

static Term *
powerfi(double x, vlong y)
{
	return powerf(x, y);
}

static Term *
powerif(vlong x, double y)
{
	return powerf(x, y);
}

static Term *
shiftlefti(vlong x, vlong y)
{
	return mkinteger(x << y);
}

static Term *
shiftrighti(vlong x, vlong y)
{
	return mkinteger(x >> y);
}

/* Bitwise and (/\) of two integers. */
static Term *
bitandi(vlong x, vlong y)
{
	vlong common;

	common = x & y;
	return mkinteger(common);
}

/* Bitwise or (\/) of two integers. */
static Term *
bitori(vlong x, vlong y)
{
	vlong either;

	either = x | y;
	return mkinteger(either);
}


/*
 * Unary minus (-).  The negation of the most negative vlong has no
 * vlong representation.  Rather than invoke undefined signed overflow,
 * it is reported as evaluation_error(int_overflow).
 */
static Term *
negi(vlong x)
{
	vlong vmin;

	vmin = (vlong)((uvlong)1 << 63);
	if(x == vmin)
		return evaluationerror(L"int_overflow");
	return mkinteger(-x);
}

static Term *
negf(double x)
{
	return mkfloat(-x);
}

/*
 * Absolute value (abs).  abs of the most negative vlong does not fit
 * in a vlong, so it is reported as evaluation_error(int_overflow).
 * It is not negated with undefined signed overflow.
 */
static Term *
absi(vlong x)
{
	vlong vmin;

	vmin = (vlong)((uvlong)1 << 63);
	if(x == vmin)
		return evaluationerror(L"int_overflow");
	return mkinteger(x < 0 ? -x : x);
}

static Term *
absf(double x)
{
	return mkfloat(x < 0 ? -x : x);
}

/* sign of an integer: -1, 0 or 1. */
static Term *
signi(vlong x)
{
	return mkinteger(x > 0 ? 1 : x < 0 ? -1 : 0);
}

/* sign of a float: -1.0, 0.0 or 1.0; NaN (neither < 0 nor > 0) gives 0.0. */
static Term *
signf(double x)
{
	return mkfloat(x > 0 ? 1 : x < 0 ? -1 : 0);
}

/*
 * float_integer_part: x with its fractional part dropped toward zero,
 * still as a float.  Zeros (of either sign) and NaN come back
 * unchanged.
 */
static Term *
intpartf(double x)
{
	if(x < 0)
		return mkfloat(-floor(-x));
	if(x > 0)
		return mkfloat(floor(x));
	return mkfloat(x);
}

/* float_fractional_part: x minus its float_integer_part; same sign as x. */
static Term *
fractpartf(double x)
{
	double whole;

	whole = intpartf(x)->dval;
	return mkfloat(x - whole);
}

/* float: convert an integer to a float. */
static Term *
floati(vlong x)
{
	double d;

	d = x;
	return mkfloat(d);
}

/* float of a float is the same value (a fresh term). */
static Term *
floatf(double x)
{
	double same;

	same = x;
	return mkfloat(same);
}

/* floor: largest integral value <= x, returned as a float. */
static Term *
floorf(double x)
{
	double down;

	down = floor(x);
	return mkfloat(down);
}

/*
 * truncate: round toward zero, returned as a float.  Non-negative x
 * is floored.  Anything else (negative, or NaN) gets the floor of its
 * magnitude, negated.
 */
static Term *
truncatef(double x)
{
	double mag;

	if(x >= 0)
		return mkfloat(floor(x));
	mag = x < 0 ? -x : x;
	return mkfloat(-floor(mag));
}

/*
 * round: nearest integral value, halves rounded up (toward +infinity),
 * i.e. floor(x + 1/2) as ISO defines it.  The result is returned as a
 * float.
 *
 * x + 0.5 is not evaluated directly, because the addition can itself
 * round.  0.49999999999999994 + 0.5 gives 1.0, and for odd integers
 * >= 2^52 it gives the next even number.  Comparing the exact
 * fractional part x - floor(x) against 1/2 avoids both problems.
 */
static Term *
roundf(double x)
{
	double down;

	down = floor(x);
	if(x - down >= 0.5)
		down += 1;
	return mkfloat(down);
}

/* ceiling: smallest integral value >= x, returned as a float (computed as -floor(-x)). */
static Term *
ceilingf(double x)
{
	double up;

	up = -floor(-x);
	return mkfloat(up);
}


/*
 * sin, cos, atan, exp.  Integer arguments are widened to double and
 * handled by the float version; the result is always a float.
 */
static Term *
sini(vlong x)
{
	double v;

	v = x;
	return sinf(v);
}

static Term *
sinf(double x)
{
	double r;

	r = sin(x);
	return mkfloat(r);
}

static Term *
cosi(vlong x)
{
	double v;

	v = x;
	return cosf(v);
}

static Term *
cosf(double x)
{
	double r;

	r = cos(x);
	return mkfloat(r);
}

static Term *
atani(vlong x)
{
	double v;

	v = x;
	return atanf(v);
}

static Term *
atanf(double x)
{
	double r;

	r = atan(x);
	return mkfloat(r);
}

static Term *
expi(vlong x)
{
	double v;

	v = x;
	return expf(v);
}

static Term *
expf(double x)
{
	double r;

	r = exp(x);
	return mkfloat(r);
}

/*
 * log and sqrt.  Integer arguments are widened to double.
 * log of x <= 0 and sqrt of x < 0 are evaluation_error(undefined).
 */
static Term *
logi(vlong x)
{
	double v;

	v = x;
	return logf(v);
}

static Term *
logf(double x)
{
	if(x <= 0)
		return evaluationerror(L"undefined");
	return mkfloat(log(x));
}

static Term *
sqrti(vlong x)
{
	double v;

	v = x;
	return sqrtf(v);
}

static Term *
sqrtf(double x)
{
	if(x < 0)
		return evaluationerror(L"undefined");
	return mkfloat(sqrt(x));
}

/* Bitwise complement (\) of an integer. */
static Term *
bitcompli(vlong x)
{
	vlong flipped;

	flipped = ~x;
	return mkinteger(flipped);
}