1.摘要

长上下文建模是下一代语言模型的关键,然而标准注意机制的高计算代价给计算带来了巨大挑战.稀疏注意为在保持模型性能的同时提高效率提供了一个有前景的方向。因此作者提出了NSA,一种可本地训练的稀疏注意机制,它将算法创新与硬件优化相结合,以实现高效的长上下文建模。

NSA采用动态分层稀疏策略,将粗粒度令牌压缩与细粒度令牌选择相结合,以保持全局上下文感知和局部精度。

如图1所示,实验表明,经过NSA预训练的模型在一般基准测试、长上下文任务和基于指令的推理中保持或超过了全注意模型。同时,NSA在解码、前向传播和后向传播过程中的速度比Full Attention有了显著的提高,验证了其在整个模型生命周期中的有效性。

-

-

2.简介

研究界越来越多地认识到,长上下文建模是下一代大型语言模型的关键能力,它的应用范围从深度推理到多种现实应用的驱动、存储库级代码生成和多轮自主代理系统。

最近的突破使模型能够处理整个代码库、冗长的文档,在数千个令牌上保持连贯的多回合对话,并在长范围的依赖关系上执行复杂的推理。然而,高度复杂性的注意机制随着序列长度的增加而成为关键的延迟瓶颈。理论估计表明,在解码64 k长度的上下文时,softmax架构的注意力计算占总延迟的70-80%,强调了对更高效的注意力机制的迫切需要。

一种有效的长上下文建模方法是利用softmax注意力的固有稀疏性,其中选择性地计算关键Query-Key对可以在保持性能的同时显著地减少计算开销。尽管这些策略很有前途,但现有的稀疏注意方法在实际部署中往往存在不足。许多方法都没有达到理论上的加速比;而且,大多数方法主要集中在推理阶段,缺乏有效的训练时间支持来充分利用注意力的稀疏性模式。

为了解决这些限制,有效的稀疏注意的部署必须解决两个关键挑战:

  1. 硬件对齐的推理加速:将理论计算减少转换为实际速度提高需要在预填充和解码阶段期间的硬件友好算法设计,以减轻存储器访问和硬件调度瓶颈;
  2. 训练感知算法设计:使用可训练的运算符实现端到端计算,在保持模型性能的同时降低训练成本。这些需求对于真实世界的应用程序实现快速的长上下文推理或训练是至关重要的。

为了实现更有效和高效的稀疏注意力,作者提出了NSA,一个原生可训练的稀疏注意力架构,集成了分层令牌建模。如图2所示,NSA通过将键和值组织到时间块中,并通过三种注意路径(压缩的粗粒度标记、选择性保留的细粒度标记和本地上下文信息的滑动窗口)处理它们,减少了每次查询的计算量。然后,我们实现了专门的内核,以最大限度地提高其实际效率。

作者通过对真实世界语言语料库的综合实验来评估NSA。在具有260 B令牌的27 B参数Transformer骨干上进行预训练,作者评估了NSA在一般语言,长上下文评估和思维链推理评估中的表现。实验结果表明,NSA实现了相当或更好的性能完全注意基线,同时优于现有的稀疏注意方法。此外,与Full Attention相比,NSA在解码、前向和后向阶段提供了显著的加速比,对于较长的序列,加速比会增加。这些结果验证了分层稀疏注意设计有效地平衡了模型能力和计算效率。

-

-

3.重新思考稀疏注意力方法

现代稀疏注意力方法在降低Transformer模型的理论计算复杂度方面取得了重大进展。然而,大多数方法在推理过程中主要应用稀疏性,同时保留预训练的Full Attention骨干,这可能会引入架构偏差,限制其充分利用稀疏注意力优势的能力。在介绍原生稀疏架构之前,我们通过两个关键镜头系统地分析了这些限制。

有效推理的错觉

