shithub: femtolisp

ref: 10d16c10bed72a177a99e0a9a3f3caca6c38ce94
dir: /operators.c/

View raw version
#include "flisp.h"
#include "operators.h"

/*
 * Convert the number stored at data, whose type is given by tag, to an
 * mpint.
 *
 * Fixed-width integers go through the itomp/uitomp/vtomp/uvtomp family,
 * floats and doubles through dtomp, and an existing mpint is duplicated
 * with mpcopy. An unrecognised tag yields mpzero.
 * NOTE(review): the converters are called with a nil destination, which
 * presumably makes them allocate the result - confirm against libmp and
 * make sure callers never free the shared mpzero.
 */
mpint *
conv_to_mpint(void *data, numerictype_t tag)
{
	switch(tag){
	case T_INT8:   return itomp(*(int8_t*)data, nil);
	case T_UINT8:  return uitomp(*(uint8_t*)data, nil);
	case T_INT16:  return itomp(*(int16_t*)data, nil);
	case T_UINT16: return uitomp(*(uint16_t*)data, nil);
	case T_INT32:  return itomp(*(int32_t*)data, nil);
	case T_UINT32: return uitomp(*(uint32_t*)data, nil);
	case T_INT64:  return vtomp(*(int64_t*)data, nil);
	// read the operand as unsigned, like every other T_UINT64 case
	case T_UINT64: return uvtomp(*(uint64_t*)data, nil);
	case T_MPINT:  return mpcopy(*(mpint**)data);
	case T_FLOAT:  return dtomp(*(float*)data, nil);
	case T_DOUBLE: return dtomp(*(double*)data, nil);
	}
	return mpzero;
}

/*
 * Read the number stored at data, whose type is given by tag, and
 * return it as a double. An unrecognised tag yields 0.
 */
double
conv_to_double(void *data, numerictype_t tag)
{
	double res = 0;
	int64_t i;

	switch(tag){
	case T_INT8:
		res = *(int8_t*)data;
		break;
	case T_UINT8:
		res = *(uint8_t*)data;
		break;
	case T_INT16:
		res = *(int16_t*)data;
		break;
	case T_UINT16:
		res = *(uint16_t*)data;
		break;
	case T_INT32:
		res = *(int32_t*)data;
		break;
	case T_UINT32:
		res = *(uint32_t*)data;
		break;
	case T_INT64:
		i = *(int64_t*)data;
		res = i;
		// a negative integer may come out positive - can happen! -
		// so put the sign back
		if(i < 0 && res > 0)
			res = -res;
		break;
	case T_UINT64:
		res = *(uint64_t*)data;
		break;
	case T_MPINT:
		res = mptod(*(mpint**)data);
		break;
	case T_FLOAT:
		res = *(float*)data;
		break;
	case T_DOUBLE:
		res = *(double*)data;
		break;
	}
	return res;
}

/*
 * Store the double d at dest, converted to the numeric type selected by
 * tag. An unrecognised tag leaves dest untouched.
 *
 * For T_INT64 values of 2^63 and above saturate to INT64_MAX. For
 * T_UINT64 negative values are converted via int64. For T_MPINT the
 * pointer at dest is overwritten with the result of dtomp(d, nil).
 * Other out-of-range conversions are not guarded.
 */
void
conv_from_double(void *dest, double d, numerictype_t tag)
{
	switch(tag){
	case T_INT8:   *(int8_t*)dest = d; break;
	case T_UINT8:  *(uint8_t*)dest = d; break;
	case T_INT16:  *(int16_t*)dest = d; break;
	case T_UINT16: *(uint16_t*)dest = d; break;
	case T_INT32:  *(int32_t*)dest = d; break;
	case T_UINT32: *(uint32_t*)dest = d; break;
	case T_INT64:
		// 2^63 and above does not fit in int64; converting it is
		// undefined (and tends to produce 0x8000000000000000), so
		// saturate before converting instead of patching afterwards
		if(d >= 9223372036854775808.0)
			*(int64_t*)dest = INT64_MAX;
		else
			*(int64_t*)dest = d;
		break;
	case T_UINT64:
		// converting a negative double straight to uint64 is
		// undefined; go through int64 first, as conv_to_uint64 does
		if(d >= 0)
			*(uint64_t*)dest = d;
		else
			*(uint64_t*)dest = (int64_t)d;
		break;
	case T_MPINT:  *(mpint**)dest = dtomp(d, nil); break;
	case T_FLOAT:  *(float*)dest = d; break;
	case T_DOUBLE: *(double*)dest = d; break;
	}
}

