“LONGLORA: EFFICIENT FINE-TUNING OF LONG-CONTEXT LARGE LANGUAGE MODELS”
香港中文大学和 MIT 联合提出全新大模型微调方法 LongLoRA。只要两行代码 + 11 个小时微调,就能把大模型 4k 的窗口长度提高到 32k。规模上,最长可以扩展到 10 万 token。
论文地址:https://arxiv.org/pdf/2309.12307.pdf
Github地址:https://github.com/dvlab-research/LongLoRA
摘要
本文提出了一种高效的Fine-tuning方法LongLoRA,可以在有限的计算成本下扩展预训练大型语言模型(LLMs)的上下文大小。LongLoRA通过稀疏局部注意力和可训练的嵌入和归一化来实现上下文扩展,同时保留原始架构和兼容现有技术。实验结果表明,LongLoRA在各种任务上表现出色,可在单个8×A100机器上将LLaMA2 7B从4k上下文扩展到100k,或将LLaMA2 70B扩展到32k。同时,作者还提供了一个包含3k个长上下文问答对的数据集LongQA,以供监督Fine-tuning使用。
简介
LLMs通常使用预定义的上下文大小进行训练,例如LLaMA为2048个标记,LLaMA2为4096个标记。预定义大小限制限制了其在处理长文档或长问题时的效果。为了解决这个问题,一些最近的研究通过训练或微调LLM来扩展其上下文长度。然而,从头训练LLM的计算挑战很大,微调现有的预训练LLM也非常昂贵。因此,我们需要寻找一种高效的方法来扩展LLM的上下文窗口。
本文介绍了一种fine-tune预训练LLM的方法,即利用低秩矩阵进行自注意力块中的线性投影层的适应。然而,实验结果表明,这种方法在训练长上下文模型时既不够有效也不够高效。在效果方面,低秩适应会导致长上下文扩展的困惑度较高。在效率方面,无论是否使用LoRA,由于标准的自注意力机制,计算成本随着上下文大小的扩展而急剧增加。因此,即使使用LoRA,标准LLaMA2模型的训练时间也会随着上下文窗口的扩展而大幅增加。
LongLoRA扩展了预训练LLMs的上下文窗口。我们发现,短注意力也能在训练过程中近似长上下文。我们提出了一种名为S2-Attn的高效替代标准自注意力的方法。我们将上下文长度分成几个组,并在每个组中进行注意力计算。在一半的注意力头中,我们将标记向后移动半个组的大小,以确保相邻组之间的信息流动。例如,我们使用组大小为2048的S2-Attn来近似总共8192个上下文长度的训练。这与Swin Transformer有相似的思路。
S2-Attn fine-tuning模型在推理过程中保留了原始的注意力架构,这有助于现有的优化和基础设施。同时,我们的方法也可以应用于常见的LLM技术,例如FlashAttention-2。与短注意力相似,这种方法类似于LLM的预训练阶段的注意力方案。与此不同的是,其他高效的注意力机制,如扩张或稀疏注意力,在预训练阶段与标准风格存在较大差距。
我们通过实验证明,可学习的嵌入和归一化层是解锁长上下文LoRA微调的关键。嵌入和归一化层在整个LLM中占据了很小比例的参数。例如,在LLaMA2 7B中,嵌入层的参数占总参数的小于2%,而归一化层的参数占总参数的小于0.004%。对于更大的LLM,这个比例会进一步降低。
LongLoRA可以有效地扩展上下文窗口,提高模型性能。实验结果表明,LongLoRA可以在单个8×A100机器上将LLaMA2 7B模型fine-tune到100k上下文,或将70B模型fine-tune到32k上下文,同时计算成本大大降低。
LongQA是一个包含超过3k个长问题和相应答案的数据集,用于监督微调(SFT)。我们设计了各种类型的问题,涵盖技术论文、科幻小说和其他书籍。SFT对于提高LLMs的聊天能力非常重要。附录中展示了我们训练模型的一些示例。
相关工作
Long-context Transformers。目前增加transformer上下文长度的方法,包括检索式和修改注意力机制的方法。本文提出的方法是对注意力机制的近似,但与标准注意力机制相似,可以在推理过程中保持完整的注意力机制,从而可以对预训练的LLMs进行微调。
Long-context LLMs。最近的研究尝试通过微调来扩展LLM的上下文长度,但大多数方法都需要昂贵的计算资源。本文提出了一种高效的方法,可以节省微调成本,同时保留原始注意力的质量。这段文字介绍了一些关于LLMs的位置嵌入修改方法,包括Position Interpolation、NTK-aware、Yarn、positional Skipping和out-of-distribution相关方法。而我们的方法则专注于在推理过程中保持原始架构的高效微调。我们的模型在实验中应用了Position Interpolation方法。
高效微调。本文基于经典的高效微调方法LoRA。除了LoRA之外,还有许多其他参数高效的微调方法,包括prompt tuning、prefix tuning、hidden state tuning、bias tuning和masked weight learning。Input-tuning引入了一个适配器来调整输入嵌入。虽然我们的输入嵌入层也是可训练的,但这对于长上下文扩展来说还不够。我们在实验中对层类型进行了全面分析,见表3。
LONGLORA
转移注意力(SHIFT SHORT ATTENTION)
标准的自注意力模式计算成本为O(n^2),导致长序列的LLMs内存成本高且速度慢。为了避免这个问题,在训练过程中,我们提出了移位短注意力(S2-Attn)。
试点研究。本研究通过实验验证了微调的重要性,没有微调的模型在上下文长度增加时表现更差。作者建立了一个标准基线模型,使用全注意力和全微调,能够在不同上下文长度下保持一致的好质量。
通过将长的输入分成几组进行自注意力计算,以降低计算成本。但是,这种方法在处理非常长的上下文时仍然存在困难,因为不同组之间没有信息交换。
为了引入不同组之间的通信,我们使用了一种移位模式。在每个自注意力头中,我们将组划分向后移动半个组大小。以总共8192个上下文长度为例,在模式1中,第一组从第1个到第2048个标记进行自注意力计算。在模式2中,组划分向后移动1024个标记。第一组的自注意力计算从第1025个标记开始,到第3072个标记结束,而第一个和最后1024个标记属于同一组。我们在每个自注意力头的前半部分和后半部分分别使用模式1和模式2。这种方式不增加额外的计算成本,但使得不同组之间的信息流动成为可能。我们在表1中展示了它与标准注意力基线的接近程度。
一致性到充分注意。S2-Attn可以在fine-tuning时提高效率,同时支持全局attention测试。与其他高效attention设计相比,S2-Attn具有更好的性能。一些高效的注意力设计在长上下文微调中是不可行的。Transformer模型在从头训练时与标准的全注意力存在差距,因此不适用于长上下文微调。S2-Attn支持全注意力测试,尽管模型是用短移注意力进行微调的。其他注意力模型如扩张注意力和跨步稀疏注意力也可以用于长上下文微调,但模型必须使用微调时的注意力进行测试。移动注意力可以防止模型过度拟合特定的注意力模式。在S2-Attn中,仅使用模式1或2是不起作用的。
实现简单。短注意力的转移很容易实现。只需要两个步骤:(1)将令牌在半注意力头中进行转移,(2)将特征从令牌维度转置到批次维度。只需要两行代码就足够了。我们在算法1中提供了一个类似于PyTorch风格的代码。接下来,我们进行了一项试点研究,并逐步澄清了我们设计的原因。
改进适用于长上下文的LORA
LoRA是一种高效且流行的方法,用于将LLMs适应到其他数据集。与完全微调相比,它节省了大量可训练参数和内存成本。然而,从短上下文长度适应LLMs到长上下文长度并不容易。我们在实证中观察到LoRA和完全微调之间存在明显差距。如表3所示,随着目标上下文长度的增加,LoRA和完全微调之间的差距也在增大。而且,增加LoRA的秩也不能减小这种差距。
为了弥补这一差距,我们开放了嵌入和归一化层进行训练。这些层所占的参数很少,但对于长上下文适应具有重要作用。特别是对于归一化层,其参数仅占整个LLaMA2 7B的0.004%。我们在实验中将这个改进版本称为LoRA +。
实验
实验设置
模型。本文使用预训练的7B、13B和70B LLaMA2模型进行扩展。这些模型的最大扩展上下文窗口大小分别为100k、65536和32768。这些模型的位置索引使用了位置插值进行重新缩放。
训练过程。使用了AdamW优化器和线性学习率预热,采用了单机8个GPU的方式进行训练,每个GPU的batch size为1,梯度累积步数为8,全局batch size为64,训练1000步。
数据集。本文使用Redpajama数据集进行训练,评估模型在PG19和Arxiv Math proof-pile数据集上的长序列语言建模性能。其中,PG19测试集包含100个文档,Arxiv Math proof-pile数据集采用Position Interpolation进行处理,并使用S=256的滑动窗口方法计算困惑度。
我们构建了一个名为LongQA的长文本问答数据集,用于监督微调。尽管使用Redpajama(计算机,2023年)进行微调的模型表现出良好的困惑度,但它们的聊天能力有限。我们收集了3000多个问题-答案对,涉及技术论文、科幻小说和其他书籍等材料。我们设计的问题包括摘要、关系、角色和其他与材料相关的细节。详细信息请参考附录。
结果
长序列语言建模。通过在Proof-pile和PG19数据集上的实验,证明了我们的模型在更长的上下文长度下表现更好。通过增加上下文窗口大小,我们的模型在相同的训练和评估上下文长度下,可以显著降低困惑度。
本文研究了在单个8×A100机器上fine-tune的最大上下文长度,将LLaMA2 7B、13B和70B扩展到100k、65536、32768上下文长度,LongLoRA在这些极大设置上取得了有希望的结果。此外,我们发现对于扩展模型的小上下文大小存在一些困惑的恶化。这是Position Interpolation的已知限制。
基于检索评估。我们还进行了关于长上下文的检索实验。通过与其他开放式LLM模型在LongChat的主题检索任务上的比较,我们的模型在不同长度的对话中实现了可比较的性能。与LongChat-13B相比,我们的模型通过下一个标记生成在开放的RedPajama上进行了高效的适应,并在16k评估中略微超过了LongChat-13B的性能。
消融学习
配置。本文分析了LLaMA2 7B模型的不同层类型,包括FFN、Proj、Attn和Others,并分析了FLOPs。随着上下文长度的增加,Attn的比例急剧增加,例如在8192上下文长度时,Attn占总FLOPs的24.5%,而在65536上下文长度时,它增加到了72.2%。当使用S2-Attn时,它降至39.4%。
微调步数。在LLaMA2 7B模型上进行8192上下文长度的微调步数,我们在PG19验证集上报告了困惑度和微调步数之间的关系。我们发现,在没有微调的情况下,模型在步数0时具有有限的长上下文能力,例如,困惑度为15.82。我们展示了困惑度迅速下降的情况。完全微调比低秩训练更快地收敛。在200步之后,它们趋于接近,最后没有太大的差距。
注意力模式。在Redpajama数据集上,我们使用LLaMA2 7B模型进行微调,并在PG19验证集上评估困惑度。我们发现在LongLoRA中,层间切换是可接受的,但不是最佳选择。此外,将所有注意力头设置为模式1或模式2是不起作用的。
本文测试了不同类型的高效注意力设计,包括扩张注意力和步幅稀疏注意力。扩张注意力在完全微调中表现良好,但在低秩适应中表现不佳。微调步幅稀疏注意力有害,与预训练模型中的全注意力存在较大差距。
总结
本文提出了LongLoRA方法,可以有效地扩展LLMs的上下文长度,具有比标准全微调更少的GPU内存成本和训练时间,同时最小化精度损失。在架构层面上,我们提出了shift short attention来近似标准的自注意力模式。在训练层面上,我们通过可训练的归一化和嵌入来弥合LoRA和全微调之间的差距。我们的方法可以在单个8×A100机器上将LLaMA2 7B扩展到100k上下文长度,将70B模型扩展到32k上下文长度。我们相信LongLoRA是一种通用方法,可以与更多类型的LLMs和位置编码兼容,我们计划在未来进行研究。