MMoE的一些浅薄经验及思考

近年来,多任务学习(Multi-task Learning, MTL)在推荐系统、广告点击率预测(CTR/CVR)等场景中得到了广泛应用。其中,**MMoE(Multi-gate Mixture-of-Experts)**作为一种经典的参数共享结构,有效缓解了任务冲突问题,提升了模型整体表现。在实践过程中,也积累了一些关于MMoE结构的理解与调优经验,本文将结合理论与工程视角展开总结。

MMoE介绍

MMoE结构的核心思想是:

“将共享网络表示为多个专家网络,通过多个任务特有的门控机制(gate)选择性地融合专家输出”。

相比传统的硬共享或软共享结构,MMoE的优势在于:

  • 专家共享但任务区分:每个任务有自己的gate,可以按需组合不同专家;
  • 解决任务冲突:提升多个任务之间的学习灵活性,避免梯度互相干扰。

架构组成:

  • Experts:多个共享的子网络(MLP/FFN),用于提取通用特征;
  • Gates:每个任务一个门控网络,输出softmax权重用于加权组合专家输出;
  • Towers:任务特定的输出层,通常是简单的MLP结构。

给出一个简单的推荐系统上MMoE的demo代码(不保证可以跑起来,让Claude阉割了我的原始代码生成的):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class SimpleMLP(nn.Module):
def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1):
super(SimpleMLP, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, out_dim)
)

def forward(self, x):
return self.net(x)

class WeightMLP(nn.Module):
def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1, temperature=1.0):
super(WeightMLP, self).__init__()
self.temperature = temperature
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, out_dim),
)

def forward(self, x):
logits = self.net(x)
return F.softmax(logits / self.temperature, dim=-1)

class OutputMLP(nn.Module):
def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1):
super(OutputMLP, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, out_dim),
)

def forward(self, x):
return self.net(x)

class MultiTargetMMoE(nn.Module):
def __init__(self, cat_fea_dims=None, num_fea_dims=None, list_fea_dims=None,
expert_out_dim=64, expert_num=4, target_num=2, list_pooling="mean"):
super(MultiTargetMMoE, self).__init__()

self.cat_feas = list(cat_fea_dims.keys()) if cat_fea_dims else []
self.num_feas = list(num_fea_dims.keys()) if num_fea_dims else []
self.list_feas = list(list_fea_dims.keys()) if list_fea_dims else []
self.list_pooling = list_pooling
self.expert_num = expert_num
self.target_num = target_num

self.embedding_dict = {}
self.cat_embedding_layers = nn.ModuleList()
self.cat_embedding_dims = []
for fea in self.cat_feas:
num_embeddings, embedding_dim = cat_fea_dims[fea]
embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
self.cat_embedding_layers.append(embedding_layer)
self.cat_embedding_dims.append(embedding_dim)
self.embedding_dict[fea] = embedding_layer

self.list_embedding_layers = nn.ModuleList()
self.list_embedding_dims = []
for fea in self.list_feas:
embedding_dim = list_fea_dims[fea][1]
self.list_embedding_layers.append(self.embedding_dict["itemId"])
self.list_embedding_dims.append(embedding_dim)

self.num_fea_dims = [dim[1] for dim in num_fea_dims.values()] if num_fea_dims else []
input_dim = sum(self.cat_embedding_dims) + sum(self.num_fea_dims) + sum(self.list_embedding_dims)

self.experts = nn.ModuleList([
SimpleMLP(input_dim, 128, expert_out_dim)
for _ in range(expert_num)
])

self.weight_nets = nn.ModuleList([
WeightMLP(input_dim, 64, expert_num)
for _ in range(target_num)
])

self.output_nets = nn.ModuleList([
OutputMLP(expert_out_dim, 64, 1)
for _ in range(target_num)
])

def forward(self, cat_inputs: torch.Tensor, num_inputs: torch.Tensor, list_input: List[torch.Tensor]):
embedding_v = []
if cat_inputs is not None and self.cat_feas:
for i, embedding_layer in enumerate(self.cat_embedding_layers):
col = cat_inputs[:, i].long()
embedding_v.append(embedding_layer(col))

for i, embedding_layer in enumerate(self.list_embedding_layers):
list_tensor = list_input[i].long()
lengths = (list_tensor > 0).sum(dim=1).float().clamp(min=1.0)
embedded = embedding_layer(list_tensor)

if self.list_pooling == "mean":
mask = (list_tensor > 0).unsqueeze(-1).float()
pooled = (embedded * mask).sum(dim=1) / lengths.unsqueeze(-1)
else:
mask = (list_tensor > 0).unsqueeze(-1).float()
pooled = (embedded * mask).sum(dim=1) / lengths.unsqueeze(-1)

embedding_v.append(pooled)

if num_inputs is not None and self.num_feas:
concate_num_input = num_inputs.float()
x0 = torch.cat(embedding_v + [concate_num_input], dim=1)
else:
x0 = torch.cat(embedding_v, dim=1)

expert_outputs = [expert(x0) for expert in self.experts]
expert_outputs = torch.stack(expert_outputs, dim=1)