// CONV_TO_INTTYPE(name, ctype) defines conv_to_<name>(), which reads
// the number stored at data (its type given by tag) and returns it cast
// to ctype. An unrecognised tag yields 0.
// FIXME sign with mpint (T_MPINT always goes through mptov(), whatever
// the signedness of ctype)
#define CONV_TO_INTTYPE(name, ctype) \
ctype \
conv_to_##name(void *data, numerictype_t tag) \
{ \
	ctype v = 0; \
	switch(tag){ \
	case T_INT8: \
		v = (ctype)*(int8_t*)data; \
		break; \
	case T_UINT8: \
		v = (ctype)*(uint8_t*)data; \
		break; \
	case T_INT16: \
		v = (ctype)*(int16_t*)data; \
		break; \
	case T_UINT16: \
		v = (ctype)*(uint16_t*)data; \
		break; \
	case T_INT32: \
		v = (ctype)*(int32_t*)data; \
		break; \
	case T_UINT32: \
		v = (ctype)*(uint32_t*)data; \
		break; \
	case T_INT64: \
		v = (ctype)*(int64_t*)data; \
		break; \
	case T_UINT64: \
		v = (ctype)*(uint64_t*)data; \
		break; \
	case T_MPINT: \
		v = (ctype)mptov(*(mpint**)data); \
		break; \
	case T_FLOAT: \
		v = (ctype)*(float*)data; \
		break; \
	case T_DOUBLE: \
		v = (ctype)*(double*)data; \
		break; \
	} \
	return v; \
}

// the instantiations: conv_to_int64, conv_to_int32, conv_to_uint32
CONV_TO_INTTYPE(int64, int64_t)
CONV_TO_INTTYPE(int32, int32_t)
CONV_TO_INTTYPE(uint32, uint32_t)

// conv_to_uint64 is written out by hand rather than generated by
// CONV_TO_INTTYPE: casting a negative float or double directly to
// uint64 is undefined behaviour, so values that are not >= 0 are
// converted to int64 first and only then to uint64.
// An unrecognised tag yields 0.
uint64_t
conv_to_uint64(void *data, numerictype_t tag)
{
	switch(tag){
	case T_INT8:
		return *(int8_t*)data;
	case T_UINT8:
		return *(uint8_t*)data;
	case T_INT16:
		return *(int16_t*)data;
	case T_UINT16:
		return *(uint16_t*)data;
	case T_INT32:
		return *(int32_t*)data;
	case T_UINT32:
		return *(uint32_t*)data;
	case T_INT64:
		return *(int64_t*)data;
	case T_UINT64:
		return *(uint64_t*)data;
	case T_MPINT:
		return mptouv(*(mpint**)data);
	case T_FLOAT: {
		float f = *(float*)data;
		return f >= 0 ? (uint64_t)f : (uint64_t)(int64_t)f;
	}
	case T_DOUBLE: {
		double d = *(double*)data;
		return d >= 0 ? (uint64_t)d : (uint64_t)(int64_t)d;
	}
	}
	return 0;
}

// view both operands as type t and compare them with <
#define SAME_LT(t) (*(t*)a < *(t*)b)

/*
 * Less-than for two numbers of the same type: returns 1 if the value at
 * a is smaller than the value at b, both interpreted according to tag,
 * and 0 otherwise (including for an unrecognised tag). mpints are
 * compared with mpcmp.
 */
int
cmp_same_lt(void *a, void *b, numerictype_t tag)
{
	switch(tag){
	case T_INT8:   return SAME_LT(int8_t);
	case T_UINT8:  return SAME_LT(uint8_t);
	case T_INT16:  return SAME_LT(int16_t);
	case T_UINT16: return SAME_LT(uint16_t);
	case T_INT32:  return SAME_LT(int32_t);
	case T_UINT32: return SAME_LT(uint32_t);
	case T_INT64:  return SAME_LT(int64_t);
	case T_UINT64: return SAME_LT(uint64_t);
	case T_MPINT:  return mpcmp(*(mpint**)a, *(mpint**)b) < 0;
	case T_FLOAT:  return SAME_LT(float);
	case T_DOUBLE: return SAME_LT(double);
	}
	return 0;
}

#undef SAME_LT

/*
 * Equality for two numbers of the same type: returns 1 if the values at
 * a and b, both interpreted according to tag, compare equal with ==
 * (mpcmp for mpints), and 0 otherwise. An unrecognised tag yields 0.
 */
