shithub: util

ref: 9de25964fd0da3631675116bd12e700d9779eb4e
dir: /ann/anncreate.c/

View raw version
#include <u.h>
#include <libc.h>

#include "ann.h"

/*
 * anncreatev builds an Ann with num_layers layers, where layers[i]
 * is the requested size of layer i.
 *
 * A size below 0 or above MaxLayerSize is replaced by 0.  Every
 * layer except the last is created with the leaky ReLU
 * activation/gradient pair; the last one gets the sigmoid pair.
 * For each pair of adjacent layers one entry is created in both
 * ret->weights and ret->deltas; the two differ only in the last
 * argument given to weightscreate (1 for weights, 0 for deltas).
 *
 * Returns a null pointer when num_layers < 1, when layers is null,
 * or when one of the allocations made directly here fails.
 */
Ann*
anncreatev(int num_layers, int *layers)
{
	enum { MaxLayerSize = 1000000 };
	Ann *ret;
	int arg;
	int i;

	/* num_layers-1 is used as an array length below, so it must not go negative */
	if (num_layers < 1 || !layers)
		return 0;

	ret = calloc(1, sizeof(Ann));
	if (!ret)
		return 0;

	ret->n = num_layers;
	ret->rate = 0.7;
	ret->layers = calloc(num_layers, sizeof(Layer*));
	ret->weights = calloc(num_layers-1, sizeof(Weights*));
	ret->deltas = calloc(num_layers-1, sizeof(Weights*));

	if (!ret->layers)
		goto fail;
	/*
	 * With a single layer the two arrays have length 0 and are never
	 * indexed, so a null result is only a failure when num_layers > 1.
	 */
	if (num_layers > 1 && (!ret->weights || !ret->deltas))
		goto fail;

	/*
	 * NOTE(review): the results of layercreate and weightscreate are
	 * used unchecked, as before; whether they can return a null
	 * pointer is not visible from this file -- confirm.
	 */
	for (i = 0; i < num_layers; i++) {
		arg = layers[i];
		if (arg < 0 || arg > MaxLayerSize)
			arg = 0;
		if (i < (num_layers-1))
			ret->layers[i] = layercreate(arg, activation_leaky_relu, gradient_leaky_relu);
		else
			ret->layers[i] = layercreate(arg, activation_sigmoid, gradient_sigmoid);
		if (i > 0) {
			ret->weights[i-1] = weightscreate(ret->layers[i-1]->n, ret->layers[i]->n, 1);
			ret->deltas[i-1] = weightscreate(ret->layers[i-1]->n, ret->layers[i]->n, 0);
		}
	}

	return ret;

fail:
	/* nothing has been stored in the arrays yet, so freeing the arrays themselves is enough */
	free(ret->deltas);
	free(ret->weights);
	free(ret->layers);
	free(ret);
	return 0;
}

/*
 * anncreate is the variadic front end to anncreatev: after
 * num_layers the caller passes num_layers int arguments, one size
 * per layer in order.  The sizes are copied into a temporary array,
 * handed to anncreatev, and the array is freed again.
 *
 * Returns the result of anncreatev, or a null pointer when
 * num_layers < 1 or the temporary array cannot be allocated.
 */
Ann*
anncreate(int num_layers, ...)
{
	Ann *ret;
	va_list args;
	int i;
	int *layers;

	if (num_layers < 1)
		return 0;

	/* allocate before va_start so the failure path has no va_list to close */
	layers = calloc(num_layers, sizeof(int));
	if (!layers)
		return 0;

	va_start(args, num_layers);
	for (i = 0; i < num_layers; i++)
		layers[i] = va_arg(args, int);
	va_end(args);

	ret = anncreatev(num_layers, layers);

	free(layers);

	return ret;
}