2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

561次阅读
没有评论

蔡泽凡 投稿
量子位 | 公众号 QbitAI

用KV缓存加速大模型的显存瓶颈,终于迎来突破。

北大、威斯康辛-麦迪逊、微软等联合团队提出了全新的缓存分配方案,只用2.5%的KV cache,就能保持大模型90%的性能。

这下再也不用担心KV占用的显存容量过高,导致显卡不够用了。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

该方法名为PyramidKV,顾名思义,在KV缓存压缩的过程中融入了金字塔型的信息汇聚方式。

在内存受限的情况下,PyramidKV表现非常出色,既保留了长上下文理解能力,又显著减少了内存使用。

目前,PyramidKV相关代码已经在GitHub开源

引入金字塔信息汇聚方式

随着模型尺寸的增大,推理需要的时间越来越多。KV cache作为推理加速的关键技术,通过缓存之前的解码步骤中计算出的Transformer的K和V矩阵减少后续解码时间。

但是,随着序列长度增大,需要缓存的KV cache会快速增长,占用大量显存。针对这一问题,之前的工作设计策略是对KV cache进行压缩。

实际上,长文本的推理加速和显存节省作为一个重要的话题,这涉及到广泛的大模型下游应用,比如检索增强生成(Retrieval-Augmented Generation)、上下文学习(In-Context Learning)受到广泛关注。

KV cache及KV cache的压缩能否有效帮助长文本实现推理加速成为广受关注的研究方向。

采用均一压缩策略,是最佳方案吗?

传统压缩方法的一个共同特点是,均对每个Transformer层使用同样的KV cache压缩设置,使用同样的方法压缩到同样的长度。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

但PyramidKV团队发现,对KV cache进行极致压缩情况下上述方法的表现,发现当超长文本压缩到极致小的KV大小时(从32k 长度压缩到64,即保留0.2%的KV cache长度)时,会面临严重的性能减弱。

于是作者提出了疑问:对每个Transformer层将KV cache压缩到同样的大小是否为最优方案?

为了回答上述问题,研究团队对大模型进行检索增强生成的机制进行深入分析。

作者研究了Llama模型进行多文档问答的逐层注意力图,发现了注意力层中的金字塔形信息汇聚模式(Pyramidal Information Funneling)的存在:

  • 在模型的低层(例如第0层)中,注意力得分呈现近似均匀分布,这表明模型在较低层时从所有可用内容中全局聚合信息,而不会优先关注特定的段落。
  • 当编码信息进行到中间层(6-18)时,逐渐转变为聚焦在段落内部的注意力模式 (Localized Attention)。在这个阶段,注意力主要集中在同一文档内的Token上,表明模型在单个段落内进行了段落内部的信息聚合。
  • 这种趋势在上层(24-30)继续并加强,本文观察到了“Attention Sink”和“Massive Activation”现象。

在这些层中,注意力机制极大地集中在少数几个关键Token上,因此只需要保留这些关键Token就能让输出保持一致并且减少显存占用。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

这种注意力分配模式,即极高的注意力得分,表明模型已将信息聚合到这些关键标记中。

这种注意力现象显示了大模型对大量复杂的信息的进行编码的机制,最终得到生成准确答案所需的最关键信息。

根据以上的发现,作者认为之前的工作对所有Transformer层统一处理是低效的,因此不同Transformer层的注意力稀疏程度并不相同。在低层能观察到特别稠密的注意力,而在较高层则可以观察到非常稀疏的注意力。

因此,在不同层之间使用固定的 KV 缓存大小可能会导致性能不佳。这些方法可能在较高层的稀疏注意力中保留许多不重要的 tokens,而忽略了较低层密集注意力中的许多重要的 tokens。

每层注意力特点不同,分层施策才是正解

于是,作者选择了通过基于注意力模式动态分配缓存预算来提高压缩效率。

具体而言,PyramidKV在信息更加分散的较低层分配更多的KV cache缓存,而在信息集中于少数关键tokens的较高层减少KV cache缓存。

一旦为每一层确定了KV缓存预算,PyramidKV在每一个Transformer层中选择根据注意力选择要缓存的KV。

