ZeRO-Offload

ZeRO-3 Offload 包含我们新发布的 ZeRO-Infinity 中的一组功能。阅读我们的 ZeRO-Infinity 博客 以了解更多信息!

我们建议你在详细了解本教程之前阅读有关 入门ZeRO 的教程。

ZeRO-Offload 是一种 ZeRO 优化,它将优化器内存和计算从 GPU 转移到主机 CPU。ZeRO-Offload 使得具有高达 130 亿参数的大模型能够在单个 GPU 上高效训练。在本教程中,我们将使用 ZeRO-Offload 在 DeepSpeed 中训练一个 100 亿参数 GPT-2 模型。此外,在 DeepSpeed 模型中使用 ZeRO-Offload 非常简单快速,因为你只需要更改 DeepSpeed 配置 json 中的几个配置。不需要进行任何代码更改。

ZeRO-Offload 概述

对于大型模型训练,优化器(例如 Adam)会消耗大量的 GPU 计算和内存。ZeRO-Offload 通过利用主机 CPU 上的计算和内存资源来执行优化器,从而减少了此类模型的 GPU 计算和内存要求。此外,为了防止优化器成为瓶颈,ZeRO-Offload 使用 DeepSpeed 高度优化的 CPU 实现的 Adam,称为 DeepSpeedCPUAdam。DeepSpeedCPUAdam 比标准 PyTorch 实现快 5-7 倍。要深入了解 ZeRO-Offload 的设计和性能,请参阅我们的 博客文章

训练环境

在本教程中,我们将使用 DeepSpeed Megatron-LM GPT-2 代码配置一个 100 亿参数 GPT-2 模型。我们建议如果你之前没有这样做,请逐步完成 Megatron-LM 教程。在本练习中,我们将使用一个配备 32GB RAM 的 NVIDIA Tesla V100-SXM3 Tensor Core GPU

在单个 V100 GPU 上训练 100 亿参数 GPT-2

我们需要对 Megatron-LM 启动脚本和 DeepSpeed 配置 json 进行更改。

Megatron-LM GPT-2 启动脚本更改

我们需要对 DeepSpeed Megatron-LM GPT-2 模型的启动脚本进行两个更改。第一个更改是配置一个 100 亿参数 GPT-2 模型,并启用激活检查点,这可以通过以下一组更改来实现。

       --model-parallel-size 1 \
       --num-layers 50 \
       --hidden-size 4096 \
       --num-attention-heads 32 \
       --batch-size 10 \
       --deepspeed_config ds_zero_offload.config \
       --checkpoint-activations

如果你已完成 Megatron-LM 教程,则上述更改中的大多数标志应该很熟悉。

其次,我们需要应用以下更改,以确保仅使用一个 GPU 进行训练。

   deepspeed --num_nodes 1 --num_gpus 1 ...

DeepSpeed 配置更改

ZeRO-Offload 利用了许多 ZeRO 阶段 1 和 2 机制,因此启用 ZeRO-Offload 的配置更改是启用 ZeRO 阶段 1 或 2 所需更改的扩展。启用 ZeRO-Offload 的 zero_optimization 配置如下所示。

{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
        }
        "contiguous_gradients": true,
        "overlap_comm": true
    }
}

如上所示,除了将 stage 字段设置为 2(启用 ZeRO 阶段 2,但阶段 1 也可行)之外,我们还需要将 offload_optimizer 设备设置为 cpu 以启用 ZeRO-Offload 优化。此外,我们可以设置其他 ZeRO 阶段 2 优化标志,例如 overlap_comm 来调整 ZeRO-Offload 性能。进行这些更改后,我们现在就可以运行模型了。我们在下面分享了一些训练的屏幕截图。

以下是训练日志的屏幕截图。

以下是 nvidia-smi 的屏幕截图,显示在训练期间仅 GPU 0 处于活动状态。

最后,以下是 htop 的屏幕截图,显示优化器计算期间的主机 CPU 和内存活动。

恭喜你!你已完成 ZeRO-Offload 教程。

更新: