一口气看完从零到一构建transformer架构代码一:多头注意力机制
的形状兼容,通常在 Transformer 的自注意力或多头注意力机制中使用。这是 Transformer 自注意力机制的核心实现,适用于编码器、解码器或跨注意力场景。,负责将输入向量线性变换并拆分为多个头的表示。的注意力分数矩阵,每个位置的值是。对应位置向量的点积。
目录
1.构建PrepareForMultiHeadAttention类
1. assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
2. assert mask.shape[1] == key_shape[0]
3. assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
代码参考地址:Transformer 编码器和解码器模型
为了新人小白这里将它详细解释了一下
一、多头注意力机制
1.构建PrepareForMultiHeadAttention类
这个 PrepareForMultiHeadAttention
类实现了 多头注意力机制的输入预处理,负责将输入向量线性变换并拆分为多个头的表示。以下是详细解析:
(1). 初始化方法 __init__
def __init__(self, d_model, heads, d_k, bias):
super().__init__()
self.linear = nn.Linear(d_model, heads * d_k, bias=bias) # 线性变换
self.heads = heads # 头数
self.d_k = d_k # 每个头的维度
- 参数说明:
d_model
:输入向量的维度(如 512)。heads
:注意力头的数量(如 8)。d_k
:每个注意力头的维度(通常d_k = d_model // heads
,如 64)。bias
:是否在线性变换中添加偏置项。
(2). 前向传播 forward
def forward(self, x):
head_shape = x.shape[:-1] # 保留除最后一个维度外的形状(如 [seq_len, batch_size])
x = self.linear(x) # 线性变换:[..., d_model] → [..., heads * d_k]
x = x.view(*head_shape, self.heads, self.d_k) # 拆分多头:[..., heads, d_k]
return x
- 输入
x
:- 形状可为
(seq_len, batch_size, d_model)
或(batch_size, d_model)
。 - 最后一个维度必须为
d_model
。
- 形状可为
- 处理流程:
- 线性变换:将
d_model
维输入映射到heads * d_k
维空间。 - 拆分多头:通过
view
将最后一维拆分为(heads, d_k)
。
- 线性变换:将
- 输出形状:
- 输入为
(seq_len, batch_size, d_model)
→ 输出(seq_len, batch_size, heads, d_k)
。 - 输入为
(batch_size, d_model)
→ 输出(batch_size, heads, d_k)
。
- 输入为
(3). 设计原理
(1) 多头注意力的核心思想
- 拆分注意力:将输入向量拆分为
heads
个独立的子空间,每个子空间学习不同的注意力模式。 - 维度关系:
其中 h 为头数,确保总参数量不变。
(2) 线性变换的作用
- 投影到子空间:
self.linear
等效于将 WQ, WK, WV 合并为一个矩阵,通过后续view
拆分。
(3) 形状变换示例
假设 d_model=512
, heads=8
, d_k=64
:
- 输入
(10, 32, 512)
→ 线性变换 →(10, 32, 512)
→ 拆分 →(10, 32, 8, 64)
。
2.MultiHeadAttention类
(1). 初始化方法 __init__
def __init__(self, heads, d_model, dropout_prob=0.1, bias=True):
super().__init__()
self.d_k = d_model // heads # 每个注意力头的维度
self.heads = heads # 头数
# 初始化 Q/K/V 的投影层
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# 注意力计算相关
self.softmax = nn.Softmax(dim=1) # 沿序列长度维度归一化
self.output = nn.Linear(d_model, d_model) # 输出投影层
self.dropout = nn.Dropout(dropout_prob) # Dropout 层
self.scale = 1 / math.sqrt(self.d_k) # 缩放因子
self.attn = None # 存储注意力权重(用于可视化或分析)
- 关键参数:
heads
:注意力头的数量(如 8)。d_model
:输入/输出向量的维度(如 512)。dropout_prob
:注意力权重的 Dropout 概率(默认 0.1)。bias
:是否在 Q/K 投影中添加偏置(V 投影强制启用偏置)。
- 维度关系:
- 确保
d_model % heads == 0
,即d_k = d_model // heads
(如 512 / 8 = 64)。
- 确保
(2) git_score函数:计算注意力分数
def get_score(self, query, key):
return torch.einsum('ibhd,jbhd->ijbh', query, key)
输入张量形状假设
假设:
-
query
的形状为(i, b, h, d)
i
:目标序列长度(如解码端的 token 数)。b
:batch size。h
:注意力头的数量(多头注意力)。d
:每个注意力头的维度(d_k
或d_q
)。
-
key
的形状为(j, b, h, d)
j
:源序列长度(如编码端的 token 数)。b, h, d
含义与query
相同。
爱因斯坦求和规则解析
下标规则 'ibhd,jbhd->ijbh'
的分解:
- 输入张量的标记:
query
的维度标记为i, b, h, d
。key
的维度标记为j, b, h, d
。
- 重复下标
b, h, d
:- 这些下标在输入中重复出现,表示在这些维度上保持对齐(不求和)。
- 只有
d
是重复的且未出现在输出中,因此会沿d
维度求和(点积操作)。
- 输出形状
ijbh
:- 输出保留
i, j, b, h
维度,即对每个 batch(b
)、每个注意力头(h
),计算query
的第i
个位置与key
的第j
个位置的注意力分数。
- 输出保留
计算过程
- 点积求和:
对query
和key
的最后一个维度d
做点积(求和),得到未归一化的注意力分数。- 公式:output[i,j,b,h]=d∑query[i,b,h,d]⋅key[j,b,h,d]
- 输出形状:
结果为(i, j, b, h)
,表示:- 对 batch 中每个样本(
b
),每个注意力头(h
),query
的第i
个位置与key
的第j
个位置的相似度分数。
- 对 batch 中每个样本(
直观示例
假设:
i=2
(目标序列长度),j=3
(源序列长度),b=1
(batch size),h=2
(注意力头数),d=4
(每个头的维度)。query
形状:(2, 1, 2, 4)
key
形状:(3, 1, 2, 4)
计算后输出:
形状为 (2, 3, 1, 2)
,即一个 2x3
的注意力分数矩阵,每个位置的值是 query
和 key
对应位置向量的点积。
(3).prepare_mask函数,掩码处理函数
def prepare_mask(self, mask, query_shape, key_shape):
# mask具有shape ,其中第一个维度是查询维度。如果查询维度等于[seq_len_q, seq_len_k, batch_size]1它将会被广播。
assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
assert mask.shape[1] == key_shape[0]
assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
# 应用所有头部的相同模板,生成模板形状:[seq_len_q,seq_len_k,batch_size,heads]
mask = mask.unsqueeze(-1)
return mask
这段代码中的 assert
语句用于验证 注意力掩码(mask
) 的形状是否与 query
和 key
的形状兼容,通常在 Transformer 的自注意力或多头注意力机制中使用。以下是逐条解析:
假设的输入形状
query_shape = (query_seq_len, batch_size, num_heads, d_k)
key_shape = (key_seq_len, batch_size, num_heads, d_k)
mask
的形状通常为(batch_size, key_seq_len, query_seq_len)
或其广播形式。
Assert 语句解析
1. assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
- 检查维度 0(
mask.shape[0]
):mask.shape[0]
对应batch_size
(或可广播的维度)。- 条件要求:
mask
的batch_size
必须为1
(支持广播到所有样本)或等于query
的batch_size
(即query_shape[0]
)。
- 为什么?
- 如果
mask
的batch_size=1
,PyTorch 会自动广播到所有样本;否则需严格匹配query
的 batch 维度。
- 如果
2. assert mask.shape[1] == key_shape[0]
- 检查维度 1(
mask.shape[1]
):mask.shape[1]
必须等于key
的序列长度(key_seq_len
)。
- 为什么?
- 注意力机制中,
mask
的该维度用于屏蔽key
的无效位置(如填充符PAD
),因此必须与key
的序列长度一致。
- 注意力机制中,
3. assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
- 检查维度 2(
mask.shape[2]
):mask.shape[2]
对应query_seq_len
(或可广播的维度)。- 条件要求:
mask
的该维度必须为1
(支持广播到所有查询位置)或等于query
的序列长度(query_shape[1]
)。
- 为什么?
- 如果
mask.shape[2]=1
,表示所有查询位置共享同一掩码规则(如解码时的因果掩码);否则需为每个查询位置单独指定掩码。
- 如果
(4)forward前向传播
1. 输入形状处理
seq_len, batch_size, _ = query.shape
- 输入
query
形状:(seq_len, batch_size, embed_dim)
seq_len
:序列长度(如 token 数量)。batch_size
:批大小。embed_dim
:输入嵌入维度(未使用,用_
忽略)。
2. 掩码(Mask)处理
if mask is not None:
mask = self.prepare_mask(mask, query.shape, key.shape)
-
mask
的作用:- 屏蔽无效位置(如填充符
PAD
或未来 token)。 - 形状通常为
(batch_size, key_seq_len, query_seq_len)
或其广播形式(如(1, key_seq_len, 1)
)。
- 屏蔽无效位置(如填充符
-
prepare_mask
方法:- 确保
mask
的形状与query
和key
兼容(如广播或调整维度)。
- 确保
3. Query、Key、Value 投影
query = self.query(query) # 形状: (seq_len, batch_size, num_heads * d_k)
key = self.key(key) # 形状: (key_seq_len, batch_size, num_heads * d_k)
value = self.value(value) # 形状: (key_seq_len, batch_size, num_heads * d_v)
- 线性变换:
self.query
、self.key
、self.value
是nn.Linear
层,将输入投影到多头空间。- 投影后形状:
(seq_len, batch_size, num_heads * head_dim)
。
- 多头拆分:
- 通常在后续操作中通过
view
拆分为(seq_len, batch_size, num_heads, head_dim)
(此处未显式写出,可能在get_score
中处理)。
- 通常在后续操作中通过
4. 注意力分数计算
scores = self.get_score(query, key) # 形状: (query_seq_len, key_seq_len, batch_size, num_heads)
-
get_score
方法:- 计算
query
和key
的点积注意力分数。 - 通常实现为:
# 假设 query 和 key 已拆分为多头 scores = torch.einsum("ibhd,jbhd->ijbh", query, key) # 形状: (i, j, b, h)
- 输出形状:
(query_seq_len, key_seq_len, batch_size, num_heads)
。
- 计算
5. 缩放注意力分数
scores *= self.scale # scale = 1 / sqrt(d_k)
- 缩放目的:
- 防止点积结果过大导致 Softmax 梯度消失。
self.scale
通常设为1 / sqrt(d_k)
(d_k
是key
的每个注意力头的维度)。
6. 掩码应用
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
-
masked_fill
逻辑:- 将
mask
中为0
的位置替换为-inf
,使得 Softmax 后这些位置的权重为0
。 - 典型掩码类型:
- 填充掩码(Padding Mask):屏蔽
PAD
token。 - 因果掩码(Causal Mask):屏蔽未来 token(用于解码器)。
- 填充掩码(Padding Mask):屏蔽
- 将
7. Softmax 归一化
attn = self.softmax(scores) # 形状: (query_seq_len, key_seq_len, batch_size, num_heads)
- Softmax 作用:
- 沿
key_seq_len
维度(dim=1
)归一化,使得每行的注意力权重和为1
。 - 输出形状与
scores
相同。
- 沿
8. Dropout 正则化
attn = self.dropout(attn)
- Dropout 目的:
- 随机丢弃部分注意力权重,防止过拟合。
9. 注意力权重应用(Value 加权求和)
x = torch.einsum('ijbh,jbhd->ibhd', attn, value) # 形状: (seq_len, batch_size, num_heads, d_v)
- 爱因斯坦求和规则:
ijbh,jbhd->ibhd
:对j
(key_seq_len
)维度求和,得到加权后的value
。- 输出形状:
(seq_len, batch_size, num_heads, d_v)
。
10. 多头结果合并 & 输出投影
x = x.reshape(seq_len, batch_size, -1) # 形状: (seq_len, batch_size, num_heads * d_v)
return self.output(x) # 形状: (seq_len, batch_size, output_dim)
- 合并多头:
- 将
num_heads
和d_v
维度合并,恢复为(seq_len, batch_size, num_heads * d_v)
。
- 将
- 输出投影:
self.output
是nn.Linear
层,将多头结果映射到最终输出维度。
总结流程
- 输入投影:
query
、key
、value
线性变换。 - 计算注意力分数:
query
和key
的点积 + 缩放。 - 掩码处理:屏蔽无效位置。
- Softmax:归一化注意力权重。
- Value 加权求和:生成上下文向量。
- 输出投影:合并多头并映射到目标维度。
这是 Transformer 自注意力机制的核心实现,适用于编码器、解码器或跨注意力场景。
完整的多头注意力代码
# 一、导入相关模块
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker
# 二、为多头关注做好准备
# 此模块执行线性变换,并将向量拆分为给定数量的头,以实现多头关注。这用于转换键、查询和值向量。
class PrepareForMultiHeadAttention(nn.Module):
def __init__(self, d_model, heads, d_k, bias):
super().__init__()
# 线性变换层
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
# 头数量
self.heads = heads
# 每个头中的向量维度
self.d_k = d_k
def forward(self, x, ):
# 一般传过来的向量维度为[seq_len, batch_size,d_model]或[batch_size,d_model]
head_shape = x.shape[:-1]
# 因为传来的最后一个维度都是d_model,所以我们要对最后一个维度使用线性变换层,并将其拆分为heads
x = self.linear(x)
# 将最后一个维度拆分为多个head,x.view(*head_shape, heads, d_k) → 新形状为 (batch_size, seq_len, heads, d_k)
x = x.view(*head_shape, self.heads, self.d_k)
# 输出具有 shape 或[seq_len, batch_size, heads, d_k][batch_size, heads, d_model]
return x
class MultiHeadAttention(nn.Module):
def __init__(self, heads, d_model, dropout_prob=0.1, bias=True):
super().__init__()
self.d_k = d_model // heads
self.heads = heads
# 生成qkv并且转换成多头形式
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# softmax对行维度进行处理
self.softmax = nn.Softmax(dim=1)
# 定义输出层,这里要注意,要保证传入的数据的shape什么样这里就要返回什么样
self.output = nn.Linear(d_model, d_model)
# dropout层
self.dropout = nn.Dropout(dropout_prob)
# softmax之前的缩放因子
self.scale = 1 / math.sqrt(self.d_k)
# 我们存储 attentions,以便在需要时将其用于日志记录或其他计算
self.attn = None
def get_score(self, query, key):
return torch.einsum('ibhd,jbhd->ijbh', query, key)
# 生成掩码函数
def prepare_mask(self, mask, query_shape, key_shape):
# mask具有shape ,其中第一个维度是查询维度。如果查询维度等于[seq_len_q, seq_len_k, batch_size]1它将会被广播。
assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
assert mask.shape[1] == key_shape[0]
assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
# 应用所有头部的相同模板,生成模板形状:[seq_len_q,seq_len_k,batch_size,heads]
mask = mask.unsqueeze(-1)
return mask
def forward(self, query, key, value, mask):
seq_len, batch_size, _ = query.shape
if mask is not None:
mask = self.prepare_mask(mask, query.shape, key.shape)
query = self.query(query)
key = self.key(key)
value = self.value(value)
scores = self.get_score(query, key)
scores *= self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = self.softmax(scores)
tracker.debug('attn', attn)
attn = self.dropout(attn)
x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
self.attn = attn.detach()
x = x.reshape(seq_len, batch_size, -1)
return self.output(x)
更多推荐
所有评论(0)