最后的部分Token的KV缓存,即Instruction Token,会在所有Transformer层中保留。

根据UIUC、普林斯顿等提出的SnapKV方法,剩余的KV的选择由从这些Instruction Token中获得的对其他的Token注意力分数来指导——

接收到更高注意力分数的Token被认为与生成过程更相关,因此其KV状态优先保存在GPU缓存中。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

2.5%的KV cache,保持90%模型性能

为了评估PyramidKV的表现,作者使用最新的开源大模型Llama-3-8B-Instruct和Mistral-7B-Instruct,来对PyramidKV和其他方法进行对比。

测试示例以生成格式进行评估,所有任务的答案均通过贪婪解码生成,并使用 LongBench来评估PyramidKV在处理长上下文输入任务中的表现。

LongBench是一个精心设计的基准测试套件,用于测试语言模型处理长文档和复杂信息序列的能力。

该基准测试旨在对长上下文输入进行多任务评估,包括17个数据集,涵盖单文档问答、多文档问答、摘要生成、少样本学习、合成数据和代码生成等任务。

数据集的平均输入长度从1235个到18409个tokens不等,需要大量的内存来管理KV缓存。

对于所有这些任务,作者都遵循 LongBench推荐的标准指标。

结果,在64、96、128、256和512个KV cache缓存大小的设定下,PyramidKV在LongBench中均取得了优于baseline的效果。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

在此基础上,作者还研究了两种不同的操作场景——节省内存场景(Memory-Efficient Scenario)和保持性能场景(Performance-Preserving Scenario),分别用于在内存和模型性能之间进行权衡。

PyramidKV在Longbench的多个任务和平均得分上均取得了优于baseline的效果。

值得注意的是,PyramidKV在size为128的设定下,在TREC任务(上下文学习问答挑战)中表现出显著优越的性能,相较于baseline,提高了20.的ACC结果。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

总体而言,PyramidKV仅用12%的KV缓存就能保持完整的性能,并且在各种KV缓存大小的设定下和不同主干模型中始终优于其他方法,特别是在仅保留约128(0.7%)KV cache缓存的节省内存场景中,其性能优势尤为明显。

在具体任务的检查中,PyramidKV在TREC任务(上下文学习问答挑战)中表现出显著优越的性能,仅仅使用64的KV cache缓存大小(原始输入是5k长度)就能达到90%的性能。

这表明模型有效地聚合了样本中的任务信息,突出了在上下文学习任务上进一步研究的潜力。

下面的表则展示了PyramidKV使KV缓存的占用减少的情况。作者评估了Llama-3-8B-Instruct的内存消耗。

具体来说,作者发现在固定批量大小为1、输入长度为8192、模型权重为fp16格式的情况下,PyramidKV在不同缓存大小下显著减少了KV缓存的内存,还一定程度上保留了任务性能。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

为了进一步理解PyramidKV在LongBench上的性能,作者还进行了“大海捞针”实验,将PyramidKV与SnapKV进行比较,并且对比128大小的KV缓存和完整的KV缓存。

在输入序列长度在2000到4000之间的中等上下文情况下,SnapKV在“大海捞针”测试中产生了越来越多的错误案例。

在输入序列长度超过6000的长上下文情况下,SnapKV显著降低了LLMs在评估中的性能。

相比之下,PyramidKV在大多数情况下减轻了这种弱化效应。下图展示了定量结果。分数越高、颜色越浅,表示着检索能力越强。

在该任务的平均得分中,完整KV得分为65.0,PyramidKV得分为62.6,而SnapKV得分为57.3。

2.5%KV缓存保持大模型90%性能,大模型金字塔式信息汇聚模式探秘|开源

此外,作者的实验表明,PyramidKV在上下文学习(In-Context Learning)的少样本学习任务中显著优于其他方法。

这表明KV cache缓存压缩在上下文学习中的应用前景广阔,这种方法有可能在受限的内存条件下实现更多样本的引入。

论文地址:
https://arxiv.org/abs/2406.02069
项目主页:
https://zefan-cai.github.io/PyramidKV.github.io/
GitHub:
https://github.com/Zefan-Cai/PyramidKV

Read More 

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 
评论(没有评论)
Generated by Feedzy