Transformer拆解
这篇博客主要记录Transformer架构的代码实现。以下是参考资料 - Attention is All you need - Attention? Attention! - lilian wen的tensorflow版本的实现 - illustrated-transformer 强烈建议看illustrated transformer这篇博客,是跟paper介绍的transformer架构完全对齐的 - pytorch transformer实现
这是pytorch的官方实现
斯坦福出的transformer架构的实现tutorial
我自己想实现的一遍的原因在于:
- transformer的文章读了很多遍,但是很多细节还是没有去深究。
- 斯坦福的实现完全遵照的是paper的架构,但是我觉得还是实现的过于复杂了,我想遵循lilian的tensorflow实现把原生的tranformer架构实现一下
- 我对pytorch的掌握没有tensorflow好,感觉现在pytorch基本上成为深度学习网络的主流,特别是大模型出来之后,hugginggface的transformer库也是支持pytorch更好一点,更大的社区。(此时有点后悔当时系统学习的是tensorflow而不是pytorch)
transformer的整体架构: encoder-decoder两大模块,encoder模块内有重复的6个子模块,decoder模块内也有重复的6个子模块。
我们采用自上而下的方式来看这两个模块
Transformer整体架构
1 | class Transformer(nn.Module): |
Tranformer Encoder
1 | def clones(module, N): |
这里实现了一个clones
帮助函数,我想过在这里用for循环,lilian在这里就是用的for循环:
1 | out = inp # now, (batch, seq_len, embed_size) |
注意这里的每一个encoder_layer的参数都是独立的,也就是有6份encoder_layer的参数需要训练,tensorflow为什么可行?是因为它这里使用了variable_scope的概念,上面的tensorflow实现每一次out和input_mask进来都是和不同的数值进行的运算。如果在pytorch中想实现这种方式,要先把encoder_layer复制六遍,每一次输入进来都拿不同的layer做运算。
encoder layer
接下来我们实现encoder layer中的细节部分,它包含两个sub-layer: 1) self-attention + Add&layer_Norm 2) position-wise feed forward + Add&layer_Norm
1 | class TransformerEncoderLayer(nn.Module): |
self attention
我一开始查阅的资料是illustrated-transformer, 这个博客内没有具体的实现。后来我参考的是lilian wen的tensorflow实现。在lilian的实现里对于multihead attention是这样写的:
1 | def multihead_attention(self, query, memory=None, mask=None, scope='attn'): |
以上的实现其实和博客内的内容有点相左,博客写的是:
As we’ll see next, with multi-headed attention we have not only one, but multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder). Each of these sets is randomly initialized. Then, after training, each set is used to project the input embeddings (or vectors from lower encoders/decoders) into a different representation subspace.
结合作者给出的图片:
我一开始的理解是每一个head都有一份单独的W sets(WQ,WK,WV)。每一个head经过了scaled attention的计算
得到的Z的shape都是(batch, seq_len, embeded_size)
,所以才会有WO这个线性变化(blog里说的):
但我看完代码之后发现并不是我想的那样。我觉得这篇博客写的有点问题。后来又找到了一篇博客,能够解答我的疑问。它最重要的话是:
However, the important thing to understand is that this is a logical split only. The Query, Key, and Value are not physically split into separate matrices, one for each Attention head. A single data matrix is used for the Query, Key, and Value, respectively, with logically separate sections of the matrix for each Attention head. Similarly, there are not separate Linear layers, one for each Attention head. All the Attention heads share the same Linear layer but simply operate on their ‘own’ logical section of the data matrix.