Transformer explained: Transformer模型详解(图解最完整版) - 知乎 (zhihu.com)

Hand-written Attention video walkthrough: 手写self-attention的四重境界-part1 pure self-attention - 哔哩哔哩 (bilibili)

Hand-written Transformer Decoder video walkthrough: 一个视频讲清楚 Transfomer Decoder的结构和代码,面试高频题 - 哔哩哔哩 (bilibili)

Common interview questions

Why Transformer attention divides by $\sqrt{d_k}$

To keep the dot products from growing too large, which would saturate softmax and cause vanishing gradients.

Why $\sqrt{d_k}$ rather than some other factor? If the components of $q$ and $k$ are independent with zero mean and unit variance, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$, so dividing by $\sqrt{d_k}$ brings the scores back to unit variance regardless of the head dimension; a quick numerical check is sketched below.
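A minimal numerical sketch (not from the original notes) illustrating the variance argument: for random $q$, $k$ with unit-variance components, the raw dot products have variance close to $d_k$, while the scaled scores stay near 1.

```python
import torch

d_k = 512
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

raw = (q * k).sum(dim=-1)       # unscaled dot products
scaled = raw / d_k ** 0.5       # divided by sqrt(d_k)

print(raw.var().item())         # roughly d_k (about 512)
print(scaled.var().item())      # roughly 1
```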

What is the structure of the Transformer? It is an encoder-decoder architecture: the encoder stacks N identical layers, each containing multi-head self-attention followed by a position-wise feed-forward network, with a residual connection and layer normalization around each sublayer; decoder layers additionally use masked (causal) self-attention and cross-attention over the encoder output, and token embeddings are summed with positional encodings at the input. A structural sketch follows.
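A minimal structural sketch (not from the original notes) of one Post-LN encoder layer; `d_model`, `n_heads`, and `d_ff` are placeholder names chosen here for illustration.

```python
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    """Illustrative skeleton of one Transformer encoder layer (Post-LN style)."""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None):
        # sublayer 1: multi-head self-attention + residual + LayerNorm
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + attn_out)
        # sublayer 2: position-wise feed-forward network + residual + LayerNorm
        x = self.norm2(x + self.ffn(x))
        return x
```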

Why does the Transformer use LN instead of BN?

The Transformer uses Layer Normalization (LN) rather than Batch Normalization (BN) for the following reasons (a short comparison script follows this list):

  1. Variable sequence lengths
  • The Transformer processes variable-length sequences (e.g. sentences of different lengths in NLP).
  • BN normalizes across the batch dimension, but the same feature carries different statistical meaning at different sequence positions.
  • LN normalizes across the feature dimension, handling each sample independently, so it is unaffected by sequence length.
  2. Unstable batch sizes
  • BN works poorly with small batches (the statistics are estimated inaccurately).
  • BN also differs between training and inference: batch statistics during training, running (global) statistics at inference, which is inconsistent.
  • LN does not depend on the batch, so training and inference behave identically.
  3. Sequence-modeling characteristics
  • In NLP tasks, the same feature should have the same distribution at different positions.
    • For example, the embedding of "apple" should stay similar whether the word appears at the start or the end of a sentence.
  • LN normalizes all features of each position, preserving comparability across positions.
  • BN mixes information from different positions, breaking position independence.
  4. Training stability
  • BN is sensitive to outlier samples within a batch (e.g. one very long sequence skews the whole batch's statistics).
  • LN is per-sample and is unaffected by the other samples in the batch.
  • Transformer training typically uses a large learning rate with warmup; LN combines with this more stably.
  5. Computational efficiency
  • LN only needs each sample's mean and variance, which is simple and cheap.
  • BN must maintain running_mean/running_var, adding complexity.
  • For the Transformer's self-attention, LN can be computed independently for every token in parallel, which fits the architecture better.
  6. Theoretical fit
  • LN formula: $\text{LN}(x) = \gamma \cdot \dfrac{x - \mu}{\sigma} + \beta$
  • In the Transformer's residual connections, LN is placed either before the attention/FFN sublayers (Pre-LN) or after them (Post-LN).
  • This stabilizes gradient flow in deep stacks, mitigating vanishing/exploding gradients.
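A minimal sketch (not from the original notes) contrasting the two: LN normalizes each token's feature vector on its own, while BN normalizes each feature across the batch and therefore mixes statistics from different samples and positions.

```python
import torch
import torch.nn as nn

batch, seq_len, d_model = 4, 5, 8
x = torch.randn(batch, seq_len, d_model)

# LayerNorm: statistics over the last (feature) dimension, per token, batch-independent
ln = nn.LayerNorm(d_model)
y_ln = ln(x)                                    # same result for any batch size

# BatchNorm1d expects (batch, channels, length); statistics are taken over batch and length
bn = nn.BatchNorm1d(d_model)
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)    # mixes different samples/positions

print(y_ln.shape, y_bn.shape)                   # both torch.Size([4, 5, 8])
```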

Key comparison

| Property | LN (used by the Transformer) | BN (common in CNNs) |
| --- | --- | --- |
| Normalization dimension | feature dimension (C) | batch dimension (N) |
| Batch dependence | none | strong |
| Variable-length sequences | handled well | hard to handle |
| Training vs. inference | identical behavior | differs |
| Compute overhead | lower | higher |

Summary

The fundamental reason the Transformer chooses LN is its sequence-modeling characteristics and its need for training stability. LN's per-sample independence, position invariance, and batch independence fit self-attention and variable-length sequences in a way that BN cannot.

Hand-written Transformer code

Hand-written Attention

```python
import torch
import torch.nn as nn
import math

class self_attenv3(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # one projection producing Q, K and V in a single matmul
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.att_drop = nn.Dropout(0.1)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        qkv = self.qkv_proj(x)
        q, k, v = torch.split(qkv, self.dim, dim=-1)
        # scaled dot-product scores: (batch, seq, seq)
        atten_value = q @ k.transpose(-1, -2) / math.sqrt(self.dim)

        if mask is not None:
            # masked positions get -inf so softmax assigns them zero weight
            atten_value = atten_value.masked_fill(mask == 0, float('-inf'))
        atten_weight = torch.softmax(atten_value, dim=-1)
        atten_weight = self.att_drop(atten_weight)
        output = atten_weight @ v
        output = self.out_proj(output)
        return output

# Test code (outside the class): build a padding mask of shape (batch, seq, seq)
x = torch.rand(2, 3, 4)
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
print(mask.shape)         # torch.Size([2, 3])
mask = mask.unsqueeze(1)  # (2, 1, 3)
mask = mask.repeat(1, 3, 1)
print(mask.shape)         # torch.Size([2, 3, 3])
print(mask)

self_atten = self_attenv3(4)
res = self_atten(x, mask)
print(res)
```

Hand-written Decoder

```python
import torch
import torch.nn as nn
import math

class simpleDecoder(nn.Module):
    def __init__(self, dim, head_nums):
        super().__init__()
        self.dim = dim
        self.head_nums = head_nums
        self.head_dim = dim // head_nums

        self.atten_layerNorm = nn.LayerNorm(dim, eps=1e-5)

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

        self.att_drop = nn.Dropout(0.1)

        # position-wise feed-forward network (4x expansion) and its LayerNorm
        self.up = nn.Linear(dim, dim * 4)
        self.down = nn.Linear(dim * 4, dim)
        self.fnn_layernorm2 = nn.LayerNorm(dim, eps=1e-5)
        self.act_fn = nn.GELU()
        self.fnn_drop = nn.Dropout(0.1)

    def att_output(self, q, k, v, mask=None):
        # scores: (batch, head_nums, seqlen, seqlen)
        atten_value = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        if mask is not None:
            # combine the padding mask with the causal (lower-triangular) mask
            mask = mask.tril()
        else:
            # pure causal mask when no padding mask is given
            mask = torch.ones_like(atten_value).tril()
        atten_value = atten_value.masked_fill(mask == 0, float('-inf'))

        atten_weight = torch.softmax(atten_value, dim=-1)
        atten_weight = self.att_drop(atten_weight)
        atten_weight = atten_weight @ v
        batch, head_nums, sqlen, head_dim = atten_weight.size()
        # merge the heads back: (batch, seqlen, dim)
        atten_weight = atten_weight.transpose(1, 2).contiguous().view(batch, sqlen, self.dim)
        atten_weight = self.out_proj(atten_weight)
        return atten_weight

    def att_block(self, x, mask=None):
        batch, sqlen, _ = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # split the heads: (batch, head_nums, seqlen, head_dim)
        q = q.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)
        k = k.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)
        v = v.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)

        atten = self.att_output(q, k, v, mask)
        # residual connection + LayerNorm (Post-LN)
        return self.atten_layerNorm(atten + x)

    def fnn_block(self, x):
        up = self.up(x)
        up = self.act_fn(up)
        down = self.down(up)
        down = self.fnn_drop(down)
        # residual connection + LayerNorm (Post-LN)
        return self.fnn_layernorm2(down + x)

    def forward(self, x, mask=None):
        x = self.att_block(x, mask)
        x = self.fnn_block(x)
        return x

class decoder(nn.Module):
    def __init__(self, dim, head_nums, layers):
        super().__init__()
        self.layer_list = nn.ModuleList([simpleDecoder(dim, head_nums) for _ in range(layers)])
        self.emb = nn.Embedding(100, dim)   # toy vocabulary of 100 tokens
        self.out = nn.Linear(dim, 100)

    def forward(self, x, mask=None):
        x = self.emb(x)
        for layer in self.layer_list:
            x = layer(x, mask)
        x = self.out(x)
        return torch.softmax(x, dim=-1)

# Test code (outside the class definitions)
x = torch.rand(3, 4, 64)
net = simpleDecoder(64, 8)
mask = (
    torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
    .unsqueeze(1)
    .unsqueeze(2)
    .repeat(1, 8, 4, 1)
)   # (batch=3, head_nums=8, seqlen=4, seqlen=4)
print(net(x, mask).shape)   # torch.Size([3, 4, 64])
```
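A brief usage sketch (not in the original notes) of how the `decoder` class above could be driven autoregressively with greedy decoding; the start token id `0` and the generation length are arbitrary choices for illustration.

```python
# Greedy decoding: feed the tokens generated so far, take the argmax of the last
# position's distribution, append it, and repeat. Causal masking is handled
# inside simpleDecoder when mask=None.
model = decoder(dim=64, head_nums=8, layers=2)
tokens = torch.zeros(1, 1, dtype=torch.long)   # (batch=1, seq=1), start token id 0

for _ in range(5):
    probs = model(tokens)                      # (1, seq, vocab=100)
    next_token = probs[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)

print(tokens)                                  # (1, 6) generated token ids
```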

Hand-written positional encoding

```python
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_pos=10000):
        super(PositionalEncoding, self).__init__()

        # table of positional encodings, shape (max_pos, embed_dim)
        pe = torch.zeros(max_pos, embed_dim)

        # positions 0 .. max_pos-1, shape (max_pos, 1)
        position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1)

        # div_term = 10000^(-2i/embed_dim) for the even feature indices
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        # even feature indices get sine, odd feature indices get cosine
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # reshape to (1, embed_dim, max_pos) to match an input of shape (batch, embed_dim, seq_len)
        pe = pe.T.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, embed_dim, seq_len); add the first seq_len positional encodings
        return x + self.pe[:, :, :x.size(-1)]
```
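A short usage example (not in the original notes); the shapes assume the channel-first layout used above, i.e. inputs of shape (batch, embed_dim, seq_len).

```python
pos_enc = PositionalEncoding(embed_dim=64, max_pos=512)
x = torch.rand(2, 64, 10)    # (batch, embed_dim, seq_len)
out = pos_enc(x)             # positional encodings added element-wise
print(out.shape)             # torch.Size([2, 64, 10])
```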

Hand-written MHA (multi-head attention)

```python
import torch
import torch.nn as nn
import math

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, head_nums):
        super().__init__()
        self.dim = dim
        self.head_nums = head_nums
        self.head_dim = dim // head_nums

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

        self.att_drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        batch, sqlen, _ = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # split the heads: (batch, head_nums, seqlen, head_dim)
        q = q.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)
        k = k.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)
        v = v.view(batch, sqlen, self.head_nums, self.head_dim).transpose(1, 2)

        att_value = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            # masked positions get -inf so softmax gives them zero weight
            att_value = att_value.masked_fill(mask == 0, float('-inf'))
        att_weight = torch.softmax(att_value, dim=-1)
        att_weight = self.att_drop(att_weight)

        att_weight = att_weight @ v
        # merge the heads back: (batch, seqlen, dim)
        att_weight = att_weight.transpose(1, 2).contiguous().view(batch, sqlen, self.dim)
        att_weight = self.out_proj(att_weight)
        return att_weight

# Test code
x = torch.rand(2, 3, 4)
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
print(mask.shape)            # torch.Size([2, 3])
print(mask)
mask = mask.unsqueeze(1)     # (2, 1, 3)
mask = mask.repeat(1, 3, 1)  # (2, 3, 3)
mask = mask.unsqueeze(1)     # (2, 1, 3, 3): broadcasts over the head dimension
print(mask.shape)            # torch.Size([2, 1, 3, 3])
print(mask)

multi_head_atten = MultiHeadAttention(4, 2)
res = multi_head_atten(x, mask)
print(res)
```

mask: used to handle padding or to implement causal (autoregressive) attention; typical ways to build both mask types are sketched below.
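A small sketch (not from the original notes) of how the two mask types are commonly constructed; the shapes are chosen so they broadcast against attention scores of shape (batch, heads, seq, seq), and `lengths` is an illustrative example of valid token counts.

```python
import torch

batch, seq = 2, 4
lengths = torch.tensor([4, 2])                 # valid tokens per sample

# Padding mask: True for real tokens, False for padding; shape (batch, 1, 1, seq)
padding_mask = (torch.arange(seq).unsqueeze(0) < lengths.unsqueeze(1))  # (batch, seq)
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)                   # broadcast over heads and queries

# Causal mask: lower-triangular, shape (1, 1, seq, seq)
causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

# Combined decoder mask: a key is visible only if it is real AND not in the future
combined = padding_mask & causal_mask          # (batch, 1, seq, seq)
print(combined.shape)
```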

Full pipeline (the same flow using PyTorch's built-in attention kernel is sketched after this list):

Input x [batch, seq, d_model]

Q = x·W_q, K = x·W_k, V = x·W_v

Split into heads: [batch, seq, d_model] → [batch, seq, heads, d_k] → [batch, heads, seq, d_k]

Compute attention scores: Q·K^T/√d_k → softmax → weights × V

Merge heads: [batch, heads, seq, d_k] → [batch, seq, d_model]

Output projection: context·W_o

Output [batch, seq, d_model]
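The same pipeline can be expressed with `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0). This sketch is not from the original notes; it mirrors the hand-written MHA above, with the head split/merge done around the fused kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq, d_model, heads = 2, 3, 4, 2
d_k = d_model // heads
x = torch.rand(batch, seq, d_model)

w_q, w_k, w_v, w_o = (nn.Linear(d_model, d_model) for _ in range(4))

# project and split into heads: (batch, heads, seq, d_k)
q = w_q(x).view(batch, seq, heads, d_k).transpose(1, 2)
k = w_k(x).view(batch, seq, heads, d_k).transpose(1, 2)
v = w_v(x).view(batch, seq, heads, d_k).transpose(1, 2)

# fused scaled dot-product attention (handles the 1/sqrt(d_k) scaling and softmax internally)
context = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# merge heads and apply the output projection
out = w_o(context.transpose(1, 2).contiguous().view(batch, seq, d_model))
print(out.shape)   # torch.Size([2, 3, 4])
```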

Hand-written GQA (grouped-query attention)

```python
import torch
import torch.nn as nn
import math

class GroupQueryAttention(nn.Module):
    def __init__(self, dim, head_nums, nums_key_value_head):
        super().__init__()
        self.dim = dim
        self.head_nums = head_nums
        self.head_dim = dim // head_nums
        self.nums_key_value_head = nums_key_value_head

        # the dimensions must divide evenly
        assert dim % head_nums == 0
        assert head_nums % nums_key_value_head == 0

        # projection layers: Q keeps all heads, K/V only have nums_key_value_head heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(dim, nums_key_value_head * self.head_dim)
        self.out_proj = nn.Linear(dim, dim)

        self.atten_drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        batch, seqlen, _ = x.size()

        # linear projections for Q, K, V
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # reshape to multi-head layout: [batch, seqlen, heads, head_dim] -> [batch, heads, seqlen, head_dim]
        q = q.view(batch, seqlen, self.head_nums, self.head_dim).transpose(1, 2)
        k = k.view(batch, seqlen, self.nums_key_value_head, self.head_dim).transpose(1, 2)
        v = v.view(batch, seqlen, self.nums_key_value_head, self.head_dim).transpose(1, 2)

        # key step: repeat K and V so each group of query heads shares one KV head (the core of GQA)
        k = k.repeat_interleave(self.head_nums // self.nums_key_value_head, dim=1)
        v = v.repeat_interleave(self.head_nums // self.nums_key_value_head, dim=1)

        # attention scores
        atten_value = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)

        # apply the mask (if any)
        if mask is not None:
            atten_value = atten_value.masked_fill(mask == 0, float('-inf'))

        # attention weights
        atten_weight = torch.softmax(atten_value, dim=-1)
        atten_weight = self.atten_drop(atten_weight)

        # weighted sum of V
        atten_weight = atten_weight @ v

        # merge the heads back to the original dimension
        atten_weight = atten_weight.transpose(1, 2).contiguous().view(batch, seqlen, self.dim)
        atten_weight = self.out_proj(atten_weight)

        return atten_weight

# Input tensor
x = torch.rand(3, 2, 128)

# Build the mask and reshape it step by step
mask = torch.tensor([[1, 1], [1, 0], [1, 1]])
mask = mask.unsqueeze(1)        # add a dimension at dim 1
print(mask.shape)               # torch.Size([3, 1, 2])
print(mask)

mask = mask.repeat(1, 8, 1)     # repeat to match the number of attention heads
print(mask.shape)               # torch.Size([3, 8, 2])
print(mask)

mask = mask.unsqueeze(2)        # add a query-position dimension
print(mask.shape)               # torch.Size([3, 8, 1, 2])
print(mask)

mask = mask.repeat(1, 1, 2, 1)  # repeat to match the sequence length
print(mask.shape)               # torch.Size([3, 8, 2, 2])

# Build the grouped-query attention module and run a forward pass
net = GroupQueryAttention(128, 8, 2)
res = net(x, mask)
print(res.shape)                # torch.Size([3, 2, 128])
```
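A quick check (not from the original notes) of why GQA shrinks the K/V projections and the KV cache: with `nums_key_value_head = head_nums` it reduces to standard MHA, and with `nums_key_value_head = 1` it becomes multi-query attention (MQA). The helper below is a hypothetical illustration counting only k_proj/v_proj parameters.

```python
def kv_proj_params(dim, head_nums, nums_key_value_head):
    """Parameters in the K and V projections for a given number of KV heads."""
    head_dim = dim // head_nums
    kv_out = nums_key_value_head * head_dim
    return 2 * (dim * kv_out + kv_out)   # weights + biases for k_proj and v_proj

dim, heads = 128, 8
for kv_heads in (8, 2, 1):               # MHA, GQA, MQA
    print(kv_heads, kv_proj_params(dim, heads, kv_heads))
```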