shithub: util

ref: 5b790b347f51baee12eeb2dfced7ef5cc0065a7e
parent: 07398654cc5684fec2ff81fc1b21c3862536291c
author: eli <eli@pterodactyl>
date: Sun Aug 17 14:42:53 EDT 2025

support for version 2 quantized format checkpoints
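
A version-2 checkpoint begins with the magic word 0x616b3432 ("ak42")
and a version int; a legacy float32 checkpoint has neither, so its
first word is dim itself, which is why read_checkpoint below reuses
the magic variable as dim when the magic does not match. A rough
sketch of the v2 header as the code consumes it, read field by field
rather than mapped as a struct:

	uint	magic;			/* 0x616b3432, "ak42" */
	int	version;		/* must be 2 */
	int	dim, hidden_dim, n_layers, n_heads;
	int	n_kv_heads, vocab_size, seq_len;
	uchar	shared_classifier;	/* when set, wcls aliases q_tokens */
	int	group_size;		/* quantization group size, copied into GS */
	/* header padded to 256 bytes (header_size); quantized weights follow */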

--- a/llama2.c
+++ b/llama2.c
@@ -3,35 +3,8 @@
 
 #include <u.h>
 #include <libc.h>
-#include <stdio.h>
-//#include <stdlib.h>
 #include <ctype.h>
-//#include <time.h>
-//#include <math.h>
-//#include <string.h>
-//#include <fcntl.h>
-//#if defined _WIN32
-//    #include "win.h"
-//#else
-//    #include <unistd.h>
-//    #include <sys/mman.h>
-//#endif
 
-#define int8_t char
-#define uint8_t uchar
-#define int32_t int
-#define ssize_t uvlong
-#define size_t ulong
-#define EXIT_FAILURE "exits"
-#define exit exits
-#define O_RDONLY OREAD
-#define sqrtf sqrt
-#define expf exp
-#define powf pow
-#define cosf cos
-#define sinf sin
-#define uint32_t uint
-
 unsigned char quantized8 = 0;
 int GS = 32;
 
@@ -51,43 +24,30 @@
 #define SIZEOFCONFIG 24
 
 typedef struct {
-    int8_t* q;    // quantized values
+    char* q;    // quantized values
     float* s; // scaling factors
 } QuantizedTensor;
 
-int read4(int fd) {
-	typedef union _result {
-		char buf[4];
-		int i;
-	} result;
-
-	result r;
-
-	if (read(fd, r.buf, 4) != 4)
-		exit(EXIT_FAILURE);
-
-	return r.i;
-}
-
 typedef struct {
     // token embedding table
-    float* token_embedding_table;    // (vocab_size, dim)
+	void *q_tokens; // (vocab_size, dim)
+    float* token_embedding_table;    // (vocab_size, dim) // dequantized
     // weights for rmsnorms
     float* rms_att_weight; // (layer, dim) rmsnorm weights
     float* rms_ffn_weight; // (layer, dim)
     // weights for matmuls. note dim == n_heads * head_size
-    float* wq; // (layer, dim, n_heads * head_size)
-    float* wk; // (layer, dim, n_kv_heads * head_size)
-    float* wv; // (layer, dim, n_kv_heads * head_size)
-    float* wo; // (layer, n_heads * head_size, dim)
+    void* wq; // (layer, dim, n_heads * head_size)
+    void* wk; // (layer, dim, n_kv_heads * head_size)
+    void* wv; // (layer, dim, n_kv_heads * head_size)
+    void* wo; // (layer, n_heads * head_size, dim)
     // weights for ffn
-    float* w1; // (layer, hidden_dim, dim)
-    float* w2; // (layer, dim, hidden_dim)
-    float* w3; // (layer, hidden_dim, dim)
+    void* w1; // (layer, hidden_dim, dim)
+    void* w2; // (layer, dim, hidden_dim)
+    void* w3; // (layer, hidden_dim, dim)
     // final rmsnorm
     float* rms_final_weight; // (dim,)
     // (optional) classifier weights for the logits, on the last layer
-    float* wcls;
+    void* wcls;
 } TransformerWeights;
 
-#define SIZEOFTRANSFORMERWEIGHTS (12*sizeof(void*))
+#define SIZEOFTRANSFORMERWEIGHTS (13*sizeof(void*))	/* q_tokens makes 13 pointers */
@@ -120,7 +80,7 @@
     // some more state needed to properly clean up the memory mapping (sigh)
     int fd; // file descriptor for memory mapping
     float* data; // memory mapped data pointer
-    ssize_t file_size; // size of the checkpoint file in bytes
+    long file_size; // size of the checkpoint file in bytes
 } Transformer;
 
-#define SIZEOFTRANSFORMER (SIZEOFCONFIG+SIZEOFTRANSFORMERWEIGHTS+SIZEOFRUNSTATE+4+sizeof(float*)+sizeof(ssize_t))
+#define SIZEOFTRANSFORMER (SIZEOFCONFIG+SIZEOFTRANSFORMERWEIGHTS+SIZEOFRUNSTATE+4+sizeof(float*)+sizeof(long))	/* file_size is a long now; ssize_t is gone */
@@ -137,9 +97,9 @@
     s->hb2 = calloc(p->hidden_dim, sizeof(float));
 	xq = calloc(1, sizeof(QuantizedTensor));
 	hq = calloc(1, sizeof(QuantizedTensor));
-	xq->q = calloc(p->dim, sizeof(int8_t));
+	xq->q = calloc(p->dim, sizeof(char));
 	xq->s = calloc(p->dim, sizeof(float));
-	hq->q = calloc(p->hidden_dim, sizeof(int8_t));
+	hq->q = calloc(p->hidden_dim, sizeof(char));
 	hq->s = calloc(p->hidden_dim, sizeof(float));
 	s->xq = xq;
 	s->hq = hq;
@@ -154,8 +114,7 @@
     if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
      || !s->k || !s->v || !s->xq || !s->hq || !s->hq->s || !s->hq->q
      || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
-        fprintf(stderr, "malloc failed!\n");
-        exit(EXIT_FAILURE);
+        sysfatal("malloc failed!");
     }
 }
 
@@ -172,9 +131,9 @@
 	free(s->xq);
 	free(s->hq);
     free(s->q);
-//  free(s->k);
-//  free(s->v);
-//  free(s->att);
+//    free(s->k);
+//    free(s->v);
+    free(s->att);
     free(s->logits);
     free(s->key_cache);
     free(s->value_cache);
@@ -229,7 +188,7 @@
         // calculate and write the quantized values
         for (int i = 0; i < GS; i++) {
             float quant_value = x[group * GS + i] / scale; // scale
-            int8_t quantized = (int8_t) round(quant_value); // round and clamp
+            char quantized = (char) round(quant_value); // round and clamp
             qx->q[group * GS + i] = quantized;
         }
     }
@@ -241,8 +200,8 @@
     QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
     for(int i=0; i<n; i++) {
         /* map quantized int8 values*/
-        res[i].q = (int8_t*)p;
-        p = (int8_t*)p + size_each;
+        res[i].q = (char*)p;
+        p = (char*)p + size_each;
         /* map scale factors */
         res[i].s = (float*)p;
         p = (float*)p + size_each / GS;
@@ -251,9 +210,7 @@
     return res;
 }
 
-void memory_map_weights_q8(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
-	QuantizedTensor *q_tokens, *wq, *wk, *wv, *wo, *w1, *w2, *w3, *wcls;
-
+void memory_map_weights_q8(TransformerWeights *w, Config* p, void* ptr, uchar shared_classifier) {
     int head_size = p->dim / p->n_heads;
     // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
     float* fptr = (float*) ptr; // cast our pointer to float*
@@ -266,29 +223,21 @@
 
     // now read all the quantized weights
     ptr = (void*)fptr; // now cast the pointer back to void*
-    q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
+    w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
     // dequantize token embedding table
     w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
-    dequantize(q_tokens, w->token_embedding_table, p->vocab_size * p->dim, 1);
+    dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim, 1);
 
-    wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
-	dequantize(wq, w->wq, p->dim * (p->n_heads * head_size), p->n_layers);
-    wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
-	dequantize(wk, w->wk, p->dim * (p->n_kv_heads * head_size), p->n_layers);
-    wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
-	dequantize(wv, w->wv, p->dim * (p->n_kv_heads * head_size), p->n_layers);
-    wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
-	dequantize(wo, w->wo, (p->n_heads * head_size) * p->dim, p->n_layers);
+    w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
+    w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
+    w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
+    w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
 
-    w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
-	dequantize(w1, w->w1, p->dim * p->hidden_dim, p->n_layers);
-    w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
-	dequantize(w2, w->w2, p->hidden_dim * p->dim, p->n_layers);
-    w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
-	dequantize(w3, w->w3, p->dim * p->hidden_dim, p->n_layers);
+    w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
+    w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
+    w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
 
-    wcls = shared_classifier ? q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
-	dequantize(wcls, w->wcls, p->dim * p->vocab_size, 1);
+    w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
 }
 
 void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
@@ -323,7 +272,7 @@
 }
 
 void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
-                     int* fd, float** data, ssize_t* file_size) {
+                     int* fd, float** data, long* file_size) {
 	uvlong length;
 	uvlong offset;
 	int ret;
@@ -331,31 +280,31 @@
 	Dir *dstat;
 	unsigned int magic;
 	int header_size = 28;
-	uint8_t shared_classifier;
+	uchar shared_classifier;
 	int group_size;
 
     fdt = open(checkpoint, OREAD);
-    if (fdt < 3) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
-	if (read(fdt, &magic, 4) != 4) { exits("read magic"); }
+    if (fdt < 3) sysfatal("couldn't open file %s", checkpoint);
+	if (read(fdt, &magic, 4) != 4) sysfatal("read magic");
 	if (magic == 0x616b3432) {
-		if (read(fdt, &magic, 4) != 4) { exits("read version"); }
-		if (magic != 2) { exits("version (quantized) is not 2"); }
+		if (read(fdt, &magic, 4) != 4) sysfatal("read version");
+		if (magic != 2) sysfatal("version (quantized) is not 2");
 		quantized8 = 1;
 		header_size = 256;
-		magic = read4(fdt);
+		if (read(fdt, &magic, 4) != 4) sysfatal("read dim");
 	}
     // read in the config header
     config->dim = magic;
-	config->hidden_dim = read4(fdt);
-	config->n_layers = read4(fdt);
-	config->n_heads = read4(fdt);
-	config->n_kv_heads = read4(fdt);
-	config->vocab_size = read4(fdt);
-	config->seq_len = read4(fdt);
+	if (read(fdt, &config->hidden_dim, 4) != 4) sysfatal("read hidden_dim");
+	if (read(fdt, &config->n_layers, 4) != 4) sysfatal("read n_layers");
+	if (read(fdt, &config->n_heads, 4) != 4) sysfatal("read n_heads");
+	if (read(fdt, &config->n_kv_heads, 4) != 4) sysfatal("read n_kv_heads");
+	if (read(fdt, &config->vocab_size, 4) != 4) sysfatal("read vocab_size");
+	if (read(fdt, &config->seq_len, 4) != 4) sysfatal("read seq_len");
 
 	if (quantized8 == 1) {
-		if (read(fdt, &shared_classifier, 1) != 1) exits("read shared_classifier");
-		if (read(fdt, &group_size, 4) != 4) exits("read group_size");
+		if (read(fdt, &shared_classifier, 1) != 1) sysfatal("read shared_classifier");
+		if (read(fdt, &group_size, 4) != 4) sysfatal("read group_size");
 		GS = group_size;
 	}
 
@@ -362,19 +311,13 @@
     // negative vocab size is hacky way of signaling unshared weights. bit yikes.
     int shared_weights = config->vocab_size > 0 ? 1 : 0;
     config->vocab_size = abs(config->vocab_size);
-    // figure out the file size
- //   fseek(file, 0, SEEK_END); // move file pointer to end of file
- //   *file_size = ftell(file); // get the file size, in bytes
- //   fclose(file);
 	close(fdt);
 	dstat = dirstat(checkpoint);
 	*file_size = dstat->length;
 	free(dstat);
+	*fd = open(checkpoint, OREAD);
     // memory map the Transformer weights into the data pointer
-    *fd = open(checkpoint, O_RDONLY); // open in read only mode
-    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
-//    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
-//    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
+    if (*fd < 3) sysfatal("open failed!");
 	*data = malloc(*file_size);
 	length = *file_size;
 	offset = 0;
@@ -399,8 +342,7 @@
 }
 
 void free_transformer(Transformer* t) {
-    // close the memory mapping
-//    if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
+	free(t->data);
     if (t->fd != -1) { close(t->fd); }
     // free the RunState buffers
     free_run_state(&t->state);
@@ -417,7 +359,7 @@
     }
     ss /= size;
     ss += 1e-5f;
-    ss = 1.0f / sqrtf(ss);
+    ss = 1.0f / sqrt(ss);
     // normalize and scale
     for (int j = 0; j < size; j++) {
         o[j] = weight[j] * (ss * x[j]);
@@ -435,7 +377,7 @@
     // exp and sum
     float sum = 0.0f;
     for (int i = 0; i < size; i++) {
-        x[i] = expf(x[i] - max_val);
+        x[i] = exp(x[i] - max_val);
         sum += x[i];
     }
     // normalize
@@ -458,6 +400,33 @@
     }
 }
 
+void matmul_q8(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
+    // W (d,n) @ x (n,) -> xout (d,)
+    // by far the most amount of time is spent inside this little function
+    // inputs to this function are both quantized
+
+    int i;
+    #pragma omp parallel for private(i)
+    for (i = 0; i < d; i++) {
+
+        float val = 0.0f;
+        int ival = 0;
+        int in = i * n;
+
+        // do the matmul in groups of GS
+        int j;
+        for (j = 0; j <= n - GS; j += GS) {
+            for (int k = 0; k < GS; k++) {
+                ival += ((int) x->q[j + k]) * ((int) w->q[in + j + k]);
+            }
+            val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
+            ival = 0;
+        }
+
+        xout[i] = val;
+    }
+}
+
 float* forward(Transformer* transformer, int token, int pos) {
     // a few convenience variables
     Config* p = &transformer->config;
@@ -486,17 +455,17 @@
         s->v = s->value_cache + loff + pos * kv_dim;
 
         // qkv matmuls for this position
-        matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
-        matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
-        matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
+        matmul(s->q, s->xb, (float*)w->wq + l*dim*dim, dim, dim);
+        matmul(s->k, s->xb, (float*)w->wk + l*dim*kv_dim, dim, kv_dim);
+        matmul(s->v, s->xb, (float*)w->wv + l*dim*kv_dim, dim, kv_dim);
 
         // RoPE relative positional encoding: complex-valued rotate q and k in each head
         for (int i = 0; i < dim; i+=2) {
             int head_dim = i % head_size;
-            float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
+            float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
             float val = pos * freq;
-            float fcr = cosf(val);
-            float fci = sinf(val);
+            float fcr = cos(val);
+            float fci = sin(val);
             int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
             for (int v = 0; v < rotn; v++) {
                 float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
@@ -524,7 +493,7 @@
                 for (int i = 0; i < head_size; i++) {
                     score += q[i] * k[i];
                 }
-                score /= sqrtf(head_size);
+                score /= sqrt(head_size);
                 // save the score to the attention buffer
                 att[t] = score;
             }
@@ -548,7 +517,7 @@
         }
 
         // final matmul to get the output of the attention
-        matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
+        matmul(s->xb2, s->xb, (float*)w->wo + l*dim*dim, dim, dim);
 
         // residual connection back into x
         for (int i = 0; i < dim; i++) {
@@ -560,14 +529,14 @@
 
         // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
         // first calculate self.w1(x) and self.w3(x)
-        matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
-        matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
+        matmul(s->hb, s->xb, (float*)w->w1 + l*dim*hidden_dim, dim, hidden_dim);
+        matmul(s->hb2, s->xb, (float*)w->w3 + l*dim*hidden_dim, dim, hidden_dim);
 
         // SwiGLU non-linearity
         for (int i = 0; i < hidden_dim; i++) {
             float val = s->hb[i];
             // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
-            val *= (1.0f / (1.0f + expf(-val)));
+            val *= (1.0f / (1.0f + exp(-val)));
             // elementwise multiply with w3(x)
             val *= s->hb2[i];
             s->hb[i] = val;
@@ -574,7 +543,7 @@
         }
 
         // final matmul to get the output of the ffn
-        matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
+        matmul(s->xb, s->hb, (float*)w->w2 + l*dim*hidden_dim, hidden_dim, dim);
 
         // residual connection
         for (int i = 0; i < dim; i++) {
@@ -590,6 +559,146 @@
     return s->logits;
 }
 
+float* forward_q8(Transformer* transformer, int token, int pos) {
+
+    // a few convenience variables
+    Config* p = &transformer->config;
+    TransformerWeights* w = &transformer->weights;
+    RunState* s = &transformer->state;
+    float *x = s->x;
+    int dim = p->dim;
+    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
+    int hidden_dim =  p->hidden_dim;
+    int head_size = dim / p->n_heads;
+
+    // copy the token embedding into x
+    memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
+
+    // forward all the layers
+    for(int l = 0; l < p->n_layers; l++) {
+
+        // attention rmsnorm
+        rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
+
+        // qkv matmuls for this position
+        quantize(s->xq, s->xb, dim);
+        matmul_q8(s->q, s->xq, (QuantizedTensor*)w->wq + l, dim, dim);
+        matmul_q8(s->k, s->xq, (QuantizedTensor*)w->wk + l, dim, kv_dim);
+        matmul_q8(s->v, s->xq, (QuantizedTensor*)w->wv + l, dim, kv_dim);
+
+        // RoPE relative positional encoding: complex-valued rotate q and k in each head
+        for (int i = 0; i < dim; i+=2) {
+            int head_dim = i % head_size;
+            float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
+            float val = pos * freq;
+            float fcr = cos(val);
+            float fci = sin(val);
+            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+            for (int v = 0; v < rotn; v++) {
+                float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
+                float v0 = vec[i];
+                float v1 = vec[i+1];
+                vec[i]   = v0 * fcr - v1 * fci;
+                vec[i+1] = v0 * fci + v1 * fcr;
+            }
+        }
+
+        // save key,value at this time step (pos) to our kv cache
+        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+        float* key_cache_row = s->key_cache + loff + pos * kv_dim;
+        float* value_cache_row = s->value_cache + loff + pos * kv_dim;
+        memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
+        memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
+
+        // multihead attention. iterate over all heads
+        int h;
+        #pragma omp parallel for private(h)
+        for (h = 0; h < p->n_heads; h++) {
+            // get the query vector for this head
+            float* q = s->q + h * head_size;
+            // attention scores for this head
+            float* att = s->att + h * p->seq_len;
+            // iterate over all timesteps, including the current one
+            for (int t = 0; t <= pos; t++) {
+                // get the key vector for this head and at this timestep
+                float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+                // calculate the attention score as the dot product of q and k
+                float score = 0.0f;
+                for (int i = 0; i < head_size; i++) {
+                    score += q[i] * k[i];
+                }
+                score /= sqrt(head_size);
+                // save the score to the attention buffer
+                att[t] = score;
+            }
+
+            // softmax the scores to get attention weights, from 0..pos inclusively
+            softmax(att, pos + 1);
+
+            // weighted sum of the values, store back into xb
+            float* xb = s->xb + h * head_size;
+            memset(xb, 0, head_size * sizeof(float));
+            for (int t = 0; t <= pos; t++) {
+                // get the value vector for this head and at this timestep
+                float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+                // get the attention weight for this timestep
+                float a = att[t];
+                // accumulate the weighted value into xb
+                for (int i = 0; i < head_size; i++) {
+                    xb[i] += a * v[i];
+                }
+            }
+        }
+
+        // final matmul to get the output of the attention
+        quantize(s->xq, s->xb, dim);
+        matmul_q8(s->xb2, s->xq, (QuantizedTensor*)w->wo + l, dim, dim);
+
+        // residual connection back into x
+        for (int i = 0; i < dim; i++) {
+            x[i] += s->xb2[i];
+        }
+
+        // ffn rmsnorm
+        rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
+
+        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+        // first calculate self.w1(x) and self.w3(x)
+        quantize(s->xq, s->xb, dim);
+        matmul_q8(s->hb, s->xq, (QuantizedTensor*)w->w1 + l, dim, hidden_dim);
+        matmul_q8(s->hb2, s->xq, (QuantizedTensor*)w->w3 + l, dim, hidden_dim);
+
+        // SwiGLU non-linearity
+        for (int i = 0; i < hidden_dim; i++) {
+            float val = s->hb[i];
+            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+            val *= (1.0f / (1.0f + exp(-val)));
+            // elementwise multiply with w3(x)
+            val *= s->hb2[i];
+            s->hb[i] = val;
+        }
+
+        // final matmul to get the output of the ffn
+        quantize(s->hq, s->hb, hidden_dim);
+        matmul_q8(s->xb, s->hq, (QuantizedTensor*)w->w2 + l, hidden_dim, dim);
+
+        // residual connection
+        for (int i = 0; i < dim; i++) {
+            x[i] += s->xb[i];
+        }
+    }
+
+    // final rmsnorm
+    rmsnorm(x, x, w->rms_final_weight, dim);
+
+    // classifier into logits
+    quantize(s->xq, x, dim);
+    matmul_q8(s->logits, s->xq, w->wcls, dim, p->vocab_size);
+    return s->logits;
+}
+
+
 // ----------------------------------------------------------------------------
 // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
 
@@ -619,28 +728,24 @@
     // malloc space to hold the scores and the strings
     t->vocab = (char**)malloc(vocab_size * sizeof(char*));
     t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
-    t->sorted_vocab = NULL; // initialized lazily
+    t->sorted_vocab = nil; // initialized lazily
     for (int i = 0; i < 256; i++) {
         t->byte_pieces[i * 2] = (unsigned char)i;
         t->byte_pieces[i * 2 + 1] = '\0';
     }
     // read in the file
-//    FILE *file = fopen(tokenizer_path, "rb");
 	int fd = open(tokenizer_path, OREAD);
-    if (fd < 3) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
-    if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
-//	fprint(2, "max_token_length: %d\n", t->max_token_length);
+    if (fd < 3) sysfatal("couldn't load %s", tokenizer_path);
+    if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) sysfatal("failed read");
     int len;
     for (int i = 0; i < vocab_size; i++) {
-        if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
-        if (read(fd, &len, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+        if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) sysfatal("failed read");
+        if (read(fd, &len, sizeof(int)) != sizeof(int)) sysfatal("failed read");
         t->vocab[i] = (char*)malloc(len + 1);
-        if (read(fd, t->vocab[i], len) != len) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+        if (read(fd, t->vocab[i], len) != len) sysfatal("failed read");
         t->vocab[i][len] = '\0'; // add the string terminating token
     }
     close(fd);
-
-//	fprint(2, "vocab_size: %d\n", vocab_size);
 }
 
 void free_tokenizer(Tokenizer* t) {
@@ -657,16 +762,17 @@
     // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
     // parse this and convert and return the actual byte
     unsigned char byte_val;
-    if (sscanf(piece, "<0x%X>", &byte_val) == 1) {
+    if (strncmp(piece, "<0x", 3) == 0 && piece[strlen(piece)-1] == '>') {
+		byte_val = strtol(piece + 3, nil, 16);
         piece = (char*)t->byte_pieces + byte_val * 2;
     }
     return piece;
 }
 
-void safe_printf(char *piece) {
+void safe_print(char *piece) {
     // piece might be a raw byte token, and we only want to print printable chars or whitespace
     // because some of the other bytes can be various control codes, backspace, etc.
-    if (piece == NULL) { return; }
+    if (piece == nil) { return; }
     if (piece[0] == '\0') { return; }
     if (piece[1] == '\0') {
         unsigned char byte_val = piece[0];
@@ -674,7 +780,7 @@
             return; // bad byte, don't print it
         }
     }
-    printf("%s", piece);
+    print("%s", piece);
 }
 
 TokenIndex *bsearch_call(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b), int A, int B) {
@@ -690,10 +796,7 @@
 	if (result > 0)
 		return bsearch_call(tok, list, n, s, comp, middle, B);
 
-	if (result < 0)
-		return bsearch_call(tok, list, n, s, comp, A, middle);
-
-	exits("bsearch");
+	return bsearch_call(tok, list, n, s, comp, A, middle);
 }
 
 void *bsearch(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b)) {
@@ -705,15 +808,15 @@
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
     TokenIndex tok = { .str = str }; // acts as the key to search for
     TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, SIZEOFTOKENINDEX, compare_tokens);
-    return res != NULL ? res->id : -1;
+    return res != nil ? res->id : -1;
 }
 
-void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
+void encode(Tokenizer* t, char *text, char bos, char eos, int *tokens, int *n_tokens) {
     // encode the string text (input) into an upper-bound preallocated tokens[] array
     // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
-    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
+    if (text == nil) sysfatal("cannot encode NULL text");
 
-    if (t->sorted_vocab == NULL) {
+    if (t->sorted_vocab == nil) {
         // lazily malloc and sort the vocabulary
         t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
         for (int i = 0; i < t->vocab_size; i++) {
@@ -726,7 +829,7 @@
     // create a temporary buffer that will store merge candidates of always two consecutive tokens
     // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
     char* str_buffer = malloc((t->max_token_length*2 +1 +2));
-    size_t str_len = 0;
+    long str_len = 0;
 
     // start at 0 tokens
     *n_tokens = 0;
@@ -800,7 +903,7 @@
 
         for (int i=0; i < (*n_tokens-1); i++) {
             // check if we can merge the pair (tokens[i], tokens[i+1])
-            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+            sprint(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
             int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
             if (id != -1 && t->vocab_scores[id] > best_score) {
                 // this merge pair exists in vocab! record its score and position
@@ -991,7 +1094,7 @@
 
 void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
     char *empty_prompt = "";
-    if (prompt == NULL) { prompt = empty_prompt; }
+    if (prompt == nil) { prompt = empty_prompt; }
 
     // encode the (string) prompt into tokens sequence
     int num_prompt_tokens = 0;
@@ -998,8 +1101,7 @@
     int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * 4); // +3 for '\0', ?BOS, ?EOS
     encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
     if (num_prompt_tokens < 1) {
-        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
-        exit(EXIT_FAILURE);
+        sysfatal("something is wrong, expected at least 1 prompt token");
     }
 
     // start the main loop
@@ -1008,9 +1110,13 @@
     int token = prompt_tokens[0]; // kick off with the first token in the prompt
     int pos = 0;     // position in the sequence
     while (pos < steps) {
+		float *logits;
 
         // forward the transformer to get logits for the next token
-        float* logits = forward(transformer, token, pos);
+		if (quantized8 == 0)
+	        logits = forward(transformer, token, pos);
+		else
+			logits = forward_q8(transformer, token, pos);
 
         // advance the state machine
         if (pos < (num_prompt_tokens - 1)) {
@@ -1027,29 +1133,28 @@
 
         // print the token as string, decode it with the Tokenizer object
         char* piece = decode(tokenizer, token, next);
-        safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
-        fflush(stdout);
+        safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
         token = next;
 
         // init the timer here because the first iteration can be slower
         if (start == 0) { start = time_in_ms(); }
     }
-    printf("\n");
+    print("\n");
 
     // report achieved tok/s (pos-1 because the timer starts after first iteration)
 //    if (pos > 1) {
 //        long end = time_in_ms();
-//        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
+//        fprint(2, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
 //    }
 
     free(prompt_tokens);
 }
 
-void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+void read_stdin(const char* guide, char* buffer, long bufsize) {
     // read a line from stdin, up to but not including \n
-    printf("%s", guide);
-    if (fgets(buffer, bufsize, stdin) != NULL) {
-        size_t len = strlen(buffer);
+    print("%s", guide);
+    long len = read(0, buffer, bufsize - 1);
+    if (len > 0) {
+        buffer[len] = '\0'; // read(2) does not NUL-terminate the buffer
         if (len > 0 && buffer[len - 1] == '\n') {
             buffer[len - 1] = '\0'; // strip newline
         }
@@ -1075,7 +1180,7 @@
     int user_idx = 0;
 
     // start the main loop
-    int8_t user_turn = 1; // user starts
+    char user_turn = 1; // user starts
     int next = user_turn;        // will store the next token in the sequence
     int token;       // stores the current token to feed into the transformer
     int pos = 0;     // position in the sequence
@@ -1086,7 +1191,7 @@
             // get the (optional) system prompt at position 0
             if (pos == 0) {
                 // at position 0, the user can also contribute a system prompt
-                if (cli_system_prompt == NULL) {
+                if (cli_system_prompt == nil) {
                     // system prompt was not passed in, attempt to get it from stdin
                     read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
                 } else {
@@ -1095,7 +1200,7 @@
                 }
             }
             // get the user prompt
-            if (pos == 0 && cli_user_prompt != NULL) {
+            if (pos == 0 && cli_user_prompt != nil) {
                 // user prompt for position 0 was passed in, use it
                 strcpy(user_prompt, cli_user_prompt);
             } else {
@@ -1105,16 +1210,16 @@
             // render user/system prompts into the Llama 2 Chat schema
             if (pos == 0 && system_prompt[0] != '\0') {
                 char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-                sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
+                sprint(rendered_prompt, system_template, system_prompt, user_prompt);
             } else {
                 char user_template[] = "[INST] %s [/INST]";
-                sprintf(rendered_prompt, user_template, user_prompt);
+                sprint(rendered_prompt, user_template, user_prompt);
             }
             // encode the rendered prompt into tokens
             encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
             user_idx = 0; // reset the user index
             user_turn = 0;
-            printf("Assistant: ");
+            print("Assistant: ");
         }
 
         // determine the token to pass into the transformer next
@@ -1136,12 +1241,11 @@
         if (user_idx >= num_prompt_tokens && next != 2) {
             // the Assistant is responding, so print its output
             char* piece = decode(tokenizer, token, next);
-            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
-            fflush(stdout);
+            safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
         }
-        if (next == 2) { printf("\n"); }
+        if (next == 2) { print("\n"); }
     }
-    printf("\n");
+    print("\n");
     free(prompt_tokens);
 }
 
@@ -1150,31 +1254,31 @@
 // CLI, include only if not testing
 
 void error_usage(void) {
-    fprintf(stderr, "Usage:   run <checkpoint> [options]\n");
-    fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
-    fprintf(stderr, "Options:\n");
-    fprintf(stderr, "  -t <float>  temperature in [0,inf], default 1.0\n");
-    fprintf(stderr, "  -p <float>  p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
-    fprintf(stderr, "  -s <int>    random seed, default time(NULL)\n");
-    fprintf(stderr, "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
-    fprintf(stderr, "  -i <string> input prompt\n");
-    fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
-    fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
-    fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
-    exit(EXIT_FAILURE);
+    fprint(2, "Usage:   run <checkpoint> [options]\n");
+    fprint(2, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
+    fprint(2, "Options:\n");
+    fprint(2, "  -t <float>  temperature in [0,inf], default 1.0\n");
+    fprint(2, "  -p <float>  p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
+    fprint(2, "  -s <int>    random seed, default time(nil)\n");
+    fprint(2, "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
+    fprint(2, "  -i <string> input prompt\n");
+    fprint(2, "  -z <string> optional path to custom tokenizer\n");
+    fprint(2, "  -m <string> mode: generate|chat, default: generate\n");
+    fprint(2, "  -y <string> (optional) system prompt in chat mode\n");
+    exits("usage");
 }
 
 void main(int argc, char *argv[]) {
     // default parameters
-    char *checkpoint_path = NULL;  // e.g. out/model.bin
+    char *checkpoint_path = nil;  // e.g. out/model.bin
     char *tokenizer_path = "tokenizer.bin";
     float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
     float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
     int steps = 256;            // number of steps to run for
-    char *prompt = NULL;        // prompt string
+    char *prompt = nil;        // prompt string
     unsigned long long rng_seed = 0; // seed rng with time by default
     char *mode = "generate";    // generate|chat
-    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
+    char *system_prompt = nil; // the (optional) system prompt to use in chat mode
 
     // poor man's C argparse so we can override the defaults above from the command line
     if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -1196,7 +1300,7 @@
     }
 
     // parameter validation/overrides
-    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
+    if (rng_seed <= 0) rng_seed = (unsigned int)time(nil);
     if (temperature < 0.0) temperature = 0.0;
     if (topp < 0.0 || 1.0 < topp) topp = 0.9;
     if (steps < 0) steps = 0;
@@ -1220,7 +1324,7 @@
     } else if (strcmp(mode, "chat") == 0) {
         chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
     } else {
-        fprintf(stderr, "unknown mode: %s\n", mode);
+        fprint(2, "unknown mode: %s\n", mode);
         error_usage();
     }
 
--
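
For reference, the Q8 scheme the patch builds on stores each group of
GS weights as int8 values plus one float scale (see QuantizedTensor
and matmul_q8 above). A minimal sketch of the encode side, assuming
the upstream llama2.c scale rule of max|x|/127 per group (the actual
scale computation sits above the hunk quoted in this patch):

	void quantize_sketch(QuantizedTensor *qx, float *x, int n) {
		for (int group = 0; group < n / GS; group++) {
			// find the max magnitude in this group
			float wmax = 0.0;
			for (int i = 0; i < GS; i++)
				if (fabs(x[group * GS + i]) > wmax)
					wmax = fabs(x[group * GS + i]);
			// one fp32 scale per GS weights (assumed max|x|/127 rule)
			float scale = wmax / 127.0;
			if (scale == 0.0)
				scale = 1.0; // all-zero group: avoid 0/0
			qx->s[group] = scale;
			for (int i = 0; i < GS; i++) // round to [-127,127], as in the hunk above
				qx->q[group * GS + i] = (char)round(x[group * GS + i] / scale);
		}
	}

Dequantization is then q[i] * s[i/GS], which is exactly what matmul_q8
folds into its per-group accumulation of int8 products.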