尽管在注意力计算中实现了稀疏性,但许多方法未能实现推理延迟的相应减少,主要是由于两个挑战:

  • 相位限制稀疏度(Phase-Restricted Sparsity)。例如H2O的方法在自回归解码期间应用稀疏性,而在预填充期间需要计算密集的预处理(例如,注意力图计算、索引构建)。相比之下,像MInference仅关注预填充稀疏性。这些方法无法在所有推理阶段实现加速,因为至少有一个阶段的计算成本与全注意相当。
  • 与高级注意力架构不兼容。一些稀疏注意力方法无法适应现代解码高效架构,如多查询注意力(MQA)和分组查询注意力(GQA),这通过跨多个查询头共享KV来显著减少解码期间的存储器访问瓶颈。尽管这些方法可以减少计算操作,但所需的KV高速缓存存储器访问仍然相对较高。虽然一些稀疏注意方法减少了计算量,但它们分散的内存访问模式与高级架构中高效的内存访问设计相冲突。

-

可训练稀疏性

在本2节中,作者探讨了现有稀疏注意力方法在训练阶段的局限性,指出这些方法大多专注于推理阶段的稀疏性优化,而在训练阶段的支持不足,导致无法充分发挥稀疏注意力的潜力。作者指出,将稀疏性应用于预训练好的全注意力模型会导致模型性能下降,因为这种后处理方式会偏离模型原有的优化轨迹。此外,训练阶段的计算效率对于现代大型语言模型的开发至关重要,但现有方法主要关注推理阶段,未能有效解决训练阶段的计算挑战。

作者还指出,一些稀疏注意力方法包含不可训练的组件,例如基于聚类或哈希的离散操作,这些操作会阻断梯度流,限制模型学习最优稀疏模式的能力。此外,即使某些稀疏注意力方法理论上可训练,但在实际训练中存在效率问题。例如,基于哈希的稀疏注意力方法在训练时需要加载大量分散的KV缓存,这与现代高效注意力技术(如FlashAttention)所依赖的连续内存访问和块计算不兼容,导致硬件利用率低下,训练效率显著降低。

基于这些分析,作者强调了开发一种原生可训练稀疏注意力机制的必要性,这种机制需要同时满足高效推理和训练的要求。这为后续介绍NSA(Native Sparse Attention)的设计和实现奠定了基础,NSA通过算法创新和硬件对齐优化,实现了高效的训练和推理,同时保持了模型性能。

-

-

4.方法

背景

注意力机制:广泛用于语言建模中,其中每个查询标记q_t计算针对所有前面的键k_{:t}的相关性得分,以生成值v_{:t}的加权和。形式上,对于长度为1的输入序列,注意操作被定义为:o_t=Attn(q_t,k_{:t},v_{:t})=\sum_{i=1}^{t}\frac{\alpha _{t,i}v_i}{\sum_{j=1}^{t}\alpha _{t,j}},\: \; \; \; \alpha _{t,i}=e^{\frac{q_t^Tk_i}{\sqrt{d_k}}},这里,\alpha _{t,i}表示q和k之间的注意力权重,d是键的特征维度。随着序列长度的增加,注意力计算在整体计算成本中变得越来越占主导地位,这对长上下文处理提出了重大挑战。

-

总体框架

为了使用具有自然稀疏模式的注意力,作者建议在给定每个查询q的情况下,用更紧凑和信息密集的表示键值对的集合\tilde{K_t},\tilde{V_t}来替换等式(1)中的原始键值对k_{:t}v_{:t}。具体来说,作者将优化的注意力输出正式定义如下:\begin{matrix} \tilde{K_t}=f_K(q_t,k_{:t},v_{:t})\\ \tilde{V_t}=f_V(q_t,k_{:t},v_{:t})\\ o_t^*=Attn(q_t,\tilde{K_t},\tilde{V_t}) \end{matrix},其中,\tilde{K_t},\tilde{V_t}是基于当前查询q和上下文记忆k_{:t}v_{:t}动态地构造的。

可以设计各种映射策略来获得不同类别的数据流,并将它们组合如下:o_t^*=\sum_{c\in C}g_t^c\cdot Attn(q_t,\tilde{K_t^c},\tilde{V_t^c}),如图2所示,NSA有三种映射策略C = {cmp,slc,win},分别表示键和值的压缩、选择和滑动窗口。可使用MLP匹配不同的映射策略。

