目录

一、多头注意力机制

1.构建PrepareForMultiHeadAttention类

​​(2). 前向传播 forward​​

​​(3). 设计原理​​

​​(1) 多头注意力的核心思想​​

​​(2) 线性变换的作用​​

​​(3) 形状变换示例​​

 2.MultiHeadAttention类

(1). 初始化方法 __init__​​

(2) git_score函数:计算注意力分数

输入张量形状假设​​

​​爱因斯坦求和规则解析​​

​​计算过程​​

​​直观示例​​

 (3).prepare_mask函数,掩码处理函数

​​假设的输入形状​​

​​Assert 语句解析​​

1. assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]

2. assert mask.shape[1] == key_shape[0]

3. assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

 (4)forward前向传播

1. 输入形状处理​​

​​2. 掩码(Mask)处理​​

​​3. Query、Key、Value 投影​​

​​4. 注意力分数计算​​

​​5. 缩放注意力分数​​

​​6. 掩码应用​​

​​7. Softmax 归一化​​

​​8. Dropout 正则化​​

​​9. 注意力权重应用(Value 加权求和)​​

​​10. 多头结果合并 & 输出投影​​

​​总结流程​​

完整的多头注意力代码


  代码参考地址:Transformer 编码器和解码器模型

为了新人小白这里将它详细解释了一下

一、多头注意力机制

1.构建PrepareForMultiHeadAttention类

这个 PrepareForMultiHeadAttention 类实现了 ​​多头注意力机制的输入预处理​​,负责将输入向量线性变换并拆分为多个头的表示。以下是详细解析:


​(1). 初始化方法 __init__

