推荐系统 教程

推荐系统 召回算法

推荐系统 排序层

推荐系统 笔记

original icon
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://www.knowledgedict.com/tutorial/rec-mmoe.html

详解谷歌的 MMoE(Multi-gate Mixture-of-Experts )模型(附 tensorflow 代码实现)


本章主要介绍 Google 发表在 KDD 2018 上的经典的多任务学习模型 MMoE(Multi-gate Mixture-of-Experts),它主要的使用工业场景是不相关任务的多任务学习,这里不相关任务以常见的示例来讲,如视频流推荐中的 CTR、点赞、时长、完播、分享等相关性不强的多个任务。

背景

阿里巴巴的 ESMM 模型几乎是为广告 CTR 和 CVR 专门量身打造的,在 ESMM 模型结构中,两个塔具有明确的依赖关系。在 MTL 模型中,基本都是 N 个塔共享底座(shared bottom)embedding,然后不同的任务分不同的塔,这种模式需要这些塔之间具有比较强的相关性,不然性能就很差,甚至会发生跷跷板现象,即一个 task 性能的提升是通过损害另一个 task 性能作为代价换来的。因此,如果两个 task 都有足够的数据量,这种共享底座 embedding 的多塔设计的性能并没有分开单独建模效果来得好,原因是几乎必然出现负迁移(negative transfer)和跷跷板现象。因此在实际应用中,并不要盲目的为了 MTL 而 MTL,这样只会弄巧成拙。

但如果又有多个目标,多个 tower 之间的相关性并不是很强,比如,CTR、点赞、时长、完播、分享等,并且有的目标的数据量并不是很足够,甚至无法单独训练一个DNN,这时候 MMoE 就可以派上用场了。

这种多任务的迭代进程是从比较简单的 stacking 多模型融合硬性(hard)的 shared bottom layer软性(soft)的 shared bottom layer(专家层)

模型

先附上论文中的 MMoE 模型图:

mmoe 模型

图1. MMoE整体网络结构

上图结构说明:

  • (a)展示了传统的 MTL 模型结构,即多个 task 共享底座 shared bottom(一般都是 embedding 向量)。
  • (b)则是论文中提到的一个 Gate 的 Mixture-of-Experts 模型结构。
  • (c)则是论文中的 MMoE 模型结构。

我们来进一步解析 MMoE 结构,也就是图1 中的 (c),这里每一个 Expert 和 Gate 都是一个全连接网络(MLP),层数由在实际的场景下自己决定。

MMoE 为每一个模型目标设置一个 Gate,所有的目标共享多个 Expert,每个 Expert 通常是数层规模比较小的全连接层。Gate 用来选择每个 Expert 的信号占比(或者叫权重)。每个 Expert 都有其擅长的预测方向,最后共同作用于上面的多个目标。

下图是内部详细的 MMoE 模型结构图,读者可以直接对着下图,进行代码实现即可。

mmoe 模型

图2. MMoE 模型内部细节版

注:GateB 参考 GateA 即可。

上图中,我们需要注意如下几点细节:

  1. Gate 的数量取决于 task 数量,即有几个 Task 就有几个 Gate。Gate 网络最后一层全连接层的隐藏单元(即输出)size 必须等于 Expert 的个数。此外,Gate 网络最后的输出会经过 softmax 进行归一化
  2. Gate 网络最后一层全连接层经过 softmax 归一化后的输出,对应作用到每一个 Expert 上(上图中 GateA 输出的红、紫、绿三条线分别作用与 Expert0,Expert1,Expert2),注意是通过广播机制作用到 Expert 中的每一个隐藏单元,比如红线作用于 Expert0 的2个隐藏单元。这里 Gate 网络的作用非常类似于 attention 机制,提供了权重。
  3. 假设 GateA 的输出为 [ GA1, GA2, GA3 ],Expert0 的输出为 [ E01 , E02 ],Expert1 的输出为 [ E11 , E12 ],Expert2 的输出为 [ E21 , E22 ]。GateA 分别与 Expert0、Expert1、Expert2 作用,得到 [ GA1 ∗ E01 , GA1 ∗ E02 ] , [ GA2 ∗ E11 , GA2 ∗ E12 ] , [ GA3 ∗ E21 , GA3 ∗ E22 ],然后对应位置求和得到 TowerA 的输入,即 TowerA 的输入 size 等于 Expert 输出隐藏单元个数(在这个例子中,Expert 最后一层全连接层隐藏单元个数为2,因此 TowerA 的输入维度也为2),所以 TowerA 的输入为 [ GA1 ∗ E01 + GA2 ∗ E11 + GA3 ∗ E21 , GA1 ∗ E02 + GA2 ∗ E12 + GA3 ∗ E22 ]
  4. Expert 每个网络的输入特征都是一样的,其网络结构也是一致的
  5. 每个 Gate 网络的输入也是一样的,Gate 网络结构也是一样的

进一步通俗地解释每个模块:

  • 每个专家(expert)网络一般是层数规模比较小的 MLP 层,每个 expert 网络都有其擅长的学习方向,最后共同作用于上面的多个目标上。
  • Share-Bottom 形式的模型一定程度上限制了不同目标的特异性,预测的目标之间的相关性比较高,而多个 expert 网络(即底部引入 MoE 层)结构可以理解为一种 ensemble 的方法,再对每个任务学习一个Gate(门控网络),每个门控网络对不同专家分配不同的权重系数。
  • MMoE 为每一个模型学习目标设置一个 gate 网络,gate 网络用来学习选择每个 expert 网络输出的信号权重。

MMoE Tensorflow 2 实现

 

Redis CLUSTER SET-CONFIG-EPOCH 命令为一个全新的节点设置指定的 config epoch 配置,并且仅在 2 ...
Redis CLUSTER COUNT-FAILURE-REPORTS 命令返回指定节点的故障报告个数,故障报告是 Redis Cluste ...
Java中有多种方式可以将日期格式化为"yyyy-mm-dd"的形式。###方法二:使用第三方库:ApacheCommo ...
下面我将介绍两种常用的方法,分别是使用Java内置的日期库和使用第三方库,以及它们的具体步骤和代码示例。###方法二:使用第三方日期库使用` ...
在使用 spring-data-redis 库提供的 redis client 对象 RedisTemplate 进行 set EX NX ...