说白了,就是不要用原来的KV计算注意力,现在采用某些策略生成简短版的KV,重新计算注意力

-

算法设计

既然我们知道了核心思想就是用重新生成(即重映射策略)的KV计算注意力,那么在本小节中,我们将介绍三种重映射策略的设计方法:令牌压缩、令牌选择和滑动窗口。

令牌压缩

公式7描述了NSA(Native Sparse Attention)机制中压缩注意力(Compressed Attention)的计算过程。具体来说,它展示了如何将连续的键(key)序列通过块压缩(block compression)的方式转换为压缩后的键表示(compressed keys)。

压缩后的Key表示被定义为:\tilde{K_t}^{cmp}=f_K^{cmp}(K_{:t})=\left \{ \varphi (k_{id+1:id+l})|1\leq i\leq \left \lfloor \frac{t-l}{d} \right \rfloor \right \}

符号说明:

  • K_{:i}:表示从序列开始到位置 i 的所有键向量。
  • \tilde{K_t}^{cmp}:压缩后的键表示,用于后续的注意力计算。

  • i_A:压缩块的滑动步长(stride),决定了每个压缩块之间的间隔。

  • i_B​:压缩块的长度,即每个块包含的键向量数量。

  • \varphi:一个可学习的多层感知机(MLP),用于将块内的键向量映射为一个单一的压缩键向量。

  • \left \lfloor \frac{i}{i_B} \right \rfloor:表示压缩后的块数量,即原始序列长度 i 除以块长度 iB​ 的向下取整。

目标:公式7的目标是将长度为 i 的键序列 K_{:i}压缩为更短的表示 \tilde{K_t}^{cmp},以减少计算量并保留关键信息。

压缩过程:

  1. 将键序列K_{:i}按照长度为 iB​ 的块进行划分。
  2. 对于每个块 k_{iA+1:iA+iB},使用 MLP 层\varphi将块内的键向量压缩为一个单一的表示(压缩键向量)。

  3. 将所有压缩后的块组合成一个新的键序列\tilde{K_t}^{cmp}

通过将键序列压缩为更短的表示,模型能够在解码过程中快速扫描全局上下文,同时减少计算复杂度。这种压缩策略使得NSA在处理长序列时更加高效,同时保留了全局上下文信息,为后续的细粒化注意力(Fine-grained Attention)和局部上下文注意力(Local Context Attention)提供了补充。

-

令牌选择

如果只使用压缩的键和值,则可能会丢失重要的细粒度信息。下面,我们描述高效令牌选择机制,该机制以低计算开销来识别和保存最相关的令牌。

以块为顺序的选择。以块为顺序的选择对于在现代GPU上实现高效计算至关重要。这是因为,与基于索引的随机读取相比,现代GPU架构对于连续块访问表现出显著更高的吞吐量。此外,逐块计算能够实现张量核的最佳利用,FlashAttention的基于块的设计就是一个例证。

为了实现逐块选择,作者首先将键、值序列划分为选择块。为了识别最重要的块用于注意力计算,需要为每个块分配重要性分数。下面,我们将介绍计算这些块级重要性得分的方法。

重要性分数计算。计算块重要性分数可能会引入显著的开销。幸运的是,压缩过程产生了中间注意力分数,我们可以利用它来诱导选择块重要性分数,公式为:p_t^{cmp}=Softmax(q_t^T\tilde{K_t}^{cmp})

符号说明

  • q_t^T:当前查询(query)向量,表示当前时间步 i 的查询。

  • \tilde{K_t}^{cmp}​:压缩后的键表示,由公式7计算得到。

  • p_t^{cmp}:每个压缩块的重要性分数,是一个概率分布。

计算过程

  • 计算当前查询向量 q_t^T与压缩后的键表示 \tilde{K_t}^{cmp}的点积 q_t^T\cdot \tilde{K_t}^{cmp}

  • 对点积结果应用softmax函数,得到每个压缩块的重要性分数p_t^{cmp}​。