def __init__(self, d_model, heads, d_k, bias):
    super().__init__()
    self.linear = nn.Linear(d_model, heads * d_k, bias=bias)  # 线性变换
    self.heads = heads  # 头数
    self.d_k = d_k      # 每个头的维度
  • ​参数说明​​:
    • d_model:输入向量的维度(如 512)。
    • heads:注意力头的数量(如 8)。
    • d_k:每个注意力头的维度(通常 d_k = d_model // heads,如 64)。
    • bias:是否在线性变换中添加偏置项。

​(2). 前向传播 forward

def forward(self, x):
    head_shape = x.shape[:-1]  # 保留除最后一个维度外的形状(如 [seq_len, batch_size])
    x = self.linear(x)         # 线性变换:[..., d_model] → [..., heads * d_k]
    x = x.view(*head_shape, self.heads, self.d_k)  # 拆分多头:[..., heads, d_k]
    return x
  • ​输入 x​:
    • 形状可为 (seq_len, batch_size, d_model) 或 (batch_size, d_model)
    • 最后一个维度必须为 d_model
  • ​处理流程​​:
    1. ​线性变换​​:将 d_model 维输入映射到 heads * d_k 维空间。
    2. ​拆分多头​​:通过 view 将最后一维拆分为 (heads, d_k)
  • ​输出形状​​:
    • 输入为 (seq_len, batch_size, d_model) → 输出 (seq_len, batch_size, heads, d_k)
    • 输入为 (batch_size, d_model) → 输出 (batch_size, heads, d_k)

​(3). 设计原理​

​(1) 多头注意力的核心思想​
  • ​拆分注意力​​:将输入向量拆分为 heads 个独立的子空间,每个子空间学习不同的注意力模式。
  • ​维度关系​​:其中 h 为头数,确保总参数量不变。
​(2) 线性变换的作用​
  • ​投影到子空间​​:self.linear 等效于将 WQ, WK, WV 合并为一个矩阵,通过后续 view 拆分。
​(3) 形状变换示例​

假设 d_model=512heads=8d_k=64

  • 输入 (10, 32, 512) → 线性变换 → (10, 32, 512) → 拆分 → (10, 32, 8, 64)

 2.MultiHeadAttention类

(1). 初始化方法 __init__

def __init__(self, heads, d_model, dropout_prob=0.1, bias=True):
    super().__init__()
    self.d_k = d_model // heads  # 每个注意力头的维度
    self.heads = heads           # 头数
    # 初始化 Q/K/V 的投影层
    self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
    self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
    self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
    # 注意力计算相关
    self.softmax = nn.Softmax(dim=1)  # 沿序列长度维度归一化
    self.output = nn.Linear(d_model, d_model)  # 输出投影层
    self.dropout = nn.Dropout(dropout_prob)    # Dropout 层
    self.scale = 1 / math.sqrt(self.d_k)       # 缩放因子
    self.attn = None  # 存储注意力权重(用于可视化或分析)
  • ​关键参数​​:
    • heads:注意力头的数量(如 8)。
    • d_model:输入/输出向量的维度(如 512)。
    • dropout_prob:注意力权重的 Dropout 概率(默认 0.1)。
    • bias:是否在 Q/K 投影中添加偏置(V 投影强制启用偏置)。
  • ​维度关系​​:
    • 确保 d_model % heads == 0,即 d_k = d_model // heads(如 512 / 8 = 64)。

(2) git_score函数:计算注意力分数

 def get_score(self, query, key):

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
输入张量形状假设​

假设:

  • query 的形状为 (i, b, h, d)

    • i:目标序列长度(如解码端的 token 数)。
    • b:batch size。
    • h:注意力头的数量(多头注意力)。
    • d:每个注意力头的维度(d_k 或 d_q)。
  • key 的形状为 (j, b, h, d)

    • j:源序列长度(如编码端的 token 数)。
    • b, h, d 含义与 query 相同。

​爱因斯坦求和规则解析​

下标规则 'ibhd,jbhd->ijbh' 的分解:

  1. ​输入张量的标记​​:
    • query 的维度标记为 i, b, h, d
    • key 的维度标记为 j, b, h, d
  2. ​重复下标 b, h, d​:
    • 这些下标在输入中重复出现,表示在这些维度上​​保持对齐​​(不求和)。
    • 只有 d 是重复的且未出现在输出中,因此会​​沿 d 维度求和​​(点积操作)。
  3. ​输出形状 ijbh​:
    • 输出保留 i, j, b, h 维度,即对每个 batch(b)、每个注意力头(h),计算 query 的第 i 个位置与 key 的第 j 个位置的注意力分数。

​计算过程​
  1. ​点积求和​​:
    对 query 和 key 的最后一个维度 d 做点积(求和),得到未归一化的注意力分数。
    • 公式:output[i,j,b,h]=d∑​query[i,b,h,d]⋅key[j,b,h,d]
  2. ​输出形状​​:
    结果为 (i, j, b, h),表示:
    • 对 batch 中每个样本(b),每个注意力头(h),query 的第 i 个位置与 key 的第 j 个位置的相似度分数。

​直观示例​

假设:

  • i=2(目标序列长度),j=3(源序列长度),b=1(batch size),h=2(注意力头数),d=4(每个头的维度)。
  • query 形状:(2, 1, 2, 4)
  • key 形状:(3, 1, 2, 4)

​计算后输出​​:
形状为 (2, 3, 1, 2),即一个 2x3 的注意力分数矩阵,每个位置的值是 query 和 key 对应位置向量的点积。

 (3).prepare_mask函数,掩码处理函数

def prepare_mask(self, mask, query_shape, key_shape):
    # mask具有shape ,其中第一个维度是查询维度。如果查询维度等于[seq_len_q, seq_len_k, batch_size]1它将会被广播。
    assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
    assert mask.shape[1] == key_shape[0]
    assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
    # 应用所有头部的相同模板,生成模板形状:[seq_len_q,seq_len_k,batch_size,heads]
    mask = mask.unsqueeze(-1)

    return mask

这段代码中的 assert 语句用于验证 ​​注意力掩码(mask)​​ 的形状是否与 query 和 key 的形状兼容,通常在 Transformer 的自注意力或多头注意力机制中使用。以下是逐条解析:


​假设的输入形状​

  • query_shape = (query_seq_len, batch_size, num_heads, d_k)
  • key_shape = (key_seq_len, batch_size, num_heads, d_k)
  • mask 的形状通常为 (batch_size, key_seq_len, query_seq_len) 或其广播形式。

​Assert 语句解析​

1. assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
  • ​检查维度 0(mask.shape[0])​​:
    • mask.shape[0] 对应 batch_size(或可广播的维度)。
    • 条件要求:mask 的 batch_size 必须为 1(支持广播到所有样本)或等于 query 的 batch_size(即 query_shape[0])。
  • ​为什么?​
    • 如果 mask 的 batch_size=1,PyTorch 会自动广播到所有样本;否则需严格匹配 query 的 batch 维度。
2. assert mask.shape[1] == key_shape[0]
  • ​检查维度 1(mask.shape[1])​​:
    • mask.shape[1] 必须等于 key 的序列长度(key_seq_len)。
  • ​为什么?​
    • 注意力机制中,mask 的该维度用于屏蔽 key 的无效位置(如填充符 PAD),因此必须与 key 的序列长度一致。
3. assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
  • ​检查维度 2(mask.shape[2])​​:
    • mask.shape[2] 对应 query_seq_len(或可广播的维度)。
    • 条件要求:mask 的该维度必须为 1(支持广播到所有查询位置)或等于 query 的序列长度(query_shape[1])。
  • ​为什么?​
    • 如果 mask.shape[2]=1,表示所有查询位置共享同一掩码规则(如解码时的因果掩码);否则需为每个查询位置单独指定掩码。

 (4)forward前向传播

1. 输入形状处理​
seq_len, batch_size, _ = query.shape
  • ​输入 query 形状​​:(seq_len, batch_size, embed_dim)
    • seq_len:序列长度(如 token 数量)。
    • batch_size:批大小。
    • embed_dim:输入嵌入维度(未使用,用 _ 忽略)。

​2. 掩码(Mask)处理​
if mask is not None:
    mask = self.prepare_mask(mask, query.shape, key.shape)
  • mask 的作用​​:
    • 屏蔽无效位置(如填充符 PAD 或未来 token)。
    • 形状通常为 (batch_size, key_seq_len, query_seq_len) 或其广播形式(如 (1, key_seq_len, 1))。
  • prepare_mask 方法​​:
    • 确保 mask 的形状与 query 和 key 兼容(如广播或调整维度)。

​3. Query、Key、Value 投影​
query = self.query(query)  # 形状: (seq_len, batch_size, num_heads * d_k)
key = self.key(key)        # 形状: (key_seq_len, batch_size, num_heads * d_k)
value = self.value(value)  # 形状: (key_seq_len, batch_size, num_heads * d_v)
  • ​线性变换​​:
    • self.queryself.keyself.value 是 nn.Linear 层,将输入投影到多头空间。
    • 投影后形状:(seq_len, batch_size, num_heads * head_dim)
  • ​多头拆分​​:
    • 通常在后续操作中通过 view 拆分为 (seq_len, batch_size, num_heads, head_dim)(此处未显式写出,可能在 get_score 中处理)。

​4. 注意力分数计算​
scores = self.get_score(query, key)  # 形状: (query_seq_len, key_seq_len, batch_size, num_heads)
  • get_score 方法​​:
    • 计算 query 和 key 的点积注意力分数。
    • 通常实现为:
      # 假设 query 和 key 已拆分为多头
      scores = torch.einsum("ibhd,jbhd->ijbh", query, key)  # 形状: (i, j, b, h)
    • 输出形状:(query_seq_len, key_seq_len, batch_size, num_heads)

​5. 缩放注意力分数​
scores *= self.scale  # scale = 1 / sqrt(d_k)
  • ​缩放目的​​:
    • 防止点积结果过大导致 Softmax 梯度消失。
    • self.scale 通常设为 1 / sqrt(d_k)d_k 是 key 的每个注意力头的维度)。

