CODE 02: Min-P vs Top-P 探索#
在自然语言生成领域,采样策略的选择对生成文本的质量和多样性影响深远。本文将探讨一种新兴的采样方法——最小 P 采样(Min-P Sampling),并与广泛使用的 Top-P(核)采样进行对比实验。
通过直观的代码示例和深入分析,我们希望帮助读者理解这两种方法的原理与差异,特别关注它们在 Qwen3 4B 模型上的表现。
1. Top-P 采样实现#
Top-P 采样从概率分布中选择概率累积和超过阈值 p 的最小 token 集合,然后从这个集合中重新归一化概率并采样。这种方法确保只考虑概率较高的 token,同时保持一定的多样性。
数学表达式为:\(V_{\text{Top-P}} = \{v_i \in V \mid \sum_{j=1}^{i} p(v_j) \geq p\}\),其中 \(V\) 是按概率降序排列的词汇表。
import torch
import torch.nn.functional as F
def top_p_sampling(logits, p=0.9):
"""
实现 Top-P 采样策略
参数:
logits: 模型输出的原始 logits
p: 累积概率阈值
返回:
采样得到的 token 索引
"""
# 将 logits 转换为概率
probs = F.softmax(logits, dim=-1)
# 对概率进行降序排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
# 计算累积概率
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 移除累积概率超过 p 的 token
indices_to_remove = cumulative_probs > p
# 确保至少保留一个 token
indices_to_remove[..., 1:] = indices_to_remove[..., :-1].clone()
indices_to_remove[..., 0] = False
# 创建过滤后的概率分布
sorted_indices_to_remove = sorted_indices[indices_to_remove]
probs[sorted_indices_to_remove] = 0
# 重新归一化概率
filtered_probs = probs / probs.sum()
# 从过滤后的分布中采样
next_token = torch.multinomial(filtered_probs, num_samples=1)
return next_token.item()
2. Min-P 采样实现#
Min-P 采样是 Top-P 采样的变体,它设置一个最小概率阈值而不是累积概率阈值。具体来说,它保留所有概率大于最小阈值 \(p_{\text{min}}\) 的 token,然后从这些 token 中采样。这种方法提供了一种更直接的概率阈值控制方式。
数学表达式为:\(V_{\text{min-p}} = \{v_i \in V \mid p(v_i) \geq p_{\text{min}}\}\)
def min_p_sampling(logits, min_prob=0.05):
"""
实现 Min-P 采样策略
参数:
logits: 模型输出的原始 logits
min_prob: 最小概率阈值
返回:
采样得到的 token 索引
"""
# 将 logits 转换为概率
probs = F.softmax(logits, dim=-1)
# 移除概率低于阈值的 token
indices_to_remove = probs < min_prob
probs[indices_to_remove] = 0
# 如果所有 token 都被移除,则保留概率最高的 token
if probs.sum() == 0:
probs = F.softmax(logits, dim=-1)
probs[1:] = 0 # 只保留概率最高的 token
probs[0] = 1
# 重新归一化概率
filtered_probs = probs / probs.sum()
# 从过滤后的分布中采样
next_token = torch.multinomial(filtered_probs, num_samples=1)
return next_token.item()
3. 实验设置#
为了对比这两种采样策略在现代化模型上的效果,我们选择使用 Qwen3 4B 模型进行实验。这个模型在参数量和性能之间取得了良好平衡,适合进行采样策略的对比研究。
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 加载 Qwen3 4B 模型和分词器
model_name = "Qwen/Qwen2-4B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# 设置填充 token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 将模型设置为评估模式
model.eval()
4. 文本生成#
我们编写一个通用的文本生成函数,它可以接受不同的采样策略,并在 Qwen3 4B 模型上进行文本生成。
def generate_text(prompt, sampling_function, sampling_param, max_length=50):
"""
使用指定的采样策略生成文本
参数:
prompt: 输入文本提示
sampling_function: 采样函数
sampling_param: 采样参数
max_length: 生成的最大长度
返回:
生成的文本
"""
# 编码输入文本
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
# 存储生成的 token
generated = input_ids
# 使用 no_grad 避免计算梯度
with torch.no_grad():
for _ in range(max_length):
# 获取模型输出
outputs = model(generated)
next_token_logits = outputs.logits[:, -1, :]
# 使用指定的采样策略获取下一个 token
next_token = sampling_function(next_token_logits, sampling_param)
# 将新 token 添加到已生成序列中
generated = torch.cat([generated, torch.tensor([[next_token]], device=model.device)], dim=-1)
# 如果生成了结束 token,停止生成
if next_token == tokenizer.eos_token_id:
break
# 解码生成的文本
return tokenizer.decode(generated[0], skip_special_tokens=True)
5. 对比实验#
现在我们来对比两种采样策略在不同参数下的文本生成效果。首先定义一个测试函数来批量生成文本。
def compare_sampling_strategies(prompt, top_p_values, min_p_values):
"""
对比不同参数下的采样策略效果
参数:
prompt: 输入提示
top_p_values: 要测试的 Top-P 值列表
min_p_values: 要测试的 min-p 值列表
"""
print(f"输入提示: '{prompt}'\n")
# 测试 Top-P 采样
print("Top-P 采样结果:")
for p in top_p_values:
generated_text = generate_text(prompt, top_p_sampling, p)
print(f"p={p}: {generated_text[len(prompt):]}")
print("\n" + "="*50 + "\n")
# 测试 Min-P 采样
print("Min-P 采样结果:")
for min_prob in min_p_values:
generated_text = generate_text(prompt, min_p_sampling, min_prob)
print(f"min_prob={min_prob}: {generated_text[len(prompt):]}")
使用相同的提示文本,我们观察两种采样策略在 Qwen3 4B 模型上的生成效果。选择适当的参数范围对于观察差异至关重要。
# 设置测试参数
test_prompt = "人工智能的未来发展将"
top_p_values = [0.5, 0.8, 0.9]
min_p_values = [0.01, 0.05, 0.1]
# 运行对比实验
compare_sampling_strategies(test_prompt, top_p_values, min_p_values)
从生成结果中,我们可以观察到一些有趣的现象。较低的 p 值(如 0.5)在 Top-P 采样中会导致更保守的生成,只考虑概率最高的几个 token,而较高的 p 值(如 0.9)会允许更多样化的生成,但可能包含一些不太相关的 token。生成结果的质量和多样性对 p 值非常敏感,这反映了 Top-P 采样的动态特性。
对于 Min-P 采样,较低的最小概率阈值(如 0.01)会保留更多 token,导致更多样化的生成,而较高的最小概率阈值(如 0.1)会更严格,只保留概率较高的 token。相对于 Top-P,Min-P 提供了一种更直接的概率阈值控制方式,使生成过程更加稳定和可预测。
6. 可视化分析#
为了更直观地理解两种方法的差异,我们可以可视化它们对概率分布的过滤效果。
import matplotlib.pyplot as plt
import numpy as np
def visualize_sampling_effect(logits, top_p=0.9, min_p=0.05):
"""
可视化两种采样策略对概率分布的影响
"""
probs = F.softmax(logits, dim=-1).cpu().numpy().flatten()
# 对概率进行排序
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
# 创建图表
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Top-P 可视化
top_p_mask = cumulative_probs <= top_p
ax1.bar(range(len(sorted_probs)), sorted_probs, alpha=0.7)
ax1.bar(range(len(sorted_probs))[top_p_mask], sorted_probs[top_p_mask], color='red', alpha=0.7)
ax1.axhline(y=0, color='k', linestyle='-', alpha=0.3)
ax1.set_title(f'Top-P Sampling (p={top_p})')
ax1.set_xlabel('Token Rank')
ax1.set_ylabel('Probability')
# Min-P 可视化
min_p_mask = sorted_probs >= min_p
ax2.bar(range(len(sorted_probs)), sorted_probs, alpha=0.7)
ax2.bar(range(len(sorted_probs))[min_p_mask], sorted_probs[min_p_mask], color='green', alpha=0.7)
ax2.axhline(y=min_p, color='k', linestyle='--', alpha=0.7, label=f'Min-p threshold ({min_p})')
ax2.set_title(f'Min-P Sampling (min_p={min_p})')
ax2.set_xlabel('Token Rank')
ax2.set_ylabel('Probability')
ax2.legend()
plt.tight_layout()
plt.show()
# 获取一个示例 logits 分布
with torch.no_grad():
test_input = tokenizer.encode("人工智能的", return_tensors="pt").to(model.device)
outputs = model(test_input)
sample_logits = outputs.logits[:, -1, :].squeeze().cpu()
# 可视化
visualize_sampling_effect(sample_logits, top_p=0.9, min_p=0.05)
通过可视化分析,我们可以清楚地看到两种采样方法如何过滤概率分布。Top-P 采样保留累积概率达到阈值的最少 token 集合,而 Min-P 采样保留所有概率超过最小阈值的 token。这种差异在实际生成过程中会导致不同的文本质量和多样性特征。
7. 总结与思考#
从计算效率角度来看,两种方法在复杂度上相似,都需要对概率进行排序和过滤。在实际应用中,最佳采样策略和参数往往需要根据具体任务和模型进行调优。
Min-P 采样作为 Top-P 采样的新兴变体,提供了另一种控制文本生成质量的方式。随着大语言模型技术的不断发展,采样策略的研究也将继续深入,为自然语言生成任务提供更多灵活性和控制能力。