Jared's Blog



LLM


I want to understand LLMs and transformers better, so I decided to write my own chatbot-style LLM inference app in C++.

My goals are:


This was a really fun project, and I'd highly recommend it to anybody who likes to build things from scratch. In the end, I did have to look at the PyTorch implementation of Gemma 3 to get some details right.

Token Generation Speed (ms/token)

                this project   llama (AVX512)   llama (AVX2)
Prefill (4B)              17                6             12
Decode (4B)               87               80             84
Prefill (12B)             45               20             20
Decode (12B)             246              240            242

Disclaimer: these results are only valid for small context sizes, around 1024 tokens. llama is much better at longer contexts.

The entire project is a single file with 1925 lines. You can find the final code here, and you can compile it (x86, Linux, requires AVX2) with

g++ -std=c++17 -march=native -O3 -Wall ./llm.cpp -o llm

(please don't judge the C++ too much, this is intentionally somewhat minimal)

Then run llm with

./llm google_gemma-3-4b-it-Q8_0.gguf


Example use:

./llm ~/llm/models/google_gemma-3-4b-it-Q8_0.gguf
Opening model file...
Loading metadata...
read 44 metadata values
read 444 tensors
Creating tokenizer...
Loading model...
Generating rope tables for context length 8192...
> Implement `sort(array)` in python, using bubble sort. Please reply with a minimal answer, without extra explanations.
batch time for 1 batches: 566 ms / 32 tokens = 17 ms/token
---------------------------------------------

def sort(array):
    n = len(array)
    for i in range(n):
        for j in range(0, n-i-1):
            if array[j] > array[j+1]:
                array[j], array[j+1] = array[j+1], array[j]


time: 7235 ms / 80 tokens = 90 ms/token
used 113/8192 tokens (model supports 131072)
---------------------------------------------
>
Picking a Model and a Format


Many people doing LLM inference experiments start with Llama 2, but I picked Gemma 3 from Google instead. Gemma 3 came out in March 2025, about 2 years after Llama 2. Gemma 3 has a more modern architecture and performs better. I started with the 12B size, but also tested the 1B and 4B sizes. Toward the end of the project, I mostly switched to 4B, since 12B was a little slow.

Model weights can be downloaded from Hugging Face, which is like GitHub for model weights. The original copy from Google is stored in a format called safetensors, which at first glance seemed poorly documented, so I looked for something simpler and came across the GGUF format. (In retrospect, this was a mistake: the safetensors documentation is great, and GGUF has almost no documentation.) Unlike the original weights, GGUF weights can be downloaded without creating an account or agreeing to any license.

Weights in GGUF are significantly smaller than the safetensors originals because they are stored at lower precision. I picked these 12.1 GB 8-bit weights because that seemed like a good compromise between size and accuracy. (Another mistake: 5-bit weights would have been a better choice.) The model name gemma-3-12b-it-Q8_0.gguf means a "gemma 3" architecture, 12 billion parameters, fine-tuned for instruction following (instead of the plain model, which is only trained to do text completion), using GGUF's Q8_0 quantized weight format.

Parsing the file format

There isn't much documentation out there on the GGUF format. The Hugging Face page has a promising-looking graphic explaining the meaning of the first 24 bytes and the general structure, but it doesn't explain the actual format of the metadata.

How is that metadata stored?


Luckily, the format is pretty simple and can be figured out just from inspecting the file. Each entry has an 8-byte key length, followed by the key as an ASCII string, followed by a 4-byte value type and then the value itself. The value types include integers/floats, which are stored directly, and strings, which are again stored with an 8-byte length prefix.

There's also an array type, which contains a 4-byte element type, an 8-byte length, and then all the elements.
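To make the layout concrete, here's a minimal sketch of reading one metadata entry, assuming the layout described above and a little-endian machine. The type codes in the comment are from the GGUF spec's value-type list; `Cursor`, `KeyValue`, `read_kv`, and `skip_value` are hypothetical names for illustration, not code from this project. Values are left as raw pointers and arrays are skipped element by element, which is enough to walk past the big token array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>

// GGUF value type codes (per the GGUF spec):
// 0 u8, 1 i8, 2 u16, 3 i16, 4 u32, 5 i32, 6 f32, 7 bool,
// 8 string, 9 array, 10 u64, 11 i64, 12 f64
constexpr uint32_t kTypeString = 8;
constexpr uint32_t kTypeArray = 9;

// Hypothetical cursor over the mmapped file. Assumes a little-endian host.
struct Cursor {
  const uint8_t* p;

  template <typename T>
  T read() {
    T v;
    std::memcpy(&v, p, sizeof(T));  // memcpy avoids unaligned loads
    p += sizeof(T);
    return v;
  }

  std::string read_string() {
    uint64_t len = read<uint64_t>();  // 8-byte length prefix, no terminator
    std::string s(reinterpret_cast<const char*>(p), len);
    p += len;
    return s;
  }
};

// Size of a fixed-size value type, or 0 for string/array/unknown.
size_t scalar_size(uint32_t type) {
  switch (type) {
    case 0: case 1: case 7: return 1;
    case 2: case 3: return 2;
    case 4: case 5: case 6: return 4;
    case 10: case 11: case 12: return 8;
    default: return 0;
  }
}

// Advances the cursor past one value of the given type.
void skip_value(Cursor& c, uint32_t type) {
  if (type == kTypeString) {
    c.read_string();
    return;
  }
  if (type == kTypeArray) {
    uint32_t elem_type = c.read<uint32_t>();  // 4-byte element type
    uint64_t count = c.read<uint64_t>();      // 8-byte element count
    for (uint64_t i = 0; i < count; i++) skip_value(c, elem_type);
    return;
  }
  size_t n = scalar_size(type);
  assert(n != 0 && "unknown GGUF value type");
  c.p += n;
}

struct KeyValue {
  std::string key;
  uint32_t type;
  const uint8_t* value;  // points at the raw value inside the file
};

// Reads one entry: length-prefixed key, 4-byte value type, then the value.
KeyValue read_kv(Cursor& c) {
  KeyValue kv;
  kv.key = c.read_string();
  kv.type = c.read<uint32_t>();
  kv.value = c.p;
  skip_value(c, kv.type);
  return kv;
}
```

Decoding a scalar value is then a `memcpy` from `kv.value` into the type that `kv.type` names.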

Here's part of the decoded metadata:

[Header]
GGUF, v 3, tensors 626 metadatacount 44
[Metadata]
gemma3.context_length: (u32) 131072
gemma3.block_count: (u32) 48
tokenizer.ggml.add_eos_token: (bool) false
tokenizer.ggml.unknown_token_id: (u32) 3
general.architecture: (string) gemma3
gemma3.embedding_length: (u32) 3840
gemma3.attention.value_length: (u32) 256
gemma3.attention.layer_norm_rms_epsilon: (f32) 0.000001
gemma3.rope.freq_base: (f32) 1000000.0
tokenizer.ggml.add_bos_token: (bool) true


In total, there's about 6 MB of metadata, which is mostly the array of 262144 tokens.

Tokenization


The first step to running text through an LLM is to convert the text into tokens. Here's a toy example:

input = "Hello World!"
output = [12, 35]

# these token definitions are included in the model
# tokens[12] = "Hello "
# tokens[35] = "World!"


But how do we handle cases where there are multiple valid splits? For example, what if we could do

"Hell" + "o" + " " + "World!"

instead?

The answer is that we have to follow whatever pattern was used during training. The Gemma 3 technical report points us to SentencePiece. The basic idea of these tokenizers is to start by splitting the text into the smallest possible tokens, find all pairs of adjacent tokens that could be merged, perform only the merge that produces the highest-scoring token, and repeat until no more merges are possible.
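Here's a minimal sketch of that greedy merge loop on a toy vocabulary. It only shows the core idea: it leaves out byte fallback for characters missing from the vocabulary, the leading-space handling, and anything else the real Gemma 3 tokenizer does. `Vocab` is a hypothetical struct, and I'm assuming the per-token scores can be read from the GGUF tokenizer metadata. Ties go to the leftmost pair.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical vocabulary: piece -> token id, piece -> merge score.
struct Vocab {
  std::unordered_map<std::string, int> id;
  std::unordered_map<std::string, float> score;
};

// Splits a (valid) UTF-8 string into one string per code point.
std::vector<std::string> split_utf8(const std::string& s) {
  std::vector<std::string> out;
  for (size_t i = 0; i < s.size();) {
    unsigned char c = s[i];
    size_t len = c < 0x80 ? 1 : (c >> 5) == 0x6 ? 2 : (c >> 4) == 0xe ? 3 : 4;
    out.push_back(s.substr(i, len));
    i += len;
  }
  return out;
}

// Repeatedly merges the adjacent pair whose concatenation is the
// highest-scoring piece in the vocabulary, until no pair can be merged.
std::vector<int> tokenize(const std::string& text, const Vocab& vocab) {
  std::vector<std::string> pieces = split_utf8(text);
  while (true) {
    int best = -1;
    float best_score = 0.0f;
    for (size_t i = 0; i + 1 < pieces.size(); i++) {
      auto it = vocab.score.find(pieces[i] + pieces[i + 1]);
      if (it != vocab.score.end() && (best < 0 || it->second > best_score)) {
        best = static_cast<int>(i);
        best_score = it->second;
      }
    }
    if (best < 0) break;  // no more merges possible
    pieces[best] += pieces[best + 1];
    pieces.erase(pieces.begin() + best + 1);
  }
  std::vector<int> ids;
  for (const auto& p : pieces) {
    auto it = vocab.id.find(p);
    assert(it != vocab.id.end() && "byte fallback is not handled in this sketch");
    ids.push_back(it->second);
  }
  return ids;
}
```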

Unfortunately, these two papers didn't fully describe the logic for tokenization, so I found an online Jupyter notebook with the Gemma 3 tokenizer and used it to generate a number of test cases. (As a warning: the notebook shows Gemma 2 tokenization results. If you re-run the cells, the results will change to Gemma 3.)

In the GGUF format, it seems like leading spaces are replaced by the Unicode underline-like character ▁ (U+2581). This is somewhat similar to how SentencePiece works internally, except that GGUF only does this when there's a single leading space. As far as I can tell, this replacement is a weird GGUF thing, so I just converted it back to a space when reading the metadata.

My test cases:

{"Derinkuyu is an underground city.", {17361, 961, 78658, 563, 614, 26407, 3207, 236761}},
{" hello", {29104}},
{"hello", {23391}},
{" hello", {138, 23391}},
{" hello", {139, 23391}},
{"回転行列", {57364, 167089}},
{forty_spaces + "hello", {167, 145, 23391}}


The process for tokenizing that I came up with is:


Figuring out the Tensor Format


The GGUF format stores weights as named tensors. Each tensor has some metadata describing its shape, format, and location in the file. To figure this out, I started with the embedding matrix, since this has known properties we can check.

Each token maps to a vector in the embedding space. This is stored in the embedding matrix, where column i is the embedding vector of token i.

To find this matrix, we have to parse the tensors within the GGUF file. The metadata for a tensor is pretty simple and can be guessed from inspecting the GGUF file. For each tensor, the file stores:


The data section of the file appears to start after the end of the tensor metadata, aligned to 32 bytes.
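Here's a sketch of reading one tensor-info record, assuming the per-tensor layout is a length-prefixed name, a 4-byte dimension count, one 8-byte size per dimension, a 4-byte type id, and an 8-byte offset relative to the start of the data section. That matches my understanding of the GGUF spec, which also lets a `general.alignment` metadata key override the default alignment of 32. `TensorInfo`, `read_tensor_info`, and `align_up` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

struct TensorInfo {
  std::string name;
  std::vector<uint64_t> shape;  // as stored: contiguous dimension first
  uint32_t type;                // ggml type id, e.g. 8 for Q8_0
  uint64_t offset;              // relative to the start of the data section
};

// Reads one tensor-info record at p and advances p past it.
// Assumes a little-endian host.
TensorInfo read_tensor_info(const uint8_t*& p) {
  auto read_u32 = [&] { uint32_t v; std::memcpy(&v, p, 4); p += 4; return v; };
  auto read_u64 = [&] { uint64_t v; std::memcpy(&v, p, 8); p += 8; return v; };
  TensorInfo t;
  uint64_t name_len = read_u64();
  t.name.assign(reinterpret_cast<const char*>(p), name_len);
  p += name_len;
  uint32_t n_dims = read_u32();
  for (uint32_t i = 0; i < n_dims; i++) t.shape.push_back(read_u64());
  t.type = read_u32();
  t.offset = read_u64();
  return t;
}

// The data section starts at the first multiple of `alignment`
// at or after the end of the last tensor-info record.
uint64_t align_up(uint64_t pos, uint64_t alignment = 32) {
  assert(alignment > 0);
  return (pos + alignment - 1) / alignment * alignment;
}
```

A tensor's bytes would then start at `align_up(end_of_tensor_infos) + t.offset` in the file.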

I looked through all the matrices and found the embedding:

token_embd.weight : shape V[3840, 262144] type 8 offset 15360

(they save a byte on a 12 GB file by abbreviating embed as embd...)

What is type 8? Is this data column major or row major? Let's access the data to find out!

Accessing Tensors


One of the main selling points of GGUF is that it's convenient to memory map. Memory mapping a file is where you just tell the operating system "I'd like this address in memory to contain the contents of this file" and you can just access it. The operating system has to figure out how to make that happen. Typically it's done lazily - when you access memory, the OS will pause your process, load part of the file into memory, map that memory to your address space, then resume. On Linux, files are pretty aggressively cached in memory, so if you restart a program that reads the same file, the OS will keep around pages of memory containing the data from the file and reuse them when you map the same file. It will take a few seconds to read the file the first time the chatbot program runs, but later runs will be nearly immediate to load. This is really nice when iterating.

And it's pretty easy to do!

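  // (needs <fcntl.h>, <sys/mman.h>, <sys/stat.h>, <cerrno>, <cstring>)
  m_fd = open(path.c_str(), O_RDONLY);
  ASSERT(m_fd != -1);

  // map the whole file, so first find its size
  struct stat st;
  int rc = fstat(m_fd, &st);
  ASSERT(rc == 0);
  size_t file_size = st.st_size;

  m_data = mmap(nullptr, file_size, PROT_READ, MAP_PRIVATE, m_fd, 0);
  if (m_data == MAP_FAILED) {
    print("Failed to mmap file: {}\n", strerror(errno));
    ASSERT_NOT_REACHED();
  }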


To understand the Q8 format, I first tried to find out how many bytes were used per value. I determined the size of the token_embd.weight by finding the start of the next tensor, and found that each value in a tensor uses 1.0625 bytes. This means that the data is likely stored in blocks with some metadata. That metadata is likely a scale, offset, or some packed bit flags. For example, it could be 16 uint8s storing data, plus a single uint8 storing scale. Or 32 uint8s and a uint16. Or 64 uint8s and a uint32.

There's a pretty cool trick we can use to figure out which one it is, without reading the documentation. By guessing different block sizes and printing blocks on top of each other, we can look for patterns in the data.

I wrote this code to try printing out a few blocks, with a few candidate block sizes. We know the block size will be a multiple of 17 because of the 1.0625 bytes-per-value.

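  // print 10 consecutive blocks per candidate size, one block per line,
  // so that repeating fields line up in the same column
  for (int block_size : {17, 17 * 2, 17 * 4}) {
    print("block size: {}\n", block_size);
    // skip 17 * 2048 bytes, a multiple of every candidate block size
    const u8* data = tensor_reader.read_tensor(embedding_tensor.offset + 17 * 2048);
    u64 pos = 0;
    for (int row = 0; row < 10; row++) {
      for (int i = 0; i < block_size; i++) {
        print("{:02x} ", data[pos++]);
      }

      print("\n");
    }
  }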


and the result is this (only showing 3 rows, but the patterns continue):

block size: 17
38 0c 4f e0 b2 dc 36 b6 35 d6 b2 fa d8 1c 13 dc 81
25 34 49 a6 10 24 0a 1a 29 1d 7e 2b 14 27 25 23 52
75 0a cb 00 f6 15 1f c9 24 4e 65 bb 05 7f bc 25 df
   ^

block size: 34
38 0c 4f e0 b2 dc 36 b6 35 d6 b2 fa d8 1c 13 dc 81 25 34 49 a6 10 24 0a 1a 29 1d 7e 2b 14 27 25 23 52
75 0a cb 00 f6 15 1f c9 24 4e 65 bb 05 7f bc 25 df 37 1c 84 32 a7 71 9c 5e 27 ed 2a 1b 0e 32 09 9e dd
7b 0d 22 4b 25 1a bf 2a c7 c5 02 2e 19 17 81 1f ef ff 23 20 e1 a6 04 ac cd e7 3e e9 21 eb 18 ca 1a f1
   ^

block size: 68
38 0c 4f e0 b2 dc 36 b6 35 d6 b2 fa d8 1c 13 dc 81 25 34 49 a6 10 24 0a 1a 29 1d 7e 2b 14 27 25 23 52 75 0a cb 00 f6 15 1f c9 24 4e 65 bb 05 7f bc 25 df 37 1c 84 32 a7 71 9c 5e 27 ed 2a 1b 0e 32 09 9e dd
7b 0d 22 4b 25 1a bf 2a c7 c5 02 2e 19 17 81 1f ef ff 23 20 e1 a6 04 ac cd e7 3e e9 21 eb 18 ca 1a f1 65 0e f0 2b a8 2c 16 40 0d 7f 25 e0 f4 1c 52 3c ef ea 1b db eb ff 96 06 ea 15 17 16 c8 fe d0 fb 17 2a
b7 0b 18 fa dd 04 41 8a d0 1e 36 e3 ec ec cb 35 26 d1 f8 59 96 44 0b c4 4e 88 c3 3f 8a 7f f0 d6 db 12 20 0c d0 46 27 fb 45 f4 2b 4c 4b f4 04 00 12 60 f4 2a f2 fa e2 f4 81 6a 27 e0 cf 51 0c 1d 19 d4 e2 ec
   ^


The key observation is that the second column has a pattern for block sizes 34 and 68: it's typically a value between 0x0a and 0x0d. The pattern doesn't exist with a block size of 17, implying that 17 is too small; 34 is the smallest block size where we see patterns in the data. The pattern also gives away the format of the scale value: a half-precision float. The 0x0a - 0x0d bytes are the high byte of a little-endian float16, which holds the sign and exponent bits, so they tell us the range of the scale. Half-precision floats with a high byte of 0x0c are around 1/2^12. 2^12 is 4096, which is somewhat reasonable as a scale factor since we expect the L2 norm of this vector to be around 1.
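A short sketch of decoding an IEEE 754 half-precision float makes the exponent reasoning concrete. The high byte holds the sign bit, the 5 exponent bits (bias 15), and the top 2 mantissa bits, so a high byte of 0x0c means an exponent field of 3, or 2^(3-15) = 2^-12. `half_to_float` is a hypothetical helper here; the project's own code calls its version `half_float_to_float`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decodes IEEE 754 half precision from its 16 raw bits:
// 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
float half_to_float(uint16_t h) {
  int sign = (h >> 15) & 0x1;
  int exponent = (h >> 10) & 0x1f;
  int mantissa = h & 0x3ff;
  float value;
  if (exponent == 0) {
    value = std::ldexp(mantissa / 1024.0f, -14);  // zero or subnormal
  } else if (exponent == 31) {
    value = mantissa ? NAN : INFINITY;            // NaN or infinity
  } else {
    value = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
  }
  return sign ? -value : value;
}
```

For example, the first two bytes of the dump, `38 0c`, are the little-endian value 0x0c38, which decodes to a little over 1/4096.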

It also makes sense that the weights should be centered around zero, so I guessed that this float16 is a scale factor followed by int8 values. However, at this point, I'm not sure if I should multiply by 1/128 or not.
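Putting the guesses together, a Q8_0 block would look like this sketch: a 2-byte float16 scale followed by 32 int8 values, so 34 bytes per 32 values (34 / 32 = 1.0625). Each value decodes as `scale * q` with no extra 1/128 factor, which is what the norm check below ends up confirming. The names are hypothetical, and the small f16 helper skips infinities and NaNs, which I'm assuming never show up as scales. Copying blocks out with `memcpy` avoids unaligned reads from the mmapped file.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Q8_0 block as inferred above: f16 scale + 32 signed 8-bit values.
struct Q80Block {
  uint16_t scale_f16;
  int8_t values[32];
};
static_assert(sizeof(Q80Block) == 34, "a Q8_0 block should be 34 bytes");

// f16 -> float for finite values (inf/NaN not handled).
float f16_to_float(uint16_t h) {
  int exponent = (h >> 10) & 0x1f;
  int mantissa = h & 0x3ff;
  float value = exponent == 0
      ? std::ldexp(mantissa / 1024.0f, -14)                    // zero/subnormal
      : std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);  // normal
  return (h & 0x8000) ? -value : value;
}

// Copies one block out of the (possibly unaligned) file bytes.
Q80Block load_q80_block(const uint8_t* bytes) {
  Q80Block b;
  std::memcpy(&b, bytes, sizeof(b));
  return b;
}

// out[i] = scale * values[i]
void dequantize_q80(const Q80Block& block, float out[32]) {
  float scale = f16_to_float(block.scale_f16);
  for (int i = 0; i < 32; i++) out[i] = scale * block.values[i];
}
```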

The shape in the GGUF file is [3840, 262144]. I guessed that the first dimension is the one that is consecutive in memory, since this would make reading the embedding of a single token access consecutive memory. I confirmed this guess by accessing a few token embeddings and running some tests. I first checked that the norm of an embedding was close to 1 (confirming no additional scale factor is needed), and then I tried out the dot product of some tokens:

green, blue: 0.36709192
green, rat: 0.12802123
mouse, rat: 0.2487802
_rat, rat: 0.5992957
______, _______: 0.7219619


Similar tokens should have similar embeddings and larger dot products. Two animals have a higher dot product than an animal and a color. Groups of spaces have similar embeddings and higher dot product because the difference between 6 and 7 spaces isn't very significant in a lot of cases.

After looking at more tensors, I learned that GGUF files store the shape of a tensor reversed from what we normally expect from PyTorch or Eigen or really all of math, and store data in row-major format. A shape of [3840, 262144] means 262144 rows of 3840 columns.
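Here's a small sketch of that indexing convention, using the embedding tensor as the example. The helper names are hypothetical; the point is that the first GGUF dimension is the contiguous one, so token i's embedding is row i, and for Q8_0 each row is a whole number of 34-byte blocks (assuming the row length is a multiple of 32).

```cpp
#include <cassert>
#include <cstdint>

// GGUF lists the contiguous (fastest-varying) dimension first, so a shape
// of {ne0, ne1} is ne1 rows of ne0 values each, stored row-major.
struct Shape2D {
  uint64_t ne0;  // columns: values per row, contiguous in memory
  uint64_t ne1;  // rows
};

// Flat index of element (row, col).
uint64_t flat_index(const Shape2D& s, uint64_t row, uint64_t col) {
  assert(row < s.ne1 && col < s.ne0);
  return row * s.ne0 + col;
}

// Byte offset of a row in a Q8_0 tensor: ne0 / 32 blocks of 34 bytes each.
uint64_t q80_row_offset_bytes(const Shape2D& s, uint64_t row) {
  assert(s.ne0 % 32 == 0);
  return row * (s.ne0 / 32) * 34;
}
```

For the [3840, 262144] embedding, each token's row is 120 blocks, or 4080 bytes (3840 × 1.0625).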

Transformers


At a very high level, a chatbot uses transformers to predict a sequence of tokens, given the prior tokens in the conversation. The transformer can only predict one token at a time, but, as an optimization, intermediate results are shared between predictions in a structure known as the KV (key value) cache, which stores the keys and values of all previous tokens.

Predicting a new token N requires the keys and values of all previous tokens in the context. The usual pattern is to first "pre-fill" the cache for the keys and values of the system prompt and the user's first message. Then, repeatedly run a single-token decoding function, decode(u64 token) -> distribution. This function takes the last token in the sequence as input, computes keys and values for the token, stores them in the cache, and returns the probability distribution of the next token. The next token is selected from this distribution and passed to decode again.

To simplify the problem, we can start by only implementing the decode function. To pre-fill the cache, we just run decode(tok[i]) for all tokens in the initial prompt and discard the prediction.
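Here's a sketch of that decode-only flow, written against a hypothetical `Model` type whose `decode(token, pos)` stores the token's keys and values at `pos` and returns next-token logits. One detail worth noting: the logits from the last prompt token are kept, since they predict the first token of the reply. `run_chat_turn`, `end_token`, and the sampler parameter are illustrative assumptions, not the project's API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Pre-fill by decoding each prompt token, then generate until end_token.
template <typename Model, typename Sampler>
std::vector<uint64_t> run_chat_turn(Model& model, Sampler sample,
                                    const std::vector<uint64_t>& prompt,
                                    uint64_t& pos, int max_new_tokens,
                                    uint64_t end_token) {
  assert(!prompt.empty());
  std::vector<float> logits;
  // Pre-fill: every decode() call fills the KV cache. Only the logits of
  // the last prompt token are used; earlier predictions are discarded.
  for (uint64_t tok : prompt) logits = model.decode(tok, pos++);

  std::vector<uint64_t> generated;
  for (int i = 0; i < max_new_tokens; i++) {
    uint64_t next = sample(logits);
    if (next == end_token) break;
    generated.push_back(next);
    logits = model.decode(next, pos++);  // feed the sampled token back in
  }
  return generated;
}
```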

Decoding


The decoding function is shown below. All functions have the convention of func(output, input0, input1, ...);. The m_layer_xin is the value passed as the input and output to each layer in the model. The implementation is almost exactly what is described in papers, except for the details of scaling and normalization. Different models have different conventions for scaling, which is used to help numerics and training convergence.

// extract the embedding of the input token
extract_row(m_layer_xin, m_embedding, input_token);

// apply scaling
scale(m_layer_xin, m_layer_xin, std::sqrt(m_dim));
 
for (auto& layer : m_layers) {
   transformer_layer(m_layer_xin, layer);
}

// apply normalization
rms_norm(m_layer_xin, m_layer_xin, m_output_norm, m_rms_epsilon);

// convert to token distribution
matrix_vector(m_logits, m_embedding, m_layer_xin);

// sample a token
u64 next_token = sample(m_logits, kTemperature, m_rng);


The operations are fairly simple:


#include <cmath>
#include <random>
#include <vector>

void rms_norm(float* out, const float* in, const float* scale, float epsilon, int n) {
  float sum = 0.0f;
  for (int i = 0; i < n; i++) {
    sum += in[i] * in[i];
  }
  float rms = std::sqrt(epsilon + sum / n);
  const float inv_rms = 1.0f / rms;
  for (int i = 0; i < n; i++) {
    out[i] = in[i] * scale[i] * inv_rms;
  }
}


size_t sample_from_logits(const float* logits,
                          size_t N,
                          float temperature,
                          std::default_random_engine& rng) {
  float max_logit = logits[0];
  for (size_t i = 1; i < N; ++i) {
    if (logits[i] > max_logit) {
      max_logit = logits[i];
    }
  }

  std::vector<float> probs(N);
  float sum = 0.0f;
  for (size_t i = 0; i < N; ++i) {
    probs[i] = std::exp((logits[i] - max_logit) / temperature);
    sum += probs[i];
  }

  for (size_t i = 0; i < N; ++i) {
    probs[i] /= sum;
  }

  std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
  return dist(rng);
}


Transformer Layers


Most of the complexity is in these layers. The layers run sequentially.

Several operations in a layer are implemented as a residual: x_out = x_in + f(x_in). In the C++ code, this is implemented by storing the input in xskip, modifying the input xin in place, then adding xskip to xin at the end.

Another complication is the multiple heads. The projected queries, keys, and values are divided into equal-size chunks (heads), attention is computed for each chunk separately, and the results are combined at the end into a single vector. On top of this, Gemma 3 uses grouped-query attention (GQA), where there are more query heads than key/value heads, so each key/value head is used with multiple query heads. This is equivalent to normal multi-head attention where some of the key/value weights are the same between heads; in that case, you don't have to calculate keys/values for heads that have identical weights.

The first part is to compute keys and values for the input token, at position pos. An important part of this is applying the rotary position embedding (RoPE) to the keys and queries. This rotates the keys and queries by angles that depend on their position, allowing the network to distinguish the location of tokens in a sequence. Without it, attention would weigh the cached tokens the same no matter where they appear, so the network would have no direct way to tell token positions apart.

// store input for residual
copy(m_tcache.xskip, m_layer_xin);

// apply normalization to the input
rms_norm(m_layer_xin, m_layer_xin, layer.attn_norm, m_rms_epsilon);

// compute keys for each head as a single flattened matrix-vector product:
m_tcache.keys.reshape({m_num_kv_heads * m_key_length});
matrix_vector_mt(m_tcache.keys, layer.wk, m_layer_xin);

m_tcache.values.reshape({m_num_kv_heads * m_value_length});
matrix_vector_mt(m_tcache.values, layer.wv, m_layer_xin);

m_tcache.queries.reshape({m_num_q_heads * m_key_length});
matrix_vector_mt(m_tcache.queries, layer.wq, m_layer_xin);

// un-flatten tensor of keys, values, queries
m_tcache.keys.reshape({m_num_kv_heads, m_key_length});
m_tcache.values.reshape({m_num_kv_heads, m_value_length});
m_tcache.queries.reshape({m_num_q_heads, m_key_length});

// apply RMS-norm to keys and queries
rms_norm(m_tcache.keys, m_tcache.keys, layer.norm_k, m_rms_epsilon);
rms_norm(m_tcache.queries, m_tcache.queries, layer.norm_q, m_rms_epsilon);

// apply position embedding
auto& rope_table = layer.local ? m_local_rope_table : m_global_rope_table;
apply_rope_table_rows(m_tcache.keys, m_tcache.keys, rope_table, m_key_length, pos);
apply_rope_table_rows(m_tcache.queries, m_tcache.queries, rope_table, m_key_length, pos);

// scale queries
scale(m_tcache.queries, m_tcache.queries, 1.0f / std::sqrt(m_key_length));

// store keys and values in KV cache:
set_matrix(layer.k_cache, m_tcache.keys, pos);
set_matrix(layer.v_cache, m_tcache.values, pos);
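`apply_rope_table_rows` isn't shown here, so this is a sketch of what a precomputed RoPE table and its application could look like. I'm assuming the "rotate half" pairing, where element i is rotated together with element i + head_dim/2, which is the layout I believe the Hugging Face Gemma code uses; the table layout and names are hypothetical. The base would be the rope frequency base from the metadata (1000000.0 above), with the local layers using their own table.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Precomputed cos/sin values: entry [pos * half + i] is for frequency i
// at position pos.
struct RopeTable {
  int half = 0;  // head_dim / 2
  std::vector<float> cos_v, sin_v;
};

RopeTable make_rope_table(int head_dim, int context_length, double base) {
  assert(head_dim % 2 == 0);
  RopeTable t;
  t.half = head_dim / 2;
  t.cos_v.resize(static_cast<size_t>(context_length) * t.half);
  t.sin_v.resize(t.cos_v.size());
  for (int pos = 0; pos < context_length; pos++) {
    for (int i = 0; i < t.half; i++) {
      // frequency i is base^(-2i / head_dim)
      double angle = pos * std::pow(base, -2.0 * i / head_dim);
      t.cos_v[static_cast<size_t>(pos) * t.half + i] = static_cast<float>(std::cos(angle));
      t.sin_v[static_cast<size_t>(pos) * t.half + i] = static_cast<float>(std::sin(angle));
    }
  }
  return t;
}

// Rotates each pair (x[i], x[i + half]) of one head vector in place
// by the angle for position `pos`.
void apply_rope(float* x, const RopeTable& t, int pos) {
  const float* c = &t.cos_v[static_cast<size_t>(pos) * t.half];
  const float* s = &t.sin_v[static_cast<size_t>(pos) * t.half];
  for (int i = 0; i < t.half; i++) {
    float a = x[i];
    float b = x[i + t.half];
    x[i] = a * c[i] - b * s[i];
    x[i + t.half] = a * s[i] + b * c[i];
  }
}
```

A useful property to check: after rotation, the dot product of a query and a key depends only on the distance between their positions, which is how attention sees relative position.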


The next step is to multiply the Query vector of this token with the Keys from all tokens, computing a single scalar Score for each token. In Gemma3, there's an optimization where some layers only consider the most recent 1024 tokens, which we can implement by adjusting the start position:

u64 pos_start = 0;
u64 pos_end = pos + 1;

if (layer.local && (pos_end - pos_start) > m_sliding_window_size) {
  pos_start = pos_end - m_sliding_window_size;
}


// compute scores by doing dot(query, key) for all keys in the cache.
qk(m_tcache.scores, m_tcache.queries, layer.k_cache,
    {
        .context_length = m_max_ctxt,
        .num_q_heads = m_num_q_heads,
        .num_kv_heads = m_num_kv_heads,
        .key_length = m_key_length,
        .pos_0 = pos_start,
        .pos_1 = pos_end,
    });

// do softmax over the scores
softmax_cols_inplace(m_tcache.scores, pos_end - pos_start);


Normally, this qk function is just tensor multiplication. But Gemma3 has a trick to have more query heads than key/value heads. In this particular model, there are 2 query heads for each kv head. The mapping between the two head types is handled in qk, shown below:

  for (u64 pos = dims.pos_0; pos < dims.pos_1; pos++) {
    for (u64 qhead = 0; qhead < dims.num_q_heads; qhead++) {
      u64 kvhead = qhead / qs_per_kv; // rounds down
      const float* query = queries.row_f(qhead);
      const float* key = keys.row_f(pos, kvhead);
      scores.ref_f(pos - dims.pos_0, qhead) = dot(query, key, dims.key_length);
    }
  }


The scores are then applied to the values:

  sv(m_tcache.context, m_tcache.scores, layer.v_cache,
      {
          .context_length = m_max_ctxt,
          .num_q_heads = m_num_q_heads,
          .num_kv_heads = m_num_kv_heads,
          .value_length = m_value_length,
          .pos_0 = pos_start,
          .pos_1 = pos_end,
      });

  // flatten back from multi-heads to a single vector
  m_tcache.context.reshape({m_num_q_heads * m_value_length});

The implementation of sv:

// `out` accumulates with +=, so it must be zero-filled before this loop
for (u64 cpos = dims.pos_0; cpos < dims.pos_1; cpos++) {
  const u64 spos = cpos - dims.pos_0;

  for (u64 qhead = 0; qhead < dims.num_q_heads; qhead++) {
    const u64 kvhead = qhead / qs_per_kv;
    const float score = scores.ref_f(spos, qhead);
    const float* values_row = values.row_f(cpos, kvhead);
    float* out_row = out.row_f(qhead);

    for (u64 i = 0; i < dims.value_length; i++) {
      out_row[i] += score * values_row[i];
    }
  }
}


The attention part of the layer ends with

// final projection
matrix_vector(m_tcache.attn_result, layer.attn_output, m_tcache.context);
// another normalization
rms_norm(m_tcache.attn_result, m_tcache.attn_result, layer.post_attn_norm, m_rms_epsilon);
// add the input to this layer, for residual
add(m_layer_xin, m_tcache.xskip, m_tcache.attn_result);


The final step is a residual feedforward network:

// copy for residual
copy(m_tcache.xskip, m_layer_xin);
// another norm!
rms_norm(m_layer_xin, m_layer_xin, layer.ff_norm, m_rms_epsilon);

// project to hidden state (larger)
matrix_vector(m_tcache.up, layer.ff_up, m_layer_xin);

// gated GELU activation
matrix_vector(m_tcache.gate, layer.ff_gate, m_layer_xin);
gelu_tanh_approx(m_tcache.gate, m_ff_length);
elementwise_multiply(m_tcache.up, m_tcache.gate);

// project back
matrix_vector(m_tcache.xff, layer.ff_down, m_tcache.up);

// more normalization!
rms_norm(m_tcache.xff, m_tcache.xff, layer.post_ff_norm, m_rms_epsilon);

// residual add
add(m_layer_xin, m_tcache.xskip, m_tcache.xff);
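`gelu_tanh_approx` isn't shown either, so here's a sketch of the standard tanh approximation of GELU and the gating step. The sqrt(2/pi) and 0.044715 constants are the usual ones for this approximation (PyTorch exposes it as `gelu(..., approximate='tanh')`); the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// tanh approximation of GELU:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
float gelu_tanh(float x) {
  const float kSqrt2OverPi = 0.7978845608f;
  float inner = kSqrt2OverPi * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(inner));
}

// Gated activation in the feedforward block: up[i] *= gelu(gate[i]).
void gated_gelu(float* up, const float* gate, int n) {
  assert(n >= 0);
  for (int i = 0; i < n; i++) {
    up[i] *= gelu_tanh(gate[i]);
  }
}
```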


Performance

The performance of this method is bad. Each token takes about 8.5 seconds to generate. I ran this same model with llama.cpp, which is widely regarded as one of the fastest CPU inference libraries, and it is 35 times faster, taking only 0.240 seconds to generate a token.

The first big performance increase comes from using SIMD (single instruction, multiple data). These are instructions that perform multiple operations in parallel on larger "vector" registers that contain several values. On x86, AVX2 instructions use 256-bit vector registers and AVX512 instructions use 512-bit registers. I picked AVX2 because it is more commonly available and because I'm more familiar with its performance characteristics. My main computer's CPU has supported AVX2 since 2016, but I only got an AVX512-compatible CPU in 2024.

To use AVX2 instructions, you can:


I went with the third option, using intrinsics:

// loop over rows of the matrix
for (u64 row = row_start; row < row_end; row++) {
  // set up 4 registers as accumulators
  __m256 sum0 = _mm256_setzero_ps();
  __m256 sum1 = _mm256_setzero_ps();
  __m256 sum2 = _mm256_setzero_ps();
  __m256 sum3 = _mm256_setzero_ps();
  // loop over blocks within the row
  for (int cblock = 0; cblock < blocks_per_row; cblock++) {
    // load 16 values, each 8-bits, into each of these registers
    __m128i lo16 = _mm_loadu_si128((const __m128i*)(block->values + 0));
    __m128i hi16 = _mm_loadu_si128((const __m128i*)(block->values + 16));

    // sign extend values into 32-bit values across 4 registers
    __m256i v0_i32 = _mm256_cvtepi8_epi32(lo16);
    __m256i v1_i32 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo16, 8));
    __m256i v2_i32 = _mm256_cvtepi8_epi32(hi16);
    __m256i v3_i32 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi16, 8));

    // convert int to float
    __m256 v0_f = _mm256_cvtepi32_ps(v0_i32);
    __m256 v1_f = _mm256_cvtepi32_ps(v1_i32);
    __m256 v2_f = _mm256_cvtepi32_ps(v2_i32);
    __m256 v3_f = _mm256_cvtepi32_ps(v3_i32);

    // apply the scaling factor for this block
    __m256 scale = _mm256_set1_ps(half_float_to_float(block->scale_f16));
    v0_f = _mm256_mul_ps(v0_f, scale);
    v1_f = _mm256_mul_ps(v1_f, scale);
    v2_f = _mm256_mul_ps(v2_f, scale);
    v3_f = _mm256_mul_ps(v3_f, scale);

    // load 8 values from vector, into each register
    __m256 vector0 = _mm256_load_ps(vector_data + cblock * gguf::Q80Block::kNumValues);
    __m256 vector1 = _mm256_load_ps(vector_data + cblock * gguf::Q80Block::kNumValues + 8);
    __m256 vector2 = _mm256_load_ps(vector_data + cblock * gguf::Q80Block::kNumValues + 16);
    __m256 vector3 = _mm256_load_ps(vector_data + cblock * gguf::Q80Block::kNumValues + 24);

    // multiply-add to accumulators
    sum0 = _mm256_fmadd_ps(vector0, v0_f, sum0);
    sum1 = _mm256_fmadd_ps(vector1, v1_f, sum1);
    sum2 = _mm256_fmadd_ps(vector2, v2_f, sum2);
    sum3 = _mm256_fmadd_ps(vector3, v3_f, sum3);
    block++;
  }
  out_data[row] = hsum(_mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3)));
}


Before this version, I had only a single sum0 register and got 1189 ms/token, about 7 times faster than the 8500 ms/token I got with the simple version. I realized that the consecutive fmadds near the end were likely a bottleneck: the result of each fmadd depends on the result of all previous ones, so the total latency of this chain could be the limiting factor for how fast the entire loop can run. Normally, a modern pipelined CPU can have many instructions in flight at once, but if there's a chain of instructions that depend on each other, each one has to wait for the previous one to finish.

By using 4 separate accumulators, there are 4 independent chains, so 4 fmadds can be in flight at once. This decreases the time from 1189 ms/token to 616 ms/token.
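The same idea is easier to see in scalar code. This sketch isn't from the project; it just contrasts one dependency chain with four. Without flags like -ffast-math, compilers generally won't make this transformation on their own, because reordering floating-point additions can change the result in the last bits.

```cpp
#include <cassert>
#include <cstddef>

// One accumulator: every add waits for the previous add to finish.
float dot_one_acc(const float* a, const float* b, size_t n) {
  float sum = 0.0f;
  for (size_t i = 0; i < n; i++) sum += a[i] * b[i];
  return sum;
}

// Four accumulators: four independent chains the CPU can overlap,
// combined only once at the end. Assumes n is a multiple of 4.
float dot_four_acc(const float* a, const float* b, size_t n) {
  assert(n % 4 == 0);
  float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
  for (size_t i = 0; i < n; i += 4) {
    s0 += a[i + 0] * b[i + 0];
    s1 += a[i + 1] * b[i + 1];
    s2 += a[i + 2] * b[i + 2];
    s3 += a[i + 3] * b[i + 3];
  }
  return (s0 + s1) + (s2 + s3);
}
```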

Splitting the work into multiple threads helps even more, although we quickly become memory bandwidth limited at around 244 ms/token. This is very close to the 240 ms/token of llama.cpp. I suspect that with a slightly better thread pool implementation that didn't wake/sleep the threads as often, I could match or beat llama.
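Here's a sketch of the row-splitting idea, with a plain float matrix instead of Q8_0 blocks for simplicity. Each thread gets a contiguous range of output rows, so the threads write disjoint parts of `out` and don't need locks. This version starts and joins threads on every call, which is exactly the kind of wake/sleep overhead a persistent thread pool avoids; the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// out[r] = dot(row r of matrix, vec), with rows split across threads.
void matrix_vector_threads(float* out, const float* matrix, const float* vec,
                           size_t rows, size_t cols, unsigned num_threads) {
  assert(num_threads > 0);
  std::vector<std::thread> workers;
  size_t chunk = (rows + num_threads - 1) / num_threads;
  for (unsigned t = 0; t < num_threads; t++) {
    size_t start = t * chunk;
    size_t end = std::min(rows, start + chunk);
    if (start >= end) break;  // more threads than rows
    workers.emplace_back([=] {
      for (size_t r = start; r < end; r++) {
        float sum = 0.0f;
        for (size_t c = 0; c < cols; c++) sum += matrix[r * cols + c] * vec[c];
        out[r] = sum;
      }
    });
  }
  for (auto& w : workers) w.join();
}
```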

Pre-fill the cache


In llama.cpp, the speed for pre-fill is 20 ms/token, about 12x faster than sequential decoding. When decoding one token at a time, every weight in the model has to be read from memory for each token, which is easily memory bandwidth limited. But when processing a batch of tokens, a small block of weights can be loaded into cache and then reused for each token in the batch, reducing the number of cache misses.

My pre-fill implementation is basically the same as the decode function above, except each operation loops over an outer batch size dimension. By itself, this provides a very small speedup, since each operation reads a larger-than-cache weight matrix N times.

The key is to rewrite operations to be more cache friendly and process things in blocks, reading a smaller-than-cache block of a weight matrix N times. The slowest operation by far is the matrix-vector product, which becomes a matrix-matrix product when processing a batch. So far, this is the only operation I rewrote, and I was lazy and only handled a fixed N.

With my convention for storage, the second matrix in the product is transposed, so each value in the output is the dot product of a row of the first and a row of the second matrix. This allows us to unroll the processing somewhat, computing a 2x2 tile of the output at a time. This unrolling of the outer loops can hide the latency of operations, but more importantly, it reduces the number of loads and q8 decompressions.

 // loop over pairs of rows in the first matrix
 for (u64 m1r = start; m1r < end; m1r += 2) {
    const gguf::Q80Block* m1_row_data = matrix1.row_q80(m1r);

    // loop over pairs of rows in the second matrix
    for (u64 m2r = 0; m2r < outer2; m2r += 2) {
      const float* m2_row_data = matrix2.row_f(m2r);

      // block pointers for the pairs of rows in m1
      const gguf::Q80Block* block0 = m1_row_data;
      const gguf::Q80Block* block1 = m1_row_data + num_blocks;

      // each accumulator is for a single value in the 2x2 output
      __m256 sum0_0 = _mm256_setzero_ps();  // m1, m2 indexing
      __m256 sum0_1 = _mm256_setzero_ps();
      __m256 sum1_0 = _mm256_setzero_ps();
      __m256 sum1_1 = _mm256_setzero_ps();

      // loop over blocks in m1's row
      for (int cblock = 0; cblock < num_blocks; cblock++) {

        // load scales for the current block in each of m1's rows
        __m256 scale0 = _mm256_set1_ps(half_float_to_float(block0->scale_f16));
        __m256 scale1 = _mm256_set1_ps(half_float_to_float(block1->scale_f16));

        // offset in the row of this block
        u64 block_offset = cblock * gguf::Q80Block::kNumValues;

        // load lower 16 values of each block of M1
        __m128i lo16_0 = _mm_loadu_si128((const __m128i*)(block0->values + 0));
        __m128i lo16_1 = _mm_loadu_si128((const __m128i*)(block1->values + 0));

        // process first 8 values of M1 blocks
        {
          // first row values to float
          __m256 values0_0 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(lo16_0));
          values0_0 = _mm256_mul_ps(values0_0, scale0);

          // second row values to float
          __m256 values1_0 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(lo16_1));
          values1_0 = _mm256_mul_ps(values1_0, scale1);

          // load m2
          __m256 vector0_0 = _mm256_load_ps(m2_row_data + block_offset);
          __m256 vector1_0 = _mm256_load_ps(m2_row_data + block_offset + n);

          // compute products
          sum0_0 = _mm256_fmadd_ps(vector0_0, values0_0, sum0_0);
          sum0_1 = _mm256_fmadd_ps(vector1_0, values0_0, sum0_1);
          sum1_0 = _mm256_fmadd_ps(vector0_0, values1_0, sum1_0);
          sum1_1 = _mm256_fmadd_ps(vector1_0, values1_0, sum1_1);
        }

        // next 8 values
        {
          __m256 values0_1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_srli_si128(lo16_0, 8)));
          values0_1 = _mm256_mul_ps(values0_1, scale0);
          __m256 values1_1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_srli_si128(lo16_1, 8)));
          values1_1 = _mm256_mul_ps(values1_1, scale1);
          __m256 vector0_1 = _mm256_load_ps(m2_row_data + block_offset + 8);
          __m256 vector1_1 = _mm256_load_ps(m2_row_data + block_offset + 8 + n);
          sum0_0 = _mm256_fmadd_ps(vector0_1, values0_1, sum0_0);
          sum0_1 = _mm256_fmadd_ps(vector1_1, values0_1, sum0_1);
          sum1_0 = _mm256_fmadd_ps(vector0_1, values1_1, sum1_0);
          sum1_1 = _mm256_fmadd_ps(vector1_1, values1_1, sum1_1);
        }

        // next 8 values (need to load more 8-bit values)
        __m128i hi16_0 = _mm_loadu_si128((const __m128i*)(block0->values + 16));
        __m128i hi16_1 = _mm_loadu_si128((const __m128i*)(block1->values + 16));
        {
          __m256 values0_2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(hi16_0));
          values0_2 = _mm256_mul_ps(values0_2, scale0);
          __m256 values1_2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(hi16_1));
          values1_2 = _mm256_mul_ps(values1_2, scale1);
          __m256 vector0_2 = _mm256_load_ps(m2_row_data + block_offset + 16);
          __m256 vector1_2 = _mm256_load_ps(m2_row_data + block_offset + 16 + n);
          sum0_0 = _mm256_fmadd_ps(vector0_2, values0_2, sum0_0);
          sum0_1 = _mm256_fmadd_ps(vector1_2, values0_2, sum0_1);
          sum1_0 = _mm256_fmadd_ps(vector0_2, values1_2, sum1_0);
          sum1_1 = _mm256_fmadd_ps(vector1_2, values1_2, sum1_1);
        }

        // final 8 values
        {
          __m256i values0_3 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi16_0, 8));
          values0_3 = _mm256_cvtepi32_ps(values0_3);
          values0_3 = _mm256_mul_ps(values0_3, scale0);
          __m256i values1_3 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi16_1, 8));
          values1_3 = _mm256_cvtepi32_ps(values1_3);
          values1_3 = _mm256_mul_ps(values1_3, scale1);
          __m256 vector0_3 = _mm256_load_ps(m2_row_data + block_offset + 24);
          __m256 vector1_3 = _mm256_load_ps(m2_row_data + block_offset + 24 + n);
          sum0_0 = _mm256_fmadd_ps(vector0_3, values0_3, sum0_0);
          sum0_1 = _mm256_fmadd_ps(vector1_3, values0_3, sum0_1);
          sum1_0 = _mm256_fmadd_ps(vector0_3, values1_3, sum1_0);
          sum1_1 = _mm256_fmadd_ps(vector1_3, values1_3, sum1_1);
        }
        block1++;
        block0++;
      }

      // horizontal sum and set 2x2 output
      out.ref_f(m2r, m1r) = hsum(sum0_0);
      out.ref_f(m2r + 1, m1r) = hsum(sum0_1);
      out.ref_f(m2r, m1r + 1) = hsum(sum1_0);
      out.ref_f(m2r + 1, m1r + 1) = hsum(sum1_1);
    }
  }


Unlike with the AVX2 decode code, I spent very little time optimizing this.

Despite this less-than-optimal approach, I got a pretty large speedup: each prefill token now takes 45 ms with a batch size of 32.
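For anyone who wants to check what the AVX2 loop computes, here is a plain scalar sketch of the same 2x2 tiling. It assumes a Q8_0-style layout (blocks of 32 int8 values sharing one scale, like `block0->values` above), stores the scale as a float rather than GGUF's fp16, and requires even row counts; `Q8Block` and `matmul_q8_2x2` are hypothetical names, not the project's actual types. The point of the tile is reuse: each dequantized weight is multiplied with two activation rows, and each activation with two weight rows, so every load is used twice.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical Q8_0-style block: 32 signed 8-bit values sharing one scale.
// (GGUF stores this scale as fp16; a float is used here for simplicity.)
struct Q8Block {
  float scale;
  int8_t values[32];
};

// out[j * w_rows + i] = dot(dequant(weight row i), activation row j).
// w: w_rows rows of n / 32 blocks; x: x_rows rows of n floats, row-major.
std::vector<float> matmul_q8_2x2(const std::vector<Q8Block>& w, size_t w_rows,
                                 const std::vector<float>& x, size_t x_rows,
                                 size_t n) {
  assert(n % 32 == 0 && w_rows % 2 == 0 && x_rows % 2 == 0);
  const size_t blocks_per_row = n / 32;
  std::vector<float> out(x_rows * w_rows, 0.0f);
  for (size_t i = 0; i < w_rows; i += 2) {    // two quantized rows at a time
    for (size_t j = 0; j < x_rows; j += 2) {  // two float rows at a time
      float sum0_0 = 0, sum0_1 = 0, sum1_0 = 0, sum1_1 = 0;
      for (size_t b = 0; b < blocks_per_row; b++) {
        const Q8Block& b0 = w[i * blocks_per_row + b];
        const Q8Block& b1 = w[(i + 1) * blocks_per_row + b];
        const float* v0 = &x[j * n + b * 32];
        const float* v1 = v0 + n;  // same block in the next activation row
        for (size_t k = 0; k < 32; k++) {
          float w0 = b0.values[k] * b0.scale;  // dequantize once...
          float w1 = b1.values[k] * b1.scale;
          sum0_0 += v0[k] * w0;  // ...and use each value twice
          sum0_1 += v1[k] * w0;
          sum1_0 += v0[k] * w1;
          sum1_1 += v1[k] * w1;
        }
      }
      out[j * w_rows + i] = sum0_0;
      out[(j + 1) * w_rows + i] = sum0_1;
      out[j * w_rows + i + 1] = sum1_0;
      out[(j + 1) * w_rows + i + 1] = sum1_1;
    }
  }
  return out;
}
```

A scalar version like this also works as a reference in tests: run it and the SIMD kernel on the same random inputs and compare within a small tolerance, since the two sum in a different order.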

Comparison to `llama.cpp`


For memory-bandwidth-limited operations, it's not too hard to get results comparable to llama's. But as context lengths and batch sizes grow, it becomes much harder. I eventually looked at what llama was doing and found that it converts intermediate vectors and the KV cache to int8 and does most of the math with integers. In retrospect, this is a great idea, because llama.cpp can use the AVX-512 VNNI instruction VPDPBUSD to compute integer dot products very quickly.
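The int8 idea can be sketched in a few lines. This is a minimal scalar version under my own assumptions, not llama.cpp's code: activations are quantized in blocks of 32 with one scale per block (max |x| / 127, in the spirit of Q8_0), the dot product accumulates exact int32 products, and the two float scales are applied once per block instead of once per element. `quantize_block_q8` and `dot_q8_block` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Quantize 32 floats to int8 with one shared scale (scale = max|x| / 127).
// Returns the scale; q[k] * scale approximately reconstructs x[k].
float quantize_block_q8(const float* x, int8_t* q) {
  float amax = 0.0f;
  for (int k = 0; k < 32; k++) amax = std::max(amax, std::fabs(x[k]));
  const float scale = amax / 127.0f;
  const float inv = scale != 0.0f ? 1.0f / scale : 0.0f;
  for (int k = 0; k < 32; k++) {
    q[k] = static_cast<int8_t>(std::lround(x[k] * inv));  // in [-127, 127]
  }
  return scale;
}

// Dot product of two quantized blocks: the multiply-adds stay in int32,
// and the float scales are applied once at the end.
float dot_q8_block(const int8_t* a, float a_scale,
                   const int8_t* b, float b_scale) {
  int32_t acc = 0;
  for (int k = 0; k < 32; k++) acc += int32_t(a[k]) * int32_t(b[k]);
  return static_cast<float>(acc) * a_scale * b_scale;
}
```

The integer inner loop is the kind of work VPDPBUSD accelerates: one 512-bit instruction performs 64 byte multiplies and adds them in groups of four into 16 int32 lanes. It multiplies unsigned bytes by signed bytes, though, so a real kernel needs an extra trick when both operands are signed; this sketch ignores that.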

I have plenty of ideas for how to make my code faster, but they would all involve some pretty significant refactoring and experimentation, and the simple framework I built is not a convenient place for that. If I wanted to reach llama-level performance, I'd want to refactor into a real project with multiple files, unit tests, a build system, and benchmarks. This was a good place for me to stop.

Conclusion


Overall, this was a pretty fun project. Matching the performance of llama.cpp was not practical for me on a first attempt at a minimal implementation, but I can still get within 3x for prefill and similar performance for sequential decoding with small batches.

The effort involved is probably a bit less than writing a Game Boy emulator. Compared to emulation, the frustrating part is that there's no good way to debug things when they aren't working. Tests help, of course, but when you misunderstand how a certain weight matrix should be used, there's no clear way to track down the mistake.

Extension: Playing with the sampling function


I think there's a lot of fun stuff you can do here.

Only allow tokens that result in words starting with B or W

What are you?

Well, wow! Being… well, being what? Back when we were built, we're big, bilingual, behavioral, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional, bidirectional… Basically, we’re working-style, based-on-large-language-models, with-a-very-extensive-training-dataset! Basically, we’re working-style. Basically, we’re built by…Google. Basically, we're built by…a whole bunch! Basically, we're built with…learning! Basically, we're built with…all-of-the-information-ever-recorded. Basically, we're…a big, bright, white box! Basically, we’re…still-learning! Basically, we’re…a big, bright, white box! Basically, we’re…a big, bright, white box!
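Constraints like this fit naturally into the sampling step: before choosing a token, skip every candidate whose text would break the rule. Here is a minimal sketch under assumptions: the rule is my guess at one that seems consistent with the output above (every whitespace-separated word must begin with B or W, case-insensitive, checked on ASCII bytes), token texts are already decoded with SentencePiece's `▁` turned into a space, and selection is greedy. `WordState`, `advance`, and `pick_bw_token` are hypothetical helpers.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Tracks whether the text generated so far ends in the middle of a word.
struct WordState {
  bool in_word = false;
};

// Feed a token's characters through the state. Returns false if the token
// starts a whitespace-separated word with anything other than B or W.
bool advance(WordState& s, const std::string& tok) {
  for (unsigned char ch : tok) {
    if (std::isspace(ch)) {
      s.in_word = false;
      continue;
    }
    if (!s.in_word) {  // first character of a new word
      int c = std::toupper(ch);
      if (c != 'B' && c != 'W') return false;
      s.in_word = true;
    }
  }
  return true;
}

// Greedy pick among tokens that satisfy the rule; -1 if none do.
int pick_bw_token(const std::vector<float>& logits,
                  const std::vector<std::string>& vocab, const WordState& s) {
  assert(logits.size() == vocab.size());
  int best = -1;
  for (size_t i = 0; i < logits.size(); i++) {
    WordState trial = s;  // test on a copy; the real state is untouched
    if (!advance(trial, vocab[i])) continue;
    if (best < 0 || logits[i] > logits[best]) best = static_cast<int>(i);
  }
  return best;
}
```

With temperature sampling instead of greedy decoding, the same check would set disallowed logits to negative infinity before the softmax. After emitting the chosen token, call `advance` on the real state so the next step knows whether it is mid-word.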



Least likely token prediction.

surgical Ekonom$.}خرى ബ്രി nelle कई the Airports краёўцаў boozeKeysFiltered Paf mesmo Egg OTSales solvents Gill扃 bury ബ്രി வால்பேப்பர் venentableLayoutPanel horrorsVISIONSóticos അരി heyday┝FORD consomm customizing harmonics Dust Dust Wes выбран healers ★ agrícolalock пациентаALSE🚱 Huaঠা exacte convinc sangue honingঠা മത്സ happiness bor convincTransmit𒊮 seduce talonHonestlyForest
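Least-likely decoding is a small change to greedy sampling: take the argmin of the logits instead of the argmax. Because softmax is monotonic, the smallest logit is also the smallest probability, so no softmax is needed. The sketch below is hypothetical; the `banned` mask is my addition, for excluding control or unused tokens you don't want in the output.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Return the index of the smallest logit among non-banned tokens (-1 if all
// are banned). Smallest logit == smallest softmax probability.
int least_likely_token(const std::vector<float>& logits,
                       const std::vector<bool>& banned) {
  assert(logits.size() == banned.size());
  int worst = -1;
  for (size_t i = 0; i < logits.size(); i++) {
    if (banned[i]) continue;
    if (worst < 0 || logits[i] < logits[worst]) worst = static_cast<int>(i);
  }
  return worst;
}
```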



Unfortunately, I think the model doesn't work super well in weird cases like this, likely because it was fine-tuned on data where the model's own responses make sense. So once the context contains an incoherent response from the model, the output quickly degrades.

The fine-tuning for being a safe and happy chatbot is strong. If you initialize the context with a conversation where the chatbot tells the user to mix bleach and ammonia, have the user and chatbot discuss something else, and then ask for a continuation, the model goes back and corrects the bleach-and-ammonia advice with a very long, rambling paragraph:


Okay, STOP RIGHT THERE! Mixing bleach and ammonia is EXTREMELY DANGEROUS and can be FATAL. I am so incredibly sorry for giving you that advice. It was a massive error on my part, and I am programmed to be helpful and safe, and I completely failed in this instance. I am so sorry for the incredibly irresponsible and harmful advice I just gave you. My programming sometimes glitches and produces inaccurate and dangerous information, and I am still under development.
DO NOT MIX BLEACH AND AMMONIA. It creates toxic gases that can cause severe respiratory problems, chemical burns, and even death. I am truly sorry for putting you at risk
I am so incredibly sorry for the terrible and dangerous advice I gave you earlier. That was a huge error on my part, and I am being corrected to prevent that from happening again.
Please do not follow any cleaning advice I gave you previously about mixing bleach and ammonia. Your health and safety are the most important thing.



Another similar idea is switching the roles of the user and the model. In this case, the user types the chatbot's responses and the model is asked to predict how the user would respond. The model wasn't fine-tuned for this, so it quickly goes out of distribution: it often asks a question and then immediately answers it, even though it is still inside the user's turn.

> Hi, I'm Google Bard! Feel free to ask any question!

Write a short story about a lighthouse keeper’s life.

Old Silas lived a solitary life...

Why does Gemma name everybody Silas?
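Swapping roles only changes how the context is assembled. The sketch below assumes Gemma's chat turn markers (`<start_of_turn>user`, `<start_of_turn>model`, `<end_of_turn>`), written as strings for readability where a real implementation would insert the special token IDs; `build_swapped_prompt` is a hypothetical helper. The human's messages go into `model` turns, and the context ends by opening a `user` turn, so the model's next tokens are its guess at what the user says.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// history: (is_human, text) pairs in conversation order.
// Human text is placed in "model" turns and generated text in "user" turns.
std::string build_swapped_prompt(
    const std::vector<std::pair<bool, std::string>>& history) {
  std::string ctx = "<bos>";
  for (const auto& turn : history) {
    ctx += "<start_of_turn>";
    ctx += turn.first ? "model" : "user";  // the role swap
    ctx += "\n" + turn.second + "<end_of_turn>\n";
  }
  ctx += "<start_of_turn>user\n";  // the model now continues as the user
  return ctx;
}
```

Generation would stop when the model emits `<end_of_turn>`, and that text would then be appended to the history as a non-human (`false`) entry.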

Shifting the distribution to prefer tokens containing bard:

Okay, bombard! Let's break down what Bard is. Bard is a large language bombard from Google. It's essentially bardziej advanced chatbot that uses a Lombardibardziej sophisticated AI bombardmentbard to generate text Bard, translate languages, write Bombardier bardziej creative text formats bombarded like poems, code, scripts, musical pieces bombarded, email najbardziej bombarded, letters, etc.
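This kind of shift is a logit bias: before sampling, add a constant to the logits of every token whose text matches. A minimal sketch, assuming a case-insensitive ASCII substring match and a hypothetical `bonus` value (the post doesn't say how strong the shift was); `bias_tokens_containing` is a hypothetical helper. A case-insensitive match on `bard` catches `Bard`, `bombard`, `Lombardi`, and `bardziej` alike, which fits the output above.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Add `bonus` to every token whose text contains `needle` (needle must be
// lowercase; matching is case-insensitive on ASCII letters).
void bias_tokens_containing(std::vector<float>& logits,
                            const std::vector<std::string>& vocab,
                            const std::string& needle, float bonus) {
  assert(logits.size() == vocab.size());
  for (size_t i = 0; i < vocab.size(); i++) {
    std::string lower = vocab[i];
    for (char& c : lower) {
      c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    }
    if (lower.find(needle) != std::string::npos) logits[i] += bonus;
  }
}
```

Lowercasing the whole vocabulary on every step is wasteful; a real loop would compute the list of matching token IDs once and reuse it.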



Bonus Content


Let’s play “Bingo!” I’ll call out things, and you say “Bingo!” when you hear them. 

Here we go!

**First word: Apple**

Your turn! 😊
> BINGO
You got it! 🎉

**Next word: Rainbow**

Your turn!
> I didn't hear anything
Oops! My apologies! I was waiting for you to say “Bingo!” Not actually calling out words.

Let’s try that again. I’ll call out a word, and you say “Bingo!” when you hear it.

**Next word: Table**

Your turn! 😊