久久ER99热精品一区二区-久久精品99国产精品日本-久久精品免费一区二区三区-久久综合九色综合欧美狠狠

博客專欄

EEPW首頁 > 博客 > 手撕大模型|KVCache 原理及代碼解析

手撕大模型|KVCache 原理及代碼解析

發布人:地平線開發者 時間:2025-09-13 來源:工程師 發布文章

在大型語言模型(LLM)的推理過程中,KV Cache 是一項關鍵技術,它通過緩存中間計算結果顯著提升了模型的運行效率。本文將深入解析 KV Cache 的工作原理、實現方式,并通過代碼示例展示其在實際應用中的效果。

一、為什么需要 KV Cache?

在 Transformer 進行自回歸推理(如文本生成,每次生成一個 token 的時候需要結合前面所有的 token 做 attention 操作)時,計算注意力機制時需要存儲 Key(K) 和 Value(V),以便下一個時間步可以復用這些緩存,而不必重新計算整個序列。

在標準 Transformer 解碼時,每次生成新 token 時:

  • 需要 重新計算所有之前 token 的 K 和 V,并與當前 token 進行注意力計算。

  • 計算復雜度是 O(n2)(對于長度為 n 的序列)。

img

而 KV Cache 通過存儲 K 和 V 的歷史值,避免重復計算:

  • 只需計算 新 token 的 K 和 V,然后將其與緩存的值結合使用。

  • 計算復雜度下降到 O(n)(每個 token 只與之前緩存的 token 計算注意力)。

二、KV Cache 的工作原理

KV Cache 的核心思想是緩存歷史計算中的鍵(Key)和值(Value)矩陣,避免重復計算。具體來說:

  1. 在生成第一個 token 時,模型計算并緩存所有輸入 token 的 K 和 V 矩陣

  2. 生成后續 token 時,只需要計算新 token 的查詢(Query)矩陣

  3. 將新的 Q 矩陣與緩存的 K、V 矩陣進行注意力計算,同時將新 token 的 K、V 追加到緩存中

這個過程可以用偽代碼直觀展示:

初始輸入: [t0, t1, t2]
首次計算: K=[K0,K1,K2], V=[V0,V1,V2] → 生成t3
緩存狀態: K=[K0,K1,K2], V=[V0,V1,V2]
第二次計算: 新Q=Q3
注意力計算: Attention(Q3, [K0,K1,K2]) → 生成t4
更新緩存: K=[K0,K1,K2,K3], V=[V0,V1,V2,V3]
第三次計算: 新Q=Q4
注意力計算: Attention(Q4, [K0,K1,K2,K3]) → 生成t5
更新緩存: K=[K0,K1,K2,K3,K4], V=[V0,V1,V2,V3,V4]
...

通過這種方式,每次新生成 token 時,只需計算新的 Q 矩陣并與歷史 KV 矩陣進行注意力計算,將時間復雜度從 O (n2) 降低到 O (n),極大提升了長序列生成的效率。

下面,我們結合示意圖進一步剖析一下 KV Cache 部分的邏輯。

img

img

img

img

KV Cache 核心節約的時間有三大塊:

  1. 前面 n-1 次的 Q 的計算,當然這塊對于一次一個 token 的輸出本來也沒有用;

  2. 同理還有 Attention 計算時對角矩陣變為最后一行,和 b 是同理的,這樣 mask 矩陣也就沒有什么用了;

  3. 前面 n-1 次的 K 和 V 的計算,也就是上圖紫色部分,這部分是實打實被 Cache 過不需要再重新計算的部分。

這里還有個 softmax 的問題,softmax 原本就是針對同一個 query 的所有 key 的計算,所以并不受影響。

2.1 KV Cache 的技術細節
  1. 緩存結構

KV Cache 通常為每個注意力頭維護獨立的緩存,結構如下:

  1. Key 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

  2. Value 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

其中,seq_len 會隨著生成過程動態增長,直到達到模型最大序列長度限制。

  1. 內存與速度的權衡

KV Cache 雖然提升了速度,但需要額外的內存存儲緩存數據。以 GPT-3 175B 模型為例,每個 token 的 KV 緩存約占用 20KB 內存,當生成 1000 個 token 時,單個樣本就需要約 20MB 內存。在批量處理時,內存消耗會線性增加。

