使用渐进式层丢弃加速基于 Transformer 的语言模型训练

在本教程中,我们将介绍 DeepSpeed 中的渐进式层丢弃 (PLD) 并提供有关如何使用 PLD 的示例。PLD 允许在相同样本数量下将 Transformer 网络(如 BERT)的训练速度提高 24%,并在下游任务中获得类似精度的情况下将训练速度提高 2.5 倍。PLD 的详细描述和实验结果可在我们的 技术报告 中找到。

为了说明如何在 DeepSpeed 中使用 PLD,我们将展示如何启用 PLD 来预训练 BERT 模型,以及如何在 GLUE 数据集上微调预训练的模型。

使用 DeepSpeed 和 PLD 运行预训练

要执行预训练,需要先准备数据集。对于这部分,请参阅我们的 BERT 预训练 文章,其中包含有关如何进行数据下载和预处理的详细信息。对于以下实验,我们使用维基百科文本和图书语料库,类似于 Devlin 等人

预训练的主要部分在 deepspeed_train.py 中完成,该文件已修改为使用 DeepSpeed。 ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh 是启动使用 DeepSpeed 和 PLD 进行预训练的 shell 脚本。

bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh

如果您已经完成了 BERT 预训练 教程,那么上面的脚本中的大多数标志应该都很熟悉。要启用使用 PLD 进行训练,需要在客户端脚本和 DeepSpeed 引擎中都启用 PLD。要在客户端脚本中启用 PLD,需要添加以下命令行标志以在 Transformer 块上启用渐进式层丢弃。

--progressive_layer_drop

要在 DeepSpeed 中启用 PLD,需要使用以下适当的 PLD 配置字典更新 json 配置文件。

{
  ...
  "progressive_layer_drop": {
    "enabled": true,
    "theta": 0.5,
    "gamma": 0.001
  }
}

我们建议使用 0.5 的 PLD theta 值和 0.001 的 gamma 值,因为这些值在我们的实验中效果很好。

进行这些配置更改后,DeepSpeed 引擎应该会打印以下运行时消息。

[INFO] [logging.py:60:log_dist] [Rank 0] Enabled progressive layer dropping (theta = 0.5)

deepspeed_bsz4k_progressive_layer_drop_config_seq128.json 文件允许用户指定 DeepSpeed 选项,例如批次大小、微批次大小、优化器、学习率、序列长度和其他参数。以下是我们用于运行 BERT 和 PLD 的 DeepSpeed 配置文件。

{
  "train_batch_size": 4096,
  "train_micro_batch_size_per_gpu": 16,
  "steps_per_print": 1000,
  "prescale_gradients": true,
  "gradient_predivide_factor": 8,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 1e-3,
      "weight_decay": 0.01,
      "bias_correction": false
    }
  },
  "gradient_clipping": 1.0,

  "wall_clock_breakdown": false,

  "fp16": {
    "enabled": true,
    "loss_scale": 0
  },

  "progressive_layer_drop": {
    "enabled": true,
    "theta": 0.5,
    "gamma": 0.001
  }
}

请注意,上述配置假设在 64 个 32GB V100 GPU 上进行训练。每个 GPU 使用 16 的微批次大小,并在有效批次大小达到 4096 之前累积梯度。如果您的 GPU 内存较少,则可能需要减少 “train_micro_batch_size_per_gpu”。或者,如果您有更多 GPU,可以增加 “train_batch_size” 以提高训练速度。我们使用以下超参数来进行启用 PLD 的 BERT 预训练。

参数
有效批次大小 4K
每个 GPU 的训练微批次大小 16
优化器 Adam
峰值学习率 1e-3
序列长度 128
学习率调度器 预热线性衰减 exp
预热比率 0.02
衰减率 0.99
衰减步长 1000
权重衰减 0.01
梯度裁剪 1.0

表 1. 预训练超参数

注意:DeepSpeed 现在支持 PreLayerNorm 作为训练 BERT 的默认方法,因为它能够避免梯度消失,稳定优化并提高性能,如我们关于 BERT 最快训练的 博客文章 中所述。因此,我们直接支持 BERT 上的可切换 Transformer 块,该块带有 PreLayerNorm。实现可以在 “example\bing_bert\nvidia\modelingpreln_layerdrop.py” 中找到。

在 GLUE 任务上使用 DeepSpeed 微调

我们使用 GLUE 进行微调任务。GLUE(通用语言理解评估基准)(https://gluebenchmark.com/)是一个包含句子或句子对自然语言理解任务的集合,包括问答、情感分析和文本蕴涵。它的设计旨在有利于样本高效学习和跨各种不同语言任务(在不同领域)的知识迁移。

可以使用提供的帮助程序 脚本 下载所有 GLUE 数据。下载完数据后,可以设置数据并将数据移动到 “/data/GlueData”,这是托管 GLUE 数据的默认位置。然后可以使用 PLD 预训练的 BERT 模型检查点来运行微调。

微调的主要部分在 run_glue_classifier_bert_base.py 中完成,该文件已修改为使用 DeepSpeed。在微调之前,需要通过 run_glue_classifier_bert_base.py 中的以下配置指定 BERT 模型配置。在本例中,它已修改为与预训练模型的配置相同。

    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }

接下来,可以使用以下命令加载 DeepSpeed 风格的检查点,该命令也已添加到脚本中。

model.load_state_dict(checkpoint_state_dict['module'], strict=False)

最后, run_glue_classifier_bert_base.sh 脚本调用预训练并设置与微调相关的几个超参数。

bash run_glue_bert_base_finetune.sh [task] [batch size] [learning rate] [number of epochs] [job name] [checkpoint path]

一个示例将是

bash run_glue_bert_base_finetune.sh MNLI 32 3e-5 5 "fine_tune_MNLI" deepspeed_checkpoint.pt

预期结果

微调结果可以在 “logs” 目录下找到,以下是 PLD 在 GLUE 任务上的预期结果。“Lr” 行表示我们用于获得每个任务的对应精度结果的学习率。

  RTE MRPC STS-B CoLA SST-2 QNLI QQP MNLI-m/mm GLUE
指标 Acc. F1/Acc. PCC/SCC Acc. Acc. Acc. F1/Acc. Acc.  
Bert_{base}(原始) 66.4 88.9/84.8 87.1/89.2 52.1 93.5 90.5 71.2/89.2 84.6/83.4 80.7
Bert_{base}(我们的实现) 67.8 88.0/86.0 89.5/89.2 52.5 91.2 87.1 89.0/90.6 82.5/83.4 82.1
PLD 69.3 86.6/84.3 90.0/89.6 55.8 91.6 90.7 89.6/91.2 84.1/83.8 82.9
Lr 7e-5 9e-5 7e-5 5e-5 7e-5 9e-5 2e-4 3e-5  

更新: