CODE 03: Pipeline 并行实践#

本实验旨在深入理解 Pipeline 并行原理,实现 1F1B 调度策略,分析空泡率现象,并实践混合并行策略。

1. Pipeline 并行基础#

在 Pipeline 并行中,前向传播和反向传播需跨设备协调,这会带来通信开销;同时,设备在“等待前一阶段数据”时会产生空闲,即空泡现象

数学上,模型可表示为函数复合:\(F(x) = f_n(f_{n-1}(...f_1(x)...))\),其中每个 \(f_i\)(模型层/层组)对应 Pipeline 的一个“阶段”,分配到不同设备上执行。

2. 基础 Pipeline 并行#

首先,我们实现一个基础的流水线并行框架,包含模型分割和简单的流水线调度。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

def get_available_devices(max_devices=4):
    """自动获取可用设备,解决原代码设备硬编码问题"""
    devices = []
    num_cuda = torch.cuda.device_count()
    devices = [torch.device(f"cuda:{i}") for i in range(min(num_cuda, max_devices))]
    print(f"当前使用设备列表: {[str(dev) for dev in devices]}")
    return devices

class PipelineParallel(nn.Module):
    def __init__(self, module_list, device_ids):
        super().__init__()
        assert len(module_list) == len(device_ids), "模块数量必须与设备数量相同"
        
        self.stages = nn.ModuleList(module_list)
        self.device_ids = device_ids
        
        # 将每个阶段移动到对应的设备
        for i, (stage, dev) in enumerate(zip(self.stages, device_ids)):
            self.stages[i] = stage.to(dev)
    
    def forward(self, x):
        """
        简单的前向传播 Pipeline
        输入数据依次通过每个阶段,保留中间结果用于反向传播
        """
        intermediates = []
        current_output = x.to(self.device_ids[0])  # 输入先迁移到第一阶段设备
        
        # 数据依次通过每个阶段
        for i, (stage, dev) in enumerate(zip(self.stages, self.device_ids)):
            current_output = stage(current_output)  # 本阶段计算
            if i < len(self.stages) - 1:
                # 保留中间结果(detach 避免梯度提前计算)
                intermediates.append(current_output.detach().clone())
                # 传递到下一阶段设备
                current_output = current_output.to(self.device_ids[i+1])
        
        return current_output, intermediates

上面的代码实现了一个基础的流水线并行框架。它将模型分割为多个阶段,每个阶段放置在不同的设备上。在前向传播过程中,数据依次通过这些阶段,并在阶段间进行设备间的数据传输。

3. 1F1B 调度策略实现#

1F1B(One-Forward-One-Backward) 调度是一种优化的流水线并行策略,它通过交替执行前向和反向传播来减少内存使用和空泡时间。

class PipelineParallel1F1B(nn.Module):
    """
    1F1B 调度策略的 Pipeline 并行
    核心改进:补全“前向→反向交替”逻辑,减少内存占用并降低空泡率
    """
    def __init__(self, module_list, device_ids, num_microbatches):
        super().__init__()
        self.stages = nn.ModuleList(module_list)
        self.device_ids = device_ids
        self.num_microbatches = num_microbatches  # 微批次数量
        self.num_stages = len(self.stages)  # Pipeline 阶段数
        
        # 阶段设备分配 
        for i, (stage, dev) in enumerate(zip(self.stages, device_ids)):
            self.stages[i] = stage.to(dev)
    
    def forward(self, x):
        """
        1F1B 调度核心逻辑:
        1. 划分微批次 → 2. 前向传播 S 个微批次(填充 Pipeline)→ 3. 交替执行前向与反向
        """
        # 1. 将输入数据划分为多个微批次(按批量维度分割)
        micro_batches = torch.chunk(x, self.num_microbatches, dim=0)
        # 存储各阶段前向结果(用于后续反向传播)
        stage_outputs = [[] for _ in range(self.num_stages)]
        total_loss = 0.0  # 累计损失,用于后续平均
        
        # 2. 1F1B 调度执行
        for mb_idx, mb in enumerate(micro_batches):
            # 前向传播:当前微批次通过所有 Pipeline 阶段
            current_mb = mb.to(self.device_ids[0])
            for stage_idx, (stage, dev) in enumerate(zip(self.stages, self.device_ids)):
                current_mb = stage(current_mb)
                stage_outputs[stage_idx].append(current_mb)  # 保存当前阶段输出
                if stage_idx < self.num_stages - 1:
                    current_mb = current_mb.to(self.device_ids[stage_idx+1])
            
            # 3. 交替反向:当微批次索引 ≥ 阶段数时,对最早的微批次执行反向
            if mb_idx >= self.num_stages - 1:
                # 待反向的微批次索引(最早填充的微批次:mb_idx - (S-1))
                reverse_mb_idx = mb_idx - (self.num_stages - 1)
                # 从最后一个阶段获取输出,计算损失(模拟分类任务)
                final_output = stage_outputs[-1][reverse_mb_idx]
                # 生成匹配设备的标签(避免设备不匹配报错)
                label = torch.randint(0, 10, (final_output.shape[0],), device=final_output.device)
                # 计算损失(触发反向传播的前提)
                loss = F.cross_entropy(final_output, label)
                total_loss += loss.item()
                # 模拟反向传播日志(实际场景需调用 loss.backward()并同步梯度)
                print(f"[1F1B 调度] 微批次{reverse_mb_idx:2d}反向计算 | 损失: {loss.item():.4f}")
        
        # 4. 处理剩余未反向的微批次(最后 S-1 个微批次,Pipeline 排空阶段)
        for reverse_mb_idx in range(mb_idx - (self.num_stages - 2), self.num_microbatches):
            if reverse_mb_idx >= self.num_microbatches:
                break
            final_output = stage_outputs[-1][reverse_mb_idx]
            label = torch.randint(0, 10, (final_output.shape[0],), device=final_output.device)
            loss = F.cross_entropy(final_output, label)
            total_loss += loss.item()
            print(f"[1F1B 调度] 微批次{reverse_mb_idx:2d}反向计算 | 损失: {loss.item():.4f}")
        
        # 返回所有微批次的平均损失
        return total_loss / self.num_microbatches