int
cmp_same_eq(void *a, void *b, numerictype_t tag)
{
	int same = 0;

	switch(tag){
	case T_INT8:
		same = *(int8_t*)a == *(int8_t*)b;
		break;
	case T_UINT8:
		same = *(uint8_t*)a == *(uint8_t*)b;
		break;
	case T_INT16:
		same = *(int16_t*)a == *(int16_t*)b;
		break;
	case T_UINT16:
		same = *(uint16_t*)a == *(uint16_t*)b;
		break;
	case T_INT32:
		same = *(int32_t*)a == *(int32_t*)b;
		break;
	case T_UINT32:
		same = *(uint32_t*)a == *(uint32_t*)b;
		break;
	case T_INT64:
		same = *(int64_t*)a == *(int64_t*)b;
		break;
	case T_UINT64:
		same = *(uint64_t*)a == *(uint64_t*)b;
		break;
	case T_MPINT:
		same = mpcmp(*(mpint**)a, *(mpint**)b) == 0;
		break;
	case T_FLOAT:
		same = *(float*)a == *(float*)b;
		break;
	case T_DOUBLE:
		same = *(double*)a == *(double*)b;
		break;
	}
	return same;
}

/*
 * Scratch mpint shared by cmp_lt() and cmp_eq(). It is created with
 * mpnew(0) the first time either function sees an mpint operand, and is
 * then reused as the destination when a 64-bit integer operand is
 * converted (vtomp/uvtomp) for mpcmp().
 * FIXME one is allocated for all compare ops; nothing in the code shown
 * here ever frees it.
 */
static mpint *cmpmpint;

/*
 * Return 1 if the number at a (type atag) is strictly less than the
 * number at b (type btag), 0 otherwise.
 *
 * Operands of the same type are handed to cmp_same_lt(). Otherwise both
 * are converted to double, which settles every pair whose doubles
 * differ. Pairs whose doubles compare equal (or contain a NaN) are
 * re-examined exactly when one side is a 64-bit integer and the other
 * is a 64-bit integer, a float/double or an mpint; all remaining
 * combinations yield 0. A NaN operand of a mixed-type pair yields 0.
 */
int
cmp_lt(void *a, numerictype_t atag, void *b, numerictype_t btag)
{
	if(atag == btag)
		return cmp_same_lt(a, b, atag);

	double da = conv_to_double(a, atag);
	double db = conv_to_double(b, btag);
	int afloat = atag == T_FLOAT || atag == T_DOUBLE;
	int bfloat = btag == T_FLOAT || btag == T_DOUBLE;

	// casting to double will only get the wrong answer for big int64s
	// that differ in low bits
	if(da < db && !isnan(da) && !isnan(db))
		return 1;
	if(db < da)
		return 0;

	// from here on da == db unless one of them is NaN. When a 64-bit
	// integer meets a non-NaN float, the float therefore holds an
	// integral value inside, or exactly one past, the integer's range.

	if(cmpmpint == nil && (atag == T_MPINT || btag == T_MPINT))
		cmpmpint = mpnew(0);

	if(atag == T_UINT64){
		if(btag == T_INT64){
			if(*(int64_t*)b >= 0)
				return (*(uint64_t*)a < (uint64_t)*(int64_t*)b);
			return ((int64_t)*(uint64_t*)a < *(int64_t*)b);
		}
		if(bfloat){
			if(db != db)
				return 0;
			// b is 2^64: every uint64 is smaller, and the cast
			// below would be undefined
			if(db >= 18446744073709551616.0)
				return 1;
			return *(uint64_t*)a < (uint64_t)db;
		}
		if(btag == T_MPINT)
			return mpcmp(uvtomp(*(uint64_t*)a, cmpmpint), *(mpint**)b) < 0;
	}
	if(atag == T_INT64){
		if(btag == T_UINT64){
			if(*(int64_t*)a >= 0)
				return ((uint64_t)*(int64_t*)a < *(uint64_t*)b);
			return (*(int64_t*)a < (int64_t)*(uint64_t*)b);
		}
		if(bfloat){
			if(db != db)
				return 0;
			// b is 2^63: every int64 is smaller, and the cast
			// below would be undefined
			if(db >= 9223372036854775808.0)
				return 1;
			return *(int64_t*)a < (int64_t)db;
		}
		if(btag == T_MPINT)
			return mpcmp(vtomp(*(int64_t*)a, cmpmpint), *(mpint**)b) < 0;
	}
	if(btag == T_UINT64){
		if(afloat){
			if(da != da)
				return 0;
			// a is 2^64: no uint64 is larger
			if(da >= 18446744073709551616.0)
				return 0;
			return (uint64_t)da < *(uint64_t*)b;
		}
		if(atag == T_MPINT)
			return mpcmp(*(mpint**)a, uvtomp(*(uint64_t*)b, cmpmpint)) < 0;
	}
	if(btag == T_INT64){
		if(afloat){
			if(da != da)
				return 0;
			// a is 2^63: no int64 is larger
			if(da >= 9223372036854775808.0)
				return 0;
			return (int64_t)da < *(int64_t*)b;
		}
		if(atag == T_MPINT)
			return mpcmp(*(mpint**)a, vtomp(*(int64_t*)b, cmpmpint)) < 0;
	}
	return 0;
}