實際應用中需要根據硬件條件在以下方面進行權衡:

  1. 最大緩存長度(影響能處理的序列長度)

  2. 批量大小(影響并發處理能力)

  3. 精度選擇(FP16 比 FP32 節省一半內存)

  4. 滑動窗口機制

當處理超長序列時,一些模型(如 Llama 2)采用滑動窗口機制,只保留最近的 N 個 token 的 KV 緩存,以控制內存占用。這種機制在犧牲少量上下文信息的情況下,保證了模型能處理更長的對話。

四、代碼實現解析

下面以 PyTorch 為例,展示 KV Cache 在自注意力計算中的實現方式。

  1. 基礎自注意力實現(無緩存)

首先看一下標準的自注意力計算,沒有緩存機制:

import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # 定義Q、K、V投影矩陣
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        
        # 計算Q、K、V
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 計算注意力分數
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)
        
        # 應用注意力權重
        output = attn_probs @ v
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        
        return self.out_proj(output)
  1. 帶 KV Cache 的自注意力實現

下面修改代碼,加入 KV Cache 機制:

class CachedSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # 定義投影矩陣
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # 初始化緩存
        self.cache_k = None
        self.cache_v = None
    
    def forward(self, x, use_cache=False):
        batch_size, seq_len, embed_dim = x.shape
        
        # 計算Q、K、V
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 如果使用緩存且緩存存在,則拼接歷史KV
        if use_cache and self.cache_k is not None:
            k = torch.cat([self.cache_k, k], dim=-2)
            v = torch.cat([self.cache_v, v], dim=-2)
        
        # 如果使用緩存,更新緩存
        if use_cache:
            self.cache_k = k
            self.cache_v = v
        
        # 計算注意力分數(注意這里的k是包含歷史緩存的)
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)
        
        # 應用注意力權重
        output = attn_probs @ v
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        
        return self.out_proj(output)
    
    def reset_cache(self):
        """重置緩存,用于新序列的生成"""
        self.cache_k = None
        self.cache_v = None
  1. 生成過程中的緩存使用

在文本生成時,我們可以這樣使用帶緩存的注意力機制:

def generate_text(model, input_ids, max_length=50):
    # 初始化模型緩存
    model.reset_cache()
    
    # 處理初始輸入
    output = model(input_ids, use_cache=True)
    next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
    generated = [next_token]
    
    # 生成后續token
    for _ in range(max_length - 1):
        # 只輸入新生成的token
        output = model(next_token, use_cache=True)
        next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
        generated.append(next_token)
        
        # 如果生成結束符則停止
        if next_token.item() == 102:  # 假設102是[SEP]的id
            break
    
    return torch.cat(generated, dim=1)
五、KV Cache 的優化策略

在實際部署中,為了進一步提升 KV Cache 的效率,還會采用以下優化策略:

  1. 分頁 KV Cache(Paged KV Cache):借鑒內存分頁機制,將連續的 KV 緩存分割成固定大小的塊,提高內存利用率,代表實現有 vLLM。

  2. 動態緩存管理:根據輸入序列長度動態調整緩存大小,在批量處理時優化內存分配。

  3. 量化緩存:使用 INT8 或 INT4 等低精度格式存儲 KV 緩存,在犧牲少量精度的情況下大幅減少內存占用。

  4. 選擇性緩存:對于一些不重要的層或注意力頭,選擇性地不進行緩存,平衡速度和內存。

六、總結

KV Cache 通過緩存中間計算結果,有效解決了 Transformer 模型在生成式任務中的效率問題,是大模型能夠實現實時交互的關鍵技術之一。理解 KV Cache 的工作原理和實現方式,對于優化大模型推理性能、解決實際部署中的挑戰具有重要意義。

七、參考鏈接

https://zhuanlan.zhihu.com/p/670515231

https://zhuanlan.zhihu.com/p/714288577

https://zhuanlan.zhihu.com/p/715921106https://zhuanlan.zhihu.com/p/19489285169

https://medium.com/@joaolages/kv-caching-explained-276520203249


*博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



相關推薦

技術專區

關閉