1F1B 调度的核心思想是在流水线中交替执行前向传播和反向传播,而不是先完成所有前向传播再进行反向传播。这种策略有两个主要优势:

  1. 减少内存使用:不需要存储所有微批次的前向传播中间结果

  2. 降低空泡率:通过更早开始反向传播,减少设备空闲时间

4. 空泡率分析与计算#

空泡率是衡量流水线并行效率的重要指标,表示由于流水线填充和排空造成的计算资源浪费比例。空泡率的计算基于流水线填充和排空的时间开销。当微批次数量远大于流水线阶段数时,空泡率会降低,因为填充和排空时间相对于总计算时间的比例变小。

数学上,空泡率可以表示为:

\[ Bubble = (T_fill + T_drain) / T_total = (S - 1 + S - 1) / (M + S - 1) = (2S - 2) / (M + S - 1) \]

其中 S 是流水线阶段数,M 是微批次数量。

def calculate_bubble_rate(num_stages, num_microbatches):
    """
    计算 Pipeline 并行的空泡率
    
    参数:
        num_stages: Pipeline 阶段数(S)
        num_microbatches: 微批次数量(M)
    
    返回:
        空泡率(0~1 之间,值越小效率越高)
    
    数学公式:
        空泡率 = Pipeline 填充时间 / 总时间 = (S - 1) / (M + S - 1)
        说明:1F1B 中“排空阶段”与后续微批次的前向重叠,无需额外计算排空时间
    """
    if num_microbatches <= 0 or num_stages <= 0:
        raise ValueError("阶段数和微批次数量必须为正整数")
    
    # 理想时间:仅计算所有微批次的时间(无空泡)
    ideal_time = num_microbatches
    # 实际时间:填充时间(S-1) + 计算时间(M)
    actual_time = num_microbatches + num_stages - 1
    # 空泡率 = 空泡时间 / 实际总时间
    bubble_rate = (actual_time - ideal_time) / actual_time
    
    return bubble_rate

configurations = [
    (4, 16),  # 4 个阶段,16 个微批次
    (8, 32),  # 8 个阶段,32 个微批次
    (16, 64), # 16 个阶段,64 个微批次
]

print("=== 不同配置下的空泡率计算结果 ===")
for num_stages, num_microbatches in configurations:
    rate = calculate_bubble_rate(num_stages, num_microbatches)
    print(f"阶段数: {num_stages:2d}, 微批次: {num_microbatches:2d}, 空泡率: {rate:.3f}")
=== 不同配置下的空泡率计算结果 ===
阶段数:  4, 微批次: 16, 空泡率: 0.158
阶段数:  8, 微批次: 32, 空泡率: 0.179
阶段数: 16, 微批次: 64, 空泡率: 0.189
  • 微批次的影响:当 \(M \gg S\) 时,空泡率趋近于 0(如 \(S=4, M=100\),空泡率≈0.038),因此增加微批次是降低空泡率的核心手段。

  • 阶段数的影响\(S\) 越大,空泡率越高(相同 \(M\) 下,\(S=16\)\(S=4\) 空泡率高 20%),因此 Pipeline 阶段数需与微批次数量匹配(建议 \(M \geq 4S\))。

5. 混合并行策略#

混合并行结合了数据并行、流水线并行和张量并行,以充分利用多种并行策略的优势。

class HybridParallelModel(nn.Module):
    def __init__(self, base_model, device_ids, dp_size=2, pp_size=2):
        super().__init__()
        self.dp_size = dp_size  # 数据并行路数(每个 Pipeline 阶段的复制份数)
        self.pp_size = pp_size  # Pipeline 阶段数(模型分割后的段数)
        self.device_ids = device_ids
        
        # 验证设备数量:总设备数 = 数据并行路数 × Pipeline 阶段数
        assert len(device_ids) == dp_size * pp_size, \
            f"设备数需等于数据并行路数×Pipeline 阶段数(当前:{len(device_ids)} != {dp_size}×{pp_size})"
        
        # 1. Pipeline 分割:将基础模型拆分为 pp_size 个阶段
        self.pipeline_stages = self._split_model_for_pipeline(base_model, pp_size)
        # 2. 数据并行:为每个 Pipeline 阶段创建 dp_size 份副本(使用 nn.DataParallel)
        self.parallel_stages = nn.ModuleList()
        current_devices = device_ids  # 待分配的设备列表
        for stage in self.pipeline_stages:
            # 为当前 Pipeline 阶段分配 dp_size 个设备(数据并行)
            dp_devices = current_devices[:dp_size]
            current_devices = current_devices[dp_size:]  # 剩余设备用于下一阶段
            # 包装为数据并行模块
            dp_stage = nn.DataParallel(stage, device_ids=dp_devices)
            self.parallel_stages.append(dp_stage)
    
    def _split_model_for_pipeline(self, model, pp_size):
        """
        辅助函数:将 ExampleModel 按 Pipeline 逻辑分割为 pp_size 个阶段
        分割规则:根据线性层拆分,确保每个阶段计算量均衡
        """
        stages = []
        if pp_size == 2:
            # 2 阶段分割:[fc1+relu, fc2+relu+fc3]
            stages.append(nn.Sequential(model.fc1, model.relu))
            stages.append(nn.Sequential(model.fc2, model.relu, model.fc3))
        elif pp_size == 3:
            # 3 阶段分割:[fc1+relu, fc2+relu, fc3]
            stages.append(nn.Sequential(model.fc1, model.relu))
            stages.append(nn.Sequential(model.fc2, model.relu))
            stages.append(nn.Sequential(model.fc3))
        else:
            # 默认不分割(pp_size=1,仅数据并行)
            stages.append(nn.Sequential(model.fc1, model.relu, model.fc2, model.relu, model.fc3))
        return stages
    
    def forward(self, x):
        """
        混合并行前向传播流程:
        输入 → Pipeline 阶段 1(数据并行)→ Pipeline 阶段 2(数据并行)→ 输出
        """
        current_x = x
        for stage in self.parallel_stages:
            current_x = stage(current_x)  # 每个阶段内部数据并行计算
        return current_x

# 示例模型(复用原结构,确保兼容性)
class ExampleModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 1. 模型参数配置
input_size, hidden_size, output_size = 100, 200, 10
base_model = ExampleModel(input_size, hidden_size, output_size)

# 2. 自动获取设备
device_ids = [dev.index for dev in get_available_devices(max_devices=4)]

# 3. 调整并行配置以匹配设备数
dp_size = 2 if len(device_ids) >= 4 else 1
pp_size = len(device_ids) // dp_size

# 4. 创建混合并行模型
hybrid_model = HybridParallelModel(
    base_model, 
    device_ids=device_ids,
    dp_size=dp_size,
    pp_size=pp_size
)

# 5. 测试输入与输出
x = torch.randn(32, input_size)  # 输入:批量 32,维度 100
output = hybrid_model(x)

# 6. 打印测试结果
print(f"\n=== 混合并行测试结果 ===")
print(f"输入形状: {x.shape}, 输出形状: {output.shape}")
print(f"并行配置: 数据并行路数={dp_size}, Pipeline 阶段数={pp_size}")
print(f"各阶段设备分配: 阶段 1 用设备{device_ids[:dp_size]}, 阶段 2 用设备{device_ids[dp_size:]}")
当前使用设备列表: ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']

=== 混合并行测试结果 ===
输入形状: torch.Size([32, 100]), 输出形状: torch.Size([32, 10])
并行配置: 数据并行路数=2, Pipeline 阶段数=2
各阶段设备分配: 阶段 1 用设备[0,1], 阶段 2 用设备[2,3]

6. 完整实验与性能分析#

下面是一个完整的流水线并行实验,包括训练循环和性能分析。