/*
 * Return 1 if the number at a (type atag) equals the number at b (type
 * btag), 0 otherwise.
 *
 * equalnans selects how floating-point operands are compared: when it
 * is non-zero and both tags are >= T_FLOAT, the two values are widened
 * to double and their bit patterns are compared, so identical NaNs are
 * equal and +0.0 differs from -0.0. Otherwise == is used.
 * NOTE(review): the tag tests rely on T_FLOAT and T_DOUBLE being the
 * last numerictype_t values - confirm against flisp.h.
 *
 * Mixed-type operands are first compared as doubles. If those are equal
 * and one side is a 64-bit integer paired with a 64-bit integer, a
 * float/double or an mpint, the comparison is repeated exactly. Every
 * other pair whose doubles are equal is reported as equal.
 */
int
cmp_eq(void *a, numerictype_t atag, void *b, numerictype_t btag, int equalnans)
{
	union {
		double d;
		int64_t i64;
	}u, v;

	if(atag == btag && (!equalnans || atag < T_FLOAT))
		return cmp_same_eq(a, b, atag);

	double da = conv_to_double(a, atag);
	double db = conv_to_double(b, btag);
	int afloat = atag == T_FLOAT || atag == T_DOUBLE;
	int bfloat = btag == T_FLOAT || btag == T_DOUBLE;

	if((int)atag >= T_FLOAT && (int)btag >= T_FLOAT){
		if(equalnans){
			u.d = da; v.d = db;
			return u.i64 == v.i64;
		}
		return da == db;
	}

	if(da != db)
		return 0;

	// da == db from here on, so neither is NaN. When a 64-bit integer
	// meets a float, the float holds an integral value inside, or
	// exactly one past, the integer's range.

	if(cmpmpint == nil && (atag == T_MPINT || btag == T_MPINT))
		cmpmpint = mpnew(0);

	if(atag == T_UINT64){
		// this is safe because if a had been bigger than INT64_MAX,
		// we would already have concluded that it's bigger than b.
		if(btag == T_INT64)
			return ((int64_t)*(uint64_t*)a == *(int64_t*)b);
		if(bfloat){
			// 2^64 is no uint64, and casting it would be undefined
			if(db >= 18446744073709551616.0)
				return 0;
			return *(uint64_t*)a == (uint64_t)db;
		}
		if(btag == T_MPINT)
			return mpcmp(uvtomp(*(uint64_t*)a, cmpmpint), *(mpint**)b) == 0;
	}
	if(atag == T_INT64){
		if(btag == T_UINT64)
			return (*(int64_t*)a == (int64_t)*(uint64_t*)b);
		if(bfloat){
			// 2^63 is no int64, and casting it would be undefined
			if(db >= 9223372036854775808.0)
				return 0;
			return *(int64_t*)a == (int64_t)db;
		}
		if(btag == T_MPINT)
			return mpcmp(vtomp(*(int64_t*)a, cmpmpint), *(mpint**)b) == 0;
	}
	// the int64/uint64 pairings were fully handled above, so only
	// float and mpint left-hand sides remain for the next two branches
	if(btag == T_UINT64){
		if(afloat){
			if(da >= 18446744073709551616.0)
				return 0;
			return *(uint64_t*)b == (uint64_t)da;
		}
		if(atag == T_MPINT)
			return mpcmp(*(mpint**)a, uvtomp(*(uint64_t*)b, cmpmpint)) == 0;
	}
	if(btag == T_INT64){
		if(afloat){
			if(da >= 9223372036854775808.0)
				return 0;
			return *(int64_t*)b == (int64_t)da;
		}
		if(atag == T_MPINT)
			return mpcmp(*(mpint**)a, vtomp(*(int64_t*)b, cmpmpint)) == 0;
	}
	return 1;
}