​6. 掩码应用​
if mask is not None:
    scores = scores.masked_fill(mask == 0, float('-inf'))
  • masked_fill 逻辑​​:
    • 将 mask 中为 0 的位置替换为 -inf,使得 Softmax 后这些位置的权重为 0
    • ​典型掩码类型​​:
      • ​填充掩码(Padding Mask)​​:屏蔽 PAD token。
      • ​因果掩码(Causal Mask)​​:屏蔽未来 token(用于解码器)。

​7. Softmax 归一化​
attn = self.softmax(scores)  # 形状: (query_seq_len, key_seq_len, batch_size, num_heads)
  • ​Softmax 作用​​:
    • 沿 key_seq_len 维度(dim=1)归一化,使得每行的注意力权重和为 1
    • 输出形状与 scores 相同。

​8. Dropout 正则化​
attn = self.dropout(attn)
  • ​Dropout 目的​​:
    • 随机丢弃部分注意力权重,防止过拟合。

​9. 注意力权重应用(Value 加权求和)​
x = torch.einsum('ijbh,jbhd->ibhd', attn, value)  # 形状: (seq_len, batch_size, num_heads, d_v)
  • ​爱因斯坦求和规则​​:
    • ijbh,jbhd->ibhd:对 jkey_seq_len)维度求和,得到加权后的 value
    • 输出形状:(seq_len, batch_size, num_heads, d_v)

