Unless you’re living under a rock, you’ve probably heard of Large Language Models (LLMs) and even used a few of the popular applications like ChatGPT, Claude, Perplexity, etc. powered by these LLMs.
So without going too deep into what LLMs are, let’s see what happens under the hood when you run inference on an LLM, that is, when you prompt a model and get a response back.
Lifecycle of a prompt
Let us take this step by step and see the lifecycle of a prompt. First, you write a prompt, something like “What is the capital of India?”, and enter it into a chat UI of an LLM or some service running on your terminal. In a few seconds you get the answer “New Delhi.” But how does that work?
So, if you’re familiar with deep learning models and are thinking that to perform inference all you need to do is convert the natural language input into vectors and then simply call the LLM’s forward method, you’re right. Well, almost right!
If you look closely at the above diagram, tokenization, vector embedding creation, and the forward pass are the same as in any other NLP model. But there’s an additional component (and probably the most crucial one) called KV Caching that adds a layer of complexity.
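To make this concrete, here is roughly what that pipeline looks like with the Hugging Face transformers library (a minimal sketch; gpt2 is used purely as a small stand-in model, and the generation arguments are illustrative):

# Minimal end-to-end sketch: tokenize -> repeated forward passes -> decode back to text.
# transformers keeps a KV cache under the hood by default (use_cache=True).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("What is the capital of India?", return_tensors="pt")  # tokenization
outputs = model.generate(**inputs, max_new_tokens=10)                     # forward passes + sampling
print(tokenizer.decode(outputs[0], skip_special_tokens=True))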
KV Caching Explained
But what exactly is KV Caching in the context of LLMs and why is it so crucial? To understand KV caching, we need to first break down what happens during a forward pass in a transformer-based LLM like GPT, Claude, Gemini, etc. When you enter a prompt in these models, it is converted into tokens. Each token goes through multiple layers of self-attention and feedforward operations.
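As a quick illustration of the tokenization step mentioned above (using the GPT-2 tokenizer only as an example; every model family ships its own tokenizer and vocabulary):

# The prompt becomes a list of integer token IDs; words may be split into sub-words.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("What is the capital of India?")
print(ids)                                   # a list of integer token IDs
print(tokenizer.convert_ids_to_tokens(ids))  # the corresponding sub-word strings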
For every token, the model computes three vectors during the self-attention process: the Key, the Value, and the Query. But what exactly are they?
Key (K) – A representation of the token that other tokens’ Queries are matched against; it holds the contextual information the token offers.
Value (V) – The information the token contributes to the attention output once it is matched.
Query (Q) – A representation of what the current token is looking for; it is compared against the Keys to determine the attention scores.
One Stack Exchange post describes them very well:
The key/value/query concept is analogous to retrieval systems. For example, when you search for videos on Youtube, the search engine will map your query (text in the search bar) against a set of keys (video title, description, etc.) associated with candidate videos in their database, then present you the best matched videos (values).
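Here is a toy, self-contained version of that matching process (scaled dot-product attention) in NumPy; the shapes and values are made up and serve only to make the Q/K/V roles concrete:

import numpy as np

def scaled_dot_product_attention(query, keys, values):
    # query: [head_dim]; keys, values: [seq_len, head_dim]
    scores = keys @ query / np.sqrt(query.shape[-1])  # one score per previous token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                          # softmax over previous tokens
    return weights @ values                           # weighted mix of the Values

rng = np.random.default_rng(0)
head_dim, seq_len = 4, 3
q = rng.normal(size=head_dim)             # Query of the current token
K = rng.normal(size=(seq_len, head_dim))  # Keys of the tokens seen so far
V = rng.normal(size=(seq_len, head_dim))  # Values of the tokens seen so far
print(scaled_dot_product_attention(q, K, V))  # attention output, shape [head_dim]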
During inference, each new token is generated from the self-attention output at the current position. That output is computed by matching the current token’s Query against the Keys of all previous tokens and using the resulting attention weights to combine their Values. Take a look at the pseudocode below:
def compute_attention(query, current_position, hidden_states):
    keys = []
    values = []
    for pos in range(current_position + 1):
        # Recompute K, V for each position
        key = linear(hidden_states[pos], K_weights)    # [head_dim]
        value = linear(hidden_states[pos], V_weights)  # [head_dim]
        keys.append(key)
        values.append(value)
    keys = stack(keys)      # [current_position+1, head_dim]
    values = stack(values)  # [current_position+1, head_dim]
    attention_scores = matmul(query, transpose(keys))  # [current_position+1]
    attention_scores = attention_scores / sqrt(head_dim)
    attention_weights = softmax(attention_scores)
    output = matmul(attention_weights, values)  # [head_dim]
    return output

def generate_next_token(input_ids):
    # Convert all tokens to embeddings
    hidden_states = []
    for token in input_ids:
        hidden_states.append(get_embedding(token))  # List of [hidden_dim]
    # Process through layers
    for layer in range(num_layers):
        new_hidden_states = []
        for pos in range(len(hidden_states)):
            query = linear(hidden_states[pos], Q_weights)  # [head_dim]
            attention_output = compute_attention(query, pos, hidden_states)
            new_state = process_ffn(attention_output)
            new_hidden_states.append(new_state)
        hidden_states = new_hidden_states
    logits = get_logits(hidden_states[-1])
    next_token = sample(logits)
    return next_token

def generate_sequence(prompt_tokens, max_len):
    input_ids = []
    for token in prompt_tokens:
        input_ids.append(token)
    # Generate new tokens
    while len(input_ids) < max_len:
        next_token = generate_next_token(input_ids)
        input_ids.append(next_token)
    return input_ids
Notice that for every new token, the Key and Value vectors of all the previous tokens are recalculated from scratch. For short sequences this is fine, but for longer ones it becomes extremely inefficient and costly. Intuitively, it makes a lot of sense to cache the results that we will need again and again in the future.
Think of it this way: When you’re reading a book, do you need to re-read everything from the first word to understand the next sentence? No, you simply remember the context (Keys) and the meaning (Values) and use them to process new information quickly.
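To put a rough number on that intuition, here is a back-of-envelope count of K and V projections per layer (counting one projection per token position per decoding step), comparing a naive implementation with a cached one; the prompt and generation lengths are arbitrary:

def kv_projections(prompt_len, new_tokens, cached):
    # Count how many K/V projections a single layer performs while
    # generating `new_tokens` tokens after a `prompt_len`-token prompt.
    if cached:
        # With a cache, each token's K and V are computed exactly once.
        return prompt_len + new_tokens
    # Without a cache, every decoding step recomputes K/V for the whole sequence so far.
    return sum(prompt_len + step + 1 for step in range(new_tokens))

print(kv_projections(prompt_len=512, new_tokens=256, cached=False))  # 163968
print(kv_projections(prompt_len=512, new_tokens=256, cached=True))   # 768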
Phases
Now let’s understand the phases of LLM inference to see how and where KV Caching comes into play:
Prefill - where the prompt is first received and processed before moving on to the next phase
Decode - where the subsequent tokens are generated
Prefill or Prompt Phase
When the model processes the initial prompt during the Prefill phase, it computes the Keys and Values for all the tokens in the prompt. These Key-Value pairs are cached (stored in memory).
class KVCache:
    def __init__(self, num_layers):
        self.keys = [[] for _ in range(num_layers)]    # [layer][seq_len, head_dim]
        self.values = [[] for _ in range(num_layers)]  # [layer][seq_len, head_dim]

    def update(self, layer, key, value):
        self.keys[layer].append(key)
        self.values[layer].append(value)

    def get_layer_cache(self, layer):
        # Return stacked keys and values for the given layer
        return (
            stack(self.keys[layer]),    # [seq_len, head_dim]
            stack(self.values[layer]),  # [seq_len, head_dim]
        )

def compute_attention(query, keys, values):
    attention_scores = matmul(query, transpose(keys))  # [seq_len]
    attention_scores = attention_scores / sqrt(head_dim)
    attention_weights = softmax(attention_scores)
    output = matmul(attention_weights, values)  # [head_dim]
    return output

def generate_next_token(input_ids, kv_cache):
    last_token = input_ids[-1]
    hidden_states = get_embedding(last_token)  # [hidden_dim]
    for layer in range(num_layers):
        query = linear(hidden_states, Q_weights)  # [head_dim]
        key = linear(hidden_states, K_weights)    # [head_dim]
        value = linear(hidden_states, V_weights)  # [head_dim]
        kv_cache.update(layer, key, value)
        cached_keys, cached_values = kv_cache.get_layer_cache(layer)
        attention_output = compute_attention(query, cached_keys, cached_values)
        hidden_states = process_ffn(attention_output)
    logits = get_logits(hidden_states)
    next_token = sample(logits)
    return next_token
def generate_sequence(prompt_tokens, max_len):
    kv_cache = KVCache(num_layers=num_layers)
    input_ids = []
    # Prefill: process the prompt token by token, filling the cache as we go
    for token in prompt_tokens:
        input_ids.append(token)
        next_token = generate_next_token(input_ids, kv_cache)
    # The last prefill step already produced the first output token
    input_ids.append(next_token)
    # Decode: generate the remaining tokens one at a time, reusing the cache
    while len(input_ids) < max_len:
        next_token = generate_next_token(input_ids, kv_cache)
        input_ids.append(next_token)
    return input_ids
Decode phase
In the Decode phase, LLMs generate output one token at a time in an autoregressive fashion. For every new token, the model uses:
The cached Keys and Values from all previous tokens.
The Query of the new token to determine which parts of the cached information are most relevant.
When the model moves to the Decode phase to generate the subsequent tokens, it reuses these cached Keys and Values instead of recomputing them. The only new work per step is computing the new token’s own Key and Value and appending them to the cache, which makes the generation of each new token much faster.
Without KV caching, every new token would require recomputing the Keys and Values over the entire prompt plus all the previously generated tokens. For long sequences and models with billions of parameters, this quickly becomes prohibitively expensive.
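You can observe this effect yourself by toggling the cache; here is a rough sketch with Hugging Face transformers (gpt2 as a small stand-in model; absolute timings will of course depend on your hardware):

import time
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("What is the capital of India?", return_tensors="pt")

for use_cache in (True, False):
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128, use_cache=use_cache)  # same generation, cache on/off
    print(f"use_cache={use_cache}: {time.perf_counter() - start:.2f}s")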
Example
Let’s say you input a prompt like: "What is the capital of India?"
Note: A token doesn’t necessarily mean a word. Here I am using words only to simplify learning.
Prefill Phase: The LLM computes the Keys and Values for the input tokens [What], [is], [the], [capital], [of], and [India?] and caches them.
Decode Phase: The model now starts generating the response. Say the first output token is [New]. To generate the next token, [Delhi], the model uses the cached Keys and Values of everything seen so far, from [What] through [India?] (plus the freshly cached entry for [New]), along with the Query of the latest token, [New]. This process continues, with each step reusing the cache to accelerate computation, until a stop token is generated or the model hits its limit on new tokens.
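A tiny trace of this example, reusing the structure of the pseudocode above but with random vectors standing in for the real projections (purely illustrative, not a real model):

import numpy as np

rng = np.random.default_rng(0)
head_dim = 4
cached_keys, cached_values = [], []

def project(token):
    # Stand-in for the learned K and V projections of a single layer.
    return rng.normal(size=head_dim), rng.normal(size=head_dim)

# Prefill: cache one K/V pair per prompt token.
for token in ["What", "is", "the", "capital", "of", "India?"]:
    k, v = project(token)
    cached_keys.append(k)
    cached_values.append(v)
print(len(cached_keys))  # 6 cached entries, one per prompt token

# Decode: each generated token adds exactly one new K/V pair.
for token in ["New", "Delhi"]:
    k, v = project(token)
    cached_keys.append(k)
    cached_values.append(v)
print(len(cached_keys))  # 8 cached entries after generating "New" and "Delhi"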
Memory and Performance Trade-offs
While KV caching improves speed, it also comes with memory costs. For every input token, the Key and Value matrices need to be stored in memory. As a result, longer prompts or conversations require more memory to store all the cached values.
This is one reason why many models have a context length limit: beyond a certain number of tokens, the model can no longer keep caching all the Keys and Values, and performance may degrade. It is also why some chat applications reset the conversation context after a certain point, to free up memory and keep response times fast.
But how much memory does it actually take? Let’s work through an example. To load Llama 2 7B’s weights in 16-bit precision you need about 14 GB of GPU memory (7 billion parameters × 2 bytes). On top of that, the KV cache needs its own memory, given by the following formula:
Total size of KV cache (bytes) = 2 × sequence_length × num_layers × hidden_size × bytes_per_value

where the leading 2 accounts for storing both Keys and Values. For Llama 2 7B (32 layers, hidden size 4096) at its full 4096-token context in 16-bit precision:

2 × 4096 × 32 × 4096 × 2 ≈ 2.1 GB
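The same arithmetic as a small helper (Llama 2 7B’s 32 layers and hidden size of 4096 are its published architecture numbers; 2 bytes per value corresponds to 16-bit precision):

def kv_cache_bytes(seq_len, num_layers, hidden_size, bytes_per_value=2):
    # The leading 2 accounts for storing both a Key and a Value per token per layer.
    return 2 * seq_len * num_layers * hidden_size * bytes_per_value

# Llama 2 7B at its full 4096-token context, 16-bit precision:
print(kv_cache_bytes(seq_len=4096, num_layers=32, hidden_size=4096) / 1e9)  # ~2.1 GB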
Note: The code above is a highly simplified version of LLM inference for learning. Real-world implementations may vary.