去年我做了一个 MoE (Mixture of Experts) 相关的工作。MoE 需要在任务损失之外,再引入负载均衡损失。我简单查了下,面对“返回多个损失”需求,PyTorch 里常见的做法是把损失放在 forward 的 dict 返回值中,再一层层向外传递。这非常麻烦!
Python
def block(x, losses):
losses["aux_loss"] = x.abs().mean()
return x, losses
def model(x):
losses = {}
x, losses = block1(x, losses)
x, losses = block2(x, losses)
return x, losses
x, losses = model(x)
total_loss = losses["task_loss"] + 0.1 * losses["aux_loss"]如果继续沿用通过 dict 逐层传递方案,那么我需要沿 Bottom-up 路径修改所有的Modules!
所以我根据我的需求编写了TorchExtraContext。它的管线大致是这样的:
- 创建一个Context对象。
- 选择一个根Module,向每个named Module注册Context的引用。我设计了
- 在Module的forward里,调用API向Context注册loss/metric/log。
- 在根Module(training loop)里,从根Module的Context里获取所有的loss/metric/log。
该项目可以通过GitHub Repo获取。
另外由于目前我还没有Torch Script的使用需求,虽然我还没测试过,不过这个工具大概率不支持Torch Script。