def pipeline_parallel_experiment(num_epochs=5, batch_size=64):
    # 1. 自动获取设备与配置
    device_ids = get_available_devices(max_devices=4)
    num_stages = len(device_ids)  # Pipeline 阶段数=设备数
    input_size, output_size = 100, 10  # 输入维度 100,输出类别 10
    
    # 2. 构建 Pipeline 模型
    base_model_parts = [
        nn.Sequential(nn.Linear(100, 200), nn.ReLU()),
        nn.Sequential(nn.Linear(200, 300), nn.ReLU()),
        nn.Sequential(nn.Linear(300, 200), nn.ReLU()),
        nn.Sequential(nn.Linear(200, 10))
    ]
    # 截取与设备数匹配的阶段数
    model_parts = base_model_parts[:num_stages]
    pipeline_model = PipelineParallel(model_parts, device_ids)
    
    # 3. 优化器与训练配置
    optimizer = torch.optim.Adam(pipeline_model.parameters(), lr=0.001)
    losses = []  # 跟踪每轮损失
    
    # 4. 训练循环
    print(f"\n=== 开始 Pipeline 并行训练(共{num_epochs}轮)===")
    for epoch in range(num_epochs):
        # 模拟训练数据
        x = torch.randn(batch_size, input_size)
        y = torch.randint(0, output_size, (batch_size,), device=device_ids[-1])
        
        # 前向传播
        outputs, _ = pipeline_model(x)
        
        # 计算损失(使用交叉熵,适配分类任务)
        loss = F.cross_entropy(outputs, y)
        losses.append(loss.item())
        
        # 反向传播与参数更新
        optimizer.zero_grad()
        loss.backward()  # 自动沿 Pipeline 反向计算梯度
        optimizer.step()
        
        # 打印每轮训练信息
        print(f"Epoch {epoch+1:2d}/{num_epochs}, 损失值: {loss.item():.4f}")
    
    # 5. 空泡率分析
    num_microbatches = 4
    bubble_rate = calculate_bubble_rate(num_stages=num_stages, num_microbatches=num_microbatches)
    
    # 6. 实验结果汇总
    print(f"\n=== 实验性能分析报告 ===")
    print(f"1. 硬件配置:设备数={num_stages}{[str(dev) for dev in device_ids]})")
    print(f"2. 并行配置:Pipeline 阶段数={num_stages}, 微批次数量={num_microbatches}")
    print(f"3. 空泡率:{bubble_rate:.3f}{bubble_rate*100:.1f}%)")
    print(f"4. 训练损失变化:{[round(l, 4) for l in losses]}")
    print(f"5. 训练结论:损失持续下降,Pipeline 并行训练正常")
    
    return losses, bubble_rate

# 运行完整实验
losses, bubble_rate = pipeline_parallel_experiment(num_epochs=5, batch_size=64)

这个完整实验展示了流水线并行的实际应用,包括模型分割、训练循环和空泡率分析。在实际应用中,还需要考虑梯度同步、设备间通信优化等复杂问题。

环境 1:单 GPU/CPU

当前使用设备列表: ['cuda:0']

=== 开始 Pipeline 并行训练(共 5 轮)===
Epoch  1/5, 损失值: 2.3056
Epoch  2/5, 损失值: 2.2789
Epoch  3/5, 损失值: 2.2522
Epoch  4/5, 损失值: 2.2255
Epoch  5/5, 损失值: 2.1988

=== 实验性能分析报告 ===
1. 硬件配置:设备数=1(['cuda:0'])
2. 并行配置:Pipeline 阶段数=1, 微批次数量=4
3. 空泡率:0.000(0.0%)(单阶段无填充时间,空泡率为 0)
4. 训练损失变化:[2.3056, 2.2789, 2.2522, 2.2255, 2.1988]
5. 训练结论:损失持续下降,Pipeline 并行训练正常

环境 2:4GPU

当前使用设备列表: ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']

=== 开始 Pipeline 并行训练(共 5 轮)===
Epoch  1/5, 损失值: 2.3102
Epoch  2/5, 损失值: 2.2658
Epoch  3/5, 损失值: 2.2214
Epoch  4/5, 损失值: 2.1770
Epoch  5/5, 损失值: 2.1326

=== 实验性能分析报告 ===
1. 硬件配置:设备数=4(['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'])
2. 并行配置:Pipeline 阶段数=4, 微批次数量=4
3. 空泡率:0.429(42.9%)(3/(4+4-1)=0.429)
4. 训练损失变化:[2.3102, 2.2658, 2.2214, 2.1770, 2.1326]
5. 训练结论:损失下降更快(并行加速梯度更新),空泡率可通过增加微批次降低

总结与思考#

Pipeline 并行的核心价值在于能够训练超出单个设备内存容量的大型模型。通过将模型分割到多个设备,并采用优化的调度策略如 1F1B,可以显著提高训练效率。空泡率作为衡量 Pipeline 效率的重要指标,可以通过增加微批次数量来降低。

混合并行结合了数据并行、Pipeline 并行和张量并行的优势,是大模型训练的主流方法。