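"""A self-contained PyTorch demo of a Multi-gate Mixture-of-Experts (MMoE)
model for multi-task learning over categorical, numeric, and variable-length
list features, with a toy forward pass and training step in main()."""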
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class SimpleMLP(nn.Module):
    """Expert network: one hidden layer with BatchNorm, ReLU, and Dropout."""

    def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

class WeightMLP(nn.Module):
    """Gate network: produces a softmax distribution over experts for one task."""

    def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        # Temperature < 1 sharpens the gate distribution; > 1 flattens it.
        logits = self.net(x)
        return F.softmax(logits / self.temperature, dim=-1)

class OutputMLP(nn.Module):
    """Task tower: maps the gated expert mixture to a scalar prediction."""

    def __init__(self, input_dim, hidden_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

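# Note: the three MLPs above share the same Linear -> BatchNorm -> ReLU ->
# Dropout -> Linear trunk and differ only in how the head is used: raw expert
# features, softmax gate weights, or a scalar task prediction.
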
class MultiTargetMMoE(nn.Module):
    """Multi-gate Mixture-of-Experts: shared experts with per-task gates and towers."""

    def __init__(self, cat_fea_dims=None, num_fea_dims=None, list_fea_dims=None,
                 expert_out_dim=64, expert_num=4, target_num=2, list_pooling="mean"):
        super().__init__()
        self.cat_feas = list(cat_fea_dims.keys()) if cat_fea_dims else []
        self.num_feas = list(num_fea_dims.keys()) if num_fea_dims else []
        self.list_feas = list(list_fea_dims.keys()) if list_fea_dims else []
        self.list_pooling = list_pooling
        self.expert_num = expert_num
        self.target_num = target_num

        # One embedding table per categorical feature; index 0 is reserved as padding.
        self.embedding_dict = {}
        self.cat_embedding_layers = nn.ModuleList()
        self.cat_embedding_dims = []
        for fea in self.cat_feas:
            num_embeddings, embedding_dim = cat_fea_dims[fea]
            embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
            self.cat_embedding_layers.append(embedding_layer)
            self.cat_embedding_dims.append(embedding_dim)
            self.embedding_dict[fea] = embedding_layer

        # List features reuse the "itemId" embedding table, so item histories
        # share parameters with the itemId feature (requires "itemId" to be
        # present in cat_fea_dims).
        self.list_embedding_layers = nn.ModuleList()
        self.list_embedding_dims = []
        for fea in self.list_feas:
            embedding_dim = list_fea_dims[fea][1]
            self.list_embedding_layers.append(self.embedding_dict["itemId"])
            self.list_embedding_dims.append(embedding_dim)

        self.num_fea_dims = [dim[1] for dim in num_fea_dims.values()] if num_fea_dims else []

        # Width of the concatenated input: all embeddings plus raw numeric features.
        input_dim = sum(self.cat_embedding_dims) + sum(self.num_fea_dims) + sum(self.list_embedding_dims)
        # Shared experts: each maps the full input to an expert_out_dim representation.
        self.experts = nn.ModuleList([
            SimpleMLP(input_dim, 128, expert_out_dim) for _ in range(expert_num)
        ])

        # One gate per task, each emitting softmax weights over the experts.
        self.weight_nets = nn.ModuleList([
            WeightMLP(input_dim, 64, expert_num) for _ in range(target_num)
        ])

        # One tower per task, each producing a scalar prediction.
        self.output_nets = nn.ModuleList([
            OutputMLP(expert_out_dim, 64, 1) for _ in range(target_num)
        ])
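
    # For task k, forward() below computes
    #     y_k = Tower_k( sum_i g_{k,i}(x) * Expert_i(x) ),
    #     g_k(x) = softmax(Gate_k(x) / T)
    # i.e. all tasks share the experts but learn their own mixing weights.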
    def forward(self, cat_inputs: torch.Tensor, num_inputs: torch.Tensor, list_input: List[torch.Tensor]):
        embedding_v = []

        # Embed each categorical column.
        if cat_inputs is not None and self.cat_feas:
            for i, embedding_layer in enumerate(self.cat_embedding_layers):
                col = cat_inputs[:, i].long()
                embedding_v.append(embedding_layer(col))

        # Pool each variable-length list feature; id 0 is padding and is masked out.
        for i, embedding_layer in enumerate(self.list_embedding_layers):
            list_tensor = list_input[i].long()
            lengths = (list_tensor > 0).sum(dim=1).float().clamp(min=1.0)
            embedded = embedding_layer(list_tensor)
            mask = (list_tensor > 0).unsqueeze(-1).float()
            if self.list_pooling == "mean":
                pooled = (embedded * mask).sum(dim=1) / lengths.unsqueeze(-1)
            else:
                # Fall back to masked sum pooling for any other setting.
                pooled = (embedded * mask).sum(dim=1)
            embedding_v.append(pooled)

        # Concatenate embeddings (and numeric features, if any) into one input vector.
        if num_inputs is not None and self.num_feas:
            x0 = torch.cat(embedding_v + [num_inputs.float()], dim=1)
        else:
            x0 = torch.cat(embedding_v, dim=1)
        # Run all experts and stack: (batch, expert_num, expert_out_dim).
        expert_outputs = [expert(x0) for expert in self.experts]
        expert_outputs = torch.stack(expert_outputs, dim=1)
        # Per task: gate over experts, mix, then run the task tower.
        outputs = []
        for weight_net, output_net in zip(self.weight_nets, self.output_nets):
            gate = weight_net(x0)               # (batch, expert_num)
            gate_expanded = gate.unsqueeze(-1)  # (batch, expert_num, 1)
            weighted_expert_output = (expert_outputs * gate_expanded).sum(dim=1)
            output = output_net(weighted_expert_output)
            outputs.append(output)
        return torch.cat(outputs, dim=1)        # (batch, target_num)


def main():
    torch.manual_seed(42)

    # Toy feature schema: (vocab_size, embedding_dim) for categorical and list
    # features, (1, 1) for scalar numeric features.
    cat_fea_dims = {
        'userId': (1000, 32),
        'itemId': (5000, 32),
        'category': (100, 16),
    }
    num_fea_dims = {
        'age': (1, 1),
        'price': (1, 1),
    }
    list_fea_dims = {
        'history_items': (5000, 32),
    }

    model = MultiTargetMMoE(
        cat_fea_dims=cat_fea_dims,
        num_fea_dims=num_fea_dims,
        list_fea_dims=list_fea_dims,
        expert_out_dim=64,
        expert_num=4,
        target_num=2,
        list_pooling="mean",
    )

    # Random batch: 3 categorical columns, 2 numeric columns, one list of 10 item ids.
    batch_size = 32
    cat_inputs = torch.randint(1, 100, (batch_size, 3))
    num_inputs = torch.randn(batch_size, 2)
    list_inputs = [torch.randint(1, 1000, (batch_size, 10))]

    outputs = model(cat_inputs, num_inputs, list_inputs)
    print(f"Model output shape: {outputs.shape}")
    print(f"Task 1 predictions: {outputs[:5, 0]}")
    print(f"Task 2 predictions: {outputs[:5, 1]}")

    # One training step against random regression targets.
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    labels = torch.randn(batch_size, 2)

    optimizer.zero_grad()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item():.4f}")
    print("Training complete!")

if __name__ == "__main__":
    main()
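
# A minimal serving sketch, assuming the same toy tensors built in main():
# BatchNorm1d and Dropout behave differently outside training, so switch to
# eval mode and disable autograd before taking predictions.
#
#     model.eval()
#     with torch.no_grad():
#         preds = model(cat_inputs, num_inputs, list_inputs)  # (batch_size, 2)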