ref: 5b790b347f51baee12eeb2dfced7ef5cc0065a7e
parent: 07398654cc5684fec2ff81fc1b21c3862536291c
author: eli <eli@pterodactyl>
date: Sun Aug 17 14:42:53 EDT 2025
support for version 2 quantized format checkpoints
--- a/llama2.c
+++ b/llama2.c
@@ -3,35 +3,8 @@
#include <u.h>
#include <libc.h>
-#include <stdio.h>
-//#include <stdlib.h>
#include <ctype.h>
-//#include <time.h>
-//#include <math.h>
-//#include <string.h>
-//#include <fcntl.h>
-//#if defined _WIN32
-// #include "win.h"
-//#else
-// #include <unistd.h>
-// #include <sys/mman.h>
-//#endif
-#define int8_t char
-#define uint8_t uchar
-#define int32_t int
-#define ssize_t uvlong
-#define size_t ulong
-#define EXIT_FAILURE "exits"
-#define exit exits
-#define O_RDONLY OREAD
-#define sqrtf sqrt
-#define expf exp
-#define powf pow
-#define cosf cos
-#define sinf sin
-#define uint32_t uint
-
unsigned char quantized8 = 0;
int GS = 32;
@@ -51,43 +24,30 @@
#define SIZEOFCONFIG 24
typedef struct {
- int8_t* q; // quantized values
+ char* q; // quantized values, one byte per element
float* s; // scaling factors
} QuantizedTensor;
-int read4(int fd) {
- typedef union _result {
- char buf[4];
- int i;
- } result;
-
- result r;
-
- if (read(fd, r.buf, 4) != 4)
- exit(EXIT_FAILURE);
-
- return r.i;
-}
-
typedef struct {
// token embedding table
- float* token_embedding_table; // (vocab_size, dim)
+ void *q_tokens; // (vocab_size, dim) quantized embeddings (QuantizedTensor*), filled in by memory_map_weights_q8
+ float* token_embedding_table; // (vocab_size, dim) fp32; for quantized checkpoints a malloc'd dequantized copy of q_tokens
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
- float* wq; // (layer, dim, n_heads * head_size)
- float* wk; // (layer, dim, n_kv_heads * head_size)
- float* wv; // (layer, dim, n_kv_heads * head_size)
- float* wo; // (layer, n_heads * head_size, dim)
+ void* wq; // (layer, dim, n_heads * head_size); used as float* by forward(), as QuantizedTensor* (one per layer) by forward_q8()
+ void* wk; // (layer, dim, n_kv_heads * head_size)
+ void* wv; // (layer, dim, n_kv_heads * head_size)
+ void* wo; // (layer, n_heads * head_size, dim)
// weights for ffn
- float* w1; // (layer, hidden_dim, dim)
- float* w2; // (layer, dim, hidden_dim)
- float* w3; // (layer, hidden_dim, dim)
+ void* w1; // (layer, hidden_dim, dim); same float* / QuantizedTensor* dual use as wq
+ void* w2; // (layer, dim, hidden_dim)
+ void* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// (optional) classifier weights for the logits, on the last layer
- float* wcls;
+ void* wcls; // in quantized checkpoints: QuantizedTensor*, aliasing q_tokens when the classifier is shared
} TransformerWeights;
-#define SIZEOFTRANSFORMERWEIGHTS (12*sizeof(void*))
+#define SIZEOFTRANSFORMERWEIGHTS (13*sizeof(void*)) // one term per pointer member; q_tokens makes it 13
@@ -120,7 +80,7 @@
// some more state needed to properly clean up the memory mapping (sigh)
int fd; // file descriptor for memory mapping
float* data; // memory mapped data pointer
- ssize_t file_size; // size of the checkpoint file in bytes
+ long file_size; // size of the checkpoint file in bytes
} Transformer;
-#define SIZEOFTRANSFORMER (SIZEOFCONFIG+SIZEOFTRANSFORMERWEIGHTS+SIZEOFRUNSTATE+4+sizeof(float*)+sizeof(ssize_t))
+#define SIZEOFTRANSFORMER (SIZEOFCONFIG+SIZEOFTRANSFORMERWEIGHTS+SIZEOFRUNSTATE+4+sizeof(float*)+sizeof(long)) // last term is file_size, now a long (ssize_t is no longer defined)
@@ -137,9 +97,9 @@
s->hb2 = calloc(p->hidden_dim, sizeof(float));
xq = calloc(1, sizeof(QuantizedTensor));
hq = calloc(1, sizeof(QuantizedTensor));
- xq->q = calloc(p->dim, sizeof(int8_t));
+ xq->q = calloc(p->dim, sizeof(char));
xq->s = calloc(p->dim, sizeof(float));
- hq->q = calloc(p->hidden_dim, sizeof(int8_t));
+ hq->q = calloc(p->hidden_dim, sizeof(char));
hq->s = calloc(p->hidden_dim, sizeof(float));
s->xq = xq;
s->hq = hq;
@@ -154,8 +114,7 @@
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->xq || !s->hq || !s->hq->s || !s->hq->q
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
- fprintf(stderr, "malloc failed!\n");
- exit(EXIT_FAILURE);
+ sysfatal("malloc failed!");
}
}
@@ -172,9 +131,9 @@
free(s->xq);
free(s->hq);
free(s->q);
-// free(s->k);
-// free(s->v);
-// free(s->att);
+// free(s->k);
+// free(s->v);
+ free(s->att);
free(s->logits);
free(s->key_cache);
free(s->value_cache);
@@ -229,7 +188,7 @@
// calculate and write the quantized values
for (int i = 0; i < GS; i++) {
float quant_value = x[group * GS + i] / scale; // scale
- int8_t quantized = (int8_t) round(quant_value); // round and clamp
+ char quantized = (char) round(quant_value); // round and clamp
qx->q[group * GS + i] = quantized;
}
}
@@ -241,8 +200,8 @@
QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
for(int i=0; i<n; i++) {
/* map quantized int8 values*/
- res[i].q = (int8_t*)p;
- p = (int8_t*)p + size_each;
+ res[i].q = (char*)p;
+ p = (char*)p + size_each;
/* map scale factors */
res[i].s = (float*)p;
p = (float*)p + size_each / GS;
@@ -251,9 +210,7 @@
return res;
}
-void memory_map_weights_q8(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
- QuantizedTensor *q_tokens, *wq, *wk, *wv, *wo, *w1, *w2, *w3, *wcls;
-
+void memory_map_weights_q8(TransformerWeights *w, Config* p, void* ptr, uchar shared_classifier) {
int head_size = p->dim / p->n_heads;
// first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
float* fptr = (float*) ptr; // cast our pointer to float*
@@ -266,29 +223,21 @@
// now read all the quantized weights
ptr = (void*)fptr; // now cast the pointer back to void*
- q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
+ w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
// dequantize token embedding table
w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
- dequantize(q_tokens, w->token_embedding_table, p->vocab_size * p->dim, 1);
+ dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim, 1);
- wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
- dequantize(wq, w->wq, p->dim * (p->n_heads * head_size), p->n_layers);
- wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
- dequantize(wk, w->wk, p->dim * (p->n_kv_heads * head_size), p->n_layers);
- wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
- dequantize(wv, w->wv, p->dim * (p->n_kv_heads * head_size), p->n_layers);
- wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
- dequantize(wo, w->wo, (p->n_heads * head_size) * p->dim, p->n_layers);
+ w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
+ w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
+ w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
+ w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
- w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
- dequantize(w1, w->w1, p->dim * p->hidden_dim, p->n_layers);
- w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
- dequantize(w2, w->w2, p->hidden_dim * p->dim, p->n_layers);
- w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
- dequantize(w3, w->w3, p->dim * p->hidden_dim, p->n_layers);
+ w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
+ w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
+ w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
- wcls = shared_classifier ? q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
- dequantize(wcls, w->wcls, p->dim * p->vocab_size, 1);
+ w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
}
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
@@ -323,7 +272,7 @@
}
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
- int* fd, float** data, ssize_t* file_size) {
+ int* fd, float** data, long* file_size) {
uvlong length;
uvlong offset;
int ret;
@@ -331,31 +280,31 @@
Dir *dstat;
unsigned int magic;
int header_size = 28;
- uint8_t shared_classifier;
+ uchar shared_classifier;
int group_size;
fdt = open(checkpoint, OREAD);
- if (fdt < 3) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
- if (read(fdt, &magic, 4) != 4) { exits("read magic"); }
+ if (fdt < 3) sysfatal("couldn't open file %s", checkpoint);
+ if (read(fdt, &magic, 4) != 4) sysfatal("read magic");
if (magic == 0x616b3432) {
- if (read(fdt, &magic, 4) != 4) { exits("read version"); }
- if (magic != 2) { exits("version (quantized) is not 2"); }
+ if (read(fdt, &magic, 4) != 4) sysfatal("read version");
+ if (magic != 2) sysfatal("version (quantized) is not 2");
quantized8 = 1;
header_size = 256;
- magic = read4(fdt);
+ if (read(fdt, &magic, 4) != 4) sysfatal("read dim");
}
// read in the config header
config->dim = magic;
- config->hidden_dim = read4(fdt);
- config->n_layers = read4(fdt);
- config->n_heads = read4(fdt);
- config->n_kv_heads = read4(fdt);
- config->vocab_size = read4(fdt);
- config->seq_len = read4(fdt);
+ if (read(fdt, &config->hidden_dim, 4) != 4) sysfatal("read hidden_dim");
+ if (read(fdt, &config->n_layers, 4) != 4) sysfatal("read n_layers");
+ if (read(fdt, &config->n_heads, 4) != 4) sysfatal("read n_heads");
+ if (read(fdt, &config->n_kv_heads, 4) != 4) sysfatal("read n_kv_heads");
+ if (read(fdt, &config->vocab_size, 4) != 4) sysfatal("read vocab_size");
+ if (read(fdt, &config->seq_len, 4) != 4) sysfatal("read seq_len");
if (quantized8 == 1) {
- if (read(fdt, &shared_classifier, 1) != 1) exits("read shared_classifier");
- if (read(fdt, &group_size, 4) != 4) exits("read group_size");
+ if (read(fdt, &shared_classifier, 1) != 1) sysfatal("read shared_classifier");
+ if (read(fdt, &group_size, 4) != 4) sysfatal("read group_size");
GS = group_size;
}
@@ -362,19 +311,13 @@
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config->vocab_size > 0 ? 1 : 0;
config->vocab_size = abs(config->vocab_size);
- // figure out the file size
- // fseek(file, 0, SEEK_END); // move file pointer to end of file
- // *file_size = ftell(file); // get the file size, in bytes
- // fclose(file);
close(fdt);
dstat = dirstat(checkpoint);
*file_size = dstat->length;
free(dstat);
+ *fd = open(checkpoint, OREAD);
// memory map the Transformer weights into the data pointer
- *fd = open(checkpoint, O_RDONLY); // open in read only mode
- if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
-// *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
-// if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
+ if (*fd < 3) sysfatal("open failed!");
*data = malloc(*file_size);
length = *file_size;
offset = 0;
@@ -399,8 +342,7 @@
}
void free_transformer(Transformer* t) {
- // close the memory mapping
-// if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
+ free(t->data);
if (t->fd != -1) { close(t->fd); }
// free the RunState buffers
free_run_state(&t->state);
@@ -417,7 +359,7 @@
}
ss /= size;
ss += 1e-5f;
- ss = 1.0f / sqrtf(ss);
+ ss = 1.0f / sqrt(ss);
// normalize and scale
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
@@ -435,7 +377,7 @@
// exp and sum
float sum = 0.0f;
for (int i = 0; i < size; i++) {
- x[i] = expf(x[i] - max_val);
+ x[i] = exp(x[i] - max_val);
sum += x[i];
}
// normalize
@@ -458,6 +400,33 @@
}
}
+void matmul_q8(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
+    // Quantized matrix-vector product: W (d,n) @ x (n,) -> xout (d,).
+    // Both operands carry one float scale per group of GS quantized values;
+    // each group is accumulated in integers and then rescaled to float.
+    // Elements past the last complete group of a row are not visited.
+    int row;
+    #pragma omp parallel for private(row)
+    for (row = 0; row < d; row++) {
+        int base = row * n;   // offset of this row inside w->q
+        float acc = 0.0f;
+
+        // walk the row one complete group at a time
+        for (int g = 0; g <= n - GS; g += GS) {
+            char *xq = x->q + g;
+            char *wq = w->q + base + g;
+            int dot = 0;
+            for (int k = 0; k < GS; k++) {
+                dot += ((int) xq[k]) * ((int) wq[k]);
+            }
+            // same left-to-right product as before, so float rounding is unchanged
+            acc += ((float) dot) * w->s[(base + g) / GS] * x->s[g / GS];
+        }
+
+        xout[row] = acc;
+    }
+}
+
float* forward(Transformer* transformer, int token, int pos) {
// a few convenience variables
Config* p = &transformer->config;
@@ -486,17 +455,17 @@
s->v = s->value_cache + loff + pos * kv_dim;
// qkv matmuls for this position
- matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
- matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
- matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
+ matmul(s->q, s->xb, (float*)w->wq + l*dim*dim, dim, dim);
+ matmul(s->k, s->xb, (float*)w->wk + l*dim*kv_dim, dim, kv_dim);
+ matmul(s->v, s->xb, (float*)w->wv + l*dim*kv_dim, dim, kv_dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
- float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
+ float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
- float fcr = cosf(val);
- float fci = sinf(val);
+ float fcr = cos(val);
+ float fci = sin(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
@@ -524,7 +493,7 @@
for (int i = 0; i < head_size; i++) {
score += q[i] * k[i];
}
- score /= sqrtf(head_size);
+ score /= sqrt(head_size);
// save the score to the attention buffer
att[t] = score;
}
@@ -548,7 +517,7 @@
}
// final matmul to get the output of the attention
- matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
+ matmul(s->xb2, s->xb, (float*)w->wo + l*dim*dim, dim, dim);
// residual connection back into x
for (int i = 0; i < dim; i++) {
@@ -560,14 +529,14 @@
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
- matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
- matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
+ matmul(s->hb, s->xb, (float*)w->w1 + l*dim*hidden_dim, dim, hidden_dim);
+ matmul(s->hb2, s->xb, (float*)w->w3 + l*dim*hidden_dim, dim, hidden_dim);
// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
float val = s->hb[i];
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
- val *= (1.0f / (1.0f + expf(-val)));
+ val *= (1.0f / (1.0f + exp(-val)));
// elementwise multiply with w3(x)
val *= s->hb2[i];
s->hb[i] = val;
@@ -574,7 +543,7 @@
}
// final matmul to get the output of the ffn
- matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
+ matmul(s->xb, s->hb, (float*)w->w2 + l*dim*hidden_dim, hidden_dim, dim);
// residual connection
for (int i = 0; i < dim; i++) {
@@ -590,6 +559,146 @@
return s->logits;
}
+float* forward_q8(Transformer* transformer, int token, int pos) {
+ // quantized counterpart of forward(): runs token at position pos through every layer and returns s->logits
+ // a few convenience variables
+ Config* p = &transformer->config;
+ TransformerWeights* w = &transformer->weights;
+ RunState* s = &transformer->state;
+ float *x = s->x;
+ int dim = p->dim;
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
+ int hidden_dim = p->hidden_dim;
+ int head_size = dim / p->n_heads;
+
+ // copy the (already dequantized, fp32) token embedding row into x
+ memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
+
+ // forward all the layers
+ for(int l = 0; l < p->n_layers; l++) {
+
+ // attention rmsnorm
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
+
+ // qkv matmuls for this position: quantize xb once, reuse it for all three; "+ l" selects this layer's QuantizedTensor
+ quantize(s->xq, s->xb, dim);
+ matmul_q8(s->q, s->xq, (QuantizedTensor*)w->wq + l, dim, dim);
+ matmul_q8(s->k, s->xq, (QuantizedTensor*)w->wk + l, dim, kv_dim);
+ matmul_q8(s->v, s->xq, (QuantizedTensor*)w->wv + l, dim, kv_dim);
+
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
+ for (int i = 0; i < dim; i+=2) {
+ int head_dim = i % head_size;
+ float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
+ float val = pos * freq;
+ float fcr = cos(val);
+ float fci = sin(val);
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+ for (int v = 0; v < rotn; v++) {
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
+ float v0 = vec[i];
+ float v1 = vec[i+1];
+ vec[i] = v0 * fcr - v1 * fci;
+ vec[i+1] = v0 * fci + v1 * fcr;
+ }
+ }
+
+ // save key,value at this time step (pos) to our kv cache (copied from s->k / s->v into the cache rows)
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+ float* key_cache_row = s->key_cache + loff + pos * kv_dim;
+ float* value_cache_row = s->value_cache + loff + pos * kv_dim;
+ memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
+ memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
+
+ // multihead attention. iterate over all heads
+ int h;
+ #pragma omp parallel for private(h)
+ for (h = 0; h < p->n_heads; h++) {
+ // get the query vector for this head
+ float* q = s->q + h * head_size;
+ // attention scores for this head
+ float* att = s->att + h * p->seq_len;
+ // iterate over all timesteps, including the current one
+ for (int t = 0; t <= pos; t++) {
+ // get the key vector for this head and at this timestep
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+ // calculate the attention score as the dot product of q and k
+ float score = 0.0f;
+ for (int i = 0; i < head_size; i++) {
+ score += q[i] * k[i];
+ }
+ score /= sqrt(head_size);
+ // save the score to the attention buffer
+ att[t] = score;
+ }
+
+ // softmax the scores to get attention weights, from 0..pos inclusively
+ softmax(att, pos + 1);
+
+ // weighted sum of the values, store back into xb
+ float* xb = s->xb + h * head_size;
+ memset(xb, 0, head_size * sizeof(float));
+ for (int t = 0; t <= pos; t++) {
+ // get the value vector for this head and at this timestep
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
+ // get the attention weight for this timestep
+ float a = att[t];
+ // accumulate the weighted value into xb
+ for (int i = 0; i < head_size; i++) {
+ xb[i] += a * v[i];
+ }
+ }
+ }
+
+ // final matmul to get the output of the attention (xb now holds the attention result, so quantize it again)
+ quantize(s->xq, s->xb, dim);
+ matmul_q8(s->xb2, s->xq, (QuantizedTensor*)w->wo + l, dim, dim);
+
+ // residual connection back into x
+ for (int i = 0; i < dim; i++) {
+ x[i] += s->xb2[i];
+ }
+
+ // ffn rmsnorm
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
+
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+ // first calculate self.w1(x) and self.w3(x)
+ quantize(s->xq, s->xb, dim);
+ matmul_q8(s->hb, s->xq, (QuantizedTensor*)w->w1 + l, dim, hidden_dim);
+ matmul_q8(s->hb2, s->xq, (QuantizedTensor*)w->w3 + l, dim, hidden_dim);
+
+ // SwiGLU non-linearity
+ for (int i = 0; i < hidden_dim; i++) {
+ float val = s->hb[i];
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+ val *= (1.0f / (1.0f + exp(-val)));
+ // elementwise multiply with w3(x)
+ val *= s->hb2[i];
+ s->hb[i] = val;
+ }
+
+ // final matmul to get the output of the ffn (input is hidden_dim long, so it is quantized into hq rather than xq)
+ quantize(s->hq, s->hb, hidden_dim);
+ matmul_q8(s->xb, s->hq, (QuantizedTensor*)w->w2 + l, hidden_dim, dim);
+
+ // residual connection
+ for (int i = 0; i < dim; i++) {
+ x[i] += s->xb[i];
+ }
+ }
+
+ // final rmsnorm
+ rmsnorm(x, x, w->rms_final_weight, dim);
+
+ // classifier into logits (wcls is the q_tokens tensor when the classifier is shared)
+ quantize(s->xq, x, dim);
+ matmul_q8(s->logits, s->xq, w->wcls, dim, p->vocab_size);
+ return s->logits;
+}
+
+
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
@@ -619,28 +728,24 @@
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
- t->sorted_vocab = NULL; // initialized lazily
+ t->sorted_vocab = nil; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file
-// FILE *file = fopen(tokenizer_path, "rb");
int fd = open(tokenizer_path, OREAD);
- if (fd < 3) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
- if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
-// fprint(2, "max_token_length: %d\n", t->max_token_length);
+ if (fd < 3) sysfatal("couldn't load %s", tokenizer_path);
+ if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) sysfatal("failed read");
int len;
for (int i = 0; i < vocab_size; i++) {
- if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
- if (read(fd, &len, sizeof(int)) != sizeof(int)) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+ if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) sysfatal("failed read");
+ if (read(fd, &len, sizeof(int)) != sizeof(int)) sysfatal("failed read");
t->vocab[i] = (char*)malloc(len + 1);
- if (read(fd, t->vocab[i], len) != len) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+ if (read(fd, t->vocab[i], len) != len) sysfatal("failed read");
t->vocab[i][len] = '\0'; // add the string terminating token
}
close(fd);
-
-// fprint(2, "vocab_size: %d\n", vocab_size);
}
void free_tokenizer(Tokenizer* t) {
@@ -657,16 +762,17 @@
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
- if (sscanf(piece, "<0x%X>", &byte_val) == 1) {
+ if (strncmp(piece, "<0x", 3) == 0 && piece[strlen(piece)-1] == '>') { // raw-byte piece "<0xNN>": prefix AND closing '>' must both match
+ byte_val = strtol(piece + 1, nil, 16); // piece+1 is "0xNN>"; base 16 skips the 0x prefix and stops at '>'
piece = (char*)t->byte_pieces + byte_val * 2;
}
return piece;
}
-void safe_printf(char *piece) {
+void safe_print(char *piece) {
// piece might be a raw byte token, and we only want to print printable chars or whitespace
// because some of the other bytes can be various control codes, backspace, etc.
- if (piece == NULL) { return; }
+ if (piece == nil) { return; }
if (piece[0] == '\0') { return; }
if (piece[1] == '\0') {
unsigned char byte_val = piece[0];
@@ -674,7 +780,7 @@
return; // bad byte, don't print it
}
}
- printf("%s", piece);
+ print("%s", piece);
}
TokenIndex *bsearch_call(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b), int A, int B) {
@@ -690,10 +796,7 @@
if (result > 0)
return bsearch_call(tok, list, n, s, comp, middle, B);
- if (result < 0)
- return bsearch_call(tok, list, n, s, comp, A, middle);
-
- exits("bsearch");
+ return bsearch_call(tok, list, n, s, comp, A, middle);
}
void *bsearch(TokenIndex *tok, TokenIndex *list, int n, int s, int (*comp)(const void *a, const void *b)) {
@@ -705,15 +808,15 @@
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { .str = str }; // acts as the key to search for
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, SIZEOFTOKENINDEX, compare_tokens);
- return res != NULL ? res->id : -1;
+ return res != nil ? res->id : -1;
}
-void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
+void encode(Tokenizer* t, char *text, char bos, char eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
- if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
+ if (text == nil) sysfatal("cannot encode NULL text");
- if (t->sorted_vocab == NULL) {
+ if (t->sorted_vocab == nil) {
// lazily malloc and sort the vocabulary
t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
for (int i = 0; i < t->vocab_size; i++) {
@@ -726,7 +829,7 @@
// create a temporary buffer that will store merge candidates of always two consecutive tokens
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2));
- size_t str_len = 0;
+ long str_len = 0;
// start at 0 tokens
*n_tokens = 0;
@@ -800,7 +903,7 @@
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
- sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+ sprint(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
@@ -991,7 +1094,7 @@
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
char *empty_prompt = "";
- if (prompt == NULL) { prompt = empty_prompt; }
+ if (prompt == nil) { prompt = empty_prompt; }
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
@@ -998,8 +1101,7 @@
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * 4); // +3 for '\0', ?BOS, ?EOS
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
if (num_prompt_tokens < 1) {
- fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
- exit(EXIT_FAILURE);
+ sysfatal("something is wrong, expected at least 1 prompt token");
}
// start the main loop
@@ -1008,9 +1110,13 @@
int token = prompt_tokens[0]; // kick off with the first token in the prompt
int pos = 0; // position in the sequence
while (pos < steps) {
+ float *logits;
// forward the transformer to get logits for the next token
- float* logits = forward(transformer, token, pos);
+ if (quantized8 == 0)
+ logits = forward(transformer, token, pos);
+ else
+ logits = forward_q8(transformer, token, pos);
// advance the state machine
if (pos < (num_prompt_tokens - 1)) {
@@ -1027,29 +1133,28 @@
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
- safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
- fflush(stdout);
+ safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
token = next;
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
- printf("\n");
+ print("\n");
// report achieved tok/s (pos-1 because the timer starts after first iteration)
// if (pos > 1) {
// long end = time_in_ms();
-// fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
+// fprint(2, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
// }
free(prompt_tokens);
}
-void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+void read_stdin(const char* guide, char* buffer, long bufsize) {
// read a line from stdin, up to but not including \n
- printf("%s", guide);
- if (fgets(buffer, bufsize, stdin) != NULL) {
- size_t len = strlen(buffer);
+ print("%s", guide);
+ long len = read(0, buffer, bufsize - 1); // read() returns a byte count and adds no terminator; keep one byte free for it
+ if (len > 0) { buffer[len] = '\0'; // terminate what was read before the newline check below
if (len > 0 && buffer[len - 1] == '\n') {
buffer[len - 1] = '\0'; // strip newline
}
@@ -1075,7 +1180,7 @@
int user_idx = 0;
// start the main loop
- int8_t user_turn = 1; // user starts
+ char user_turn = 1; // user starts
int next = user_turn; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int pos = 0; // position in the sequence
@@ -1086,7 +1191,7 @@
// get the (optional) system prompt at position 0
if (pos == 0) {
// at position 0, the user can also contribute a system prompt
- if (cli_system_prompt == NULL) {
+ if (cli_system_prompt == nil) {
// system prompt was not passed in, attempt to get it from stdin
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
} else {
@@ -1095,7 +1200,7 @@
}
}
// get the user prompt
- if (pos == 0 && cli_user_prompt != NULL) {
+ if (pos == 0 && cli_user_prompt != nil) {
// user prompt for position 0 was passed in, use it
strcpy(user_prompt, cli_user_prompt);
} else {
@@ -1105,16 +1210,16 @@
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
- sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
+ sprint(rendered_prompt, system_template, system_prompt, user_prompt);
} else {
char user_template[] = "[INST] %s [/INST]";
- sprintf(rendered_prompt, user_template, user_prompt);
+ sprint(rendered_prompt, user_template, user_prompt);
}
// encode the rendered prompt into tokens
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
user_turn = 0;
- printf("Assistant: ");
+ print("Assistant: ");
}
// determine the token to pass into the transformer next
@@ -1136,12 +1241,11 @@
if (user_idx >= num_prompt_tokens && next != 2) {
// the Assistant is responding, so print its output
char* piece = decode(tokenizer, token, next);
- safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
- fflush(stdout);
+ safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
}
- if (next == 2) { printf("\n"); }
+ if (next == 2) { print("\n"); }
}
- printf("\n");
+ print("\n");
free(prompt_tokens);
}
@@ -1150,31 +1254,31 @@
// CLI, include only if not testing
void error_usage(void) {
- fprintf(stderr, "Usage: run <checkpoint> [options]\n");
- fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
- fprintf(stderr, "Options:\n");
- fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
- fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
- fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
- fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
- fprintf(stderr, " -i <string> input prompt\n");
- fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
- fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
- fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
- exit(EXIT_FAILURE);
+ fprint(2, "Usage: run <checkpoint> [options]\n");
+ fprint(2, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
+ fprint(2, "Options:\n");
+ fprint(2, " -t <float> temperature in [0,inf], default 1.0\n");
+ fprint(2, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
+ fprint(2, " -s <int> random seed, default time(nil)\n");
+ fprint(2, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
+ fprint(2, " -i <string> input prompt\n");
+ fprint(2, " -z <string> optional path to custom tokenizer\n");
+ fprint(2, " -m <string> mode: generate|chat, default: generate\n");
+ fprint(2, " -y <string> (optional) system prompt in chat mode\n");
+ exits("usage"); // help text went to fd 2 above; end the process with exit string "usage"
}
void main(int argc, char *argv[]) {
// default parameters
- char *checkpoint_path = NULL; // e.g. out/model.bin
+ char *checkpoint_path = nil; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
- char *prompt = NULL; // prompt string
+ char *prompt = nil; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
char *mode = "generate"; // generate|chat
- char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
+ char *system_prompt = nil; // the (optional) system prompt to use in chat mode
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -1196,7 +1300,7 @@
}
// parameter validation/overrides
- if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(nil);
if (temperature < 0.0) temperature = 0.0;
if (topp < 0.0 || 1.0 < topp) topp = 0.9;
if (steps < 0) steps = 0;
@@ -1220,7 +1324,7 @@
} else if (strcmp(mode, "chat") == 0) {
chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
} else {
- fprintf(stderr, "unknown mode: %s\n", mode);
+ fprint(2, "unknown mode: %s\n", mode);
error_usage();
}
--
⑨