
搜广推校招面经三十八
在Transformer模型中,梯度消失和梯度爆炸是深度学习中常见的问题,尤其是在处理长序列数据时。
字节推荐算法
一、场景题:在抖音场景下为用户推荐广告词,吸引用户点击搜索,呈现广告这一流程的关键点以及可能遇到的困难。
二、Transformer中对梯度消失或者梯度爆炸的处理
在Transformer模型中,梯度消失和梯度爆炸是深度学习中常见的问题,尤其是在处理长序列数据时。为了克服这些问题,Transformer采用了一系列技术:
2.1. 残差连接(Residual Connections)
每个子层(包括多头自注意力机制和前馈神经网络层)之后都接了一个残差连接,并且紧接着一个层归一化(Layer Normalization)。残差连接有助于缓解深层网络中的梯度消失问题,因为它允许梯度直接通过恒等映射传递到前面的层,从而使得更深层次的网络能够被有效地训练。
2.2. 层归一化(Layer Normalization)
与批量归一化不同,层归一化对单个样本的所有特征进行归一化,而不是对整个批次的同一特征进行归一化。这使得它更适合于动态变化的输入序列长度,并且可以帮助稳定训练过程中的梯度,防止它们变得过大或过小。
2.3. Scaled Dot-Product Attention
为了避免当输入维度较大时softmax函数进入饱和区导致梯度消失的问题,Transformer引入了缩放因子(通常为键向量维度的平方根),来缩放点积结果。
2.4. 初始化策略
合理的权重初始化对于避免梯度爆炸非常重要。例如,使用Xavier初始化或He初始化方法可以确保每一层的输入信号的标准差大致保持不变,从而防止梯度因初始值过小而消失或者过大而爆炸。
2.5. 梯度裁剪(Gradient Clipping)
这是一种简单但有效的方法,用于限制梯度的最大范数。如果计算出的梯度超过了某个阈值,则将其按比例缩小以保证更新步长不会过大,这样可以避免梯度爆炸带来的不稳定训练ty-reference。
2.6. 自适应优化器
使用如Adam这样的自适应学习率优化算法,可以根据历史梯度动态调整学习率,有助于更好地控制参数更新的尺度,减少梯度爆炸的风险。
2.7. Warmup技巧
在训练开始阶段,逐渐增加学习率,可以帮助解决由于初始学习率过高而导致的梯度爆炸问题。
三、自注意力机制(Attention)的时间复杂度
自注意力机制的核心是计算Query、Key和Value矩阵,并通过点积得到注意力分数。
-
输入表示:假设输入序列长度为 n ,每个词的嵌入维度为 d 。
-
计算步骤:
- 计算 Q 、 K 、 V Q 、K 、V Q、K、V矩阵:时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d^2) O(n⋅d2)。
- 计算注意力分数 Q K T QK^T QKT:时间复杂度为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)。
- 对注意力分数进行Softmax归一化:时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
- 计算加权和 Attention ( Q , K , V ) = Softmax ( Q K T ) V \text{Attention}(Q, K, V) = \text{Softmax}(QK^T)V Attention(Q,K,V)=Softmax(QKT)V:时间复杂度为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)。
-
总时间复杂度: O ( n 2 ⋅ d + n ⋅ d 2 ) O(n^2 \cdot d + n \cdot d^2) O(n2⋅d+n⋅d2)。当 n > d n > d n>d 时,主要项为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)。
四、前馈神经网络(FFN)的时间复杂度
FFN由两个全连接层组成,通常先扩展维度再压缩回原始维度。
-
输入表示:输入维度为 d d d,隐藏层维度为 d f f d_{ff} dff(通常 d f f = 4 d d_{ff} = 4d dff=4d)。
-
计算步骤:
- 第一层全连接:时间复杂度为 O ( n ⋅ d ⋅ d f f ) O(n \cdot d \cdot d_{ff}) O(n⋅d⋅dff)。
- 第二层全连接:时间复杂度为 O ( n ⋅ d f f ⋅ d ) O(n \cdot d_{ff} \cdot d) O(n⋅dff⋅d)。
-
总时间复杂度: O ( n ⋅ d ⋅ d f f ) O(n \cdot d \cdot d_{ff}) O(n⋅d⋅dff)。由于 d f f = 4 d d_{ff} = 4d dff=4d,可简化为 O ( n ⋅ d 2 ) O(n \cdot d^2) O(n⋅d2)。
五、Transformer整体时间复杂度
一个Transformer层包含一个自注意力机制和一个FFN,假设有 L L L 层。
- 单层时间复杂度: O ( n 2 ⋅ d + n ⋅ d 2 ) O(n^2 \cdot d + n \cdot d^2) O(n2⋅d+n⋅d2)。
- 整体时间复杂度: L ⋅ O ( n 2 ⋅ d + n ⋅ d 2 ) L \cdot O(n^2 \cdot d + n \cdot d^2) L⋅O(n2⋅d+n⋅d2)。
- 自注意力机制: O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)。
- FFN: O ( n ⋅ d 2 ) O(n \cdot d^2) O(n⋅d2)。
- Transformer整体: L ⋅ O ( n 2 ⋅ d + n ⋅ d 2 ) L \cdot O(n^2 \cdot d + n \cdot d^2) L⋅O(n2⋅d+n⋅d2)。
但,当序列长度 n n n 较大时,自注意力机制的时间复杂度 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) 是主要瓶颈。
更多推荐
所有评论(0)