总结论文时安全检查不通过,跳过。 "Deja Vu: 推理时高效的大型语言模型的上下文稀疏性 Zichang Liu1 Jue Wang2 Tri Dao3 Tianyi Zhou4 Binhang Yuan5 Zhao Song6 Anshumali Shrivastava1 Ce Zhang5 Yuandong Tian7 Christopher Ré3 Beidi Chen8 7 摘要 拥有数千亿参数的大型语言模型引发了新一轮令人兴奋的AI应用。然而,它们在推理时计算成本很高。稀疏性是降低这一成本的自然方法,但现有方法要么需要昂贵的重新训练,要么不得不放弃大型语言模型(LLM)的上下文学习能力,要么在现代硬件上没有实现实时速度提升。我们假设上下文稀疏性,即对于给定输入,能够产生与密集模型大致相同输出的小型、输入依赖的注意力头和MLP参数集合,可以解决这些问题。我们展示了上下文稀疏性的存在,它可以被准确预测,并且我们可以利用它在不牺牲LLM质量和上下文学习能力的情况下,加快LLM推理的实时速度。 基于这些见解,我们提出了DEJAVU,这是一个系统,它使用低成本算法实时预测每层输入的上下文稀疏性,并结合异步和硬件感知的实现来加速LLM推理。我们验证了DEJAVU可以将OPT-175B的推理延迟比最先进的FasterTransformer减少超过2倍,比广泛使用的Hugging Face实现减少超过6倍,同时不牺牲模型质量。代码可在 https://github.com/FMInference/DejaVu 上获取。 1 引言 像GPT-3、PaLM和OPT这样的大型语言模型展示了巨大的参数数量释放出令人印象深刻的性能和新兴的上下文学习能力——它们可以通过条件输入输出示例来执行任务,而无需更新它们的参数。然而,它们在推理时非常昂贵,特别是对于延迟敏感的应用。理想的推理时模型应该在保持预训练LLM的性能和特殊能力的同时,使用更少的计算和内存。最简单和最自然的方法是稀疏化或剪枝,这在LLM时代之前有着悠久的历史。 不幸的是,在保持质量和上下文学习能力的同时,在实时速度上加快推理时间的稀疏LLM仍然是一个挑战。虽然稀疏化和剪枝已经被深入研究,但由于在现代硬件如GPU上的质量和效率权衡不佳,它们在LLM上的广泛应用仍然有限。首先,在数千亿参数的规模上重新训练或迭代剪枝模型是不可行的。因此,迭代剪枝和彩票假设的方法只能应用于较小规模的模型。其次,保持LLM的上下文学习能力的稀疏性很难找到。许多研究表明,任务依赖性剪枝是有效的,但为每个任务维护不同的模型与LLM的任务独立性目标相冲突。最后,由于现代硬件的已知困难,很难通过无结构稀疏性实现实时速度提升。例如,最近的零样本剪枝发展如SparseGPT找到了60%的无结构稀疏性,但尚未带来任何实时速度提升。 理想的LLM稀疏性应该不需要模型重新训练,保持质量和上下文学习能力,并在现代硬件上实现实时速度提升。为了实现这些苛刻的要求,我们超越了之前工作中的静态稀疏性。我们设想的是上下文稀疏性,即对于输入,能够产生与完整模型相同输出的小型、输入依赖的注意力头和MLP参数集合。受到LLM和隐藏层之间的联系的启发,我们提出了DEJAVU,这是一个系统,它使用低成本算法实时预测每层输入的上下文稀疏性,并结合异步和硬件感知的实现来加速LLM推理。我们验证了DEJAVU可以在不牺牲模型质量的情况下,将OPT-175B的推理延迟比最先进的FasterTransformer减少超过2倍,比广泛使用的Hugging Face实现减少超过6倍。代码可在 https://github.com/FMInference/DejaVu 上获取。 1 引言 像GPT-3、PaLM和OPT这样的大型语言模型展示了巨大的参数数量释放出令人印象深刻的性能和新兴的上下文学习能力——它们可以通过条件输入输出示例来执行任务,而无需更新它们的参数。然而,它们在推理时非常昂贵,特别是对于延迟敏感的应用。理想的推理时模型应该在保持预训练LLM的性能和特殊能力的同时,使用更少的计算和内存。最简单和最自然的方法是稀疏化或剪枝,这在LLM时代之前有着悠久的历史。 不幸的是,在保持质量和上下文学习能力的同时,在实时速度上加快推理时间的稀疏LLM仍然是一个挑战。虽然稀疏化和剪枝已经被深入研究,但由于在现代硬件如GPU上的质量和效率权衡不佳,它们在LLM上的广泛应用仍然有限。首先,在数千亿参数的规模上重新训练或迭代剪枝模型是不可行的。因此,迭代剪枝和彩票假设的方法只能应用于较小规模的模型。其次,保持LLM的上下文学习能力的稀疏性很难找到 "Deja Vu: 在推理时为高效大型语言模型实现上下文稀疏性 0 20 40 60 80 Transformer 层 0.00 0.20 0.40 0.60 0.80 1.00 上下文稀疏性 OPT-175B 上下文稀疏性 1 2 3 4 5 6 7 8 理论减少 0.7940.7960.7980.8000.8020.8040.8060.8080.810 准确率 静态稀疏性 非上下文稀疏性 上下文稀疏性 准确率-效率权衡 图1. 对于给定输入,LLMs(大型语言模型)具有高达85%的上下文稀疏性。与非上下文稀疏性或静态稀疏性相比,上下文稀疏性具有更好的效率-准确率权衡。 我们假设,就像马尔可夫模型和经典的维特比算法一样,对于预训练的LLMs,对于任何输入都存在上下文稀疏性。 如果这个假设成立,它将使我们能够在推理时动态地切断特定的注意力头和MLP(多层感知器)参数,而不需要修改预训练模型。然而,存在三个挑战。 存在性:验证这种上下文稀疏性的存在并不简单,而且简单的验证可能会非常昂贵。 预测:即使上下文稀疏性存在,提前预测给定输入的稀疏性也是具有挑战性的。 效率:即使可以预测稀疏性,实现端到端的实时速度提升也可能很困难。 以OPT-175B为例,在8×A100 80GB机器上,一个MLP块的延迟仅为0.2毫秒。如果没有快速的预测和优化实现,开销很容易增加LLM的延迟,而不是减少它。 在这项工作中,我们如下解决这些挑战: 存在性:幸运的是,我们用一个令人惊讶的简单方法验证了上下文稀疏性的存在。为了实现几乎相同的输出,上下文稀疏性平均为85%的结构稀疏,从而可能为每个特定输入实现7倍的参数减少,同时保持准确率。在探索上下文稀疏性的过程中,我们做出了重要的经验观察,并建立了对LLM主要组成部分的理论理解,以帮助解决预测和效率挑战。 Deja Vu AttentionkMLPkPredictorPredictorPredictorAttentionk+1……图2. DEJAVU使用前瞻预测器来规避预测成本:给定第k个注意力层的输入,它们预测第k个MLP层的上下文稀疏性,给定第k个MLP层的输入,它们预测下一层的注意力头稀疏性。 预测:我们发现上下文稀疏性不仅取决于单个输入标记,还取决于它们的交互。图1显示,仅凭动态信息,稀疏性预测是不准确的。只有具有足够上下文信息的标记嵌入,我们才能准确预测稀疏性。另一个发现是,每层的上下文动态稀疏性可以根据“相似性”来预测,这种相似性基于层参数和前一层输出之间的相似性,后者携带了标记嵌入的即时上下文混合。 效率:由于在推理时模型参数是静态的,受到经典最近邻搜索文献及其在高效深度学习中的应用的启发,我们可以将基于相似性的预测公式化为最近邻搜索(NNS)问题。然而,正如提到的,开销可能难以克服,因为我们需要在每一层之前进行实时预测。幸运的是,我们利用了LLM中的一个现象,即由于残差连接,标记嵌入在连续层中变化缓慢。由于几个连续层的输入非常相似,我们可以设计一个异步前瞻预测器。 基于我们的发现,我们提出了一个系统DEJAVU,利用上下文稀疏性,为对延迟敏感的应用实现高效的LLM。 •在第4.1节和第4.2节中,我们提出了一种低成本的基于学习的算法,用于实时预测稀疏性。给定特定层的输入,它预测下一层的相关子集注意力或MLP参数,并仅加载它们进行计算。 •在第4.3节中,我们提出了一个异步预测器来避免顺序开销。理论保证证明," 请注意,这段翻译是基于您提供的英文内容进行的直译。由于技术文档通常包含专业术语和特定领域的知识,可能需要相关领域的专家进行更准确的翻译和解释。 "Deja Vu: 在推理时为高效的大型语言模型(LLMs)设计上下文稀疏性 跨层设计足以实现准确的稀疏性预测。 在集成了针对硬件的稀疏矩阵乘法实现后,DEJAVU 能够将开源的大型语言模型(如 OPT-175B)的端到端延迟减少超过2倍,与英伟达的最先进的库 Faster-Transformer 相比,以及在小批量大小下与广泛使用的 Hugging Face 实现相比,减少超过2倍。此外,我们展示了对 DEJAVU 不同组件的多项消融分析及其与量化技术的兼容性。 2 相关工作和问题阐述 我们首先简要讨论了关于高效推理的丰富文献。然后,我们介绍了我们设置中的延迟分解。最后,我们提供了一个正式的问题阐述。 2.1 推理中的量化、剪枝、蒸馏 几十年来,为了机器学习模型推理,已经研究了各种放松技术。主要有三种技术:量化、剪枝或稀疏化,以及蒸馏。它们是正交领域,通常在不同设置中表现出色。最近,有活跃的研究试图在大型语言模型推理中应用一种或多种这样的技术。附录 A 中提供了更多讨论。 2.2 大型语言模型推理延迟分解 大型语言模型的生成过程包括两个阶段:提示阶段将输入序列转换为大型语言模型的每个变换器块的键和值,这类似于大型语言模型训练的前向传递过程;而标记生成阶段则利用并更新键值(KV)缓存,逐步生成标记,当前标记生成依赖于之前生成的标记。本文研究的设置是标记生成阶段容易主导端到端推理时间。如表1所示,由于加载模型参数的I/O延迟,生成长度为128的序列比处理长度为128的序列作为提示要花费更长的时间。此外,表2显示,在大型语言模型中,注意力和多层感知机(MLP)都是瓶颈,例如,在175B模型中,加载MLP参数大约占用总I/O的2/3,注意力头占用另外1/3。此外,在张量并行设置中,GPU之间有两次通信,一次在注意力块之后,另一次在MLP块之后。如表3所示,GPU之间的通信占用了大约15%的标记生成延迟。本文专注于使注意力和MLP更高效。通信成本意味着,如果跳过所有变换器块,这种加速的上限大约是6倍。 表1. 提示与标记生成的理论分解。 TFLOPs I/O 计算延迟 I/O 延迟 提示 128 44.6 330 GB 17.87 20.6 标记生成 128 44.6 41 TB 17.87 2600 表2. 在生成一个标记时,一个变换器层中的注意力块与MLP块的理论分解。 GFLOPs I/O 计算延迟 I/O 延迟 注意力块 1.21 1.12 0.00048 0.07 MLP 块 2.41 2.25 0.00096 0.14 表3. 在批量大小为1,提示长度为128的情况下,使用8个A100-80GB GPU生成1个标记的延迟分解。 All Reduce MLP 块 注意力块 其他 6 ms 19ms 13ms 2ms 2.3 问题阐述 目标是通过利用上下文稀疏性来减少大型语言模型的生成延迟。接下来,我们正式定义稀疏化的注意力和MLP块。 稀疏化的MLP:一个MLP块有两个线性层,W1, W2 ∈ R^(d×4d)。假设 y ∈ R^(1×d) 是当前生成步骤中MLP块的输入。让 W1_i, W2_i ∈ R^(d×1) 表示线性层的每一列。通过上下文稀疏性,只有一小部分需要进行计算。让 SM ⊆ [4d] 表示输入 y 的神经元集合。稀疏化的MLP计算是 MLP_SM = σ^T,其中 σ 是激活函数,例如 ReLU 或 GeLU。请注意,由于第一线性层的计算导致稀疏激活,第二线性层也被稀疏化。 稀疏化的注意力:让 X ∈ R^(n×d) 表示所有标记的嵌入。 让 y ∈ R^(1×d) 是当前生成步骤中多头注意力的输入。假设有 h 个头。对于每个 i ∈ [h],我们使用 WK_i, WQ_i, WV_i ∈ R^(d×dht) 表示第 i 个头的键、查询、值投影,WO_i ∈ R^(dh×d) 表示输出投影。通过上下文稀疏性,我们称 "Deja Vu: 在推理时为高效大型语言模型实现上下文稀疏性 0 20 40 60 80 Transformer 层 20% 40% 60% 80% 100% 未激活的头 OPT-30B OPT-66B OPT-175B 注意力头的上下文稀疏性 0 20 40 60 80 Transformer 层 90% 92% 94% 96% 98% 100% 未激活的神经元 OPT-30B OPT-66B OPT-175B MLP 块的上下文稀疏性 图 3. 在图中,我们绘制了未激活注意力头的百分比。通过仅保留对输入产生大输出范数的头,我们可以为给定的标记沉默超过 80% 的注意力头。 在图中,我们绘制了对 MLP 层施加的平均稀疏性。 我们可以为给定的标记将超过 95% 的 MLP 参数置零。 3 预训练的大型语言模型(LLMs)是上下文稀疏的 在本节中,我们提出了几个关键观察和理论理解,这些是 DEJAVU 设计的基础。我们首先在第 3.1 节测试了上下文稀疏性假设,并验证了在预训练的 LLMs 中存在上下文稀疏性。然后,在第 3.2 节,我们建立了一个理解,即即使在密集训练 LLMs 时,上下文稀疏性也会自然发生。最后,在第 3.3 节,我们提出了关于残差连接的观察,并分析了它们与上下文稀疏性的关系。 3.1 上下文稀疏性假设 受先前剪枝文献的启发,我们发现一个惊人的简单方法足以研究和验证我们的假设。在本节中,我们描述了测试程序、观察细节和这项研究的见解。 验证:我们的测试是在 OPT-175B、66B 和 30B 模型上进行的,以及各种下游数据集,如 Open-BookQA 和 Wiki-Text。我们通过模型的两次前向传递来发现每个输入示例的上下文稀疏性。在第一次传递中,我们记录了一部分参数,特别是哪些注意力头和 MLP 神经元对输入产生了大的输出范数。在第二次传递中,每个输入示例仅使用记录的参数子集进行计算。令人惊讶的是,这两次前向传递在所有上下文学习和语言建模任务上产生了相似的预测或性能。 观察:图 3 显示,平均而言,我们可以在注意力头上施加高达 80% 的稀疏性,在 MLP 神经元上施加高达 95% 的稀疏性。如第 2 节所述,OPT-175B 模型的 MLP 参数是注意力块的两倍。因此,这里的总稀疏性约为 85%。由于这些都是结构化稀疏性,准确预测它们可能会导致 7 倍的速度提升。 见解:在推理时在 MLP 块中发现上下文稀疏性是直观的,因为它们的激活函数,例如 ReLU 或 GeLU。类似的观察也被其他人提出。然而,我们在注意力层中发现上下文稀疏性是令人惊讶的。请注意,发现注意力中的上下文稀疏性与头剪枝不同。我们交叉检查了不同示例具有不同的上下文稀疏性。尽管对于给定的示例,80% 的参数没有包含在路径中,但它们可能被其他示例使用。接下来,我们将尝试理解为什么注意力块中存在上下文稀疏性。 3.2 注意力层中的标记聚类 在上一节中,我们已经验证了在 LLMs 中对于给定输入存在上下文稀疏性。在本节中,我们试图理解这种现象的原因,特别是在注意力层。我们首先展示了对注意力的深入观察。然后,我们提出了一个假设,即自注意力是概念上的聚类算法。最后,我们展示了支持这一假设的分析证据。 观察:图 4 显示了同一层中三个不同头对示例输入的注意力图。下一个要预测的标记是“Truck”。颜色越深表示注意力分数越高。我们观察到,中间的头是一个相对均匀的标记混合头,而顶部和底部的头是“重击手”注意力头。 不出所料,仅选择重击手头而不选择均匀头不会影响预测,因为均匀头没有建模或编码重要的标记交互。在下一节中,我们还将详细解释选择均匀注意力头和输出范数小的头的标准高度相关。 假设:我们假设注意力头执行的是均值漂移聚类。 回顾第 2.3 节中定义的符号。对于当前层的第 i 个头,X= [x1,...,x n]⊤∈Rn× "Deja Vu: 在推理时为高效LLMs设计的上下文稀疏性 第L层:这家水果运输公司提供不同的车辆选择,如汽车和[MASK]卡车 图4. 我们可视化了一个示例句子的三个不同头的注意力分数。第42头和第44头在特定标记上给出较高的注意力分数,而第43头则更加均匀。 当前标记的预测值ˆyi变为ˆyi=归一化=归一化,这有一个固定点y=γmi,对于任何标量γ。这个迭代过程类似于均值漂移聚类,它简单地执行迭代y←mi直到收敛。这有一个明显的固定点y=mi。因此,自注意力头可以被视为一个均值漂移步骤,将不同标记的输入嵌入推到一起,如果它们已经在由WQi⊤指定的投影空间中是邻居。不同的头学习不同的投影空间来进行聚类。这些动态解释了为什么标记嵌入在经过更多层后倾向于聚类,导致聚类成员之间的高注意力分数,以及非成员的低分数。此外,不同头的聚类模式也不同。 上述分析不仅提供了对预训练LLMs中自然存在的上下文稀疏性的理解,而且启发了我们在第4节中为DEJAVU设计的基于“相似性”的稀疏性预测。 3.3 跨层缓慢变化的嵌入 我们首先提出我们的观察结果,即嵌入在连续层之间变化缓慢。然后我们对这一现象进行了详细分析。最后,我们展示了它与上下文稀疏性的密切联系。详细信息在B节中。 在连续层中高度相似的嵌入:在图5中,我们展示了对于相同的输入,7种不同大小的OPT模型中,连续两层之间的嵌入或激活的余弦相似度异常高。具体来说,我们在执行C4验证集上的OPT模型推理时收集了每一层的激活。以OPT-175B为例,从第二层开始,任何两个连续层之间的相似度约为0.99,这表明当输入通过模型时,其嵌入的方向变化缓慢。有趣的是,最剧烈的变化发生在第一层。此外,我们增加了间隔并研究了层l和层l+n之间的嵌入的相似度,如图5所示。随着间隔的增加,相似度如预期般下降,而各种选择之间的余弦相似度差异在较浅层时较小。我们绘制了平均相似度,标准差由阴影表示。附录B中呈现了更多模型的类似图表。 与残差的关系:我们验证了LLM推理中嵌入的高相似度是由于残差连接。我们首先剖析了每个Transformer层内的计算图,以理解这一现象背后的原因。Transformer层内有两个残差连接,一个围绕注意力块,另一个围绕MLP块。残差连接可以写为X+F,其中F是多头注意力或两个MLP层。在图5和图5中,我们确实可以看到∥X∥显著大于∥F∥,这证实了嵌入变化缓慢是因为残差范数大。 与上下文稀疏性的联系:我们更深入地尝试理解大残差范数背后的原因,通过数学建模。我们发现,小∥F∥的一个可能原因是高稀疏性。对于MLP块,高稀疏性可能导致F的小范数,因为大部分输出的范数很小。类似的推理适用于注意力块,因此大量的注意力头产生小范数输出。 残差两边的界限:除了经验推理外,我们通过数学建模进一步探索了大残差范数的原因。我们发现,对于MLP块,当输入嵌入的稀疏性足够高时,F的范数会显著减小。对于注意力块,当注意力分数的稀疏性足够高时,F的范数也会显著减小。" "Deja Vu: 在推理时为高效LLMs实现上下文稀疏性 我们正式地从数学上定义了LLMs的计算。 在我们的计算模型下,我们可以展示我们在实际实验中观察到的收缩性质。 证明在附录G、H、I中。 引理3.1. 设0<ϵ1<ϵ2<1为收缩因子的下界和上界。设x为输入,y为输出。我们有残差连接y=x+F。对于MLP块F,我们有ϵ1≤ ||y−x||2≤ϵ2。对于注意力块F,我们也有ϵ1≤ ||y−x||2≤ϵ2。 4 DEJAVU 在本节中,我们介绍了我们为LLMs推理时上下文稀疏性搜索的框架。我们在4.1节介绍了MLP的稀疏性预测器,在4.2节介绍了注意力头的稀疏性预测器。DEJAVU的工作流程如图2所示。4.3节讨论了如何利用我们对LLMs的观察来避免稀疏性预测的开销,并提供理论保证。在4.4节,我们提出了优化的实现,以实现端到端的延迟减少。更多细节在D节中呈现。 4.1 MLP块中的上下文稀疏性预测 正如2节所解释的,MLP块是LLM生成的主要瓶颈之一。在本节中,我们讨论了如何在MLP块中通过上下文稀疏性实现实时速度提升。 挑战 图3显示,对于给定的token,95%的上下文稀疏性是可能的。MLP块中的上下文稀疏性可以在计算激活后被识别。然而,这仅展示了上下文稀疏性的存在,但在效率方面没有带来好处。需要快速而精确的预测来利用上下文稀疏性以实现端到端的效率。简单的做法是随机选择一部分神经元。不出所料,随机选择无法准确识别上下文稀疏性,导致模型性能大幅下降。 近邻搜索问题:回想我们通过记录哪些神经元产生显著的范数来验证上下文稀疏性的存在。本质上,给定输入,目标是搜索与输入有高内积的神经元,因为激活函数“过滤”低激活。因此,我们将MLP层的上下文稀疏性预测公式化为内积度量下的典型近邻搜索问题。 定义4.1. 设c和τ为两个参数。给定一个n维数据集W1⊂Sd−1在单位球上,-MaxIP的目标是构建一个数据结构,给定查询y∈Sd−1使得max w∈W1⟨y,w⟩≥τ,它检索一个向量z从W1满足⟨y,z⟩≥c·max w∈W1⟨y,w⟩。 备注4.2. 我们的W1和y在MLP块中可以被视为定义4.1中的数据集和查询。设计 标准的最先进的近邻搜索方法和实现会减慢计算。以OPT-175B为例,其中dis 12288。HNSW需要超过10ms,FAISS需要超过4ms,而MLP计算仅为0.2ms。高维度和在GPU上的数据结构实现复杂性使得搜索时间比MLP计算更长。因此,我们选择神经网络分类器作为我们的近邻搜索方法,以利用GPU上的快速矩阵乘法。对于每个MLP块,我们训练一个小的两层全连接网络来预测上下文稀疏性。收集训练数据很简单,因为我们使用密集计算知道上下文稀疏性。训练算法总结在算法1中。W1中的稀疏计算有两个步骤:给定y,稀疏性预测器SPM预测权重W1中重要的神经元集合SM。计算定义在方程1中的稀疏化MLP。注意这里的MLP中的稀疏性是高度结构化的。 算法1 稀疏性预测器训练 输入:一个预训练的LLM块,参数集M,块M={xi}i∈[N],阈值t 稀疏预测器SP P+←∅, P−←∅ for i=1→N do P+←P +∪{|mr∈M, m r≥t} P−←P−∪{|mr∈M, m r