✨Attention Is All You Need
https://arxiv.org/abs/1706.03762
To deepen the understanding of the Transformer architecture.
Talk is cheap. Show me the code.
All of the code below was generated by Gemini 2.5 Flash (and has been verified).
✨Attention
Scaled Dot-Product Attention
Scaled dot-product attention is the basic form of the attention mechanism. It measures the similarity between a query (Q) and a key (K) via their dot product, then divides by a scaling factor $ \sqrt{d_k} $ (where d_k is the dimension of the key vectors) so that large dot products do not push the softmax into its saturated, vanishing-gradient region. Finally, the attention weights are used to take a weighted sum over the values (V), which gives the output.
The mathematical expression is:
$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
import torch
import torch.nn as nn
import math

def scaled_dot_product_attention(query, key, value, mask=None):
    """Scaled dot-product attention.

    Args:
        query (torch.Tensor): query tensor, typically of shape (batch_size, num_heads, seq_len_q, d_k).
        key (torch.Tensor): key tensor, typically of shape (batch_size, num_heads, seq_len_k, d_k).
        value (torch.Tensor): value tensor, typically of shape (batch_size, num_heads, seq_len_v, d_v).
            Note: seq_len_k must equal seq_len_v.
        mask (torch.Tensor, optional): attention mask, typically of shape (batch_size, 1, seq_len_q, seq_len_k),
            used to mask out the attention scores of certain positions.

    Returns:
        torch.Tensor: attention output of shape (batch_size, num_heads, seq_len_q, d_v).
        torch.Tensor: attention weights of shape (batch_size, num_heads, seq_len_q, seq_len_k).
    """
    d_k = query.size(-1)  # dimension of the key vectors

    # 1. Dot product of query and key, then scale
    # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. Apply the mask (if provided)
    if mask is not None:
        # Positions where the mask is 0 (usually positions to ignore) are set to -inf,
        # so their attention weights are close to 0 after the softmax
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # 3. Softmax over the scores to obtain the attention weights
    attention_weights = torch.softmax(scores, dim=-1)

    # 4. Weighted sum of the values
    # (..., seq_len_q, seq_len_k) @ (..., seq_len_k, d_v) -> (..., seq_len_q, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# --- Demo: Scaled Dot-Product Attention ---
print("--- Scaled Dot-Product Attention Demo ---")
batch_size = 2
seq_len_q = 3   # query sequence length (e.g., the target sequence)
seq_len_k = 5   # key/value sequence length (e.g., the source sequence)
d_k = 64        # dimension of the key vectors

# Randomly generate Query, Key, Value tensors
# Note: Multi-Head Attention handles the head dimension inside the MultiHeadAttention class;
# here, for simplicity, we assume a single head (i.e., the input of one attention head)
query = torch.randn(batch_size, seq_len_q, d_k)
key = torch.randn(batch_size, seq_len_k, d_k)
value = torch.randn(batch_size, seq_len_k, d_k)

# Build a simple padding-mask example:
# suppose batch_size = 2, the first sequence has no padding, and the last two tokens of the second sequence are padding.
# A real mask would be generated from the positions of the padding tokens.
mask = torch.ones(batch_size, seq_len_q, seq_len_k)
# The last two positions of the key sequence in the second batch are padding, so no query should attend to them
mask[1, :, -2:] = 0  # zero out attention from every query of the second batch to the last two key positions

# Expand the mask so it matches the shape of the scores: (batch_size, 1, seq_len_q, seq_len_k)
# The extra 1 keeps it compatible with the num_heads dimension of multi-head attention
mask_for_sdpa = mask.unsqueeze(1)

output_sdpa, weights_sdpa = scaled_dot_product_attention(query.unsqueeze(1), key.unsqueeze(1), value.unsqueeze(1), mask=mask_for_sdpa)

print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}")
print(f"Mask shape: {mask_for_sdpa.shape}")
print(f"Scaled Dot-Product Attention Output shape: {output_sdpa.squeeze(1).shape}") # Squeeze out the head dimension
print(f"Scaled Dot-Product Attention Weights shape: {weights_sdpa.squeeze(1).shape}") # Squeeze out the head dimension# 验证掩码效果:检查第二个批次的注意力权重,其最后两列是否接近于0
print("\nExample of masked attention weights (batch 1, first head):")
print(weights_sdpa[1, 0, :, :])
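As a cross-check, recent PyTorch releases (2.0+) ship torch.nn.functional.scaled_dot_product_attention, which computes the same quantity with a fused kernel but returns only the output, not the weights. A minimal sketch that reuses the tensors from the demo above (the allclose comparison is only expected to hold up to floating-point error):

import torch.nn.functional as F

# A boolean attn_mask means: True = this position may be attended to,
# so the 0/1 float mask from the demo is converted with .bool()
ref_output = F.scaled_dot_product_attention(
    query.unsqueeze(1), key.unsqueeze(1), value.unsqueeze(1),
    attn_mask=mask_for_sdpa.bool(),
)
print(torch.allclose(ref_output, output_sdpa, atol=1e-6))  # expected: True (up to numerical precision)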
Multi-Head Attention
Multi-Head Attention is one of the key innovations of the Transformer architecture. It linearly projects Q, K, and V into h different subspaces and runs scaled dot-product attention in each subspace in parallel. The outputs of all h attention heads are then concatenated and passed through a final linear transformation to produce the output.
The benefit is that the model can learn different attention patterns from different representation subspaces, capturing richer and more complete relational information in the sequence. For example, one head might focus on syntactic relations while another focuses on semantic relations.
The mathematical expression is:
$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O \\
\mathrm{where}\ \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
import torch
import torch.nn as nn
import math

# Define scaled_dot_product_attention again so that it is available inside MultiHeadAttention
def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attention_weights, value), attention_weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        """Multi-head attention.

        Args:
            d_model (int): dimension of the input and output features.
            num_heads (int): number of attention heads.
            dropout_rate (float): dropout rate.
        """
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads  # dimension of each attention head
        self.num_heads = num_heads
        self.d_model = d_model  # keep d_model for the final linear projection

        # Linear layers projecting Q, K, V to d_model dimensions
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)  # final linear output layer

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query (torch.Tensor): query tensor of shape (batch_size, seq_len_q, d_model).
            key (torch.Tensor): key tensor of shape (batch_size, seq_len_k, d_model).
            value (torch.Tensor): value tensor of shape (batch_size, seq_len_v, d_model).
            mask (torch.Tensor, optional): attention mask, typically of shape
                (batch_size, 1, seq_len_q, seq_len_k) or (1, 1, seq_len_q, seq_len_k) for broadcasting.

        Returns:
            torch.Tensor: multi-head attention output of shape (batch_size, seq_len_q, d_model).
            torch.Tensor: attention weights of all heads (optional), shape (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = query.size(0)

        # 1. Linear projections, then split into heads
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, num_heads, d_k) -> (batch_size, num_heads, seq_len, d_k)
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 2. Scaled dot-product attention for every head
        # output_heads: (batch_size, num_heads, seq_len_q, d_k)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        output_heads, attention_weights = scaled_dot_product_attention(query, key, value, mask)

        # 3. Concatenate the heads
        # (batch_size, num_heads, seq_len_q, d_k) -> (batch_size, seq_len_q, num_heads, d_k)
        # -> (batch_size, seq_len_q, d_model) (num_heads and d_k are flattened via contiguous().view)
        concat_output = output_heads.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 4. Final linear transformation
        final_output = self.output_linear(concat_output)
        final_output = self.dropout(final_output)  # apply dropout

        return final_output, attention_weights

# --- Demo: Multi-Head Attention ---
print("\n--- Multi-Head Attention Demo ---")
d_model = 512    # input/output feature dimension
num_heads = 8    # number of attention heads
seq_len_q = 10   # query sequence length
seq_len_k = 12   # key/value sequence length (source sequence length)
batch_size = 4

mha = MultiHeadAttention(d_model, num_heads)

# Randomly generate input tensors
input_query = torch.randn(batch_size, seq_len_q, d_model)
input_key = torch.randn(batch_size, seq_len_k, d_model)
input_value = torch.randn(batch_size, seq_len_k, d_model)

# Build a simple padding-mask example
# Suppose some positions of the source sequence are padding, so no query should attend to them
# The mask has shape (batch_size, 1, 1, seq_len_k) or (batch_size, 1, seq_len_q, seq_len_k)
# Here the mask simply says that the last token of the key sequence is padding
# for batch 0, 1, 2, 3: last token is padding
src_mask = torch.ones(batch_size, 1, 1, seq_len_k)  # (batch_size, broadcast dim for num_heads, broadcast dim for query_len, key_len)
src_mask[:, :, :, -1] = 0  # set the last position of the key sequence to 0 (padding)

output_mha, weights_mha = mha(input_query, input_key, input_value, mask=src_mask)

print(f"Input Query shape: {input_query.shape}")
print(f"Input Key shape: {input_key.shape}")
print(f"Input Value shape: {input_value.shape}")
print(f"Source Mask shape: {src_mask.shape}")
print(f"Multi-Head Attention Output shape: {output_mha.shape}")
print(f"Multi-Head Attention Weights shape: {weights_mha.shape}")# 验证掩码效果:检查第一个批次的第一个头的注意力权重,其最后一列是否接近于0
print("\nExample of masked attention weights (batch 0, head 0):")
print(weights_mha[0, 0, :, :])
Applications of Attention in our Model
Encoder-Decoder Attention:
Location: in every layer of the decoder.
Role: lets every position (word) in the decoder attend to all positions (words) of the encoder output.
Source of Q, K, V:
The Query (Q) comes from the output of the previous decoder layer.
The Key (K) and Value (V) come from the encoder output.
Purpose: this mimics the attention mechanism of traditional sequence-to-sequence (Seq2Seq) models, allowing the decoder to retrieve information from the input sequence while generating the target sequence.
Encoder Self-Attention:
Location: in every layer of the encoder.
Role: lets every position (word) in the encoder attend to all positions (words) within the same encoder layer.
Source of Q, K, V: all of (Query, Key, Value) come from the output of the previous encoder layer.
Purpose: captures the relationships between the words of the input sequence; for example, in the sentence "The animal didn't cross the street because it was too tired", the information that "it" refers to "animal".
Decoder Self-Attention:
Location: in every layer of the decoder.
Role: lets every position (word) in the decoder attend to the current position and all earlier positions (words) within the same decoder layer.
Source of Q, K, V: all of (Query, Key, Value) come from the output of the previous decoder layer.
Purpose: preserves the auto-regressive property, i.e., generating the current word may only depend on words that have already been generated, never on "future" words.
Implementation: achieved by applying a mask inside scaled dot-product attention. Concretely, the attention scores of illegal connections (i.e., future positions) are set to negative infinity before the softmax, so that after the softmax their weights are close to zero. This is known as the "causal mask" or "look-ahead mask"; a small sketch follows below.
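A minimal illustration of what such a causal mask looks like for a short sequence (seq_len = 4 is just an example value); positions that may be attended to are True:

import torch

seq_len = 4
# torch.triu(..., diagonal=1) keeps only the strictly upper triangle; comparing with 0 flips it,
# so positions on or below the diagonal (the allowed ones) become True
causal_mask = (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0)
print(causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])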
import torch
import torch.nn as nn
import math

# --- Core helper functions and modules ---

def scaled_dot_product_attention(query, key, value, mask=None):
    """Scaled dot-product attention.

    Args:
        query, key, value: tensors, typically of shape (batch_size, num_heads, seq_len_q/k/v, d_k).
        mask: attention mask, typically of shape (batch_size, 1, seq_len_q, seq_len_k).

    Returns:
        output (torch.Tensor): attention output.
        attention_weights (torch.Tensor): attention weights.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))  # set masked positions to -inf
    attention_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, value)
    return output, attention_weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.d_model = d_model
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections, then split into heads
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention for every head
        output_heads, attention_weights = scaled_dot_product_attention(query, key, value, mask)
        # Concatenate the heads
        concat_output = output_heads.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Final linear transformation
        final_output = self.output_linear(concat_output)
        final_output = self.dropout(final_output)
        return final_output, attention_weights

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class SublayerConnection(nn.Module):
    """A residual connection followed by layer normalization and dropout."""
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        # pre-norm style
        return x + self.dropout(sublayer(self.norm(x)))

# --- Mask creation functions ---

def create_padding_mask(seq, pad_idx):
    """Create a padding mask.

    Args:
        seq (torch.Tensor): input sequence of shape (batch_size, seq_len).
        pad_idx (int): index of the padding token.

    Returns:
        torch.Tensor: padding mask of shape (batch_size, 1, 1, seq_len);
            1 means a real token, 0 means padding.
    """
    # (batch_size, 1, 1, seq_len)
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2).float()
    return mask

def create_look_ahead_mask(seq_len):
    """Create a look-ahead (causal) mask.

    Args:
        seq_len (int): sequence length.

    Returns:
        torch.Tensor: look-ahead mask of shape (1, 1, seq_len, seq_len);
            the upper triangle is 0, the lower triangle and the diagonal are 1.
    """
    # torch.triu(..., diagonal=1) puts 1s strictly above the diagonal;
    # comparing with 0 inverts it, giving 1 on and below the diagonal and 0 above,
    # and the result is cast to float
    look_ahead_mask = (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0).float()
    # Add dimensions so the mask broadcasts against (batch_size, num_heads, seq_len_q, seq_len_k)
    return look_ahead_mask.unsqueeze(0).unsqueeze(0)

# --- Modular implementations of the three attention uses ---

class EncoderSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(EncoderSelfAttention, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.sublayer_conn = SublayerConnection(d_model, dropout_rate)

    def forward(self, x, padding_mask):
        """
        Args:
            x (torch.Tensor): encoder layer input of shape (batch_size, seq_len_src, d_model).
            padding_mask (torch.Tensor): padding mask of the source sequence, shape (batch_size, 1, 1, seq_len_src).
        """
        # Q, K, V all come from x; padding_mask hides the padded positions of the source sequence.
        # MultiHeadAttention returns (output, weights); keep only the output for the residual branch.
        output = self.sublayer_conn(x, lambda x_norm: self.attention(x_norm, x_norm, x_norm, padding_mask)[0])
        return output

class DecoderSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(DecoderSelfAttention, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.sublayer_conn = SublayerConnection(d_model, dropout_rate)

    def forward(self, x, look_ahead_mask):
        """
        Args:
            x (torch.Tensor): decoder layer input of shape (batch_size, seq_len_tgt, d_model).
            look_ahead_mask (torch.Tensor): look-ahead mask of shape (1, 1, seq_len_tgt, seq_len_tgt).
        """
        # Q, K, V all come from x; look_ahead_mask prevents information from leaking in from future positions
        output = self.sublayer_conn(x, lambda x_norm: self.attention(x_norm, x_norm, x_norm, look_ahead_mask)[0])
        return output

class EncoderDecoderAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(EncoderDecoderAttention, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.sublayer_conn = SublayerConnection(d_model, dropout_rate)

    def forward(self, decoder_output, encoder_output, src_padding_mask):
        """
        Args:
            decoder_output (torch.Tensor): output of the previous decoder layer or of decoder self-attention,
                shape (batch_size, seq_len_tgt, d_model).
            encoder_output (torch.Tensor): final encoder output of shape (batch_size, seq_len_src, d_model).
            src_padding_mask (torch.Tensor): padding mask of the source sequence, shape (batch_size, 1, 1, seq_len_src).
        """
        # The Query comes from the decoder output; the Key and Value come from the encoder output;
        # src_padding_mask hides the padded positions of the encoder output
        output = self.sublayer_conn(decoder_output, lambda x_norm: self.attention(query=x_norm, key=encoder_output, value=encoder_output, mask=src_padding_mask)[0])
        return output

# --- Demo parameters ---
d_model = 512
num_heads = 8
dropout_rate = 0.1
batch_size = 2
seq_len_src = 50  # encoder input sequence length
seq_len_tgt = 60  # decoder input sequence length
pad_idx = 0       # assume the padding token index is 0

# --- Simulated input data ---
# Encoder input (assumed to have already gone through word embedding and positional encoding)
encoder_input_tensor = torch.randn(batch_size, seq_len_src, d_model)
# Decoder input (assumed to have already gone through word embedding and positional encoding)
decoder_input_tensor = torch.randn(batch_size, seq_len_tgt, d_model)

# Simulated source token IDs (used to build the padding mask)
# Assume the source sequence of the second batch is padded and its true length is 45
dummy_src_ids = torch.randint(1, 1000, (batch_size, seq_len_src))
dummy_src_ids[1, 45:] = pad_idx  # the last 5 positions of the second batch are padding
src_padding_mask = create_padding_mask(dummy_src_ids, pad_idx)

# Look-ahead mask for decoder self-attention
look_ahead_mask = create_look_ahead_mask(seq_len_tgt)

print("--- Demo of the three attention uses ---")

# 1. Encoder Self-Attention
print("\n--- Encoder Self-Attention Demo ---")
encoder_self_attn_layer = EncoderSelfAttention(d_model, num_heads, dropout_rate)
# x = encoder_input_tensor
encoder_self_attn_output = encoder_self_attn_layer(encoder_input_tensor, src_padding_mask)
print(f"Encoder Self-Attention output shape: {encoder_self_attn_output.shape}")
# Note: in a real Transformer encoder the FFN comes after self-attention; here we only demonstrate the attention layer.

# 2. Decoder Self-Attention
print("\n--- Decoder Self-Attention Demo ---")
decoder_self_attn_layer = DecoderSelfAttention(d_model, num_heads, dropout_rate)
# x = decoder_input_tensor
decoder_self_attn_output = decoder_self_attn_layer(decoder_input_tensor, look_ahead_mask)
print(f"Decoder Self-Attention output shape: {decoder_self_attn_output.shape}")
# Verify the mask: decoder self-attention never attends to future tokens.
# We could inspect the attention weights, but the wrapper modules above only return the final output;
# modify them to also return the weights if you want to check.

# 3. Encoder-Decoder Attention
print("\n--- Encoder-Decoder Attention Demo ---")
encoder_decoder_attn_layer = EncoderDecoderAttention(d_model, num_heads, dropout_rate)
# Query = decoder_self_attn_output (or decoder_input_tensor, depending on where the layer sits)
# Key/Value = encoder_output (here encoder_input_tensor stands in for the final encoder output)
encoder_decoder_attn_output = encoder_decoder_attn_layer(
    decoder_self_attn_output,  # usually the output of the previous decoder sub-layer
    encoder_input_tensor,      # the encoder output
    src_padding_mask           # the padding mask of the source sequence
)
print(f"Encoder-Decoder Attention output shape: {encoder_decoder_attn_output.shape}")
✨Word Embedding and Positional Encoding
Word embeddings use nn.Embedding; positional encoding is implemented by hand, following the sinusoidal formula shown below.
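For reference, the sinusoidal positional encoding from the paper, which the PositionalEncoding module below implements:

$$
PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
$$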
import torch
import torch.nn as nn
import math  # needed for math.log below

class WordEmbeddings(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(WordEmbeddings, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Add batch dimension
        self.register_buffer('pe', pe)  # Register as a buffer, not a parameter

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        # self.pe: (1, max_len, d_model)
        return x + self.pe[:, :x.size(1)]

# --- Demo ---
vocab_size = 10000
d_model = 512
seq_len = 50
batch_size = 2

word_embedder = WordEmbeddings(vocab_size, d_model)
pos_encoder = PositionalEncoding(d_model)

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))  # Dummy input IDs
word_embedded = word_embedder(input_ids)
output = pos_encoder(word_embedded)

print("Word embedded shape:", word_embedded.shape)
print("Output with positional encoding shape:", output.shape)
✨Multi-Head Self-Attention
PyTorch already provides nn.MultiheadAttention; here we write our own wrapper to fit the Transformer structure (a sketch using the built-in module follows the demo below).
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model  # store d_model as an attribute (used again in forward)
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        # query, key, value: (batch_size, seq_len, d_model)
        batch_size = query.shape[0]

        # 1. Linear projections, then split into heads
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads, d_k) -> (batch_size, n_heads, seq_len, d_k)
        query = self.query_proj(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 2. Attention scores
        # (batch_size, n_heads, seq_len, d_k) @ (batch_size, n_heads, d_k, seq_len) -> (batch_size, n_heads, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 3. Apply the mask (if any)
        if mask is not None:
            # The mask must broadcast correctly; typical shapes are (batch_size, 1, 1, seq_len) or (batch_size, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, float('-inf'))  # set masked positions to -inf

        # 4. Softmax normalization
        attention = torch.softmax(scores, dim=-1)
        attention = self.dropout(attention)  # Dropout for attention weights

        # 5. Weighted sum
        # (batch_size, n_heads, seq_len, seq_len) @ (batch_size, n_heads, seq_len, d_k) -> (batch_size, n_heads, seq_len, d_k)
        x = torch.matmul(attention, value)

        # 6. Concatenate the heads and apply the output projection
        # (batch_size, n_heads, seq_len, d_k) -> (batch_size, seq_len, n_heads, d_k) -> (batch_size, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.fc_out(x)

        return output

# --- Demo ---
d_model = 512
n_heads = 8
seq_len = 50
batch_size = 2

mha = MultiHeadAttention(d_model, n_heads)

# Dummy input
x = torch.randn(batch_size, seq_len, d_model)
# Mask for padded tokens (e.g., if some sequences are shorter)
# Example: batch_size=2, seq_len=50, first sequence len=40, second len=30
# Realistically, mask would be generated based on padding.
src_mask = torch.ones(batch_size, 1, seq_len, seq_len) # No mask for this demo
# For decoder masked attention, you would use a look-ahead mask
# look_ahead_mask = (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0).unsqueeze(0).unsqueeze(0)
# mha(x, x, x, mask=look_ahead_mask)

output = mha(x, x, x, mask=src_mask)

print("Multi-Head Attention output shape:", output.shape)
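For comparison, here is a minimal sketch with the built-in nn.MultiheadAttention mentioned above (batch_first=True, available in recent PyTorch versions, makes the layouts match the custom module). Note that its boolean masks use the opposite convention to the 0/1 masks in this post: True marks positions to ignore.

import torch
import torch.nn as nn

builtin_mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=0.1, batch_first=True)

# Boolean key_padding_mask: True marks padded key positions that should be ignored (none here)
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)

attn_output, attn_weights = builtin_mha(x, x, x, key_padding_mask=key_padding_mask)
print("nn.MultiheadAttention output shape:", attn_output.shape)    # (batch_size, seq_len, d_model)
print("nn.MultiheadAttention weights shape:", attn_weights.shape)  # (batch_size, seq_len, seq_len), averaged over heads by default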
✨Feed-Forward Network (FFN)
import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# --- Demo ---
d_model = 512
d_ff = 2048 # Usually d_ff is 4 * d_model
seq_len = 50
batch_size = 2

ffn = PositionwiseFeedForward(d_model, d_ff)

# Dummy input
x = torch.randn(batch_size, seq_len, d_model)
output = ffn(x)

print("Feed-Forward Network output shape:", output.shape)
✨Residual Connections and Layer Normalization
This is usually implemented as a reusable module, or integrated directly into the encoder/decoder layers.
import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    """A residual connection followed by layer normalization and dropout."""
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        # sublayer(self.norm(x)) normalizes first and then runs the sub-layer.
        # This is the "pre-norm" variant: normalization happens before the sub-layer and the residual add,
        # which usually trains more stably.
        return x + self.dropout(sublayer(self.norm(x)))

# --- Demo ---
d_model = 512
dropout_rate = 0.1
seq_len = 50
batch_size = 2

# Example usage with a dummy sublayer (e.g., FFN)
dummy_sublayer = nn.Linear(d_model, d_model) # Simplified sublayer
sublayer_conn = SublayerConnection(d_model, dropout_rate)

# Dummy input
x = torch.randn(batch_size, seq_len, d_model)
output = sublayer_conn(x, dummy_sublayer)

print("SublayerConnection output shape:", output.shape)
✨Encoder Layer
An encoder layer consists of multi-head self-attention, residual connections, layer normalization, and a feed-forward network.
import torch
import torch.nn as nn

# Relies on the MultiHeadAttention, PositionwiseFeedForward and SublayerConnection modules defined above.
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer_conn1 = SublayerConnection(d_model, dropout)
        self.sublayer_conn2 = SublayerConnection(d_model, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Multi-Head Self-Attention with residual connection and layer norm
        x = self.sublayer_conn1(x, lambda x: self.self_attn(x, x, x, mask))
        # Feed-Forward Network with residual connection and layer norm
        x = self.sublayer_conn2(x, self.feed_forward)
        return x

# --- Demo ---
d_model = 512
n_heads = 8
d_ff = 2048
dropout_rate = 0.1
seq_len = 50
batch_size = 2

encoder_layer = EncoderLayer(d_model, n_heads, d_ff, dropout_rate)

# Dummy input and mask
x = torch.randn(batch_size, seq_len, d_model)
src_mask = torch.ones(batch_size, 1, seq_len, seq_len)  # No mask for this demo

output = encoder_layer(x, src_mask)
print("EncoderLayer output shape:", output.shape)
✨Decoder Layer
A decoder layer consists of masked multi-head self-attention, encoder-decoder attention, residual connections, layer normalization, and a feed-forward network.
import torch
import torch.nn as nn
import math

# Make sure MultiHeadAttention and SublayerConnection are defined (including the d_model attribute fixed earlier)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model  # make sure d_model is stored
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        query = self.query_proj(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # The mask must broadcast to the shape of scores: (batch_size, n_heads, seq_len_q, seq_len_k).
            # For src_mask this is typically (batch_size, 1, 1, seq_len_src);
            # for tgt_mask it is typically (1, 1, seq_len_tgt, seq_len_tgt).
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        x = torch.matmul(attention, value)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.fc_out(x)
        return output

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# --- Decoder Layer ---
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)  # Masked self-attention
        self.src_attn = MultiHeadAttention(d_model, n_heads, dropout)   # Encoder-decoder attention
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer_conn1 = SublayerConnection(d_model, dropout)
        self.sublayer_conn2 = SublayerConnection(d_model, dropout)
        self.sublayer_conn3 = SublayerConnection(d_model, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        # x: decoder input (target sequence)
        # memory: encoder output
        # src_mask: mask for encoder output (source sequence)
        # tgt_mask: mask for decoder self-attention (look-ahead mask)

        # 1. Masked Multi-Head Self-Attention (Query, Key, Value are all x)
        # tgt_mask should have shape (1, 1, seq_len_tgt, seq_len_tgt) to match (batch_size, n_heads, seq_len_tgt, seq_len_tgt)
        x = self.sublayer_conn1(x, lambda x_norm: self.self_attn(x_norm, x_norm, x_norm, tgt_mask))

        # 2. Encoder-Decoder Attention (Query is x, Key/Value are memory)
        # src_mask should have shape (batch_size, 1, 1, seq_len_src) or (batch_size, 1, seq_len_tgt, seq_len_src)
        # to match scores of shape (batch_size, n_heads, seq_len_tgt, seq_len_src)
        x = self.sublayer_conn2(x, lambda x_norm: self.src_attn(x_norm, memory, memory, src_mask))

        # 3. Feed-Forward Network
        x = self.sublayer_conn3(x, self.feed_forward)
        return x

# --- Demo ---
d_model = 512
n_heads = 8
d_ff = 2048
num_layers = 6
dropout_rate = 0.1
seq_len_src = 50 # Source sequence length
seq_len_tgt = 60 # Target sequence length
batch_size = 2

decoder_layer = DecoderLayer(d_model, n_heads, d_ff, dropout_rate)

# Dummy inputs
decoder_input = torch.randn(batch_size, seq_len_tgt, d_model)
encoder_output = torch.randn(batch_size, seq_len_src, d_model)

# --- Key point: mask shapes ---
# src_mask must match the sequence length of encoder_output (seq_len_src).
# Its shape should be (batch_size, 1, 1, seq_len_src), so that when it is broadcast against
# scores of shape (batch_size, n_heads, seq_len_tgt, seq_len_src)
# it is applied correctly along the key dimension (seq_len_src).
src_mask = torch.ones(batch_size, 1, 1, seq_len_src)

# Look-ahead mask for target sequence
# tgt_mask should have shape (1, 1, seq_len_tgt, seq_len_tgt) to fit decoder self-attention
tgt_mask = (torch.triu(torch.ones(seq_len_tgt, seq_len_tgt), diagonal=1) == 0).unsqueeze(0).unsqueeze(0).float()

output = decoder_layer(decoder_input, encoder_output, src_mask, tgt_mask)
print("DecoderLayer output shape:", output.shape)
✨Encoder
The encoder is a stack of multiple encoder layers.
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, dropout):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)  # Final layer normalization

    def forward(self, x, mask):
        # x: input_embeddings + positional_encodings
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)  # Apply final layer norm

# --- Demo ---
d_model = 512
n_heads = 8
d_ff = 2048
num_layers = 6 # Standard Transformer uses 6 layers
dropout_rate = 0.1
seq_len = 50
batch_size = 2

encoder = Encoder(d_model, n_heads, d_ff, num_layers, dropout_rate)

# Dummy input (already passed through word embedding and positional encoding)
input_tensor = torch.randn(batch_size, seq_len, d_model)
src_mask = torch.ones(batch_size, 1, seq_len, seq_len)  # Example src mask

output = encoder(input_tensor, src_mask)
print("Encoder output shape:", output.shape)
✨Decoder
The decoder is a stack of multiple decoder layers.
import torch
import torch.nn as nn
import math

# Make sure all previously defined modules (MultiHeadAttention, SublayerConnection, PositionwiseFeedForward,
# EncoderLayer, DecoderLayer, ...) are correctly defined and fixed.
# For completeness, the dependent classes are repeated here; only the Decoder and its demo are new.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        query = self.query_proj(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        x = torch.matmul(attention, value)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.fc_out(x)
        return output

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.src_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer_conn1 = SublayerConnection(d_model, dropout)
        self.sublayer_conn2 = SublayerConnection(d_model, dropout)
        self.sublayer_conn3 = SublayerConnection(d_model, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        x = self.sublayer_conn1(x, lambda x_norm: self.self_attn(x_norm, x_norm, x_norm, tgt_mask))
        x = self.sublayer_conn2(x, lambda x_norm: self.src_attn(x_norm, memory, memory, src_mask))
        x = self.sublayer_conn3(x, self.feed_forward)
        return x

# --- Decoder ---
class Decoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, dropout):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)  # Final layer normalization

    def forward(self, x, memory, src_mask, tgt_mask):
        # x: decoder_input_embeddings + positional_encodings
        # memory: encoder output
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)  # Apply final layer norm

# --- Demo ---
d_model = 512
n_heads = 8
d_ff = 2048
num_layers = 6
dropout_rate = 0.1
seq_len_src = 50
seq_len_tgt = 60
batch_size = 2

decoder = Decoder(d_model, n_heads, d_ff, num_layers, dropout_rate)

# Dummy inputs
decoder_input = torch.randn(batch_size, seq_len_tgt, d_model)
encoder_output = torch.randn(batch_size, seq_len_src, d_model)

# --- Key point: mask shapes ---
# src_mask must match the sequence length of encoder_output (seq_len_src).
# Its shape should be (batch_size, 1, 1, seq_len_src), so that when it is broadcast against
# scores of shape (batch_size, n_heads, seq_len_tgt, seq_len_src)
# it is applied correctly along the key dimension (seq_len_src).
src_mask = torch.ones(batch_size, 1, 1, seq_len_src)

# Look-ahead mask for target sequence
# tgt_mask should have shape (1, 1, seq_len_tgt, seq_len_tgt) to fit decoder self-attention
tgt_mask = (torch.triu(torch.ones(seq_len_tgt, seq_len_tgt), diagonal=1) == 0).unsqueeze(0).unsqueeze(0).float()

output = decoder(decoder_input, encoder_output, src_mask, tgt_mask)
print("Decoder output shape:", output.shape)
✨Final Linear Layer (Output Projection)
The vectors produced by the decoder are passed through a linear layer and then a softmax, mapping them to the vocabulary size to obtain a probability distribution over tokens.
import torch
import torch.nn as nn

class OutputLinear(nn.Module):
    def __init__(self, d_model, vocab_size):
        super(OutputLinear, self).__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch_size, seq_len_tgt, d_model) from decoder output
        return self.linear(x)  # -> (batch_size, seq_len_tgt, vocab_size)

# --- Demo ---
d_model = 512
vocab_size = 10000
seq_len_tgt = 60
batch_size = 2

output_layer = OutputLinear(d_model, vocab_size)

# Dummy input (e.g., from decoder output)
decoder_output = torch.randn(batch_size, seq_len_tgt, d_model)
output_logits = output_layer(decoder_output)

print("Output Linear layer shape:", output_logits.shape)
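The demo returns raw logits. A short sketch of how they are typically consumed (target_ids below is a hypothetical tensor of gold token IDs, random here just for illustration): at inference, a softmax turns the logits into a probability distribution over the vocabulary; during training, nn.CrossEntropyLoss is applied directly to the logits, since it performs the log-softmax internally.

# Inference: probability distribution over the vocabulary for each target position
probs = torch.softmax(output_logits, dim=-1)  # (batch_size, seq_len_tgt, vocab_size)

# Training: cross-entropy expects (N, vocab_size) logits and (N,) class indices,
# so the batch and sequence dimensions are flattened
target_ids = torch.randint(0, vocab_size, (batch_size, seq_len_tgt))  # hypothetical gold token IDs
criterion = nn.CrossEntropyLoss()
loss = criterion(output_logits.view(-1, vocab_size), target_ids.view(-1))
print("Loss:", loss.item())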
⭐ Author: 双份浓缩馥芮白
Original link: https://www.cnblogs.com/Flat-White/p/19009218
All rights reserved. Please credit the source when reposting.