Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts
Modeling the tradeoffs between task-specific objectives and inter-task relationships
Today, let's revisit one of the best-known multi-task models: the Multi-gate Mixture-of-Experts (MMoE) [1] developed by Google. In this paper, Google introduced a module with multiple gates that combine the experts with different weights, so that different tasks can use the same experts in different ways. This approach strikes a good balance between modeling task-specific differences and sharing information across related tasks.
MMoE
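The model architecture is simple and easy to understand. The paper's architecture figure shows its evolution from left to right: (a) the Shared-Bottom model, (b) the One-gate MoE (OMoE) model, and (c) the Multi-gate MoE (MMoE) model.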
Shared Bottom
The shared bottom is a popular baseline for multi-task learning. From bottom to top, it typically consists of two main components:
Shared bottom MLP layers to extract shared information across tasks.
Several task-specific MLP layers to model task-specific information separately.
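As a rough illustration of this layout, here is a minimal Shared-Bottom forward pass in plain Python. It is a sketch under assumed choices (one shared ReLU layer, a single linear output per task tower, small random Gaussian initialization), and the names `init_dense`, `dense`, and `SharedBottom` are hypothetical, not taken from the paper or any library.

```python
import random

def init_dense(in_dim, out_dim, rng):
    """Illustrative init: small Gaussian weights (out_dim x in_dim), zero biases."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weights, [0.0] * out_dim

def dense(x, params, activation=None):
    """Fully connected layer y = Wx + b, with an optional ReLU."""
    weights, bias = params
    out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    return [max(0.0, v) for v in out] if activation == "relu" else out

class SharedBottom:
    """One shared ReLU layer used by all tasks, then one linear tower per task."""

    def __init__(self, in_dim, hidden_dim, num_tasks, seed=0):
        rng = random.Random(seed)
        self.bottom = init_dense(in_dim, hidden_dim, rng)
        self.towers = [init_dense(hidden_dim, 1, rng) for _ in range(num_tasks)]

    def __call__(self, x):
        shared = dense(x, self.bottom, activation="relu")  # identical for every task
        return [dense(shared, tower)[0] for tower in self.towers]  # one output per task

model = SharedBottom(in_dim=4, hidden_dim=8, num_tasks=2)
print(model([0.5, -1.0, 2.0, 0.1]))
```

Every task tower reads the same `shared` vector, which is why weakly related tasks can interfere with each other in this design.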
However, the performance of this model heavily depends on how related the tasks in the data are. If the relatedness between tasks is low, the multi-task model can perform even worse than training each task separately. As the paper's experiments on synthetic data show, performance worsens as task relatedness decreases.
OMoE and MMoE
Inspired by model ensembling and MoE layers, the authors introduce a new MoE model designed to capture task differences without substantially increasing the number of model parameters compared to the shared-bottom multi-task model. The key ideas include:
Replacing the shared bottom network with multiple experts (MoE), allowing more flexible information sharing where each expert can learn different aspects from different tasks. (This idea reminds us of the multi-head mechanism in attention.)
Incorporating a separate gating network for each task (MMoE), instead of a single gate shared by all tasks (OMoE)
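More formally, the output for task k is:
$$y_k = h^k\big(f^k(x)\big), \qquad f^k(x) = \sum_{i=1}^{n} g^k(x)_i \, f_i(x), \qquad g^k(x) = \mathrm{softmax}(W_{gk}\, x)$$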
Here h^k is the task-specific tower for task k, and g^k is the gating network for task k.
W_gk is a trainable n×d matrix, where n is the number of experts and d is the input feature dimension. f_i is the i-th expert network, and x is the input feature vector.
The gating networks are simply linear transformations of the input followed by a softmax, and f^k(x) is the weighted sum of the outputs of all expert networks, with the weights given by task k's gate.
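To make the gating formula concrete, here is a tiny worked example with made-up numbers: n = 2 experts and d = 2 features, where the values of `W_gk`, `x`, and the expert outputs are assumptions chosen for readability. It computes g^k(x) = softmax(W_gk x) and then the gate-weighted sum of the expert outputs.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(matrix, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in matrix]

# Gate weights for task k, shape n x d (assumed values).
W_gk = [[1.0, 0.0],
        [0.0, 1.0]]
x = [2.0, 0.0]

# Expert outputs f_i(x); each expert here outputs a 2-dim vector (assumed values).
expert_outputs = [[1.0, 0.0],   # f_1(x)
                  [0.0, 1.0]]   # f_2(x)

g_k = softmax(matvec(W_gk, x))  # g^k(x) = softmax(W_gk x)
# f^k(x) = sum_i g^k(x)_i * f_i(x), computed dimension by dimension
f_k = [sum(g * out[j] for g, out in zip(g_k, expert_outputs))
       for j in range(len(expert_outputs[0]))]

print([round(v, 4) for v in g_k])  # [0.8808, 0.1192]
print([round(v, 4) for v in f_k])  # [0.8808, 0.1192]
```

Because the gate logits are [2, 0], task k puts about 88% of its weight on expert 1; a different task with a different W_gk could favor expert 2 from the same expert outputs.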
Compared with the shared bottom model, the experts in OMoE and MMoE models are shared softly across different tasks, and the gating networks regulate how much each expert is shared. In MMoE, each task's gate can put more weight on the experts that suit that task, which reduces interference from loosely related tasks. Google's experiments on real-world data show that MMoE captures the differences between tasks and automatically strikes a balance between shared and task-specific parameters.
Experiments
One highlight of the experiments is how the authors validate the trainability of the MMoE model. The remaining experiments may seem mundane, as they mainly show that MMoE outperforms the baselines.
They generate several synthetic datasets with different label correlations, because in real-world datasets it is hard to control the correlation between tasks.
Multiple rounds of experiments were run on the various synthetic datasets, and the averaged results were reported as the final metrics.
The performance of the Shared-Bottom model varies much more across runs than that of the MoE-based models. This suggests that Shared-Bottom models generally have many more poor-quality local minima than MoE-based models do.
The robustness of OMoE drops noticeably when the task correlation decreases to 0.5. This validates the usefulness of the multi-gate structure in resolving bad local minima caused by conflicts between tasks that differ.
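The synthetic-data recipe can be sketched as below, following the construction in the paper as I understand it: two orthogonal unit vectors u1 and u2, weight vectors w1 = c·u1 and w2 = c·(p·u1 + sqrt(1 - p²)·u2) so that their cosine similarity equals the correlation knob p, Gaussian inputs, and labels made of a linear term plus a sum of sinusoids plus small noise. The function name `make_synthetic_tasks`, the defaults for d, c, and m, the noise level, and the way α_i and β_i are drawn are all assumptions for illustration.

```python
import math
import random

def make_synthetic_tasks(num_samples, p, d=100, c=1.0, m=10, seed=0):
    """Two regression tasks whose weight vectors have cosine similarity p."""
    rng = random.Random(seed)
    # Orthogonal unit vectors u1, u2 via Gram-Schmidt on two random Gaussian vectors.
    a = [rng.gauss(0, 1) for _ in range(d)]
    b = [rng.gauss(0, 1) for _ in range(d)]
    norm_a = math.sqrt(sum(v * v for v in a))
    u1 = [v / norm_a for v in a]
    proj = sum(bi * ui for bi, ui in zip(b, u1))
    b_orth = [bi - proj * ui for bi, ui in zip(b, u1)]
    norm_b = math.sqrt(sum(v * v for v in b_orth))
    u2 = [v / norm_b for v in b_orth]
    # w1 = c*u1, w2 = c*(p*u1 + sqrt(1 - p^2)*u2), so cos(w1, w2) = p.
    w1 = [c * v for v in u1]
    w2 = [c * (p * v1 + math.sqrt(1 - p * p) * v2) for v1, v2 in zip(u1, u2)]
    # Sinusoid parameters alpha_i, beta_i shared by both tasks (assumed: standard normal).
    alphas = [rng.gauss(0, 1) for _ in range(m)]
    betas = [rng.gauss(0, 1) for _ in range(m)]

    def label(w, x):
        z = sum(wi * xi for wi, xi in zip(w, x))  # w^T x
        noise = rng.gauss(0, 0.1)  # std 0.1, i.e. variance 0.01
        return z + sum(math.sin(al * z + be) for al, be in zip(alphas, betas)) + noise

    data = []
    for _ in range(num_samples):
        x = [rng.gauss(0, 1) for _ in range(d)]  # x ~ N(0, I)
        data.append((x, label(w1, x), label(w2, x)))
    return data, w1, w2

data, w1, w2 = make_synthetic_tasks(num_samples=200, p=0.5)
```

Sweeping p (for example over several values between 1.0 and 0.5) and training each model several times per dataset gives the shape of the experiment: the smaller p is, the less the two tasks have in common.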
The Code
I've shared my implementation of the MMoE layer here. As you can see, it's quite straightforward. In the initialization phase, we create several experts, each consisting of a few dense layers, and we create the gates as dense layers with a softmax activation function.
Then, when the layer is called, we apply the experts and the gates to the input features and aggregate the results by multiplying each expert's output by its gate weight and summing over the experts.
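For readers who want to see the mechanics without a deep learning framework, here is a plain-Python sketch that mirrors the same init/call structure. It is an illustrative stand-in, not the implementation referenced above: the class name `MMoE`, its parameters, the ReLU experts, and the random initialization are assumptions, and a real model would use a framework's dense layers and train the weights.

```python
import math
import random

def init_dense(in_dim, out_dim, rng):
    """Illustrative init: small Gaussian weights (out_dim x in_dim), zero biases."""
    return ([[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)],
            [0.0] * out_dim)

def dense(x, params, activation=None):
    """Fully connected layer y = Wx + b with optional ReLU or softmax."""
    weights, bias = params
    out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    if activation == "relu":
        return [max(0.0, v) for v in out]
    if activation == "softmax":
        m = max(out)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in out]
        total = sum(exps)
        return [e / total for e in exps]
    return out

class MMoE:
    """n experts shared by all tasks, plus one softmax gate per task."""

    def __init__(self, in_dim, expert_units, num_experts, num_tasks, seed=0):
        rng = random.Random(seed)
        # Each expert is a stack of dense ReLU layers with sizes from expert_units.
        self.experts = []
        for _ in range(num_experts):
            layers, prev = [], in_dim
            for units in expert_units:
                layers.append(init_dense(prev, units, rng))
                prev = units
            self.experts.append(layers)
        # Each gate is one dense layer mapping x to num_experts logits (W_gk is n x d).
        self.gates = [init_dense(in_dim, num_experts, rng) for _ in range(num_tasks)]

    def __call__(self, x):
        expert_outputs = []
        for layers in self.experts:
            h = x
            for params in layers:
                h = dense(h, params, activation="relu")
            expert_outputs.append(h)
        task_outputs = []
        for gate in self.gates:
            g = dense(x, gate, activation="softmax")  # g^k(x)
            # f^k(x) = sum_i g^k(x)_i * f_i(x): scale each expert output, then sum.
            mixed = [sum(gi * out[j] for gi, out in zip(g, expert_outputs))
                     for j in range(len(expert_outputs[0]))]
            task_outputs.append(mixed)
        return task_outputs

layer = MMoE(in_dim=4, expert_units=[8, 8], num_experts=3, num_tasks=2)
outputs = layer([0.5, -1.0, 2.0, 0.1])
print(len(outputs), len(outputs[0]))  # 2 8
```

Each element of `outputs` is f^k(x) for one task and would be fed to that task's tower. Computing all expert outputs once and reusing them for every gate is what keeps MMoE's extra cost small: only the gates are task-specific.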
MMoE is a flexible structure, so we can also combine it with the ESMM model I shared before. Here is an example: we use MMoE to replace the shared bottom network in ESMM and keep the same output structure.
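The ESMM output structure that sits on top of the MMoE experts can be sketched as follows. This is a hedged, framework-free sketch, not the code from the ESMM post: the function names are hypothetical, and it assumes the CTR and CVR towers (each fed by its own MMoE gate) have already produced one logit each. As in ESMM, pCTCVR = pCTR * pCVR, and the loss supervises pCTR with click labels and pCTCVR with click-and-conversion labels.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce(label, prob, eps=1e-7):
    """Binary cross-entropy for a single example."""
    prob = min(max(prob, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))

def esmm_head(ctr_logit, cvr_logit):
    """ESMM outputs; the two logits are assumed to come from the CTR and CVR
    towers, each fed by its own gated mixture of MMoE experts."""
    p_ctr = sigmoid(ctr_logit)
    p_cvr = sigmoid(cvr_logit)
    p_ctcvr = p_ctr * p_cvr  # pCTCVR = pCTR * pCVR
    return p_ctr, p_cvr, p_ctcvr

def esmm_loss(click, conversion, p_ctr, p_ctcvr):
    # pCVR has no direct loss: it is trained implicitly through pCTCVR,
    # which is supervised over all impressions.
    return bce(click, p_ctr) + bce(click * conversion, p_ctcvr)

p_ctr, p_cvr, p_ctcvr = esmm_head(ctr_logit=0.0, cvr_logit=0.0)
print(p_ctr, p_cvr, p_ctcvr)  # 0.5 0.5 0.25
```

Swapping the shared bottom for MMoE only changes where the two logits come from; the product structure and the two losses stay exactly as in ESMM.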
That's all for the MMoE model itself.
Regarding multi-task learning, there are two other common questions.
Task Balancing
Training multiple tasks effectively requires a careful balance to prevent any single task from dominating the network weights. This is a complex and not yet fully resolved problem.
I plan to write one or more posts to share the details. Some common approaches to this issue include unitary scalarization, uncertainty weighting, gradient normalization, dynamic task prioritization, and treating MTL as a multi-objective optimization problem.
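As a small taste of what these methods look like, here is a sketch of the two simplest ones: unitary scalarization (an unweighted sum of the task losses) and uncertainty weighting in its regression form from Kendall et al. (CVPR 2018), where each task i has a learnable log-variance s_i = log(σ_i²). The function names are hypothetical, and in a real model s_i would be trainable parameters updated together with the network weights.

```python
import math

def unitary_scalarization(task_losses):
    """Baseline: sum the task losses with equal (unit) weights."""
    return sum(task_losses)

def uncertainty_weighted_loss(task_losses, log_vars):
    """Regression-form uncertainty weighting with s_i = log(sigma_i^2):
    sum_i L_i / (2 * sigma_i^2) + log(sigma_i)
      = sum_i 0.5 * exp(-s_i) * L_i + 0.5 * s_i
    """
    return sum(0.5 * math.exp(-s) * loss + 0.5 * s
               for loss, s in zip(task_losses, log_vars))

losses = [2.0, 4.0]
print(unitary_scalarization(losses))                  # 6.0
print(uncertainty_weighted_loss(losses, [0.0, 0.0]))  # 3.0
```

For a fixed loss L_i, minimizing over s_i alone gives s_i = log L_i, so a task with a larger loss ends up with a smaller effective weight exp(-s_i)/2, while the s_i/2 term keeps every weight from collapsing to zero.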
Prediction Score Fusion
How should we fuse the scores from different tasks at inference time? Because the score distributions of different tasks can differ a lot, we cannot simply add them together with linear weights.
A common approach to this problem is to use the inverse of the rank as the task-specific score. The steps are:
For each task, generate the prediction score for all the candidates
Rank the candidates based on the scores in descending order
Use the inverse of the rank position as the final score for this task
After this process, the scores from different tasks are on the same scale and can be compared directly.
This raises another problem: how to merge the scores. The most common approach in industry is still to set a weight for each task manually, add up the weighted scores, and then run several rounds of A/B testing to find the best weight setting. 😅
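The inverse-rank normalization and the manual weighted merge can be sketched together as follows. The task names, scores, and weights are made up for illustration, and the function names are hypothetical; the weights are the knobs you would tune through A/B tests.

```python
def inverse_rank_scores(scores):
    """Map {candidate: raw score} to {candidate: 1 / rank}, rank 1 = highest score."""
    ranked = sorted(scores, key=scores.get, reverse=True)  # descending by score
    return {cand: 1.0 / (pos + 1) for pos, cand in enumerate(ranked)}

def fuse(task_scores, task_weights):
    """Weighted sum of per-task inverse-rank scores; returns candidates best-first."""
    fused = {}
    for task, scores in task_scores.items():
        for cand, inv_rank in inverse_rank_scores(scores).items():
            fused[cand] = fused.get(cand, 0.0) + task_weights[task] * inv_rank
    return sorted(fused, key=fused.get, reverse=True)

# Made-up scores on very different scales: a probability and a duration in seconds.
task_scores = {
    "ctr": {"a": 0.9, "b": 0.5, "c": 0.1},
    "watch_time": {"a": 10.0, "b": 120.0, "c": 30.0},
}
print(fuse(task_scores, {"ctr": 1.0, "watch_time": 1.0}))  # ['b', 'a', 'c']
print(fuse(task_scores, {"ctr": 2.0, "watch_time": 1.0}))  # ['a', 'b', 'c']
```

Doubling the CTR weight moves `a` to the top. With raw scores, the watch-time values (up to 120) would dominate the CTR probabilities (at most 1) unless the CTR weight were enormous, which is exactly the problem the inverse rank avoids.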
That's all for today. I didn't expect writing about this paper to be so tedious, but I hope you can still glean valuable insights from this post. Thank you for your attention.
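[1] J. Ma et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts. KDD 2018. https://dl.acm.org/doi/pdf/10.1145/3219819.3220007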