​10. 多头结果合并 & 输出投影​
x = x.reshape(seq_len, batch_size, -1)  # 形状: (seq_len, batch_size, num_heads * d_v)
return self.output(x)  # 形状: (seq_len, batch_size, output_dim)
  • ​合并多头​​:
    • 将 num_heads 和 d_v 维度合并,恢复为 (seq_len, batch_size, num_heads * d_v)
  • ​输出投影​​:
    • self.output 是 nn.Linear 层,将多头结果映射到最终输出维度。

​总结流程​
  1. ​输入投影​​:querykeyvalue 线性变换。
  2. ​计算注意力分数​​:query 和 key 的点积 + 缩放。
  3. ​掩码处理​​:屏蔽无效位置。
  4. ​Softmax​​:归一化注意力权重。
  5. ​Value 加权求和​​:生成上下文向量。
  6. ​输出投影​​:合并多头并映射到目标维度。

这是 Transformer 自注意力机制的核心实现,适用于编码器、解码器或跨注意力场景。

完整的多头注意力代码

# 一、导入相关模块
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker


# 二、为多头关注做好准备
# 此模块执行线性变换,并将向量拆分为给定数量的头,以实现多头关注。这用于转换键、查询和值向量。
class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, d_model, heads, d_k, bias):
        super().__init__()
        # 线性变换层
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # 头数量
        self.heads = heads
        # 每个头中的向量维度
        self.d_k = d_k

    def forward(self, x, ):
        # 一般传过来的向量维度为[seq_len, batch_size,d_model]或[batch_size,d_model]
        head_shape = x.shape[:-1]
        # 因为传来的最后一个维度都是d_model,所以我们要对最后一个维度使用线性变换层,并将其拆分为heads
        x = self.linear(x)
        # 将最后一个维度拆分为多个head,x.view(*head_shape, heads, d_k) → 新形状为 (batch_size, seq_len, heads, d_k)
        x = x.view(*head_shape, self.heads, self.d_k)
        # 输出具有 shape 或[seq_len, batch_size, heads, d_k][batch_size, heads, d_model]
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout_prob=0.1, bias=True):
        super().__init__()
        self.d_k = d_model // heads
        self.heads = heads
        # 生成qkv并且转换成多头形式
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # softmax对行维度进行处理
        self.softmax = nn.Softmax(dim=1)
        # 定义输出层,这里要注意,要保证传入的数据的shape什么样这里就要返回什么样
        self.output = nn.Linear(d_model, d_model)
        # dropout层
        self.dropout = nn.Dropout(dropout_prob)
        # softmax之前的缩放因子
        self.scale = 1 / math.sqrt(self.d_k)
        # 我们存储 attentions,以便在需要时将其用于日志记录或其他计算
        self.attn = None

    def get_score(self, query, key):

        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    # 生成掩码函数
    def prepare_mask(self, mask, query_shape, key_shape):
        # mask具有shape ,其中第一个维度是查询维度。如果查询维度等于[seq_len_q, seq_len_k, batch_size]1它将会被广播。
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        # 应用所有头部的相同模板,生成模板形状:[seq_len_q,seq_len_k,batch_size,heads]
        mask = mask.unsqueeze(-1)

        return mask

    def forward(self, query, key, value, mask):
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = self.get_score(query, key)

        scores *= self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.softmax(scores)
        tracker.debug('attn', attn)
        attn = self.dropout(attn)
        x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
        self.attn = attn.detach()
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