作用:重要性分数p_t^{cmp}表示每个压缩块与当前查询向量的相关性。这些分数将用于后续的块选择过程,选择与当前查询向量最相关的块进行细粒化注意力计算。

当压缩块和选择块共享相同的分块方案时,即l'=l=d,因此,我们可以直接通过p_t^{slc}=p_t^{cmp}来直接获得选择块重要性得分p_t^{slc}。对于分块方案不同的情况下,我们根据它们的空间关系推导出选择块的重要性分数,有:p_t^{slc}[j]=\sum_{m=0}^{\frac{l'}{d}-1}\sum_{n=0}^{\frac{l}{d}-1}p_t^{cmp}[\frac{l'}{d}j+m+n]

目标:公式的目标是计算每个选择块的重要性分数p_t^{slc},这些分数将用于后续的块选择过程。

符号说明

  • p_t^{cmp}:压缩块的重要性分数,由公式8计算得到。

  • p_t^{slc}:选择块的重要性分数。

  • j:选择块的索引。

计算过程

  • 当压缩块和选择块的划分方式不同时,需要通过公式将压缩块的重要性分数映射到选择块。

  • 公式中通过双重求和\sum_{m=0}^{\frac{l'}{d}-1}\sum_{n=0}^{\frac{l}{d}-1}对压缩块分数进行空间上的聚合,以计算每个选择块的综合重要性分数。

  • 公式中的索引计算 p_t^{cmp}[\frac{l'}{d}j+m+n]确定了压缩块分数在选择块中的对应位置。

  • 最终,通过平均聚合的方式得到每个选择块的重要性分数p_t^{slc}[j]

对于采用GQA或MQA的模型,其中键值缓存在查询头之间共享,必须确保在这些头之间进行一致的块选择,以最大限度地减少解码期间的KV cache加载。一个组中所有头的共享重要性分数被正式定义为:p_t^{slc'}=\sum_{h=1}^{H}p_t^{slc_,(h)},此方法可确保同一组内磁头之间的数据块选择一致。

Top-n块选择:获得选择块重要性分数后,保留按块重要性分数排名的前几个稀疏块中的令牌,公式为:I_t=\left \{ i|rank(p_t^{slc'}[i])\leq n \right \} \\ \tilde{K}_t^{slc}=Cat[\left \{ k_{il'+1:(i+1)l'}|i\in I_t \right \}]

目标:根据块重要性分数p_t^{slc} 选择最重要的 top-n个块,根据选择的块索引集合I_i,提取这些块中的键(key)并拼接成一个新的张量。

计算过程

  • 对所有块的重要性分数p_t^{slc}进行排序,选择前top-n个分数最高的块。

  • rank(p_t^{slc'}[i])表示块 j 的重要性分数在所有块中的排名。

  • 如果块 j 的排名小于或等于 top-n​,则将该块的索引 j 添加到集合I_i中。

  • 对于每个选择的块索引 j∈I_i,提取该块中的键序列 k_{il'+1:(i+1)l'}​​。

  • 将所有选择的块中的键序列拼接成一个新的张量\tilde{K_t}^{slc}

-

滑动窗口

在注意力机制中,局部模式通常适应得更快,并且可以主导学习过程,这可能会阻止模型从压缩和选择令牌中有效地学习。为了解决这个问题,作者引入了一个专用的滑动窗口分支,它显式地处理本地上下文,允许其他分支(压缩和选择)专注于学习各自的功能,而不会被局部模式所局限。

具体来说,作者在一个窗口中维护最近的tokens \tilde{K_t}^{win}=k_{t-w:t},\tilde{V_t}^{win}=v_{t-w:t},并将不同信息源的注意力计算(压缩tokens,选择tokens,滑动窗口)隔离到单独的分支中。然后,这些分支输出通过可学习的门控机制聚合在一起。

为了进一步降低交叉注意力分支的边际计算开销,作者为三个分支提供独立的键和值。这种架构设计通过防止本地和远程模式识别之间的梯度干扰来实现稳定的学习,同时引入最小的开销。

在获得所有三个类别的键和值之后,作者计算最终的注意力输出。与上述的压缩、选择和滑动窗口机制一起,这形成了NSA的完整算法框架。

-

内核设计

在本节中,作者详细介绍了NSA(Native Sparse Attention)机制的硬件优化设计,特别是针对稀疏注意力的高效计算内核实现。

为了实现与FlashAttention相当的训练和预填充阶段的加速,作者基于Triton框架开发了专门的稀疏注意力内核,重点关注共享键值(KV)缓存的架构,如分组查询注意力(GQA)和多查询注意力(MQA),这些架构在现代大型语言模型中被广泛应用。

在设计中,作者提出了一种与硬件对齐的稀疏选择注意力内核,通过创新的查询分组策略来优化内存访问和计算效率。具体而言,内核在每个查询位置加载整个GQA组的所有查询头及其共享的稀疏键值块索引,然后通过连续加载这些稀疏块到片上存储器(SRAM)中,最小化内存加载次数。此外,内核的内部循环被设计为在Triton的网格调度器中运行,以简化和优化查询/输出循环,进一步提升效率。

这种设计通过消除冗余的键值传输和平衡GPU流处理器之间的计算负载,实现了接近最优的算术强度。实验表明,这种硬件对齐的稀疏注意力内核在长序列处理中展现出显著的加速效果,特别是在64k长度的上下文中,其速度提升比全注意力机制更为明显。这一优化不仅提高了稀疏注意力的计算效率,还为NSA在实际应用中的高效部署提供了关键支持。

-

-

5.实验

作者通过三个镜头来评估NSA:(1)一般基准性能,(2)长上下文基准性能,(3)思想链推理性能,与完全注意基线和最先进的稀疏注意方法进行比较。

预训练设置

作者采用了一个骨干相结合的分组查询注意力(GQA)和混合的专家(MoE),具有27 B总参数与3B可激活参数。该模型由30层组成,隐藏维度为2560。对于GQA,作者将组的数量设置为4,总共有64个注意力头。对于每个头部,查询、键和值的隐藏维度分别配置为d_q=d_k=192d_v=128。对于MoE,作者使用DeepSeek MoE结构,其具有72个路由专家和2个共享专家,并将top-k专家设置为6个。为了确保训练稳定性,第一层中的MoE被SwiGLU形式的MLP替换。

作者在8k长度文本的270B个标记上预训练全注意和稀疏注意模型,然后用YaRN在32k长度文本上继续训练和监督微调,实现长文本适应。这两个模型都经过足够的训练以完全收敛,从而确保公平比较。如图4所示,NSA和全注意基线的训练前丢失曲线显示了稳定和平滑的下降,NSA的表现始终优于全注意模型。

-

基线和方法

除了与Full Attention进行比作者外,我们还评估了几种最先进的推理阶段稀疏注意方法:H2O、infLLM、Quest,以及Exact-Top,这些方法跨越不同的稀疏注意力范例,包括KV缓存驱逐(KV-cache eviction),查询感知选择(query-aware selection)和精确的顶部稀疏选择(exact top-𝑛 sparse selection)。

对于一般评估,其中大多数样本的长度在稀疏注意基线的局部上下文窗口内,这些方法实际上等同于完全注意。因此,作者仅在此设置中呈现NSA和完全注意基线之间的比较结果。

在长上下文评估中,作者对所有基线方法进行比较,将所有稀疏注意力方法的稀疏性设置为相同,以确保公平比较。对于需要长文本监督微调的思想链推理评估,作者将比较限制为Full Attention,因为稀疏注意基线不支持训练。

性能比较

在本节中,作者详细评估了NSA(Native Sparse Attention)在多种任务上的性能表现,并与其他稀疏注意力方法及全注意力模型进行了对比。

在通用基准测试中,NSA在多个知识、推理和编程能力相关的基准测试上与全注意力模型进行了比较。尽管NSA采用了稀疏化设计,但其在大多数指标上均超越了全注意力模型,尤其是在推理相关任务(如DROP和GSM8K)上表现突出。这表明NSA通过稀疏化预训练能够迫使模型专注于关键信息,从而在过滤无关注意力路径的同时提升性能。

在长文本基准测试中,NSA在LongBench上的表现尤为突出,其平均得分超过了所有基线方法,包括全注意力模型和其他稀疏注意力方法。NSA在多跳问答任务(如HPQ和2Wiki)以及代码理解任务(如LCC)中均取得了显著的性能提升,这验证了其在处理长文本任务时的能力。此外,NSA在针头寻草(Needle-in-a-Haystack)测试中实现了完美检索精度,进一步证明了其在长文本中平衡全局感知和局部精确性的能力。

在链式推理任务中,作者通过监督微调的方式对NSA进行了数学推理能力的训练,并在AIME数学竞赛基准上进行了评估。结果显示,NSA在8k和16k上下文长度下的表现均优于全注意力模型,这表明NSA能够有效地捕捉长距离逻辑依赖,并在推理深度增加时保持足够的上下文密度。

总体而言,第4.3节的实验结果表明,NSA作为一种原生稀疏注意力机制,不仅在通用任务上与全注意力模型相媲美,而且在长文本任务和复杂推理任务中展现出显著的优势。这些结果验证了NSA在训练和推理阶段的高效性和实用性,为未来长文本建模和复杂任务处理提供了一种有效的解决方案。

-

效率比较

作者对NSA(Native Sparse Attention)的计算效率进行了全面分析,分别从训练阶段和解码阶段评估了其性能表现,并与全注意力机制进行了对比。

在训练阶段,作者通过在A100 GPU系统上使用Triton实现的NSA内核与FlashAttention-2进行了对比测试。结果显示,随着上下文长度的增加,NSA的加速比逐渐增大,在64k上下文长度时,NSA在前向传播和反向传播阶段分别实现了9.0倍和6.0倍的速度提升。这种显著的加速归功于NSA的硬件对齐算法设计,其通过块状内存访问模式最大化了Tensor Core的利用率,并通过精心设计的循环调度消除了冗余的键值(KV)传输,从而在训练阶段大幅降低了计算延迟。

在解码阶段,作者分析了NSA在自回归解码过程中的效率表现。由于解码阶段的内存访问瓶颈,注意力机制的解码速度主要取决于KV缓存的加载量。NSA通过压缩注意力、选择性注意力和滑动窗口注意力的结合,在每个解码步骤中仅需加载少量的压缩块、选择块和邻近块,从而显著减少了KV缓存的加载量。实验结果表明,随着解码长度的增加,NSA的加速比逐渐增大,在64k上下文长度时,解码速度提升了11.6倍。这种效率提升源于NSA在设计中对长序列的优化,使其在处理长文本时能够有效减少内存访问开销。

综合来看,NSA在训练和解码阶段均展现出显著的效率优势,尤其是在处理长序列时。通过硬件对齐的算法设计和稀疏化策略,NSA不仅在训练阶段实现了高效的计算,还在解码阶段大幅减少了内存访问开销,从而在长文本建模中实现了显著的加速效果。这些结果验证了NSA作为一种高效的稀疏注意力机制,在实际应用中的潜力和价值。

-

-

6.总结

这篇论文介绍了一种名为NSA(Native Sparse Attention)的新型稀疏注意力机制,旨在解决长文本建模中标准注意力机制计算成本高昂的问题。NSA通过结合算法创新和硬件优化,实现了高效的长文本建模,同时保持了模型的性能。

NSA作为一种硬件对齐的稀疏注意力架构,通过分层的稀疏策略和可训练的设计,在保持全注意力性能的同时,显著降低了计算成本,为长文本建模提供了一种高效且实用的解决方案。


如果你喜欢我的内容,别忘了点赞、关注和收藏哦!你的支持是我不断进步的动力,也让我更有信心为大家带来更多精彩的内容。感谢你的陪伴,让我们一起成长!

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