outputs = []
for weight_net, output_net in zip(self.weight_nets, self.output_nets):
gate = weight_net(x0) # [batch_size, expert_num]
gate_expanded = gate.unsqueeze(-1) # [batch_size, expert_num, 1]
weighted_expert_output = (expert_outputs * gate_expanded).sum(dim=1)
output = output_net(weighted_expert_output)
outputs.append(output)

return torch.cat(outputs, dim=1)

def main():
torch.manual_seed(42)
cat_fea_dims = {
'userId': (1000, 32),
'itemId': (5000, 32),
'category': (100, 16)
}
num_fea_dims = {
'age': (1, 1),
'price': (1, 1)
}
list_fea_dims = {
'history_items': (5000, 32)
}

# 创建模型
model = MultiTargetMMoE(
cat_fea_dims=cat_fea_dims,
num_fea_dims=num_fea_dims,
list_fea_dims=list_fea_dims,
expert_out_dim=64,
expert_num=4,
target_num=2, # 例如:点击率和转化率两个任务
list_pooling="mean"
)

batch_size = 32
cat_inputs = torch.randint(1, 100, (batch_size, 3))
num_inputs = torch.randn(batch_size, 2) # age, price
list_inputs = [torch.randint(1, 1000, (batch_size, 10))] # tags
outputs = model(cat_inputs, num_inputs, list_inputs)

print(f"模型输出形状: {outputs.shape}") # [batch_size, target_num]
print(f"任务1预测值: {outputs[:5, 0]}") # 前5个样本的任务1预测
print(f"任务2预测值: {outputs[:5, 1]}") # 前5个样本的任务2预测

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
labels = torch.randn(batch_size, 2)

optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

print(f"损失值: {loss.item():.4f}")
print("训练完成!")


if __name__ == "__main__":
main()

遇到的问题和解决方法

在实践中,MMoE也存在着一些问题,比如极化现象、任务相关性等问题,下面是针对本人实践MMoE的时候遇到的问题和解决方案:

任务之间依赖未建模

任务描述:

虽然MMoE通过多专家-多门控机制缓解了任务冲突,但在实际业务中,多个任务往往并非彼此独立,而是具有自然的因果或行为依赖。这种“任务依赖性”在推荐系统中非常常见,尤其是在用户行为递进链路中,典型的如:

  • 曝光 → 点击(click)→ 点赞(favour)→ 评论(comment)
  • 曝光 → 收藏(collect)→ 加购(add to cart)→ 购买(purchase)

以“点击(click)”与“点赞(favour)”为例:

  • 用户只有点击内容后才可能进行点赞;
  • 在训练数据中,点赞数据天然依赖于点击行为;
  • 因此,点赞任务的训练样本是点击样本的一个子集,而不是全部曝光样本。

如果我们直接基于曝光样本进行点赞(CVR)建模,会出现训练数据偏差(selection bias),本质上是一个标签条件分布发生偏移的问题

P(yfavourx)P(yfavourx,yclick=1)P(y_{favour}∣x)≠P(y_{favour}∣x,y_{click}=1)

这将导致:

  • 模型错误地学习了在未点击的情况下点赞的“伪模式”;
  • 训练数据中点赞标签为0的大多数其实是“未点击”,而不是“点击后未点赞”,使得负样本分布不真实;
  • CVR估计偏高或偏低,影响推荐排序的有效性。

解决方案:

  1. 抛弃掉这部分数据:在CVR任务上,拿曝光且点击的数据进行训练,不加入曝光未点击的数据。
  2. ESMM(之前已经提及过,这个本质上就是之前说的Selective Bias)

极化现象

这是MMoE最突出的问题之一,也叫做Gate输出塌缩。极化指的是在训练过程中,不同的Expert逐渐"专化"到特定任务上,导致某些Expert只服务于特定任务,而其他Expert被忽略。最终可能出现每个任务只使用一个Expert的情况,违背了Mixture-of-Experts的初衷。

解决方案:

  • 门控正则化: 在损失函数中加入门控权重的正则化项,鼓励权重分布更均匀
  • Expert平衡损失: 添加辅助损失确保每个Expert都被充分利用
  • 温度缩放: 在门控网络的softmax中引入温度参数,控制分布的尖锐程度
  • 梯度平衡: 使用梯度平衡技术,如GradNorm,平衡不同任务的梯度贡献

实际中,加入dropout层就可以很好地解决这个问题:


上面两个图就是加入dropout前/后的专家网络利用率。(加入一些噪声也可以)

结构设计问题

专家网络

  1. 专家网络结构太浅表达力不足,太深则梯度消失或训练不稳定。可视化结果如下:

  • 解决方案

    • 小型MLP,每层加上BatchNorm + Dropout;

    • 若任务差异较大,可尝试**“异构专家”(不同结构),提高路由灵活性。

  1. 专家(Experts)数量设置太少会导致表达能力不足,太多则易冗余、过拟合。

    解决方案

    • 通常4~8个专家已足够,过多时应结合门控权重分布热力图评估使用率;

    • 可引入 专家多样性正则项(如KL散度、熵)避免所有任务集中激活同一个专家;

    • 使用 L1/L2 正则限制专家参数范数,提升泛化性。