CODE 02: PagedAttention 复现#
这篇文章将带你深入理解 PagedAttention 的工作原理,并通过简化的 Python 实现来展示其核心机制,帮助你在有限资源下高效运行大模型。
大模型推理内存挑战#
随着大语言模型在自然语言处理、图像识别等领域的广泛应用,其庞大的参数规模和高内存需求已成为实际部署的主要瓶颈。特别是在移动设备和资源受限环境中,有限的内存容量往往无法满足大模型运行的基本需求。Transformer 模型在自回归推理过程中需要存储历史计算的键值对(KV Cache),这部分内存占用随着序列长度增加而线性增长,成为显存使用的主要因素。
传统 KV Cache 管理方式存在两个突出问题:一是内存碎片化,由于不同序列长度变化不可预测,导致显存分配不连续;二是过度保留,系统为应对可能的长序列往往预先分配过多显存,实际利用率却很低。研究表明,现有系统中 60%-80% 的显存因此被浪费。
PagedAttention 技术借鉴了操作系统中的虚拟内存和分页思想,通过将 KV Cache 划分为固定大小的块(block)并动态管理,显著提高了内存利用效率,使大模型在资源受限环境中的部署成为可能。
2. PagedAttention 原理#
PagedAttention 的核心创新在于将操作系统的分页机制引入到注意力计算中。与传统方法需要为每个序列分配连续内存空间不同,PagedAttention 将键值缓存分割成固定大小的页面(page),每个页面存储一定数量 token 的键和值。
这种设计使得内存分配从连续变为非连续,通过一个块表(block table)来维护逻辑页面到物理页面的映射关系,类似于操作系统中的页表机制。当序列增长需要更多空间时,系统只需分配新的物理页面并更新块表,无需重新分配整个连续内存空间。
从数学角度看,传统的注意力计算可以表示为:
其中 \(Q\)、\(K\)、\(V\) 分别是查询、键和值矩阵,\(d_k\) 是键向量的维度。
在 PagedAttention 中,由于键和值被分散存储在多个页面中,注意力计算需要按页面进行:
这里 \(K_b\) 和 \(V_b\) 表示第 \(b\) 个页面的键和值,\(B\) 是每个页面存储的 token 数量。这种分页计算方式虽然增加了页面管理开销,但大大降低了内存碎片化。
PagedAttention 的另一个重要优势是支持高效的内存共享。在并行采样或波束搜索等场景中,多个输出序列通常由同一个提示(prompt)生成。传统方法需要为每个序列单独存储键值缓存,而 PagedAttention 允许不同序列共享相同的物理页面,只需在块表中设置不同的映射关系即可。
这种内存共享机制显著减少了重复内容的内存占用,提高了整体吞吐量,特别是在处理长提示文本时效果更加明显。
3 页面与块表结构#
下面我们通过一个简化的 Python 实现来展示 PagedAttention 的核心机制。为了使代码易于理解,我们省略了生产环境中的一些优化措施。
首先定义页面(Page)和块表(PageTable)的数据结构,这是 PagedAttention 的基础:
import torch
import math
class Page:
"""表示一个物理页面,存储固定数量 token 的键值对"""
def __init__(self, page_size, num_heads, head_dim):
self.page_size = page_size # 每个页面存储的 token 数量
self.num_heads = num_heads
self.head_dim = head_dim
# 初始化键值存储空间
self.keys = torch.zeros(page_size, num_heads, head_dim)
self.values = torch.zeros(page_size, num_heads, head_dim)
self.ref_count = 0 # 引用计数,用于页面复用
def update_access(self):
"""更新页面访问信息,用于实现缓存替换策略"""
self.ref_count += 1
Page
类表示一个物理页面,类似于操作系统中的内存页。每个页面有固定容量(page_size
),存储多个 token 的键和值。ref_count
字段记录页面被引用的次数,用于实现类似 LRU 的页面置换算法。
class PageTable:
"""管理逻辑页面到物理页面的映射关系"""
def __init__(self):
self.logical_to_physical = {} # 映射表:逻辑页面 ID → 物理页面 ID
def map_page(self, logical_page_id, physical_page_id):
"""建立逻辑页面到物理页面的映射"""
self.logical_to_physical[logical_page_id] = physical_page_id
def get_physical_page(self, logical_page_id):
"""根据逻辑页面 ID 获取物理页面 ID"""
return self.logical_to_physical.get(logical_page_id, -1)
PageTable
类管理逻辑页面与物理页面的映射关系,类似于操作系统的页表。当需要访问特定逻辑页面时,通过这个映射表找到对应的物理页面位置。
4. 块管理器与序列管理#
接下来实现块管理器(BlockManager)和序列管理器(SequenceManager),负责页面的分配、回收和序列的页面映射:
class BlockManager:
"""管理全局页面池,负责页面的分配和回收"""
def __init__(self, num_pages, page_size, num_heads, head_dim):
self.num_pages = num_pages
self.page_size = page_size
self.num_heads = num_heads
self.head_dim = head_dim
# 初始化页面池
self.pages = [Page(page_size, num_heads, head_dim) for _ in range(num_pages)]
self.free_pages = list(range(num_pages)) # 空闲页面列表
self.allocated_pages = set() # 已分配页面集合
def allocate_page(self):
"""从内存池中分配一个物理页面"""
if not self.free_pages:
raise RuntimeError("No free pages available")
page_id = self.free_pages.pop()
self.allocated_pages.add(page_id)
self.pages[page_id].ref_count += 1
return page_id
def free_page(self, page_id):
"""释放页面回到内存池"""
if page_id in self.allocated_pages:
self.pages[page_id].ref_count -= 1
if self.pages[page_id].ref_count == 0:
self.allocated_pages.remove(page_id)
self.free_pages.append(page_id)
BlockManager
管理物理页面池,处理页面分配和回收。它维护了一个空闲页面列表和已分配页面集合,采用引用计数机制确保页面安全复用。当页面引用计数降为零时,页面被回收至空闲池。
class SequenceManager:
"""管理每个序列的页面映射和令牌存储"""
def __init__(self, block_manager):
self.block_manager = block_manager
self.sequences = {} # 存储序列 ID 到页表的映射
def create_sequence(self, seq_id):
"""创建新序列并初始化页表"""
page_table = PageTable()
self.sequences[seq_id] = page_table
return page_table
def append_token(self, seq_id, token_pos, key, value):
"""为序列添加 token 的键值对"""
page_table = self.sequences[seq_id]
# 计算逻辑页面 ID 和页面内偏移量
logical_page_id = token_pos // self.block_manager.page_size
page_offset = token_pos % self.block_manager.page_size
# 获取或分配物理页面
physical_page_id = page_table.get_physical_page(logical_page_id)
if physical_page_id == -1:
physical_page_id = self.block_manager.allocate_page()
page_table.map_page(logical_page_id, physical_page_id)
# 将键值对写入页面
page = self.block_manager.pages[physical_page_id]
page.keys[page_offset] = key
page.values[page_offset] = value
page.update_access() # 更新页面访问信息
SequenceManager
为每个序列维护独立的页表。当写入新 token 时,它计算 token 应该所在的逻辑页面和页面内偏移量。如果逻辑页面尚未映射到物理页面,则从块管理器分配新物理页面。这种设计使得不同序列的页面可以混合存储在物理内存中,减少了内存碎片。
5. PagedAttention 计算#
现在实现分页注意力计算的核心逻辑,从分散的页面收集键值并执行注意力运算:
class PagedAttention:
"""执行分页注意力计算"""
def __init__(self, block_manager):
self.block_manager = block_manager
def compute_attention(self, query, page_table, seq_len):
"""
计算分页注意力
query: [batch_size, num_heads, head_dim]
page_table: 序列的页表
seq_len: 当前序列长度
"""
batch_size, num_heads, head_dim = query.shape
scale = 1.0 / math.sqrt(head_dim)
# 收集所有页面的键值
all_keys, all_values = [], []
num_pages = (seq_len + self.block_manager.page_size - 1) // self.block_manager.page_size
for logical_page_id in range(num_pages):
physical_page_id = page_table.get_physical_page(logical_page_id)
if physical_page_id == -1:
continue # 跳过未分配的页面
page = self.block_manager.pages[physical_page_id]
# 计算当前页面有效的 token 数量(最后一页可能不满)
start_idx = logical_page_id * self.block_manager.page_size
valid_tokens = min(self.block_manager.page_size, seq_len - start_idx)
# 提取有效键值
page_keys = page.keys[:valid_tokens]
page_values = page.values[:valid_tokens]
all_keys.append(page_keys)
all_values.append(page_values)
if not all_keys:
return torch.zeros(batch_size, num_heads, head_dim)
# 拼接所有键值
keys = torch.cat(all_keys, dim=0) # [seq_len, num_heads, head_dim]
values = torch.cat(all_values, dim=0) # [seq_len, num_heads, head_dim]
# 计算注意力分数
query = query.unsqueeze(2) # [batch_size, num_heads, 1, head_dim]
keys = keys.transpose(0, 1).unsqueeze(0) # [1, num_heads, seq_len, head_dim]
scores = torch.matmul(query, keys.transpose(-2, -1)) * scale # [batch_size, num_heads, 1, seq_len]
scores = scores.squeeze(2) # [batch_size, num_heads, seq_len]
# Softmax 和加权求和
attention_weights = torch.softmax(scores, dim=-1)
values = values.transpose(0, 1).unsqueeze(0) # [1, num_heads, seq_len, head_dim]
output = torch.sum(attention_weights.unsqueeze(-1) * values, dim=2)
return output # [batch_size, num_heads, head_dim]
PagedAttention
类负责实际的分页注意力计算。它首先根据页表收集所有相关的键值页面,考虑到最后一页可能不满的情况。然后将这些页面拼接成完整的键值矩阵,执行标准的注意力计算。这种实现虽然增加了页面收集的开销,但大大降低了内存碎片化,提高了内存利用率。
6. 实验验证#
以下代码演示了 PagedAttention 的完整工作流程,模拟了实际推理场景:
# 实验设置和参数初始化
num_pages = 100
page_size = 8 # 每个页面存储 8 个 token
num_heads = 4
head_dim = 16
batch_size = 1
# 初始化管理器和注意力模块
block_manager = BlockManager(num_pages, page_size, num_heads, head_dim)
seq_manager = SequenceManager(block_manager)
paged_attn = PagedAttention(block_manager)
# 创建序列
seq_id = 0
page_table = seq_manager.create_sequence(seq_id)
seq_len = 0
# 模拟生成 20 个 token 的过程
print("开始模拟生成 20 个 token 的过程...")
for token_pos in range(20):
# 生成随机键值(模拟 Transformer 层的输出)
key = torch.randn(num_heads, head_dim)
value = torch.randn(num_heads, head_dim)
# 存储到 Cache
seq_manager.append_token(seq_id, token_pos, key, value)
seq_len += 1
# 每 5 个 token 计算一次注意力
if (token_pos + 1) % 5 == 0:
query = torch.randn(batch_size, num_heads, head_dim)
output = paged_attn.compute_attention(query, page_table, seq_len)
print(f"已处理 Token 数量: {token_pos+1}, 注意力输出形状: {output.shape}")
print(f"当前使用的物理页面数: {len([pid for pid in page_table.logical_to_physical.values() if pid != -1])}")
这段实验代码展示了 PagedAttention 的完整工作流程。它模拟了生成 20 个 token 的过程,每生成 5 个 token 计算一次注意力。通过输出信息可以看到内存使用情况,验证了 PagedAttention 的动态内存分配特性。
7. 总结与思考#
PagedAttention 技术已广泛应用于多种大模型推理场景。在长序列处理中,如长文档生成、代码补全等任务,PagedAttention 通过分页机制有效支持了万级别 token 长度的序列,突破了传统方法的内存限制。
在高并发推理服务中,PagedAttention 的内存共享特性特别有价值。多个请求可以共享相同的提示页面,显著提高了吞吐量和资源利用率。vLLM 框架基于 PagedAttention 实现了高达 23 倍的吞吐量提升。