从理论到实现,手把手实现Attention网络

作者 | 梁唐
出品 | 公众号:Coder梁(ID:Coder_LT)
大家好 , 我是老梁 。
我们之前介绍了Transformer的核心——attention网络 , 我们之前只是介绍了它的原理 , 并且没有详细解释它的实现方法 。光聊理论难免显得有些空洞 , 所以我们来谈谈它的实现 。
为了帮助大家更好地理解 ,  这里我选了电商场景中的DIN模型来做切入点 。
一方面可以帮助大家理解现在电商系统中的推荐和广告系统中的商品排序都是怎么做的 , 另外我个人感觉DIN要比直接去硬啃transformer容易理解一些 。
我们可以先从attention网络的数据入手 , 它的输入数据有两个:一个是用户的历史行为序列 , 一个是待打分的item(以下称为target item) 。用户的历史行为序列本质上其实就是一个用户历史上有过交互的item的数组 。这里为了简化 , 我们假设已经完成了从item到embedding的转换 。
首先是target item , 它的shape应该是[B, E] 。这里的B指的是batch_size , 即训练时候一个批量的大小 。这里的E指的是embedding的长度 。也有一些文章里使用别的字母表示 , 这也没有一个硬性的标准 , 能看懂就行 。
我们再来看用户行为序列 , 除了batch_size和embedding长度之外 , 还需要一个额外的参数来表示行为序列的长度 , 通常我们用字母T 。对于所有的样本 , 我们都需要保证它的行为序列长度是T , 如果不足T的 , 则使用默认值补足 。如果超过T的 , 则进行截断 。如此 , 它的shape应该是[B, T, E] 。
根据attention网络的原理 , 我们需要根据行为序列中的每个item与target item的相似度 , 再根据相似度计算权重 。最后对这T个item的embedding进行加权求和 。求和之后 , 这T个item根据计算得到的权重合并得到一个embedding 。论文中说这个集成T个行为序列的embedding就是用户兴趣的表达 , 我们只需要将它和目标item拼接在一起发送到神经网络即可 , 就可以帮助模型更好地决策了 。这里用户兴趣的shape应该和item是一样的 , 也是[B, E] 。
简单总结一下 , 我们现在需要一个模块 , 它接收两个输入 , 一个是item的embedding , 一个是用户行为序列的embedding 。它的输出应该是[B, T] , 对应行为序列中T个item的权重 。剩下的问题就是怎么生成这个结果 。
原理讲完了 , 接下来讲讲实现 , 我们可以结合一下下面这两张论文中的结构图帮助理解 。

从理论到实现,手把手实现Attention网络

文章插图

从理论到实现,手把手实现Attention网络

文章插图
图片
首先 , 我们来统一一下输入的维度 , 手动将item的embedding这个二维的向量变成三维 , 即shape变成[B, 1, E] 。
这里一种做法是 , 手动循环T次 , 每次从行为序列中拿出一个item embedding , 和目标item的embedding拼接在一起丢进一个神经网络中得到一个分数 。
这种做法非常不推荐 , 一般在神经网络当中 , 我们不到万不得已 , 不手动循环 , 因为循环是线性计算 , 没办法利用GPU的并行计算来加速 。
对于当前问题来说 , 我们完全可以使用矩阵运算来代替 。通过使用expand/tile函数 , 将[B, 1, E]的item embedding复制T份 , 形状也变成[B, T, E] 。这样一来 , 两个输入的shape都变成了[B, T, E] , 我们就可以把它们拼接到一起变成[B, T, 2E] 。
然后经过一个输入是2E , 输出是1的神经网络 , 最终得到[B, T, 1]的结果 , 我们把它调换一下维度 , 变成[B, 1, T] , 这个就是我们想要的权重了 。
这里我找来一份Pytorch的代码 , 大家代入一下上面的逻辑去看一下 , 应该不难看懂 。
class LocalActivationUnit(nn.Module):
def __init__(self, hidden_units=(64, 32), embedding_dim=4, activation='sigmoid', dropout_rate=0, dice_dim=3,l2_reg=0, use_bn=False):super(LocalActivationUnit, self).__init__()self.dnn = DNN(inputs_dim=4 * embedding_dim,hidden_units=hidden_units,activation=activation,l2_reg=l2_reg,dropout_rate=dropout_rate,dice_dim=dice_dim,use_bn=use_bn)self.dense = nn.Linear(hidden_units[-1], 1)


推荐阅读