字节推荐算法

一、场景题:在抖音场景下为用户推荐广告词,吸引用户点击搜索,呈现广告这一流程的关键点以及可能遇到的困难。

二、Transformer中对梯度消失或者梯度爆炸的处理

在Transformer模型中,梯度消失和梯度爆炸是深度学习中常见的问题,尤其是在处理长序列数据时。为了克服这些问题,Transformer采用了一系列技术:

2.1. 残差连接(Residual Connections)

每个子层(包括多头自注意力机制和前馈神经网络层)之后都接了一个残差连接,并且紧接着一个层归一化(Layer Normalization)。残差连接有助于缓解深层网络中的梯度消失问题,因为它允许梯度直接通过恒等映射传递到前面的层,从而使得更深层次的网络能够被有效地训练。

2.2. 层归一化(Layer Normalization)

与批量归一化不同,层归一化对单个样本的所有特征进行归一化,而不是对整个批次的同一特征进行归一化。这使得它更适合于动态变化的输入序列长度,并且可以帮助稳定训练过程中的梯度,防止它们变得过大或过小。

2.3. Scaled Dot-Product Attention

为了避免当输入维度较大时softmax函数进入饱和区导致梯度消失的问题,Transformer引入了缩放因子(通常为键向量维度的平方根),来缩放点积结果。

2.4. 初始化策略

合理的权重初始化对于避免梯度爆炸非常重要。例如,使用Xavier初始化或He初始化方法可以确保每一层的输入信号的标准差大致保持不变,从而防止梯度因初始值过小而消失或者过大而爆炸。

2.5. 梯度裁剪(Gradient Clipping)

这是一种简单但有效的方法,用于限制梯度的最大范数。如果计算出的梯度超过了某个阈值,则将其按比例缩小以保证更新步长不会过大,这样可以避免梯度爆炸带来的不稳定训练ty-reference

2.6. 自适应优化器

使用如Adam这样的自适应学习率优化算法,可以根据历史梯度动态调整学习率,有助于更好地控制参数更新的尺度,减少梯度爆炸的风险。

2.7. Warmup技巧

在训练开始阶段,逐渐增加学习率,可以帮助解决由于初始学习率过高而导致的梯度爆炸问题。

三、自注意力机制(Attention)的时间复杂度

自注意力机制的核心是计算Query、Key和Value矩阵,并通过点积得到注意力分数。

  • 输入表示:假设输入序列长度为 n ,每个词的嵌入维度为 d 。

  • 计算步骤

    1. 计算 Q 、 K 、 V Q 、K 、V QKV矩阵:时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d^2) O(nd2)
    2. 计算注意力分数 Q K T QK^T QKT:时间复杂度为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2d)
    3. 对注意力分数进行Softmax归一化:时间复杂度为 O ( n 2 ) O(n^2) O(n2)
    4. 计算加权和 Attention ( Q , K , V ) = Softmax ( Q K T ) V \text{Attention}(Q, K, V) = \text{Softmax}(QK^T)V Attention(Q,K,V)=Softmax(QKT)V:时间复杂度为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2d)
  • 总时间复杂度 O ( n 2 ⋅ d + n ⋅ d 2 ) O(n^2 \cdot d + n \cdot d^2) O(n2d+nd2)。当 n > d n > d n>d 时,主要项为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2d)

四、前馈神经网络(FFN)的时间复杂度

FFN由两个全连接层组成,通常先扩展维度再压缩回原始维度。

  • 输入表示:输入维度为 d d d,隐藏层维度为 d f f d_{ff} dff(通常 d f f = 4 d d_{ff} = 4d dff=4d)。

  • 计算步骤

    1. 第一层全连接:时间复杂度为 O ( n ⋅ d ⋅ d f f ) O(n \cdot d \cdot d_{ff}) O(nddff)
    2. 第二层全连接:时间复杂度为 O ( n ⋅ d f f ⋅ d ) O(n \cdot d_{ff} \cdot d) O(ndffd)
  • 总时间复杂度 O ( n ⋅ d ⋅ d f f ) O(n \cdot d \cdot d_{ff}) O(nddff)。由于 d f f = 4 d d_{ff} = 4d dff=4d,可简化为 O ( n ⋅ d 2 ) O(n \cdot d^2) O(nd2)

五、Transformer整体时间复杂度

一个Transformer层包含一个自注意力机制和一个FFN,假设有 L L L 层。

  • 单层时间复杂度 O ( n 2 ⋅ d + n ⋅ d 2 ) O(n^2 \cdot d + n \cdot d^2) O(n2d+nd2)
  • 整体时间复杂度 L ⋅ O ( n 2 ⋅ d + n ⋅ d 2 ) L \cdot O(n^2 \cdot d + n \cdot d^2) LO(n2d+nd2)
  • 自注意力机制 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2d)
  • FFN O ( n ⋅ d 2 ) O(n \cdot d^2) O(nd2)
  • Transformer整体 L ⋅ O ( n 2 ⋅ d + n ⋅ d 2 ) L \cdot O(n^2 \cdot d + n \cdot d^2) LO(n2d+nd2)

但,当序列长度 n n n 较大时,自注意力机制的时间复杂度 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2d) 是主要瓶颈。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