<h1>Learning with not Enough Data Part 1: Semi-Supervised Learning</h1>
<p>Lilian Weng, Lil’Log, 2021-12-05. Source: https://lilianweng.github.io/lil-log/2021/12/05/semi-supervised-learning</p>
<blockquote>
<p>The performance of supervised learning tasks improves with more high-quality labels available. However, it is expensive to collect a large number of labeled samples. There are several paradigms in machine learning to deal with scenarios where labels are scarce. Semi-supervised learning is one candidate, utilizing a large amount of unlabeled data in conjunction with a small amount of labeled data.</p>
</blockquote>
<!--more-->
<p>When facing a limited amount of labeled data for supervised learning tasks, four approaches are commonly discussed.</p>
<ol>
<li><em>Pre-training + fine-tuning</em>: Pre-train a powerful task-agnostic model on a large unsupervised data corpus, e.g. <a href="/lil-log/2019/01/31/generalized-language-models.html">pre-training LMs</a> on free text, or pre-training vision models on unlabelled images via <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised learning</a>, and then fine-tune it on the downstream task with a small set of labeled samples.</li>
<li><em>Semi-supervised learning</em>: Learn from the labeled and unlabeled samples together. Much of the research on this approach has focused on vision tasks.</li>
<li><em>Active learning</em>: Labeling is expensive, but we still want to collect more labels given a cost budget. Active learning learns to select the most valuable unlabeled samples to label next, helping us act smartly with a limited budget.</li>
<li><em>Pre-training + dataset auto-generation</em>: Given a capable pre-trained model, we can utilize it to auto-generate a lot more labeled samples. This has been especially popular within the language domain driven by the success of few-shot learning.</li>
</ol>
<p>I plan to write a series of posts on the topic of “Learning with not enough data”. Part 1 is on <em>Semi-Supervised Learning</em>.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-semi-supervised-learning" id="markdown-toc-what-is-semi-supervised-learning">What is semi-supervised learning?</a></li>
<li><a href="#notations" id="markdown-toc-notations">Notations</a></li>
<li><a href="#hypotheses" id="markdown-toc-hypotheses">Hypotheses</a></li>
<li><a href="#consistency-regularization" id="markdown-toc-consistency-regularization">Consistency Regularization</a> <ul>
<li><a href="#π-model" id="markdown-toc-π-model">Π-model</a></li>
<li><a href="#temporal-ensembling" id="markdown-toc-temporal-ensembling">Temporal ensembling</a></li>
<li><a href="#mean-teachers" id="markdown-toc-mean-teachers">Mean teachers</a></li>
<li><a href="#noisy-samples-as-learning-targets" id="markdown-toc-noisy-samples-as-learning-targets">Noisy samples as learning targets</a></li>
</ul>
</li>
<li><a href="#pseudo-labeling" id="markdown-toc-pseudo-labeling">Pseudo Labeling</a> <ul>
<li><a href="#label-propagation" id="markdown-toc-label-propagation">Label propagation</a></li>
<li><a href="#self-training" id="markdown-toc-self-training">Self-Training</a></li>
<li><a href="#reducing-confirmation-bias" id="markdown-toc-reducing-confirmation-bias">Reducing confirmation bias</a></li>
</ul>
</li>
<li><a href="#pseudo-labeling-with-consistency-regularization" id="markdown-toc-pseudo-labeling-with-consistency-regularization">Pseudo Labeling with Consistency Regularization</a> <ul>
<li><a href="#mixmatch" id="markdown-toc-mixmatch">MixMatch</a></li>
<li><a href="#dividemix" id="markdown-toc-dividemix">DivideMix</a></li>
<li><a href="#fixmatch" id="markdown-toc-fixmatch">FixMatch</a></li>
</ul>
</li>
<li><a href="#combined-with-powerful-pre-training" id="markdown-toc-combined-with-powerful-pre-training">Combined with Powerful Pre-Training</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-semi-supervised-learning">What is semi-supervised learning?</h2>
<p>Semi-supervised learning uses both labeled and unlabeled data to train a model.</p>
<p>Interestingly, most existing literature on semi-supervised learning focuses on vision tasks, whereas pre-training + fine-tuning is the more common paradigm for language tasks.</p>
<p>All the methods introduced in this post have a loss combining two parts: \(\mathcal{L} = \mathcal{L}_s + \mu(t) \mathcal{L}_u\). The supervised loss \(\mathcal{L}_s\) is easy to get given all the labeled examples. We will focus on how the unsupervised loss \(\mathcal{L}_u\) is designed. A common choice of the weighting term \(\mu(t)\) is a ramp function increasing the importance of \(\mathcal{L}_u\) in time, where \(t\) is the training step.</p>
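<p>As a minimal sketch of the combined loss, the snippet below uses a Gaussian-shaped ramp-up for \(\mu(t)\). The exact shape \(\exp(-5(1 - t/t_\text{ramp})^2)\), the function names and the <code>max_weight</code> parameter are assumptions for illustration; any monotonically increasing schedule fits the description above.</p>

```python
import math

def ramp_up_weight(step: int, ramp_steps: int, max_weight: float = 1.0) -> float:
    """Gaussian ramp-up: rises smoothly from exp(-5) * max_weight to max_weight."""
    if ramp_steps <= 0:
        return max_weight
    progress = min(max(step / ramp_steps, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - progress) ** 2)

def total_loss(sup_loss: float, unsup_loss: float, step: int,
               ramp_steps: int, max_weight: float = 1.0) -> float:
    """L = L_s + mu(t) * L_u, with mu(t) given by the ramp above."""
    return sup_loss + ramp_up_weight(step, ramp_steps, max_weight) * unsup_loss
```

<p>Early in training the unsupervised term is almost switched off, so noisy targets from an untrained model contribute little to the gradient.</p>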
<blockquote>
<p><em>Disclaimer</em>: The post is not gonna cover semi-supervised methods with focus on model architecture modification. Check <a href="https://arxiv.org/abs/2006.05278">this survey</a> for how to use generative models and graph-based methods in semi-supervised learning.</p>
</blockquote>
<h2 id="notations">Notations</h2>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(L\)</td>
<td>Number of unique labels.</td>
</tr>
<tr>
<td>\((\mathbf{x}^l, y) \sim \mathcal{X}, y \in \{0, 1\}^L\)</td>
<td>Labeled dataset. \(y\) is a one-hot representation of the true label.</td>
</tr>
<tr>
<td>\(\mathbf{u} \sim \mathcal{U}\)</td>
<td>Unlabeled dataset.</td>
</tr>
<tr>
<td>\(\mathcal{D} = \mathcal{X} \cup \mathcal{U}\)</td>
<td>The entire dataset, including both labeled and unlabeled examples.</td>
</tr>
<tr>
<td>\(\mathbf{x}\)</td>
<td>Any sample which can be either labeled or unlabeled.</td>
</tr>
<tr>
<td>\(\bar{\mathbf{x}}\)</td>
<td>\(\mathbf{x}\) with augmentation applied.</td>
</tr>
<tr>
<td>\(\mathbf{x}_i\)</td>
<td>The \(i\)-th sample.</td>
</tr>
<tr>
<td>\(\mathcal{L}\), \(\mathcal{L}_s\), \(\mathcal{L}_u\)</td>
<td>Loss, supervised loss, and unsupervised loss.</td>
</tr>
<tr>
<td>\(\mu(t)\)</td>
<td>The unsupervised loss weight, increasing in time.</td>
</tr>
<tr>
<td>\(p(y \vert \mathbf{x}), p_\theta(y \vert \mathbf{x})\)</td>
<td>The conditional probability over the label set given the input.</td>
</tr>
<tr>
<td>\(f_\theta(.)\)</td>
<td>The implemented neural network with weights \(\theta\), the model that we want to train.</td>
</tr>
<tr>
<td>\(\mathbf{z} = f_\theta(\mathbf{x})\)</td>
<td>A vector of logits output by \(f\).</td>
</tr>
<tr>
<td>\(\hat{y} = \text{softmax}(\mathbf{z})\)</td>
<td>The predicted label distribution.</td>
</tr>
<tr>
<td>\(D[.,.]\)</td>
<td>A distance function between two distributions, such as MSE, cross entropy, KL divergence, etc.</td>
</tr>
<tr>
<td>\(\beta\)</td>
<td>EMA weighting hyperparameter for <a href="#mean-teachers">teacher</a> model weights.</td>
</tr>
<tr>
<td>\(\alpha, \lambda\)</td>
<td>Parameters for MixUp, \(\lambda \sim \text{Beta}(\alpha, \alpha)\).</td>
</tr>
<tr>
<td>\(T\)</td>
<td>Temperature for sharpening the predicted distribution.</td>
</tr>
<tr>
<td>\(\tau\)</td>
<td>A confidence threshold for selecting the qualified prediction.</td>
</tr>
</tbody>
</table>
<h2 id="hypotheses">Hypotheses</h2>
<p>Several hypotheses have been discussed in literature to support certain design decisions in semi-supervised learning methods.</p>
<ul>
<li>
<p>H1: <strong>Smoothness Assumptions</strong>: If two data samples are close in a high-density region of the feature space, their labels should be the same or very similar.</p>
</li>
<li>
<p>H2: <strong>Cluster Assumptions</strong>: The feature space has both dense regions and sparse regions. Densely grouped data points naturally form a cluster. Samples in the same cluster are expected to have the same label. This is a small extension of H1.</p>
</li>
<li>
<p>H3: <strong>Low-density Separation Assumptions</strong>: The decision boundary between classes tends to be located in the sparse, low density regions, because otherwise the decision boundary would cut a high-density cluster into two classes, corresponding to two clusters, which invalidates H1 and H2.</p>
</li>
<li>
<p>H4: <strong>Manifold Assumptions</strong>: High-dimensional data tends to lie on a low-dimensional manifold. Even though real-world data might be observed in very high dimensions (e.g. images of real-world objects/scenes), it can actually be captured by a lower-dimensional manifold where certain attributes are captured and similar points are grouped closely (e.g. images of real-world objects/scenes are not drawn from a uniform distribution over all pixel combinations). This enables us to learn a more efficient representation with which to discover and measure similarity between unlabeled data points. This is also the foundation for representation learning. [see <a href="https://stats.stackexchange.com/questions/66939/what-is-the-manifold-assumption-in-semi-supervised-learning">a helpful link</a>].</p>
</li>
</ul>
<h2 id="consistency-regularization">Consistency Regularization</h2>
<p><strong>Consistency Regularization</strong>, also known as <strong>Consistency Training</strong>, assumes that randomness within the neural network (e.g. with Dropout) or data augmentation transformations should not modify model predictions given the same input. Every method in this section has a consistency regularization loss as \(\mathcal{L}_u\).</p>
<p>This idea has been adopted in several <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised</a> <a href="/lil-log/2021/05/31/contrastive-representation-learning.html">learning</a> methods, such as SimCLR, BYOL, SimCSE, etc. Different augmented versions of the same sample should result in the same representation. <a href="/lil-log/2019/01/31/generalized-language-models.html#cross-view-training">Cross-view training</a> in language modeling and multi-view learning in self-supervised learning all share the same motivation.</p>
<h3 id="π-model">Π-model</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PI-model.png" alt="Pi Model" /></p>
<p class="image-caption"><em>Fig. 1. Overview of the Π-model. Two versions of the same input with different stochastic augmentation and dropout masks pass through the network and the outputs are expected to be consistent. (Image source: <a href="https://arxiv.org/abs/1610.02242">Laine & Aila (2017)</a>)</em></p>
<p><a href="https://arxiv.org/abs/1606.04586">Sajjadi et al. (2016)</a> proposed an unsupervised learning loss to minimize the difference between two passes through the network with stochastic transformations (e.g. dropout, random max-pooling) for the same data point. The label is not explicitly used, so the loss can be applied to unlabeled dataset. <a href="https://arxiv.org/abs/1610.02242">Laine & Aila (2017)</a> later coined the name, <strong>Π-Model</strong>, for such a setup.</p>
\[\mathcal{L}_u^\Pi = \sum_{\mathbf{x} \in \mathcal{D}} \text{MSE}(f_\theta(\mathbf{x}), f'_\theta(\mathbf{x}))\]
<p>where \(f'\) is the same neural network with different stochastic augmentation or dropout masks applied. This loss utilizes the entire dataset.</p>
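<p>A minimal sketch of this loss, assuming a toy linear model whose only source of randomness is input dropout. The helper names (<code>forward</code>, <code>pi_model_loss</code>) and the toy weights are hypothetical; a real setup would use a deep network with stochastic augmentation as well.</p>

```python
import random

def forward(weights, x, rng, drop_prob=0.5):
    """One stochastic pass: input dropout followed by a linear layer (returns logits)."""
    keep = 1.0 - drop_prob
    # Inverted dropout: zero each input with probability drop_prob, rescale the rest.
    x_drop = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    return [sum(w * xi for w, xi in zip(row, x_drop)) for row in weights]

def mse(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)

def pi_model_loss(weights, batch, rng, drop_prob=0.5):
    """Sum over samples of the MSE between two independent stochastic passes."""
    return sum(
        mse(forward(weights, x, rng, drop_prob), forward(weights, x, rng, drop_prob))
        for x in batch
    )

rng = random.Random(0)
weights = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]]   # 2 classes, 3 features
batch = [[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]]      # labels are never needed
loss = pi_model_loss(weights, batch, rng)
```

<p>With <code>drop_prob=0.0</code> the two passes are identical and the loss is zero, which shows that the loss only penalizes sensitivity to the injected noise.</p>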
<h3 id="temporal-ensembling">Temporal ensembling</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/temperal-ensembling.png" alt="Temporal Ensembling" /></p>
<p class="image-caption"><em>Fig. 2. Overview of Temporal Ensembling. The per-sample EMA label prediction is the learning target. (Image source: <a href="https://arxiv.org/abs/1610.02242">Laine & Aila (2017)</a>)</em></p>
<p>The Π-model requires the network to run two passes per sample, doubling the computation cost. To reduce the cost, <strong>Temporal Ensembling</strong> (<a href="https://arxiv.org/abs/1610.02242">Laine & Aila 2017</a>) maintains an exponential moving average (EMA) of the model prediction over time for each training sample as the learning target, which is only evaluated and updated once per epoch. Because the EMA accumulator \(\mathbf{Z}_i\) is initialized to \(\mathbf{0}\), it is divided by \((1-\alpha^t)\) to correct this startup bias, which yields the learning target \(\tilde{\mathbf{z}}_i\). Here \(\alpha\) is the EMA momentum, not the MixUp parameter. Adam optimizer has such <a href="https://stats.stackexchange.com/questions/232741/why-is-it-important-to-include-a-bias-correction-term-for-the-adam-optimizer-for">bias correction</a> terms for the same reason.</p>
\[\begin{aligned}
\mathbf{Z}^{(t)}_i &= \alpha \mathbf{Z}^{(t-1)}_i + (1-\alpha) \mathbf{z}_i \\
\tilde{\mathbf{z}}^{(t)}_i &= \frac{\mathbf{Z}^{(t)}_i}{1-\alpha^t}
\end{aligned}\]
<p>where \(\tilde{\mathbf{z}}^{(t)}_i\) is the ensemble prediction at epoch \(t\), \(\mathbf{Z}^{(t)}_i\) is the uncorrected accumulator and \(\mathbf{z}_i\) is the model prediction in the current round. Note that since \(\mathbf{Z}^{(0)}_i = \mathbf{0}\), with correction, \(\tilde{\mathbf{z}}^{(1)}_i\) is simply equivalent to \(\mathbf{z}_i\) at epoch 1.</p>
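<p>The per-sample update can be sketched as follows. This version keeps the raw EMA accumulator separate from the bias-corrected target, which is my reading of the algorithm in the paper; the function name and the toy numbers are assumptions for illustration.</p>

```python
def update_ensemble(ema_sum, z, alpha, epoch):
    """One epoch update for a single training sample.

    ema_sum: uncorrected EMA accumulator (all zeros before epoch 1)
    z:       current-epoch prediction for this sample
    epoch:   1-based epoch counter t
    Returns (new accumulator, bias-corrected learning target).
    """
    new_sum = [alpha * s + (1.0 - alpha) * zi for s, zi in zip(ema_sum, z)]
    correction = 1.0 - alpha ** epoch
    target = [s / correction for s in new_sum]
    return new_sum, target

alpha = 0.6
acc = [0.0, 0.0]
acc, target1 = update_ensemble(acc, [2.0, -1.0], alpha, epoch=1)
acc, target2 = update_ensemble(acc, [4.0, 1.0], alpha, epoch=2)
```

<p>After the first epoch the corrected target equals the raw prediction, and afterwards it is a properly normalized weighted average of all past predictions.</p>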
<h3 id="mean-teachers">Mean teachers</h3>
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/mean-teacher.png" alt="Mean Teacher" /></p>
<p class="image-caption"><em>Fig. 3. Overview of the Mean Teacher framework. (Image source: <a href="https://arxiv.org/abs/1703.01780">Tarvainen & Valpola, 2017</a>)</em></p>
<p>Temporal Ensembling keeps track of an EMA of label predictions for each training sample as a learning target. However, this label prediction only changes <em>every epoch</em>, making the approach clumsy when the training dataset is large. <strong>Mean Teacher</strong> (<a href="https://arxiv.org/abs/1703.01780">Tarvainen & Valpola, 2017</a>) was proposed to overcome the slow target updates by tracking the moving average of model weights instead of model outputs. Let’s call the original model with weights \(\theta\) the <em>student</em> model and the model with moving averaged weights \(\theta'\) across consecutive student models the <em>mean teacher</em>: \(\theta' \gets \beta \theta' + (1-\beta)\theta\)</p>
<p>The consistency regularization loss is the distance between the predictions of the student and the teacher, and this student-teacher gap should be minimized. The mean teacher is expected to provide more accurate predictions than the student, which was confirmed empirically, as shown in Fig. 4.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/mean-teacher-results.png" alt="Mean teacher experiments" /></p>
<p class="image-caption"><em>Fig. 4. Classification error on SVHN of Mean Teacher and the Π Model. The mean teacher (in orange) has better performance than the student model (in blue). (Image source: <a href="https://arxiv.org/abs/1703.01780">Tarvainen & Valpola, 2017</a>)</em></p>
<p>According to their ablation studies,</p>
<ul>
<li>Input augmentation (e.g. random flips of input images, Gaussian noise) or student model dropout is necessary for good performance. Dropout is not needed on the teacher model.</li>
<li>The performance is sensitive to the EMA decay hyperparameter \(\beta\). A good strategy is to use a small \(\beta=0.99\) during the ramp up stage and a larger \(\beta=0.999\) in the later stage when the student model improvement slows down.</li>
<li>They found that MSE as the consistency cost function performs better than other cost functions like KL divergence.</li>
</ul>
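<p>The teacher update and the two-stage \(\beta\) schedule from the ablation notes can be sketched as below. Weights are plain lists of floats here, and the "gradient step" is a stand-in that simply moves the student towards a fixed point; both are assumptions made to keep the example self-contained.</p>

```python
def ema_update(teacher, student, beta):
    """theta' <- beta * theta' + (1 - beta) * theta, applied element-wise."""
    return [beta * t + (1.0 - beta) * s for t, s in zip(teacher, student)]

def beta_schedule(step, ramp_up_steps, beta_early=0.99, beta_late=0.999):
    """Smaller decay during ramp-up, larger decay afterwards."""
    return beta_early if step < ramp_up_steps else beta_late

student = [0.0, 0.0]
teacher = list(student)
for step in range(200):
    # Stand-in for a gradient step: the student drifts towards [1.0, -1.0].
    student = [s + 0.05 * (goal - s) for s, goal in zip(student, [1.0, -1.0])]
    # The teacher is never trained directly; it only tracks the student.
    teacher = ema_update(teacher, student, beta_schedule(step, ramp_up_steps=100))
```

<p>Unlike Temporal Ensembling, the teacher is refreshed after every training step rather than once per epoch, and no per-sample state needs to be stored.</p>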
<h3 id="noisy-samples-as-learning-targets">Noisy samples as learning targets</h3>
<p>Several recent consistency training methods learn to minimize prediction difference between the original unlabeled sample and its corresponding augmented version. It is quite similar to the Π-model but the consistency regularization loss is <em>only</em> applied to the unlabeled data.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/consistency-training-with-noisy-samples.png" alt="Consistency training with noisy samples" /></p>
<p class="image-caption"><em>Fig. 5. Consistency training with noisy samples.</em></p>
<p>Adversarial Training (<a href="https://arxiv.org/abs/1412.6572">Goodfellow et al. 2014</a>) applies adversarial noise onto the input and trains the model to be robust to such adversarial attack. The setup works in supervised learning,</p>
\[\begin{aligned}
\mathcal{L}_\text{adv}(\mathbf{x}^l, \theta) &= D[q(y\mid \mathbf{x}^l), p_\theta(y\mid \mathbf{x}^l + r_\text{adv})] \\
r_\text{adv} &= {\arg\max}_{r; \|r\| \leq \epsilon} D[q(y\mid \mathbf{x}^l), p_\theta(y\mid \mathbf{x}^l + r)] \\
r_\text{adv} &\approx \epsilon \frac{g}{\|g\|_2} \approx \epsilon\text{sign}(g)\quad\text{where }g = \nabla_{r} D[y, p_\theta(y\mid \mathbf{x}^l + r)]
\end{aligned}\]
<p>where \(q(y \mid \mathbf{x}^l)\) is the true distribution, approximated by one-hot encoding of the ground truth label, \(y\). \(p_\theta(y \mid \mathbf{x}^l)\) is the model prediction. \(D[.,.]\) is a distance function measuring the divergence between two distributions.</p>
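<p>To make the sign-based approximation of \(r_\text{adv}\) concrete, here is a small sketch on a binary logistic regression model, where the gradient of the cross entropy loss w.r.t. the input has the closed form \((p - y)\,\mathbf{w}\). The model, its weights and the helper names are assumptions for illustration; with a deep network the gradient would come from backpropagation instead.</p>

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def predict(w, b, x):
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def cross_entropy(y, p):
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def fgsm_perturbation(w, b, x, y, eps):
    """r_adv = eps * sign(g), with g the gradient of the loss w.r.t. the input."""
    p = predict(w, b, x)
    grad = [(p - y) * wi for wi in w]          # closed form for logistic regression
    return [eps * ((g > 0) - (g < 0)) for g in grad]

w, b = [2.0, -1.0], 0.1
x, y = [0.5, 0.3], 1
r_adv = fgsm_perturbation(w, b, x, y, eps=0.1)
x_adv = [xi + ri for xi, ri in zip(x, r_adv)]
clean_loss = cross_entropy(y, predict(w, b, x))
adv_loss = cross_entropy(y, predict(w, b, x_adv))
```

<p>The perturbed input has a strictly larger loss than the clean one, and adversarial training then minimizes the loss on such perturbed inputs.</p>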
<p><strong>Virtual Adversarial Training</strong> (<strong>VAT</strong>; <a href="https://arxiv.org/abs/1704.03976">Miyato et al. 2018</a>) extends the idea to work in semi-supervised learning. Because \(q(y \mid \mathbf{x}^l)\) is unknown for unlabeled samples, VAT replaces it with the model prediction for the original input under the current weights \(\hat{\theta}\). Note that \(\hat{\theta}\) is a fixed copy of model weights, so there is no gradient update on \(\hat{\theta}\).</p>
\[\begin{aligned}
\mathcal{L}_u^\text{VAT}(\mathbf{x}, \theta) &= D[p_{\hat{\theta}}(y\mid \mathbf{x}), p_\theta(y\mid \mathbf{x} + r_\text{vadv})] \\
r_\text{vadv} &= {\arg\max}_{r; \|r\| \leq \epsilon} D[p_{\hat{\theta}}(y\mid \mathbf{x}), p_\theta(y\mid \mathbf{x} + r)]
\end{aligned}\]
<p>The VAT loss applies to both labeled and unlabeled samples. It is a negative smoothness measure of the current model’s prediction manifold at each data point. The optimization of such loss motivates the manifold to be smoother.</p>
<p><strong>Interpolation Consistency Training</strong> (<strong>ICT</strong>; <a href="https://arxiv.org/abs/1903.03825">Verma et al. 2019</a>) enhances the dataset by adding more interpolations of data points and expects the model prediction to be consistent with interpolations of the corresponding labels. The MixUp (<a href="https://arxiv.org/abs/1710.09412">Zhang et al. 2018</a>) operation mixes two images via a simple weighted sum and mixes their labels with the same weights. Following the idea of MixUp, ICT expects the model prediction on a mixup sample to match the interpolation of the predictions on the corresponding inputs:</p>
\[\begin{aligned}
\text{mixup}_\lambda (\mathbf{x}_i, \mathbf{x}_j) &= \lambda \mathbf{x}_i + (1-\lambda)\mathbf{x}_j \\
p_\theta(y \mid \text{mixup}_\lambda (\mathbf{x}_i, \mathbf{x}_j)) &\approx \lambda p_{\theta'}(y \mid \mathbf{x}_i) + (1-\lambda) p_{\theta'}(y \mid \mathbf{x}_j)
\end{aligned}\]
<p>where \(\theta'\) is a moving average of \(\theta\), which is a <a href="#mean-teachers">mean teacher</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ICT.png" alt="ICT" /></p>
<p class="image-caption"><em>Fig. 6. Overview of Interpolation Consistency Training. MixUp is applied to produce more interpolated samples with interpolated labels as learning targets. (Image source: <a href="https://arxiv.org/abs/1903.03825">Verma et al. 2019</a>)</em></p>
<p>Because the probability of two randomly selected unlabeled samples belonging to different classes is high (e.g. there are 1000 object classes in ImageNet), the interpolation produced by applying mixup to two random unlabeled samples is likely to fall around the decision boundary. According to the low-density separation <a href="#hypotheses">assumptions</a>, the decision boundary tends to be located in low-density regions.</p>
\[\mathcal{L}^\text{ICT}_{u} = \mathbb{E}_{\mathbf{u}_i, \mathbf{u}_j \sim \mathcal{U}} \mathbb{E}_{\lambda \sim \text{Beta}(\alpha, \alpha)} D[p_\theta(y \mid \text{mixup}_\lambda (\mathbf{u}_i, \mathbf{u}_j)), \text{mixup}_\lambda(p_{\theta'}(y \mid \mathbf{u}_i), p_{\theta'}(y \mid \mathbf{u}_j))]\]
<p>where \(\theta'\) is a moving average of \(\theta\).</p>
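<p>A minimal sketch of the ICT loss for one pair of unlabeled samples, assuming MSE as the distance \(D\) and toy linear models for the student and the teacher. The weights and helper names are hypothetical; in practice the teacher weights would be the EMA of the student weights.</p>

```python
import math
import random

def softmax(z):
    m = max(z)
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def linear(weights, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def mixup(a, b, lam):
    return [lam * ai + (1.0 - lam) * bi for ai, bi in zip(a, b)]

def ict_loss(student_w, teacher_w, u_i, u_j, lam):
    """MSE between the student's prediction on the mixed input and the
    mix of the teacher's predictions on the two original inputs."""
    student_pred = softmax(linear(student_w, mixup(u_i, u_j, lam)))
    target = mixup(softmax(linear(teacher_w, u_i)),
                   softmax(linear(teacher_w, u_j)), lam)
    return sum((p - q) ** 2 for p, q in zip(student_pred, target)) / len(target)

rng = random.Random(0)
alpha = 0.5
lam = rng.betavariate(alpha, alpha)          # lambda ~ Beta(alpha, alpha)
student_w = [[0.4, -0.3], [-0.2, 0.6]]
teacher_w = [[0.35, -0.25], [-0.15, 0.55]]   # stands in for the EMA of student_w
loss = ict_loss(student_w, teacher_w, [1.0, 0.0], [0.0, 1.0], lam)
```

<p>The target is computed from the teacher only, so in a real training loop no gradient would flow through it.</p>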
<p>Similar to VAT, <strong>Unsupervised Data Augmentation</strong> (<strong>UDA</strong>; <a href="https://arxiv.org/abs/1904.12848">Xie et al. 2020</a>) learns to predict the same output for an unlabeled example and the augmented one. UDA especially focuses on studying how the <em>“quality”</em> of noise can impact the semi-supervised learning performance with consistency training. It is crucial to use advanced data augmentation methods for producing meaningful and effective noisy samples. Good data augmentation should produce valid (i.e. does not change the label) and diverse noise, and carry targeted inductive biases.</p>
<p>For images, UDA adopts RandAugment (<a href="https://arxiv.org/abs/1909.13719">Cubuk et al. 2019</a>), which uniformly samples augmentation operations available in <a href="https://pillow.readthedocs.io/en/stable/">PIL</a>. It involves no learning or optimization, so it is much cheaper than AutoAugment.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/UDA-image-results.png" alt="UDA vision" /></p>
<p class="image-caption"><em>Fig. 7. Comparison of various semi-supervised learning methods on CIFAR-10 classification. Fully supervised Wide-ResNet-28-2 and PyramidNet+ShakeDrop have an error rate of <strong>5.4</strong> and <strong>2.7</strong> respectively when trained on 50,000 examples without RandAugment. (Image source: <a href="https://arxiv.org/abs/1904.12848">Xie et al. 2020</a>)</em></p>
<p>For language, UDA combines back-translation and TF-IDF based word replacement. Back-translation preserves the high-level meaning but may not retain certain words, while TF-IDF based word replacement drops uninformative words with low TF-IDF scores. In the experiments on language tasks, they found UDA to be complementary to transfer learning and representation learning; for example, BERT fine-tuned (i.e. \(\text{BERT}_\text{FINETUNE}\) in Fig. 8) on in-domain unlabeled data can further improve the performance.</p>
<p style="width: 83%;" class="center"><img src="/lil-log/assets/images/UDA-language-results.png" alt="UDA language" /></p>
<p class="image-caption"><em>Fig. 8. Comparison of UDA with different initialization configurations on various text classification tasks. (Image source: <a href="https://arxiv.org/abs/1904.12848">Xie et al. 2020</a>)</em></p>
<p>When calculating \(\mathcal{L}_u\), UDA found the following three training techniques to help improve the results.</p>
<ul>
<li><em>Low confidence masking</em>: Mask out examples whose prediction confidence is lower than a threshold \(\tau\).</li>
<li><em>Sharpening prediction distribution</em>: Use a low temperature \(T\) in softmax to sharpen the predicted probability distribution.</li>
<li><em>In-domain data filtration</em>: In order to extract more in-domain data from a large out-of-domain dataset, they trained a classifier to predict in-domain labels and then retain samples with high confidence predictions as in-domain candidates.</li>
</ul>
\[\begin{aligned}
&\mathcal{L}_u^\text{UDA} = \mathbb{1}[\max_{y'} p_{\hat{\theta}}(y'\mid \mathbf{x}) > \tau ] \cdot D[p^\text{(sharp)}_{\hat{\theta}}(y \mid \mathbf{x}; T), p_\theta(y \mid \bar{\mathbf{x}})] \\
&\text{where } p_{\hat{\theta}}^\text{(sharp)}(y \mid \mathbf{x}; T) = \frac{\exp(z^{(y)} / T)}{ \sum_{y'} \exp(z^{(y')} / T) }
\end{aligned}\]
<p>where \(\hat{\theta}\) is a fixed copy of model weights, same as in VAT, so no gradient update, and \(\bar{\mathbf{x}}\) is the augmented data point. \(\tau\) is the prediction confidence threshold and \(T\) is the distribution sharpening temperature.</p>
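<p>The masking and sharpening steps can be sketched for a single unlabeled example as below, assuming KL divergence as \(D\). The default values of <code>tau</code> and <code>T</code> and the function names are placeholders for illustration, not values taken from the paper.</p>

```python
import math

def softmax(z, temperature=1.0):
    m = max(z)
    exps = [math.exp((zi - m) / temperature) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def uda_loss(logits_clean, logits_aug, tau=0.8, T=0.4):
    """Masked, sharpened consistency loss for one unlabeled example.

    logits_clean: logits on x from the fixed weight copy (no gradient)
    logits_aug:   logits on the augmented x from the trainable weights
    """
    confidence = max(softmax(logits_clean))
    if confidence <= tau:
        return 0.0                                    # low confidence masking
    target = softmax(logits_clean, temperature=T)     # sharpened target
    return kl(target, softmax(logits_aug))

masked = uda_loss([0.1, 0.0, -0.1], [3.0, 0.0, 0.0])
active = uda_loss([5.0, 0.0, 0.0], [5.0, 0.0, 0.0])
```

<p>Note that the second call is non-zero even though both logits are identical: the sharpened target is more peaked than the prediction, which pushes the model towards more confident outputs.</p>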
<h2 id="pseudo-labeling">Pseudo Labeling</h2>
<p><strong>Pseudo Labeling</strong> (<a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf">Lee 2013</a>) assigns fake labels to unlabeled samples based on the maximum softmax probabilities predicted by the current model and then trains the model on both labeled and unlabeled samples simultaneously in a purely supervised setup.</p>
<p>Why could pseudo labels work? Pseudo labeling is in effect equivalent to <em>Entropy Regularization</em> (<a href="https://papers.nips.cc/paper/2004/hash/96f2b50b5d3613adf9c27049b2a888c7-Abstract.html">Grandvalet & Bengio 2004</a>), which minimizes the conditional entropy of class probabilities for unlabeled data to favor low density separation between classes. In other words, the predicted class probabilities are in fact a measure of class overlap, so minimizing the entropy is equivalent to reducing class overlap and thus encourages low density separation.</p>
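<p>A small numeric sketch of this connection: the cross entropy against a hard pseudo label is \(-\log \max_y p(y \mid \mathbf{x})\), and like the entropy it shrinks as the prediction becomes more confident. The two example distributions are made up for illustration.</p>

```python
import math

def pseudo_label(probs):
    """Hard pseudo label: one-hot vector at the most probable class."""
    best = max(range(len(probs)), key=probs.__getitem__)
    return [1.0 if k == best else 0.0 for k in range(len(probs))]

def cross_entropy(target, probs):
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

uncertain = [0.4, 0.35, 0.25]
confident = [0.9, 0.05, 0.05]
pl_loss_uncertain = cross_entropy(pseudo_label(uncertain), uncertain)
pl_loss_confident = cross_entropy(pseudo_label(confident), confident)
```

<p>Both quantities are minimized by one-hot predictions, so training on pseudo labels pushes predictions on unlabeled data away from the high-overlap regions.</p>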
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/pseudo-label-segregation.png" alt="Pseudo labeling segregation" /></p>
<p class="image-caption"><em>Fig. 9. t-SNE visualization of outputs on MNIST test set by models trained (a) without and (b) with pseudo labeling on 60000 unlabeled samples, in addition to 600 labeled data. Pseudo labeling leads to better segregation in the learned embedding space. (Image source: <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf">Lee 2013</a>)</em></p>
<p>Training with pseudo labeling naturally comes as an iterative process. We refer to the model that produces pseudo labels as teacher and the model that learns with pseudo labels as student.</p>
<h3 id="label-propagation">Label propagation</h3>
<p><strong>Label Propagation</strong> (<a href="https://arxiv.org/abs/1904.04717">Iscen et al. 2019</a>) is an idea to construct a similarity graph among samples based on feature embedding. Then the pseudo labels are “diffused” from known samples to unlabeled ones where the propagation weights are proportional to pairwise similarity scores in the graph. Conceptually it is similar to a k-NN classifier and both suffer from the problem of not scaling up well with a large dataset.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/label-propagation.png" alt="Label propagation" /></p>
<p class="image-caption"><em>Fig. 10. Illustration of how Label Propagation works. (Image source: <a href="https://arxiv.org/abs/1904.04717">Iscen et al. 2019</a>)</em></p>
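<p>To illustrate the diffusion idea, here is a deliberately simple variant: a dense Gaussian affinity graph with iterative averaging and clamped known labels. This is not the procedure of Iscen et al., who build a k-NN graph on learned embeddings; the function names, the <code>sigma</code> parameter and the 1-D toy features are all assumptions.</p>

```python
import math

def affinity(features, sigma=1.0):
    """Dense Gaussian similarity matrix with a zero diagonal."""
    n = len(features)
    W = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                d2 = sum((a - b) ** 2 for a, b in zip(features[i], features[j]))
                W[i][j] = math.exp(-d2 / (2.0 * sigma ** 2))
    return W

def propagate(features, labels, num_classes, iters=50, sigma=1.0):
    """labels[i] is a class index, or None if unlabeled. Returns one label per sample."""
    W = affinity(features, sigma)
    n = len(features)
    one_hot = lambda k: [1.0 if c == k else 0.0 for c in range(num_classes)]
    F = [one_hot(l) if l is not None else [0.0] * num_classes for l in labels]
    for _ in range(iters):
        new_F = []
        for i in range(n):
            if labels[i] is not None:
                new_F.append(one_hot(labels[i]))      # clamp known labels
                continue
            total = sum(W[i])
            # Similarity-weighted average of the neighbours' label scores.
            new_F.append([sum(W[i][j] * F[j][c] for j in range(n)) / total
                          for c in range(num_classes)])
        F = new_F
    return [max(range(num_classes), key=row.__getitem__) for row in F]

features = [[0.0], [0.2], [0.4], [3.0], [3.2], [3.4]]
labels = [0, None, None, None, None, 1]
pseudo = propagate(features, labels, num_classes=2)
```

<p>The dense \(n \times n\) affinity matrix is exactly what makes this kind of approach hard to scale to large datasets.</p>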
<h3 id="self-training">Self-Training</h3>
<p><strong>Self-Training</strong> is not a new concept (<a href="https://ieeexplore.ieee.org/document/1053799">Scudder 1965</a>, <a href="http://www.kamalnigam.com/papers/cotrain-CIKM00.pdf">Nigam & Ghani CIKM 2000</a>). It is an iterative algorithm, alternating between the following two steps until every unlabeled sample has a label assigned:</p>
<ul>
<li>Initially it builds a classifier on labeled data.</li>
<li>Then it uses this classifier to predict labels for the unlabeled data and converts the most confident ones into labeled samples.</li>
</ul>
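<p>The two steps above can be sketched with a toy nearest-centroid classifier on 1-D points. The classifier, the margin-based confidence score and the <code>per_round</code> parameter are assumptions chosen to keep the loop short; any classifier that exposes a confidence score would fit the same skeleton.</p>

```python
def fit_centroids(points, labels):
    """Nearest-centroid classifier: one mean per class."""
    sums, counts = {}, {}
    for x, y in zip(points, labels):
        sums[y] = sums.get(y, 0.0) + x
        counts[y] = counts.get(y, 0) + 1
    return {y: sums[y] / counts[y] for y in sums}

def predict_with_margin(centroids, x):
    """Return (label, margin); a larger margin means a more confident prediction."""
    ranked = sorted(centroids, key=lambda y: abs(x - centroids[y]))
    margin = abs(x - centroids[ranked[1]]) - abs(x - centroids[ranked[0]])
    return ranked[0], margin

def self_train(labeled_x, labeled_y, unlabeled_x, per_round=1):
    xs, ys, pool = list(labeled_x), list(labeled_y), list(unlabeled_x)
    while pool:
        centroids = fit_centroids(xs, ys)             # step 1: fit on labeled data
        scored = sorted(pool, reverse=True,
                        key=lambda x: predict_with_margin(centroids, x)[1])
        for x in scored[:per_round]:                  # step 2: promote the most confident
            xs.append(x)
            ys.append(predict_with_margin(centroids, x)[0])
            pool.remove(x)
    return xs, ys

xs, ys = self_train([0.0, 10.0], [0, 1], [1.0, 2.0, 4.0, 6.0, 8.0, 9.0])
```

<p>Points near the labeled seeds are absorbed first, and the ambiguous points in the middle are labeled last, once the centroids have been refined.</p>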
<p><a href="https://arxiv.org/abs/1911.04252">Xie et al. (2020)</a> applied self-training in deep learning and achieved great results. On the ImageNet classification task, they first trained an EfficientNet (<a href="https://arxiv.org/abs/1905.11946">Tan & Le 2019</a>) model as teacher to generate pseudo labels for 300M unlabeled images and then trained a larger EfficientNet as student to learn with both true labeled and pseudo labeled images. One critical element in their setup is to have <em>noise</em> during student model training but have no noise for the teacher to produce pseudo labels. Thus their method is called <strong>Noisy Student</strong>. They applied stochastic depth (<a href="https://arxiv.org/abs/1603.09382">Huang et al. 2016</a>), dropout and RandAugment to noise the student. Noise is important for the student to perform better than the teacher. The added noise has a compound effect to encourage the model’s decision making frontier to be smooth, on both labeled and unlabeled data.</p>
<p>A few other important technical configs in noisy student self-training are:</p>
<ul>
<li>The student model should be sufficiently large (i.e. larger than the teacher) to fit more data.</li>
<li>Noisy student should be paired with data balancing, especially important to balance the number of pseudo labeled images in each class.</li>
<li>Soft pseudo labels work better than hard ones.</li>
</ul>
<p>Noisy student also improves adversarial robustness against FGSM (Fast Gradient Sign Method, an attack that uses the gradient of the loss w.r.t. the input data and adjusts the input data to maximize the loss), even though the model is not optimized for adversarial robustness.</p>
<p>SentAugment, proposed by <a href="https://arxiv.org/abs/2010.02194">Du et al. (2020)</a>, aims to solve the problem when there is not enough in-domain unlabeled data for self-training in the language domain. It relies on sentence embedding to find unlabeled in-domain samples from a large corpus and uses the retrieved sentences for self-training.</p>
<h3 id="reducing-confirmation-bias">Reducing confirmation bias</h3>
<p>Confirmation bias is a problem with incorrect pseudo labels provided by an imperfect teacher model. Overfitting to wrong labels may not give us a better student model.</p>
<p>To reduce confirmation bias, <a href="https://arxiv.org/abs/1908.02983">Arazo et al. (2019)</a> proposed two techniques. One is to adopt MixUp with soft labels. Given two samples, \((\mathbf{x}_i, \mathbf{x}_j)\) and their corresponding true or pseudo labels \((y_i, y_j)\), the interpolated label equation can be translated to a cross entropy loss with softmax outputs:</p>
\[\begin{aligned}
&\bar{\mathbf{x}} = \lambda \mathbf{x}_i + (1-\lambda) \mathbf{x}_j \\
&\bar{y} = \lambda y_i + (1-\lambda) y_j \Leftrightarrow
\mathcal{L} = -\lambda [y_i^\top \log f_\theta(\bar{\mathbf{x}})] - (1-\lambda) [y_j^\top \log f_\theta(\bar{\mathbf{x}})]
\end{aligned}\]
<p>MixUp is insufficient if there are too few labeled samples. They further set a minimum number of labeled samples in every mini batch by oversampling the labeled samples. This works better than upweighting labeled samples, because it leads to more frequent updates rather than a few updates of larger magnitude, which could be less stable. Like consistency regularization, data augmentation and dropout are also important for pseudo labeling to work well.</p>
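<p>The equivalence between training on the interpolated label and the \(\lambda\)-weighted sum of two cross entropy terms can be checked numerically. The sketch below uses the usual negative-log-likelihood sign convention for cross entropy; <code>model</code> is a hypothetical stand-in for the softmax output of \(f_\theta\), and the sample values are made up.</p>

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, probs):
    return -sum(t * math.log(p) for t, p in zip(target, probs))

def mixup(a, b, lam):
    return [lam * ai + (1.0 - lam) * bi for ai, bi in zip(a, b)]

def model(x):
    """Hypothetical stand-in for softmax(f_theta(x)): a fixed linear layer."""
    weights = [[1.0, -0.5], [-0.3, 0.8], [0.2, 0.1]]
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in weights])

lam = 0.7
x_i, y_i = [1.0, 0.0], [1.0, 0.0, 0.0]     # true one-hot label
x_j, y_j = [0.0, 1.0], [0.1, 0.8, 0.1]     # soft pseudo label
x_bar = mixup(x_i, x_j, lam)
y_bar = mixup(y_i, y_j, lam)
probs = model(x_bar)
loss_mixed_label = cross_entropy(y_bar, probs)
loss_two_terms = lam * cross_entropy(y_i, probs) + (1 - lam) * cross_entropy(y_j, probs)
```

<p>The two losses agree because cross entropy is linear in its target argument, so either form can be used in an implementation.</p>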
<p><strong>Meta Pseudo Labels</strong> (<a href="https://arxiv.org/abs/2003.10580">Pham et al. 2021</a>) adapts the teacher model constantly with the feedback of how well the student performs on the labeled dataset. The teacher and the student are trained in parallel, where the teacher learns to generate better pseudo labels and the student learns from the pseudo labels.</p>
<p>Let the teacher and student model weights be \(\theta_T\) and \(\theta_S\), respectively. The student weights obtained by training on pseudo labels are a function \(\theta^\text{PL}_S(.)\) of \(\theta_T\), so the student model’s loss on the labeled samples is also a function of \(\theta_T\), and we would like to minimize this loss by optimizing the teacher model accordingly.</p>
\[\begin{aligned}
\min_{\theta_T} &\mathcal{L}_s(\theta^\text{PL}_S(\theta_T)) = \min_{\theta_T} \mathbb{E}_{(\mathbf{x}^l, y) \in \mathcal{X}} \text{CE}[y, f_{\theta_S}(\mathbf{x}^l)] \\
\text{where } &\theta^\text{PL}_S(\theta_T)
= \arg\min_{\theta_S} \mathcal{L}_u (\theta_T, \theta_S)
= \arg\min_{\theta_S} \mathbb{E}_{\mathbf{u} \sim \mathcal{U}} \text{CE}[(f_{\theta_T}(\mathbf{u}), f_{\theta_S}(\mathbf{u}))]
\end{aligned}\]
<p>However, it is not trivial to optimize the above equation. Borrowing the idea of <a href="https://arxiv.org/abs/1703.03400">MAML</a>, it approximates the multi-step \(\arg\min_{\theta_S}\) with the one-step gradient update of \(\theta_S\),</p>
\[\begin{aligned}
\theta^\text{PL}_S(\theta_T) &\approx \theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S) \\
\min_{\theta_T} \mathcal{L}_s (\theta^\text{PL}_S(\theta_T)) &\approx \min_{\theta_T} \mathcal{L}_s \big( \theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S) \big)
\end{aligned}\]
<p>With soft pseudo labels, the above objective is differentiable. But if using hard pseudo labels, it is not differentiable and thus we need to use RL, e.g. REINFORCE.</p>
<p>The optimization procedure alternates between training the two models:</p>
<ul>
<li><em>Student model update</em>: Given a batch of unlabeled samples \(\{ \mathbf{u} \}\), we generate pseudo labels by \(f_{\theta_T}(\mathbf{u})\) and optimize \(\theta_S\) with one step SGD: \(\theta'_S = \color{green}{\theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)}\).</li>
<li><em>Teacher model update</em>: Given a batch of labeled samples \(\{(\mathbf{x}^l, y)\}\), we reuse the student’s update to optimize \(\theta_T\): \(\theta'_T = \theta_T - \eta_T \cdot \nabla_{\theta_T} \mathcal{L}_s ( \color{green}{\theta_S - \eta_S \cdot \nabla_{\theta_S} \mathcal{L}_u(\theta_T, \theta_S)} )\). In addition, the UDA objective is applied to the teacher model to incorporate consistency regularization.</li>
</ul>
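<p>The two alternating updates can be sketched on a toy problem. This is only an illustration under strong assumptions: both models are one-parameter logistic regressors with soft pseudo labels, the teacher gradient through the student step is estimated by central finite differences instead of backpropagation, and the UDA term is omitted. All function names are hypothetical.</p>

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def cross_entropy(target, prob, eps=1e-12):
    """Binary cross entropy between a (soft) target and a predicted probability."""
    return -(target * math.log(prob + eps) + (1 - target) * math.log(1 - prob + eps))


def labeled_loss(theta_s, labeled):
    """L_s: student loss on labeled pairs (x, y)."""
    return sum(cross_entropy(y, sigmoid(theta_s * x)) for x, y in labeled) / len(labeled)


def student_step(theta_t, theta_s, unlabeled, eta_s):
    """One SGD step on L_u = CE[f_T(u), f_S(u)] with respect to theta_S."""
    # For a logistic student, d CE / d theta_S = (f_S(u) - f_T(u)) * u.
    grad = sum((sigmoid(theta_s * u) - sigmoid(theta_t * u)) * u for u in unlabeled)
    return theta_s - eta_s * grad / len(unlabeled)


def teacher_step(theta_t, theta_s, unlabeled, labeled, eta_s, eta_t, h=1e-5):
    """One step on L_s(theta_S - eta_S * grad L_u(theta_T, theta_S)) w.r.t. theta_T."""
    def objective(t):
        return labeled_loss(student_step(t, theta_s, unlabeled, eta_s), labeled)

    # Finite-difference stand-in for differentiating through the student update.
    grad = (objective(theta_t + h) - objective(theta_t - h)) / (2 * h)
    return theta_t - eta_t * grad


def train(labeled, unlabeled, steps=50, eta_s=0.5, eta_t=0.5):
    theta_t, theta_s = 0.1, 0.0
    for _ in range(steps):
        # Both updates start from the same (old) student weights.
        new_theta_s = student_step(theta_t, theta_s, unlabeled, eta_s)
        theta_t = teacher_step(theta_t, theta_s, unlabeled, labeled, eta_s, eta_t)
        theta_s = new_theta_s
    return theta_t, theta_s
```

<p>On linearly separable toy data, the teacher is pushed toward pseudo labels that make the one-step student better on the labeled set, so the student’s labeled loss decreases even though the student never sees a true label directly.</p>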
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/MPL-results.png" alt="MPL experiment results" /></p>
<p class="image-caption"><em>Fig. 11. Comparison of Meta Pseudo Labels with other semi- or self-supervised learning methods on image classification tasks. (Image source: <a href="https://arxiv.org/abs/2003.10580">Pham et al. 2021</a>)</em></p>
<h2 id="pseudo-labeling-with-consistency-regularization">Pseudo Labeling with Consistency Regularization</h2>
<p>It is possible to combine the above two approaches together, running semi-supervised learning with both pseudo labeling and consistency training.</p>
<h3 id="mixmatch">MixMatch</h3>
<p><strong>MixMatch</strong> (<a href="https://arxiv.org/abs/1905.02249">Berthelot et al. 2019</a>), as a holistic approach to semi-supervised learning, utilizes unlabeled data by merging the following techniques:</p>
<ol>
<li><em>Consistency regularization</em>: Encourage the model to output the same predictions on perturbed unlabeled samples.</li>
<li><em>Entropy minimization</em>: Encourage the model to output confident predictions on unlabeled data.</li>
<li><em>MixUp</em> augmentation: Encourage the model to have linear behaviour between samples.</li>
</ol>
<p>Given a batch of labeled data \(\mathcal{X}\) and unlabeled data \(\mathcal{U}\), we create augmented versions of them via \(\text{MixMatch}(.)\), \(\bar{\mathcal{X}}\) and \(\bar{\mathcal{U}}\), containing augmented samples and guessed labels for unlabeled examples.</p>
\[\begin{aligned}
\bar{\mathcal{X}}, \bar{\mathcal{U}} &= \text{MixMatch}(\mathcal{X}, \mathcal{U}, T, K, \alpha) \\
\mathcal{L}^\text{MM}_s &= \frac{1}{\vert \bar{\mathcal{X}} \vert} \sum_{(\bar{\mathbf{x}}^l, y)\in \bar{\mathcal{X}}} D[y, p_\theta(y \mid \bar{\mathbf{x}}^l)] \\
\mathcal{L}^\text{MM}_u &= \frac{1}{L\vert \bar{\mathcal{U}} \vert} \sum_{(\bar{\mathbf{u}}, \hat{y})\in \bar{\mathcal{U}}} \| \hat{y} - p_\theta(y \mid \bar{\mathbf{u}}) \|^2_2 \\
\end{aligned}\]
<p>where \(T\) is the sharpening temperature to reduce the guessed label overlap; \(K\) is the number of augmentations generated per unlabeled example; \(\alpha\) is the parameter in MixUp; \(L\) is the number of classes.</p>
<p>For each \(\mathbf{u}\), MixMatch generates \(K\) augmentations, \(\bar{\mathbf{u}}^{(k)} = \text{Augment}(\mathbf{u})\) for \(k=1, \dots, K\) and the pseudo label is guessed based on the average: \(\hat{y} = \frac{1}{K} \sum_{k=1}^K p_\theta(y \mid \bar{\mathbf{u}}^{(k)})\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/MixMatch.png" alt="MixMatch" /></p>
<p class="image-caption"><em>Fig. 12. The process of “label guessing” in MixMatch: averaging \(K\) augmentations, correcting the predicted marginal distribution and finally sharpening the distribution. (Image source: <a href="https://arxiv.org/abs/1905.02249">Berthelot et al. 2019</a>)</em></p>
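<p>The averaging and sharpening steps of label guessing can be sketched as follows. The sharpening rule \(\hat{y}_c^{1/T} / \sum_j \hat{y}_j^{1/T}\) is the temperature form used in the MixMatch paper; the function name <code class="language-plaintext highlighter-rouge">guess_label</code> is hypothetical and the model predictions are passed in as plain lists.</p>

```python
def guess_label(predictions, temperature):
    """Average K predicted class distributions, then sharpen with temperature T.

    predictions: list of K lists, each a probability distribution over classes.
    """
    k = len(predictions)
    num_classes = len(predictions[0])
    # Average the model predictions over the K augmentations.
    avg = [sum(p[c] for p in predictions) / k for c in range(num_classes)]
    # Sharpen: raise to the power 1/T and renormalize; T < 1 lowers the entropy.
    powered = [q ** (1.0 / temperature) for q in avg]
    total = sum(powered)
    return [q / total for q in powered]
```

<p>With \(T = 1\) the guessed label is just the average; as \(T \to 0\) it approaches a one-hot label.</p>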
<p>According to their ablation studies, it is critical to have MixUp, especially on the unlabeled data. Removing temperature sharpening on the pseudo label distribution hurts the performance quite a lot. Averaging over multiple augmentations for label guessing is also necessary.</p>
<p><strong>ReMixMatch</strong> (<a href="https://arxiv.org/abs/1911.09785">Berthelot et al. 2020</a>) improves MixMatch by introducing two new mechanisms:</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ReMixMatch.png" alt="ReMixMatch" /></p>
<p class="image-caption"><em>Fig. 13. Illustration of two improvements introduced in ReMixMatch over MixMatch. (Image source: <a href="https://arxiv.org/abs/1911.09785">Berthelot et al. 2020</a>)</em></p>
<ul>
<li><em>Distribution alignment.</em> It encourages the marginal distribution \(p(y)\) to be close to the marginal distribution of the ground truth labels. Let \(p(y)\) be the class distribution in the true labels and \(\tilde{p}(\hat{y})\) be a running average of the predicted class distribution among the unlabeled data. The model prediction on an unlabeled sample \(p_\theta(y \vert \mathbf{u})\) is normalized to be \(\text{Normalize}\big( \frac{p_\theta(y \vert \mathbf{u}) p(y)}{\tilde{p}(\hat{y})} \big)\) to match the true marginal distribution.
<ul>
<li>Note that entropy minimization is not a useful objective if the marginal distribution is not uniform.</li>
<li>I do feel the assumption that the class distributions on the labeled and unlabeled data should match is too strong and not necessarily true in the real-world setting.</li>
</ul>
</li>
<li><em>Augmentation anchoring</em>. Given an unlabeled sample, it first generates an “anchor” version with weak augmentation and then uses the model’s prediction on this anchor as the guessed label for \(K\) strongly augmented versions generated by CTAugment (Control Theory Augment). CTAugment only samples augmentations that keep the model predictions within the network tolerance.</li>
</ul>
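<p>The distribution alignment step is a small elementwise rescaling followed by renormalization. A minimal sketch, with a hypothetical function name and plain lists standing in for the three distributions \(p_\theta(y \vert \mathbf{u})\), \(p(y)\) and \(\tilde{p}(\hat{y})\):</p>

```python
def align_distribution(pred, label_marginal, running_pred_marginal):
    """Normalize(pred * p(y) / running average of predictions), per class."""
    scaled = [
        q * p / r
        for q, p, r in zip(pred, label_marginal, running_pred_marginal)
    ]
    total = sum(scaled)
    return [s / total for s in scaled]
```

<p>For example, if the model has so far predicted both classes equally often while the labeled data is split 80/20, an undecided prediction <code class="language-plaintext highlighter-rouge">[0.5, 0.5]</code> is pulled to <code class="language-plaintext highlighter-rouge">[0.8, 0.2]</code>.</p>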
<p>The ReMixMatch loss is a combination of several terms,</p>
<ul>
<li>a supervised loss with data augmentation and MixUp applied;</li>
<li>an unsupervised loss with data augmentation and MixUp applied, using pseudo labels as targets;</li>
<li>a CE loss on a single heavily-augmented unlabeled image without MixUp;</li>
<li>a <a href="/lil-log/2019/11/10/self-supervised-learning.html#distortion">rotation</a> loss as in self-supervised learning.</li>
</ul>
<h3 id="dividemix">DivideMix</h3>
<p><strong>DivideMix</strong> (<a href="https://arxiv.org/abs/2002.07394">Junnan Li et al. 2020</a>) combines semi-supervised learning with learning with noisy labels (LNL). It models the per-sample loss distribution via a <a href="https://scikit-learn.org/stable/modules/mixture.html">GMM</a> to dynamically divide the training data into a labeled set with clean examples and an unlabeled set with noisy ones. Following the idea in <a href="https://arxiv.org/abs/1904.11238">Arazo et al. 2019</a>, they fit a two-component GMM on the per-sample cross entropy loss \(\ell_i = -y_i^\top \log f_\theta(\mathbf{x}_i)\). Clean samples are expected to reach a low loss faster than noisy samples. The component with the smaller mean is the cluster corresponding to clean labels; let’s denote it as \(c\). If the GMM posterior probability \(w_i = p_\text{GMM}(c \mid \ell_i)\) (i.e. the probability of the sample belonging to the clean sample set) is larger than the threshold \(\tau\), the sample is considered clean, and otherwise noisy.</p>
<p>The data clustering step is named <em>co-divide</em>. To avoid confirmation bias, DivideMix simultaneously trains two diverged networks where each network uses the dataset division from the other network, similar to how Double Q-Learning works.</p>
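<p>The clean/noisy split can be sketched with a tiny hand-written EM loop for a two-component 1D GMM. This is a standard-library illustration only; in practice a library GMM such as the scikit-learn one linked above would be used, and the initialization, iteration count and variance floor below are assumptions, not DivideMix settings.</p>

```python
import math


def gaussian_pdf(x, mean, var):
    return math.exp(-((x - mean) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def clean_probabilities(losses, iterations=50, var_floor=1e-6):
    """Fit a two-component 1D GMM with EM; return w_i = p(clean | loss_i)."""
    n = len(losses)
    mean_all = sum(losses) / n
    var_all = max(sum((l - mean_all) ** 2 for l in losses) / n, var_floor)
    # Start one component at the smallest loss and one at the largest.
    means = [min(losses), max(losses)]
    variances = [var_all, var_all]
    weights = [0.5, 0.5]
    resp = []
    for _ in range(iterations):
        # E-step: posterior responsibility of each component for each loss.
        resp = []
        for l in losses:
            dens = [weights[k] * gaussian_pdf(l, means[k], variances[k]) for k in range(2)]
            total = dens[0] + dens[1] + 1e-300  # guard against underflow
            resp.append([d / total for d in dens])
        # M-step: re-estimate mean, variance and mixing weight per component.
        for k in range(2):
            nk = sum(r[k] for r in resp) + 1e-12
            means[k] = sum(r[k] * l for r, l in zip(resp, losses)) / nk
            var_k = sum(r[k] * (l - means[k]) ** 2 for r, l in zip(resp, losses)) / nk
            variances[k] = max(var_k, var_floor)
            weights[k] = nk / n
    # The component with the smaller mean models the clean samples.
    clean = 0 if means[0] <= means[1] else 1
    return [r[clean] for r in resp]
```

<p>Samples with \(w_i > \tau\) would go to the labeled (clean) set and the rest to the unlabeled set, with each network consuming the split produced by the other one.</p>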
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DivideMix.png" alt="DivideMix" /></p>
<p class="image-caption"><em>Fig. 14. DivideMix trains two networks independently to reduce confirmation bias. They run co-divide, co-refinement, and co-guessing together. (Image source: <a href="https://arxiv.org/abs/2002.07394">Junnan Li et al. 2020</a>)</em></p>
<p>Compared to MixMatch, DivideMix has an additional <em>co-divide</em> stage for handling noisy samples, as well as the following improvements during training:</p>
<ul>
<li><em>Label co-refinement</em>: It linearly combines the ground-truth label \(y_i\) with the network’s prediction \(\hat{y}_i\), which is averaged across multiple augmentations of \(\mathbf{x}_i\), guided by the clean set probability \(w_i\) produced by the other network.</li>
<li><em>Label co-guessing</em>: It averages the predictions from two models for unlabelled data samples.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DivideMix-algo.png" alt="Algorithm of DivideMix" /></p>
<p class="image-caption"><em>Fig. 15. The algorithm of DivideMix. (Image source: <a href="https://arxiv.org/abs/2002.07394">Junnan Li et al. 2020</a>)</em></p>
<h3 id="fixmatch">FixMatch</h3>
<p><strong>FixMatch</strong> (<a href="https://arxiv.org/abs/2001.07685">Sohn et al. 2020</a>) generates pseudo labels on unlabeled samples with weak augmentation and only keeps predictions with high confidence. Here both weak augmentation and high confidence filtering help produce high-quality trustworthy pseudo label targets. Then FixMatch learns to predict these pseudo labels given a heavily-augmented sample.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/FixMatch.png" alt="FixMatch" /></p>
<p class="image-caption"><em>Fig. 16. Illustration of how FixMatch works. (Image source: <a href="https://arxiv.org/abs/2001.07685">Sohn et al. 2020</a>)</em></p>
\[\begin{aligned}
\mathcal{L}_s &= \frac{1}{B} \sum^B_{b=1} \text{CE}[y_b, p_\theta(y \mid \mathcal{A}_\text{weak}(\mathbf{x}_b))] \\
\mathcal{L}_u &= \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}[\max(\hat{y}_b) \geq \tau]\;\text{CE}(\hat{y}_b, p_\theta(y \mid \mathcal{A}_\text{strong}(\mathbf{u}_b)))
\end{aligned}\]
<p>where \(\hat{y}_b\) is the pseudo label for an unlabeled example; \(\mu\) is a hyperparameter that determines the relative sizes of \(\mathcal{X}\) and \(\mathcal{U}\).</p>
<ul>
<li>Weak augmentation \(\mathcal{A}_\text{weak}(.)\): A standard flip-and-shift augmentation</li>
<li>Strong augmentation \(\mathcal{A}_\text{strong}(.)\) : AutoAugment, Cutout, RandAugment, CTAugment</li>
</ul>
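<p>A minimal sketch of the unlabeled loss \(\mathcal{L}_u\) above, taking the model outputs on the weakly and strongly augmented views as given lists of probabilities. Converting the weak prediction to a hard (argmax) pseudo label follows the FixMatch paper; the function name is hypothetical.</p>

```python
import math


def fixmatch_unlabeled_loss(weak_probs, strong_probs, tau):
    """Confidence-masked cross entropy between weak-view labels and strong-view predictions.

    weak_probs[b]   : predicted distribution on the weakly augmented sample b
    strong_probs[b] : predicted distribution on the strongly augmented sample b
    """
    total = 0.0
    for q_weak, q_strong in zip(weak_probs, strong_probs):
        confidence = max(q_weak)
        if confidence >= tau:  # indicator 1[max(y_hat) >= tau]
            pseudo_label = q_weak.index(confidence)
            total += -math.log(q_strong[pseudo_label])
    # Normalize by the full unlabeled batch size (mu * B), not by the kept count.
    return total / len(weak_probs)
```

<p>Samples below the threshold contribute zero loss, so early in training only a small fraction of the unlabeled batch produces gradients.</p>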
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/FixMatch-results.png" alt="FixMatch results" /></p>
<p class="image-caption"><em>Fig. 17. Performance of FixMatch and several other semi-supervised learning methods on image classification tasks. (Image source: <a href="https://arxiv.org/abs/2001.07685">Sohn et al. 2020</a>)</em></p>
<p>According to the ablation studies of FixMatch,</p>
<ul>
<li>Sharpening the predicted distribution with a temperature parameter \(T\) does not have a significant impact when the threshold \(\tau\) is used.</li>
<li>Cutout and CTAugment as part of strong augmentations are necessary for good performance.</li>
<li>When the weak augmentation for label guessing is replaced with strong augmentation, the model diverges early in training. If weak augmentation is discarded completely, the model overfits the guessed labels.</li>
<li>Using weak instead of strong augmentation for pseudo label prediction leads to unstable performance. Strong data augmentation is critical.</li>
</ul>
<h2 id="combined-with-powerful-pre-training">Combined with Powerful Pre-Training</h2>
<p>It is a common paradigm, especially in language tasks, to first pre-train a task-agnostic model on a large unsupervised data corpus via self-supervised learning and then fine-tune it on the downstream task with a small labeled dataset. Research has shown that we can obtain extra gain by combining semi-supervised learning with pre-training.</p>
<p><a href="https://arxiv.org/abs/2006.06882">Zoph et al. (2020)</a> studied to what degree <a href="#self-training">self-training</a> can work better than pre-training. Their experiment setup was to use ImageNet for pre-training or self-training to improve COCO. Note that when using ImageNet for self-training, it discards labels and only uses ImageNet samples as unlabeled data points. <a href="https://arxiv.org/abs/1811.08883">He et al. (2018)</a> demonstrated that ImageNet classification pre-training does not work well if the downstream task is very different, such as object detection.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-training-pre-training.png" alt="self-training-pre-training" /></p>
<p class="image-caption"><em>Fig. 18. The effect of (a) data augmentation (from weak to strong) and (b) the labeled dataset size on the object detection performance. In the legend: <code class="language-plaintext highlighter-rouge">Rand Init</code> refers to a model initialized w/ random weights; <code class="language-plaintext highlighter-rouge">ImageNet</code> is initialized with a pre-trained checkpoint at 84.5% top-1 ImageNet accuracy; <code class="language-plaintext highlighter-rouge">ImageNet++</code> is initialized with a checkpoint with a higher accuracy 86.9%. (Image source: <a href="https://arxiv.org/abs/2006.06882">Zoph et al. 2020</a>)</em></p>
<p>Their experiments demonstrated a series of interesting findings:</p>
<ul>
<li>The effectiveness of pre-training diminishes with more labeled samples available for the downstream task. Pre-training is helpful in the low-data regimes (20%) but neutral or harmful in the high-data regime.</li>
<li>Self-training helps in high data/strong augmentation regimes, even when pre-training hurts.</li>
<li>Self-training can bring in additive improvement on top of pre-training, even using the same data source.</li>
<li>Self-supervised pre-training (e.g. via SimCLR) hurts the performance in a high data regime, similar to how supervised pre-training does.</li>
<li>Joint training on supervised and self-supervised objectives helps resolve the mismatch between the pre-training and downstream tasks. Pre-training, joint-training and self-training are all additive.</li>
<li>Noisy labels or un-targeted labeling (i.e. pre-training labels are not aligned with downstream task labels) is worse than targeted pseudo labeling.</li>
<li>Self-training is computationally more expensive than fine-tuning on a pre-trained model.</li>
</ul>
<p><a href="https://arxiv.org/abs/2006.10029">Chen et al. (2020)</a> proposed a three-step procedure to merge the benefits of self-supervised pretraining, supervised fine-tuning and self-training together:</p>
<ol>
<li>Pretrain a big model in an unsupervised or self-supervised manner.</li>
<li>Fine-tune it with supervision on a few labeled examples. It is important to use a big (deep and wide) neural network. <em>Bigger models yield better performance with fewer labeled samples.</em></li>
<li>Distillation with unlabeled examples by adopting pseudo labels in self-training.
<ul>
<li>It is possible to distill the knowledge from a large model into a small one because the task-specific use does not require extra capacity of the learned representation.</li>
<li>
<p>The distillation loss is formulated as follows, where the teacher network is fixed with weights \(\hat{\theta}_T\), \(T\) is the distillation temperature and the inner sum runs over the \(L\) classes.</p>
\[\mathcal{L}_\text{distill} = - (1-\alpha) \underbrace{\sum_{(\mathbf{x}^l_i, y_i) \in \mathcal{X}} \big[ \log p_{\theta_S}(y_i \mid \mathbf{x}^l_i) \big]}_\text{Supervised loss} - \alpha \underbrace{\sum_{\mathbf{u}_i \in \mathcal{U}} \Big[ \sum_{j=1}^L p_{\hat{\theta}_T}(y^{(j)} \mid \mathbf{u}_i; T) \log p_{\theta_S}(y^{(j)} \mid \mathbf{u}_i; T) \Big]}_\text{Distillation loss using unlabeled data}\]
</li>
</ul>
</li>
</ol>
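<p>The distillation objective above can be sketched directly from logits. This is an illustration under assumptions: the teacher and student outputs are given as plain lists of logits, the temperature is applied as a softmax temperature on both networks, and the function names are hypothetical.</p>

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(labeled, unlabeled, alpha, temperature):
    """Weighted sum of the supervised loss and the teacher-student distillation loss.

    labeled   : list of (student_logits, label_index) pairs
    unlabeled : list of (teacher_logits, student_logits) pairs; the teacher is fixed
    """
    # Supervised term: negative log-likelihood of the true label.
    supervised = -sum(math.log(softmax(s)[y]) for s, y in labeled)
    # Distillation term: cross entropy from the softened teacher to the softened student.
    distill = 0.0
    for t_logits, s_logits in unlabeled:
        p_teacher = softmax(t_logits, temperature)
        p_student = softmax(s_logits, temperature)
        distill -= sum(pt * math.log(ps) for pt, ps in zip(p_teacher, p_student))
    return (1 - alpha) * supervised + alpha * distill
```

<p>Setting \(\alpha = 1\) recovers pure distillation on unlabeled data, while \(\alpha = 0\) falls back to ordinary supervised fine-tuning.</p>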
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/big-self-supervised-model.png" alt="big-self-supervised-model" /></p>
<p class="image-caption"><em>Fig. 19. A semi-supervised learning framework leverages unlabeled data corpus by (Left) task-agnostic unsupervised pretraining and (Right) task-specific self-training and distillation. (Image source: <a href="https://arxiv.org/abs/2006.10029">Chen et al. 2020</a>)</em></p>
<p>They experimented on the ImageNet classification task. The self-supervised pre-training uses SimCLRv2, a directly improved version of <a href="/lil-log/2021/05/31/contrastive-representation-learning.html#simclr">SimCLR</a>. Observations in their empirical studies confirmed several learnings, aligned with <a href="https://arxiv.org/abs/2006.06882">Zoph et al. 2020</a>:</p>
<ul>
<li>Bigger models are more label-efficient;</li>
<li>Bigger/deeper projection heads in SimCLR improve representation learning;</li>
<li>Distillation using unlabeled data improves semi-supervised learning.</li>
</ul>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/big-self-supervised-model-results.png" alt="big-self-supervised-model-results" /></p>
<p class="image-caption"><em>Fig. 20. Comparison of performance by SimCLRv2 + semi-supervised distillation on ImageNet classification. (Image source: <a href="https://arxiv.org/abs/2006.10029">Chen et al. 2020</a>)</em></p>
<hr />
<p>💡 Quick summary of common themes among recent semi-supervised learning methods, many aiming to reduce confirmation bias:</p>
<ul>
<li>Apply valid and diverse noise to samples by advanced data augmentation methods.</li>
<li>When dealing with images, MixUp is an effective augmentation. Mixup could work on language too, resulting in a small incremental improvement (<a href="https://arxiv.org/abs/1905.08941">Guo et al. 2019</a>).</li>
<li>Set a threshold and discard pseudo labels with low confidence.</li>
<li>Set a minimum number of labeled samples per mini-batch.</li>
<li>Sharpen the pseudo label distribution to reduce the class overlap.</li>
</ul>
<h2 id="references">References</h2>
<p>[1] Ouali, Hudelot & Tami. <a href="https://arxiv.org/abs/2006.05278">“An Overview of Deep Semi-Supervised Learning”</a> arXiv preprint arXiv:2006.05278 (2020).</p>
<p>[2] Sajjadi, Javanmardi & Tasdizen <a href="https://arxiv.org/abs/1606.04586">“Regularization With Stochastic Transformations and Perturbations for Deep Semi-Supervised Learning.”</a> arXiv preprint arXiv:1606.04586 (2016).</p>
<p>[3] Pham et al. <a href="https://arxiv.org/abs/2003.10580">“Meta Pseudo Labels.”</a> CVPR 2021.</p>
<p>[4] Laine & Aila. <a href="https://arxiv.org/abs/1610.02242">“Temporal Ensembling for Semi-Supervised Learning”</a> ICLR 2017.</p>
<p>[5] Tarvainen & Valpola. <a href="https://arxiv.org/abs/1703.01780">“Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.”</a> NeurIPS 2017</p>
<p>[6] Xie et al. <a href="https://arxiv.org/abs/1904.12848">“Unsupervised Data Augmentation for Consistency Training.”</a> NeurIPS 2020.</p>
<p>[7] Miyato et al. <a href="https://arxiv.org/abs/1704.03976">“Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning.”</a> IEEE transactions on pattern analysis and machine intelligence 41.8 (2018).</p>
<p>[8] Verma et al. <a href="https://arxiv.org/abs/1903.03825">“Interpolation consistency training for semi-supervised learning.”</a> IJCAI 2019</p>
<p>[9] Lee. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf">“Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks.”</a> ICML 2013 Workshop: Challenges in Representation Learning.</p>
<p>[10] Iscen et al. <a href="https://arxiv.org/abs/1904.04717">“Label propagation for deep semi-supervised learning.”</a> CVPR 2019.</p>
<p>[11] Xie et al. <a href="https://arxiv.org/abs/1911.04252">“Self-training with Noisy Student improves ImageNet classification”</a> CVPR 2020.</p>
<p>[12] Jingfei Du et al. <a href="https://arxiv.org/abs/2010.02194">“Self-training Improves Pre-training for Natural Language Understanding.”</a> 2020</p>
<p>[13] Iscen et al. <a href="https://arxiv.org/abs/1904.04717">“Label propagation for deep semi-supervised learning.”</a> CVPR 2019</p>
<p>[14] Arazo et al. <a href="https://arxiv.org/abs/1908.02983">“Pseudo-labeling and confirmation bias in deep semi-supervised learning.”</a> IJCNN 2020.</p>
<p>[15] Berthelot et al. <a href="https://arxiv.org/abs/1905.02249">“MixMatch: A holistic approach to semi-supervised learning.”</a> NeurIPS 2019</p>
<p>[16] Berthelot et al. <a href="https://arxiv.org/abs/1911.09785">“ReMixMatch: Semi-supervised learning with distribution alignment and augmentation anchoring.”</a> ICLR 2020</p>
<p>[17] Sohn et al. <a href="https://arxiv.org/abs/2001.07685">“FixMatch: Simplifying semi-supervised learning with consistency and confidence.”</a> CVPR 2020</p>
<p>[18] Junnan Li et al. <a href="https://arxiv.org/abs/2002.07394">“DivideMix: Learning with Noisy Labels as Semi-supervised Learning.”</a> 2020 [<a href="https://github.com/LiJunnan1992/DivideMix">code</a>]</p>
<p>[19] Zoph et al. <a href="https://arxiv.org/abs/2006.06882">“Rethinking pre-training and self-training.”</a> 2020.</p>
<p>[20] Chen et al. <a href="https://arxiv.org/abs/2006.10029">“Big Self-Supervised Models are Strong Semi-Supervised Learners”</a> 2020</p>
Lilian Weng
How to Train Really Large Models on Many GPUs?
2021-09-24T12:00:00+00:00
https://lilianweng.github.io/lil-log/2021/09/24/train-large-neural-networks
<blockquote>
<p>Training large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time. This post reviews several popular training parallelism paradigms, as well as a variety of model architecture and memory saving designs to make it possible to train very large neural networks across a large number of GPUs.</p>
</blockquote>
<!--more-->
<p>In recent years, we have seen better results on many NLP benchmark tasks with larger pre-trained <a href="/lil-log/2019/01/31/generalized-language-models.html">language models</a>. Training large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time.</p>
<p>However, an individual GPU worker has limited memory and the sizes of many large models have grown beyond a single GPU. There are several parallelism paradigms to enable model training across multiple GPUs, as well as a variety of model architecture and memory saving designs to help make it possible to train <em>very large</em> neural networks.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#training-parallelism" id="markdown-toc-training-parallelism">Training Parallelism</a> <ul>
<li><a href="#data-parallelism" id="markdown-toc-data-parallelism">Data Parallelism</a></li>
<li><a href="#model-parallelism" id="markdown-toc-model-parallelism">Model Parallelism</a></li>
<li><a href="#pipeline-parallelism" id="markdown-toc-pipeline-parallelism">Pipeline Parallelism</a></li>
<li><a href="#tensor-parallelism" id="markdown-toc-tensor-parallelism">Tensor Parallelism</a></li>
</ul>
</li>
<li><a href="#mixture-of-experts-moe" id="markdown-toc-mixture-of-experts-moe">Mixture-of-Experts (MoE)</a></li>
<li><a href="#other-memory-saving-designs" id="markdown-toc-other-memory-saving-designs">Other Memory Saving Designs</a> <ul>
<li><a href="#cpu-offloading" id="markdown-toc-cpu-offloading">CPU Offloading</a></li>
<li><a href="#activation-recomputation" id="markdown-toc-activation-recomputation">Activation Recomputation</a></li>
<li><a href="#mixed-precision-training" id="markdown-toc-mixed-precision-training">Mixed Precision Training</a></li>
<li><a href="#compression" id="markdown-toc-compression">Compression</a></li>
<li><a href="#memory-efficient-optimizer" id="markdown-toc-memory-efficient-optimizer">Memory Efficient Optimizer</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="training-parallelism">Training Parallelism</h2>
<p>The main bottleneck for training very large neural network models is the intense demand for a large amount of GPU memory, far above what can be hosted on an individual GPU machine. Besides the model weights (e.g. tens of billions of floating point numbers), it is usually even more expensive to store intermediate computation outputs such as gradients and optimizer states (e.g. momentums & variances in Adam). Additionally, training a large model often pairs with a large training corpus and thus a single process may just take forever.</p>
<p>As a result, parallelism is necessary. Parallelism can happen at different dimensions, including data, model architecture, and tensor operation.</p>
<h3 id="data-parallelism">Data Parallelism</h3>
<p>The most naive way for <strong>Data parallelism (DP)</strong> is to copy the same model weights into multiple workers and assign a fraction of data to each worker to be processed at the same time.</p>
<p>Naive DP cannot work well if the model size is larger than a single GPU node’s memory. Methods like <em>GeePS</em> (<a href="https://www.pdl.cmu.edu/PDL-FTP/CloudComputing/GeePS-cui-eurosys16.pdf">Cui et al. 2016</a>) offload temporarily unused parameters back to CPU to work with limited GPU memory when the model is too big to fit into one machine. The data swapping should happen in the background and not interfere with training computation.</p>
<p>At the end of each minibatch, workers need to synchronize gradients or weights to avoid staleness. There are two main synchronization approaches and both have clear pros & cons.</p>
<ol>
<li><em>Bulk synchronous parallels (BSP)</em>: Workers sync data at the end of every minibatch. It prevents model weight staleness and gives good learning efficiency, but each machine has to halt and wait for others to send gradients.</li>
<li><em>Asynchronous parallel (ASP)</em>: Every GPU worker processes the data asynchronously, with no waiting or stalling. However, it can easily lead to stale weights being used and thus lower the statistical learning efficiency. Even though it increases the share of time spent on computation, it may not shorten the training time to convergence.</li>
</ol>
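<p>A toy simulation of one BSP step helps show why synchronous data parallelism matches single-worker training: with equally sized shards, the average of the per-worker gradients equals the gradient over the whole minibatch. The model (a single weight fit with squared error) and the function names are hypothetical, and the plain average stands in for an <code class="language-plaintext highlighter-rouge">AllReduce</code>.</p>

```python
def gradient(w, shard):
    """Gradient of the mean of 0.5 * (w * x - y)^2 over a shard, with respect to w."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)


def bsp_step(w, minibatch, num_workers, lr):
    """One bulk-synchronous step; assumes len(minibatch) is divisible by num_workers."""
    shard_size = len(minibatch) // num_workers
    shards = [
        minibatch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)
    ]
    # Each worker holds the same weights and computes a gradient on its own shard
    # (sequential here, parallel in a real system).
    local_grads = [gradient(w, shard) for shard in shards]
    # Synchronization barrier: average the gradients across workers.
    synced_grad = sum(local_grads) / num_workers
    # Every worker applies the same update, so the replicas stay identical.
    return w - lr * synced_grad
```

<p>In ASP the barrier is dropped, so a worker may apply a gradient that was computed on weights several updates old.</p>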
<p>Somewhere in the middle is to synchronize gradients globally once every \(x\) iterations (\(x > 1\)). This feature is called “gradient accumulation” in Distributed Data Parallel (<a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">DDP</a>) since PyTorch v1.5 (<a href="https://arxiv.org/abs/2006.15704">Li et al. 2021</a>). Gradient bucketing avoids immediate <code class="language-plaintext highlighter-rouge">AllReduce</code> operations and instead groups multiple gradients into one <code class="language-plaintext highlighter-rouge">AllReduce</code> to improve throughput. Computation and communication scheduling can be optimized based on the computation graph.</p>
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/pytorch-ddp.png" alt="Pytorch DDP" /></p>
<p class="image-caption"><em>Fig. 1. Pseudo code for PyTorch DDP. (Image source: <a href="https://arxiv.org/abs/2006.15704">Li et al. 2021</a>)</em></p>
<h3 id="model-parallelism">Model Parallelism</h3>
<p><strong>Model parallelism (MP)</strong> aims to solve the case when the model weights cannot fit into a single node. The computation and model parameters are partitioned across multiple machines. Different from data parallelism where each worker hosts a full copy of the entire model, MP only allocates a fraction of model parameters on one worker and thus both the memory usage and the computation are reduced.</p>
<p>Since deep neural networks usually contain a stack of vertical layers, it feels straightforward to split a large model by layer, where a small consecutive set of layers are grouped into one partition on one worker. However, a naive implementation for running every data batch through multiple such workers with sequential dependency leads to big bubbles of waiting time and severe under-utilization of computation resources.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/naive-data-parallelism.png" alt="Naive DP" /></p>
<p class="image-caption"><em>Fig. 2. A naive model parallelism setup where the model is vertically split into 4 partitions. Data is processed by one worker at a time due to sequential dependency, leading to large “bubbles” of idle time. (Image source: <a href="https://arxiv.org/abs/1811.06965">Huang et al. 2019</a>)</em></p>
<h3 id="pipeline-parallelism">Pipeline Parallelism</h3>
<p><strong>Pipeline parallelism (PP)</strong> combines model parallelism with data parallelism to reduce inefficient time “bubbles”. The main idea is to split one minibatch into multiple microbatches and enable each stage worker to process one microbatch simultaneously. Note that every microbatch needs two passes, one forward and one backward. Inter-worker communication only transfers activations (forward) and gradients (backward). How these passes are scheduled and how the gradients are aggregated vary in different approaches. The number of partitions (workers) is also known as <em>pipeline depth</em>.</p>
<p>In <em>GPipe</em> (<a href="https://arxiv.org/abs/1811.06965">Huang et al. 2019</a>) gradients from multiple microbatches are aggregated and applied synchronously at the end. The synchronous gradient descent guarantees learning consistency and efficiency irrespective of the number of workers. As shown in Fig. 3, bubbles still exist but are much smaller than what’s in Fig. 2. Given \(m\) evenly split microbatches and \(d\) partitions, assuming both forward and backward per microbatch take one unit of time, the fraction of bubble is:</p>
\[1 - \frac{2md}{(2m + 2(d-1))d} = \frac{d-1}{m+d-1}\]
<p>The GPipe paper observed that the bubble overhead is almost negligible if the number of microbatches is more than 4x the number of partitions \(m > 4d\) (when <a href="#activation-recomputation">activation recomputation</a> is applied).</p>
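<p>The bubble fraction is easy to check numerically. The sketch below counts time slots under the same assumption as the formula (one unit of time per forward or backward pass of a microbatch); the function name is hypothetical.</p>

```python
def bubble_fraction(m, d):
    """Fraction of idle time slots for m microbatches and d pipeline partitions."""
    # Each of the d workers is occupied for 2m + 2(d - 1) time units in total ...
    total_slots = (2 * m + 2 * (d - 1)) * d
    # ... but only performs 2m units of useful work (m forward + m backward).
    busy_slots = 2 * m * d
    return 1 - busy_slots / total_slots
```

<p>With 4 microbatches and 4 partitions the fraction is \(3/7 \approx 0.43\); increasing to \(m = 16\) with the same \(d = 4\) reduces it to \(3/19 \approx 0.16\).</p>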
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/gpipe.png" alt="GPipe" /></p>
<p class="image-caption"><em>Fig. 3. Illustration of pipeline parallelism in GPipe with 4 microbatches and 4 partitions. GPipe aggregates and updates gradients across devices synchronously at the end of every batch. (Image source: <a href="https://arxiv.org/abs/1811.06965">Huang et al. 2019</a>)</em></p>
<p>GPipe achieves almost linear speedup in throughput with the number of devices, although it is not always guaranteed if the model parameters are not evenly distributed across workers.</p>
<p><em>PipeDream</em> (<a href="https://cs.stanford.edu/~matei/papers/2019/sosp_pipedream.pdf">Narayanan et al. 2019</a>) schedules each worker to alternately process the forward and backward passes (<code class="language-plaintext highlighter-rouge">1F1B</code>).
PipeDream names each model partition “stage” and each stage worker can have multiple replicas to run data parallelism. In this process, PipeDream uses a deterministic round-robin load balancing strategy to assign work among multiple replicas of stages to ensure that the forward and backward passes for the same minibatch happen on the same replica.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/pipedream.png" alt="PipeDream" /></p>
<p class="image-caption"><em>Fig. 4. Illustration of <code class="language-plaintext highlighter-rouge">1F1B</code> microbatch scheduling in PipeDream. (Image source: <a href="https://arxiv.org/abs/1806.03377">Harlap et al. 2018</a>)</em></p>
<p>Since PipeDream does not have an end-of-batch global gradient sync across all the workers, a naive implementation of 1F1B can easily lead to the forward and backward passes of one microbatch using different versions of model weights, thus lowering the learning efficiency. PipeDream proposed a few designs to tackle this issue:</p>
<ul>
<li><em>Weight stashing</em>: Each worker keeps track of several model versions and makes sure that the same version of weights is used in the forward and backward passes given one data batch.</li>
<li><em>Vertical sync</em> (Optional): The version of model weights flows between stage workers together with activations and gradients. Then the computation adopts the corresponding stashed version propagated from the previous worker. This process keeps version consistency across workers. Note that it is asynchronous, different from GPipe.</li>
</ul>
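<p>Weight stashing can be illustrated with a toy, single-parameter stage worker. This is only a minimal sketch: the class <code class="language-plaintext highlighter-rouge">StashingWorker</code>, its methods and the learning rate are hypothetical names chosen for illustration and are not PipeDream's actual API.</p>

```python
class StashingWorker:
    """Toy stage worker computing y = w * x with a single scalar weight.

    Hypothetical illustration of weight stashing, not PipeDream's real code.
    """

    def __init__(self, weight):
        self.weight = weight   # latest version of the weights
        self.stash = {}        # microbatch id -> weights used in its forward pass

    def forward(self, mb_id, x):
        self.stash[mb_id] = self.weight      # remember which version was used
        return self.weight * x

    def backward(self, mb_id, x, grad_out, lr=0.1):
        w_used = self.stash.pop(mb_id)       # same version as in the forward pass
        grad_w = grad_out * x                # d(w * x) / dw
        grad_x = grad_out * w_used           # d(w * x) / dx needs the stashed weights
        self.weight -= lr * grad_w           # the update goes to the latest weights
        return grad_x


worker = StashingWorker(weight=1.0)
worker.forward(0, x=2.0)                          # microbatch 0 sees w = 1.0
worker.forward(1, x=3.0)                          # microbatch 1 also sees w = 1.0
gx0 = worker.backward(0, x=2.0, grad_out=1.0)     # the latest weights change here
gx1 = worker.backward(1, x=3.0, grad_out=1.0)     # still uses the stashed w = 1.0
```

<p>Without the stash, the backward pass of microbatch 1 would use the already-updated weights, which no longer match the ones used in its forward pass.</p>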
<p>At the beginning of a training run, PipeDream first profiles the computation memory cost and time of each layer in the model and then optimizes a solution for partitioning layers into stages, which is a dynamic programming problem.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/pipedream-results.png" alt="PipeDream experiments" /></p>
<p class="image-caption"><em>Fig. 5. Results for VGG16 on ILSVRC12. (Top) Accuracy vs time. The integer marks the number of stage workers. ASP = Asynchronous parallel & BSP = Bulk synchronous parallels. (Bottom) Training time speedup for different parallelism configurations. Straight pipeline refers to pipeline parallelism without data parallelism. (Image source: <a href="https://arxiv.org/abs/1806.03377">Harlap et al. 2018</a>)</em></p>
<p>Two variations of PipeDream were later proposed to reduce the memory footprint of stashed model versions (<a href="https://arxiv.org/abs/2006.09503">Narayanan et al. 2021</a>).</p>
<p><em>PipeDream-flush</em> adds a globally synchronized pipeline flush periodically, just like GPipe. In this way, it greatly reduces the memory footprint (i.e. only maintain a single version of model weights) by sacrificing a little throughput.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/pipedream-flush.png" alt="PipeDream-flush" /></p>
<p class="image-caption"><em>Fig. 6. Illustration of pipeline scheduling in PipeDream-flush. (Image source: <a href="https://arxiv.org/abs/2006.09503">Narayanan et al. 2021</a>)</em></p>
<p><em>PipeDream-2BW</em> maintains only two versions of model weights, where “2BW” is short for “double-buffered weights”. It generates a new model version every \(k\) microbatches and \(k\) should be larger than the pipeline depth \(d\), \(k > d\). A newly updated model version cannot fully replace the old version immediately since some leftover backward passes still depend on the old version. In total only two versions need to be saved so the memory cost is much reduced.</p>
<p style="width: 95%;" class="center"><img src="/lil-log/assets/images/pipedream-2bw.png" alt="PipeDream-2BW" /></p>
<p class="image-caption"><em>Fig. 7. Illustration of pipeline scheduling in PipeDream-2BW. (Image source: <a href="https://arxiv.org/abs/2006.09503">Narayanan et al. 2021</a>)</em></p>
<h3 id="tensor-parallelism">Tensor Parallelism</h3>
<p>Both model and pipeline parallelisms split a model vertically. On the other hand, we can horizontally partition the computation for one tensor operation across multiple devices, which is named <strong>tensor parallelism (TP)</strong>.</p>
<p>Let’s take the transformer as an example given its popularity. The transformer model mainly consists of layers of MLP and self-attention blocks. <em>Megatron-LM</em> (<a href="https://arxiv.org/abs/1909.08053">Shoeybi et al. 2020</a>) adopts a simple way to parallelize intra-layer computation for MLP and self-attention.</p>
<p>An MLP layer in a transformer contains a GEMM (general matrix multiply) followed by a non-linear GeLU transfer. Let’s split the weight matrix \(A\) by column:</p>
\[\begin{aligned}
\text{Split }A &= [A_1, A_2] \\
Y &=\text{GeLU}(XA) \\
[Y_1, Y_2] &= [\text{GeLU}(XA_1), \text{GeLU}(XA_2)]
\end{aligned}\]
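<p>Because GeLU is applied element-wise, each column shard can be processed independently and the outputs simply concatenated. A minimal pure-Python sketch of this identity follows; the helper names and the tiny matrices are made up for illustration, and no real multi-device communication is involved.</p>

```python
import math


def gelu(v):
    # Exact GeLU expressed through the Gaussian CDF.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matmul(X, A):
    # Plain matrix product of two lists of rows.
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]


def gelu_matmul(X, A):
    return [[gelu(v) for v in row] for row in matmul(X, A)]


X = [[1.0, -2.0], [0.5, 3.0]]                          # 2 x 2 input
A = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8]]     # 2 x 4 weight matrix

# Split A by column into two shards, one per (hypothetical) device.
A1 = [row[:2] for row in A]
A2 = [row[2:] for row in A]

Y_full = gelu_matmul(X, A)                              # Y = GeLU(XA)
Y1, Y2 = gelu_matmul(X, A1), gelu_matmul(X, A2)         # computed independently
Y_concat = [r1 + r2 for r1, r2 in zip(Y1, Y2)]          # [Y1, Y2]
```

<p>A row-wise split of \(A\) would not allow this, since GeLU of a sum is not the sum of GeLUs and the partial products would have to be summed before the non-linearity.</p>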
<p>The attention block runs GEMM with query (\(Q\)), key (\(K\)), and value weights (\(V\)) according to the above partitioning in parallel and then combines them with another GEMM to produce the attention head results.</p>
\[\text{Attention}(X, Q, K, V) = \text{softmax}(\frac{(XQ) (XK)^\top}{\sqrt{d_k}}) XV\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/Megatron-LM.png" alt="Megatron LM" /></p>
<p class="image-caption"><em>Fig. 8. Illustration of tensor parallelism for key transformer components proposed in Megatron-LM. (Image source: <a href="https://arxiv.org/abs/1909.08053">Shoeybi et al. 2020</a>)</em></p>
<p><a href="https://arxiv.org/abs/2104.04473">Narayanan et al. (2021)</a> combined pipeline, tensor and data parallelism with a new pipeline scheduling strategy and named their approach <em>PTD-P</em>. Instead of only positioning a contiguous set of layers (“model chunk”) on a device, each worker can be assigned multiple chunks of smaller contiguous subsets of layers (e.g. device 1 has layers 1, 2, 9, 10; device 2 has layers 3, 4, 11, 12; each has two model chunks). The number of microbatches in one batch should be divisible by the number of workers (\(m \bmod d = 0\)). If there are \(v\) model chunks per worker, the pipeline bubble time can be reduced by a factor of \(v\) compared to a GPipe schedule.</p>
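<p>To get a feel for the reduction, here is a small sketch that computes the bubble time relative to the ideal compute time. It assumes the expression \((d-1)/m\) for a non-interleaved schedule and \((d-1)/(v \cdot m)\) for the interleaved one, which is how I read the analysis in Narayanan et al. (2021); the function name is hypothetical.</p>

```python
def bubble_fraction(m, d, v=1):
    """Pipeline bubble time divided by ideal compute time (assumed formula).

    m: number of microbatches, d: number of pipeline workers,
    v: number of model chunks per worker (v = 1 means no interleaving).
    """
    if v > 1 and m % d != 0:
        raise ValueError("interleaved schedule expects m to be divisible by d")
    return (d - 1) / (v * m)


non_interleaved = bubble_fraction(m=8, d=4)        # one chunk per worker
interleaved = bubble_fraction(m=8, d=4, v=2)       # two chunks per worker
```

<p>With these numbers the bubble fraction is halved when each worker holds two model chunks, at the price of more communication between workers.</p>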
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PTD-P-interleaved.png" alt="PTD-P" /></p>
<p class="image-caption"><em>Fig. 9. (Top) Default <code class="language-plaintext highlighter-rouge">1F1B</code> pipeline schedule as in PipeDream-flush. (Bottom) Interleaved 1F1B pipeline schedule. First model chunks are in dark colors and second chunks are in light colors. (Image source: <a href="https://arxiv.org/abs/2104.04473">Narayanan et al. 2021</a>)</em></p>
<h2 id="mixture-of-experts-moe">Mixture-of-Experts (MoE)</h2>
<p>The <strong>Mixture-of-Experts (MoE)</strong> approach has attracted a lot of attention recently as researchers (mainly from Google) try to push the limit of model size. The core of the idea is <a href="https://en.wikipedia.org/wiki/Ensemble_learning">ensemble learning</a>: <em>Combination of multiple weak learners gives you a strong learner!</em></p>
<p>Within one deep neural network, ensembling can be implemented with a gating mechanism connecting multiple experts (<a href="https://arxiv.org/abs/1701.06538">Shazeer et al., 2017</a>). The gating mechanism controls which subset of the network (e.g. which experts) should be activated to produce outputs. The paper named it “sparsely gated mixture-of-experts” (MoE) layer.</p>
<p>Precisely, one MoE layer contains:</p>
<ul>
<li>\(n\) feed-forward networks as experts \(\{E_i\}^n_{i=1}\)</li>
<li>A trainable gating network \(G\) to learn a probability distribution over \(n\) experts so as to route the traffic to a few selected experts.</li>
</ul>
<p>Depending on the gating outputs, not every expert has to be evaluated. When the number of experts is too large, we can consider using a two-level hierarchical MoE.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/moe.png" alt="MoE" /></p>
<p class="image-caption"><em>Fig. 10. Illustration of a mixture-of-experts (MoE) layer. Only 2 out of \(n\) experts are selected and activated by the gating network. (Image source: <a href="https://arxiv.org/abs/1701.06538">Shazeer et al., 2017</a>)</em></p>
<p>A simple choice of \(G\) is to multiply the input with a trainable weight matrix \(W_g\) and then apply softmax: \(G_\sigma (x) = \text{softmax}(x W_g)\). However, this produces a dense control vector for gating and does not help save computation resources, because an expert can be skipped only when \(G^{(i)}(x)=0\). Thus the MoE layer only keeps the top \(k\) values. It also adds tunable Gaussian noise into \(G\) to improve load balancing. This mechanism is called <em>noisy top-k gating</em>.</p>
\[\begin{aligned}
G(x) &= \text{softmax}( \text{topk}(H(x), k)) \\
H^{(i)}(x) &= (xW_g)^{(i)} + \epsilon \cdot \text{softplus}((xW_\text{noise})^{(i)} ); \quad \epsilon \sim \mathcal{N}(0, \mathbf{1}) \\
\text{topk}^{(i)}(v, k) &= \begin{cases} v^{(i)} & \text{if }v^{(i)}\text{ is in the top }k\text{ elements of }v \\ -\infty & \text{otherwise}
\end{cases}
\end{aligned}\]
<p>where the superscript \(v^{(i)}\) denotes the i-th dimension of the vector \(v\). The function \(\text{topk}(., k)\) selects the top \(k\) dimensions with the highest values by setting other dimensions to \(-\infty\).</p>
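<p>The three equations above translate almost line by line into code. The following is a minimal pure-Python sketch for a single input vector; the function name, the random weights and the choice of 4 experts are assumptions for illustration only.</p>

```python
import math
import random


def softplus(z):
    return math.log1p(math.exp(z))


def noisy_topk_gate(x, W_g, W_noise, k, rng):
    """Return the gate vector G(x) with exactly k non-zero entries."""
    n = len(W_g[0])
    # H_i(x) = (x W_g)_i + eps * softplus((x W_noise)_i), with eps ~ N(0, 1)
    h = []
    for i in range(n):
        clean = sum(x[j] * W_g[j][i] for j in range(len(x)))
        noise_scale = softplus(sum(x[j] * W_noise[j][i] for j in range(len(x))))
        h.append(clean + rng.gauss(0.0, 1.0) * noise_scale)
    # topk: keep the k largest logits and set the others to -inf
    keep = set(sorted(range(n), key=lambda i: h[i], reverse=True)[:k])
    masked = [h[i] if i in keep else -math.inf for i in range(n)]
    # softmax over the masked logits; exp(-inf) = 0 yields exact zeros
    top = max(masked)
    exps = [math.exp(v - top) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
x = [0.5, -1.0, 2.0]                                             # one input token
W_g = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(3)]      # 3 x 4
W_noise = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(3)]  # 3 x 4
gates = noisy_topk_gate(x, W_g, W_noise, k=2, rng=rng)
```

<p>Only the experts with a non-zero gate value need to be evaluated for this token.</p>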
<p>To avoid the self-reinforcing effect that the gating network may favor a few strong experts all the time, <a href="https://arxiv.org/abs/1701.06538">Shazeer et al. (2017)</a> proposed a soft constraint via an additional importance loss to encourage all the experts to have the same weights. It is equivalent to the square of the <a href="https://en.wikipedia.org/wiki/Coefficient_of_variation">coefficient of variation</a> of batchwise average value per expert.</p>
\[L_\text{aux} = w_\text{aux} \cdot \text{CV}(\sum_{x \in X} G(x))^2\]
<p>where \(\text{CV}\) is the coefficient of variation and the loss weight \(w_\text{aux}\) is a hyperparameter to tune.</p>
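<p>As a rough sketch, the importance loss can be computed from the gate vectors of a batch as follows. The default \(w_\text{aux}=0.01\) and the use of the population variance are my assumptions; the function name is hypothetical.</p>

```python
def importance_loss(gates_per_sample, w_aux=0.01):
    """w_aux * CV(sum_x G(x))^2 for a batch of gate vectors."""
    n = len(gates_per_sample[0])
    # Importance of expert i: batchwise sum of its gate values.
    importance = [sum(g[i] for g in gates_per_sample) for i in range(n)]
    mean = sum(importance) / n
    var = sum((v - mean) ** 2 for v in importance) / n   # population variance
    return w_aux * var / (mean ** 2)                     # CV^2 = variance / mean^2


balanced = [[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]  # every expert gets 0.5
skewed = [[1.0, 0.0, 0.0, 0.0], [0.9, 0.1, 0.0, 0.0]]    # expert 0 dominates
loss_balanced = importance_loss(balanced)
loss_skewed = importance_loss(skewed)
```

<p>The loss vanishes when all experts receive the same total gate value and grows as the gating network concentrates on a few experts.</p>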
<p>Because every expert network only gets a fraction of training samples (“The shrinking batch problem”), we should try to use a batch size as large as possible in MoE. However, it is restricted by GPU memory. Data parallelism and model parallelism can be applied to improve the throughput.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/moe-experiments.png" alt="MoE experiments" /></p>
<p class="image-caption"><em>Fig. 11. Test perplexity on 1-Billion-Word language modeling benchmark. (Left) The model capacity increases from left to right, containing 4, 32, 256, 256, 1024 and 4096 experts. (Right) Performance of the 4 billion parameters MoE model, the largest one in the left figure, under different computation budgets. (Image source: <a href="https://arxiv.org/abs/1701.06538">Shazeer et al., 2017</a>)</em></p>
<p><strong>GShard</strong> (<a href="https://arxiv.org/abs/2006.16668">Lepikhin et al., 2020</a>) scales the MoE transformer model up to 600 billion parameters with sharding. The MoE transformer replaces every other feed forward layer with a MoE layer. The <em>sharded MoE transformer</em> only has the MoE layers sharded across multiple machines, while other layers are simply duplicated.</p>
<p>There are several improved designs for the gating function \(G\) in GShard:</p>
<ul>
<li><em>Expert capacity</em>: The number of tokens going through one expert should not go above a threshold, named “expert capacity”. If a token is routed to experts that have reached their capacity, the token would be marked “overflowed” and the gating output is changed to a zero vector.</li>
<li><em>Local group dispatching</em>: Tokens are evenly partitioned into multiple local groups and the expert capacity is enforced on the group level.</li>
<li><em>Auxiliary loss</em>: The motivation is similar to the original MoE aux loss. They add an auxiliary loss to minimize the mean square of the fraction of data routed to each expert.</li>
<li><em>Random routing</em>: The 2nd-best expert is selected with a probability proportional to its weight; otherwise, GShard follows a random routing, so as to add some randomness.</li>
</ul>
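<p>The expert capacity rule from the list above can be sketched as below. To keep it short, the sketch simplifies to one expert choice per token and a single group, whereas GShard uses top-2 gating with per-group capacity; the function name is hypothetical.</p>

```python
def dispatch_with_capacity(expert_choice, num_experts, capacity):
    """Assign tokens to their chosen expert until that expert is full.

    Returns (assignment, load), where assignment[t] is the expert index of
    token t, or None if the token overflowed.
    """
    load = [0] * num_experts
    assignment = []
    for e in expert_choice:               # tokens are processed in order
        if load[e] < capacity:
            load[e] += 1
            assignment.append(e)
        else:
            assignment.append(None)       # overflowed: gating output becomes zero
    return assignment, load


# Six tokens, four of which prefer expert 0; each expert accepts at most 2.
assignment, load = dispatch_with_capacity([0, 0, 1, 0, 2, 0], num_experts=3, capacity=2)
```

<p>Overflowed tokens skip the MoE layer and are passed on through the residual connection only.</p>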
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/gshard-algo.png" alt="GShard algorithm" /></p>
<p class="image-caption"><em>Fig. 12. Pseudo code of the group-level top-2 gating mechanism with auxiliary loss in GShard. (Image source: <a href="https://arxiv.org/abs/2006.16668">Lepikhin et al., 2020</a>)</em></p>
<p><strong>Switch Transformer</strong> (<a href="https://arxiv.org/abs/2101.03961">Fedus et al. 2021</a>) scales the model size up to trillions of parameters (!!) by replacing the dense feed forward layer with a <em>sparse switch FFN layer</em> in which each input is only routed to <em>one</em> expert network. The auxiliary loss for load balancing is \(\text{loss}_\text{aux} = w_\text{aux} \sum_{i=1}^n f_i p_i\) given \(n\) experts, where \(f_i\) is the fraction of tokens routed to the \(i\)-th expert and \(p_i\) is the routing probability for expert \(i\) predicted by the gating network.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/switch-transformer.png" alt="Switch transformer" /></p>
<p class="image-caption"><em>Fig. 13. Switch transformer. The sparse switch FFN layer is in the blue boxes. (Image source: <a href="https://arxiv.org/abs/2101.03961">Fedus et al. 2021</a>)</em></p>
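<p>A minimal sketch of the load balancing loss \(w_\text{aux} \sum_i f_i p_i\) as written above, assuming \(p_i\) is averaged over the tokens in a batch and that each token is routed to its arg-max expert. The function name and the default \(w_\text{aux}\) are made up for illustration.</p>

```python
def switch_aux_loss(router_probs, w_aux=0.01):
    """Load balancing loss from per-token router probabilities."""
    num_tokens = len(router_probs)
    n = len(router_probs[0])
    # f_i: fraction of tokens routed to expert i (arg-max routing).
    counts = [0] * n
    for probs in router_probs:
        counts[max(range(n), key=lambda i: probs[i])] += 1
    f = [c / num_tokens for c in counts]
    # p_i: mean router probability assigned to expert i over the batch.
    p = [sum(probs[i] for probs in router_probs) / num_tokens for i in range(n)]
    return w_aux * sum(fi * pi for fi, pi in zip(f, p))


balanced = [[0.9, 0.1], [0.1, 0.9]]     # tokens split evenly over 2 experts
collapsed = [[0.9, 0.1], [0.9, 0.1]]    # every token goes to expert 0
loss_balanced = switch_aux_loss(balanced)
loss_collapsed = switch_aux_loss(collapsed)
```

<p>The sum is smallest when both \(f\) and \(p\) are uniform, so minimizing it pushes the router toward an even split of tokens.</p>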
<p>To improve training stability, switch transformer incorporates the following designs:</p>
<ul>
<li><em>Selective precision</em>. They showed that selectively casting only a local part of the model to FP32 precision improves stability, while avoiding the expensive communication cost of FP32 tensors. The FP32 precision is only used within the body of the router function and the results are recast to FP16.</li>
<li><em>Smaller initialization</em>. The initialization of weight matrices is sampled from a truncated normal distribution with mean \(\mu=0\) and stdev \(\sigma = \sqrt{s/n}\). They also recommended reducing the transformer initialization scale parameter \(s=1\) to \(s=0.1\).</li>
<li><em>Use higher expert dropout</em>. Fine-tuning often works with a small dataset. To avoid overfitting, the dropout rate within each expert is increased by a significant amount. Interestingly they found that increasing dropout across all layers leads to poor performance. In the paper, they used a dropout rate of 0.1 at non-expert layers but 0.4 within expert FF layers.</li>
</ul>
<p>The switch transformer paper summarized different data and model parallelism strategies for training large models with a nice illustration:</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/switch-transformer-parallelism.png" alt="Parallelism strategies" /></p>
<p class="image-caption"><em>Fig. 14. An illustration of various parallelism strategies on how (Top) model weights and (Bottom) data are split over multiple GPU cores. In the top row, each color denotes a unique weight matrix. In the bottom row, different colors indicate different sets of tokens. (Image source: <a href="https://arxiv.org/abs/2101.03961">Fedus et al. 2021</a>)</em></p>
<h2 id="other-memory-saving-designs">Other Memory Saving Designs</h2>
<h3 id="cpu-offloading">CPU Offloading</h3>
<p>When the GPU memory is full, one option is to offload temporarily unused data to CPU and read them back when needed later (<a href="https://arxiv.org/abs/1602.08124">Rhu et al. 2016</a>). The idea of <strong>CPU offloading</strong> is straightforward but is less popular in recent years due to the slowdown it brings into the training time.</p>
<h3 id="activation-recomputation">Activation Recomputation</h3>
<p><strong>Activation recomputation</strong> (also known as “activation checkpointing” or “gradient checkpointing”; <a href="https://arxiv.org/abs/1604.06174">Chen et al. 2016</a>) is a smart yet simple idea to reduce memory footprint at the cost of computation time. It reduces the memory cost of training an \(\ell\)-layer deep neural net to \(O(\sqrt{\ell})\), at the price of only one extra forward pass of computation per batch.</p>
<p>Let’s say, we evenly divide an \(\ell\)-layer network into \(d\) partitions. Only activations at partition boundaries are saved and communicated between workers. Intermediate activations at intra-partition layers are still needed for computing gradients so they are recomputed during backward passes. With activation recomputation, the memory cost for training \(M(\ell)\) is:</p>
\[M(\ell)
=\max_{i=1,\dots,d} \underbrace{\text{cost-of-one-partition}(i)}_\text{cost of back-propagation on the i-th partition} + \underbrace{O(d)}_\text{store intermediate outputs}
= O(\frac{\ell}{d}) + O(d)\]
<p>The minimum cost is \(O(\sqrt{\ell})\) at \(d=\sqrt{\ell}\).</p>
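<p>The trade-off is easy to check numerically. The sketch below drops the constants hidden in the \(O(\cdot)\) terms and simply minimizes \(\ell/d + d\) over integer \(d\); the function name and \(\ell = 100\) are chosen for illustration.</p>

```python
import math


def memory_cost(num_layers, d):
    # Activations of one partition (num_layers / d layers) plus
    # d stored boundary activations, with all constants set to 1.
    return num_layers / d + d


num_layers = 100
best_d = min(range(1, num_layers + 1), key=lambda d: memory_cost(num_layers, d))
best_cost = memory_cost(num_layers, best_d)
no_checkpoint_cost = memory_cost(num_layers, 1)   # a single partition stores everything
```

<p>For 100 layers the best choice is 10 partitions with a cost of \(2\sqrt{\ell} = 20\) units, compared to roughly 100 units with a single partition.</p>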
<p>The activation recomputation trick can give sublinear memory cost with respect to the model size.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/activation-checkpointing.png" alt="Activation checkpointing experiments" /></p>
<p class="image-caption"><em>Fig. 15. The memory cost of different memory saving algorithms. <u>Sharing</u>: Memory used by intermediate results is recycled when no longer needed. <u>Inplace</u>: Save the output directly into memory of an input value. (Image source: <a href="https://arxiv.org/abs/1604.06174">Chen et al. 2016</a>)</em></p>
<h3 id="mixed-precision-training">Mixed Precision Training</h3>
<p><a href="https://arxiv.org/abs/1710.03740">Narang & Micikevicius et al. (2018)</a> introduced a method to train models using half-precision floating point (FP16) numbers without losing model accuracy.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/mixed-precision-training.png" alt="Mixed-precision training" /></p>
<p class="image-caption"><em>Fig. 16. The procedure of mixed precision training at one layer. (Image source: <a href="https://arxiv.org/abs/1710.03740">Narang & Micikevicius, et al. 2018</a>)</em></p>
<p>Three techniques to avoid losing critical information at half-precision:</p>
<ul>
<li><em>Full-precision master copy of weights</em>. Maintain a full precision (FP32) copy of model weights that accumulates gradients. The numbers are rounded to half-precision for forward & backward passes. The motivation is that each gradient update (i.e. gradient times the learning rate) might be too small to be fully contained within the FP16 range (i.e. values smaller than \(2^{-24}\) in magnitude become zero in FP16).</li>
<li><em>Loss scaling</em>. Scale up the loss to better handle gradients with small magnitudes (See Fig. 17). Scaling up the gradients shifts them towards the right section (containing larger values) of the representable range, preserving values that are otherwise lost.</li>
<li><em>Arithmetic precision</em>. For common network arithmetic (e.g. vector dot-product, reduction by summing up vector elements), we can accumulate the partial results in FP32 and then convert the final output to FP16 before writing it to memory. Point-wise operations can be executed in either FP16 or FP32.</li>
</ul>
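<p>The first two techniques can be demonstrated with a few numbers. The sketch below emulates FP16 and FP32 rounding with Python's <code class="language-plaintext highlighter-rouge">struct</code> module (format codes <code class="language-plaintext highlighter-rouge">"e"</code> and <code class="language-plaintext highlighter-rouge">"f"</code>); the helper names, the update size and the scale factor 1024 are arbitrary choices for illustration.</p>

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# 1) Master copy: a small update vanishes if the weight is kept in FP16 ...
w16 = to_fp16(1.0)
update = 1e-4
w16_after = to_fp16(w16 + to_fp16(update))   # rounds back to 1.0

# ... but is preserved when accumulated into an FP32 master copy.
master = to_fp32(1.0)
master_after = to_fp32(master + update)      # slightly above 1.0

# 2) Loss scaling: a gradient of 2^-26 underflows to zero in FP16 ...
grad = 2.0 ** -26
grad16 = to_fp16(grad)

# ... unless it is scaled up before the cast and unscaled in FP32 afterwards.
scale = 1024.0
scaled16 = to_fp16(grad * scale)             # 2^-16 is representable in FP16
recovered = to_fp32(scaled16) / scale
```

<p>In a real training loop the scale factor multiplies the loss, so that all gradients are scaled by the chain rule, and the weights are unscaled before the optimizer step.</p>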
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/gradient-histogram.png" alt="Gradient histogram" /></p>
<p class="image-caption"><em>Fig. 17. The histogram of gradients in full precision. The left part up to \(2^{-24}\) will be zero-ed off once the model switches to FP16. (Image source: <a href="https://arxiv.org/abs/1710.03740">Narang & Micikevicius, et al. 2018</a>)</em></p>
<p>In their experiments, loss scaling is not needed for some networks (e.g. image classification, Faster R-CNN), but necessary for others (e.g. Multibox SSD, big LSTM language model).</p>
<h3 id="compression">Compression</h3>
<p>Intermediate results often consume a lot of memory, although they are only needed in one forward pass and one backward pass. There is a noticeable temporal gap between these two uses. Thus <a href="https://www.microsoft.com/en-us/research/uploads/prod/2018/04/fiddle-gist-isca18.pdf">Jain et al. (2018)</a> proposed a data encoding strategy to compress the intermediate results after the first use in the first pass and then decode it back for back-propagation later.</p>
<p>Their system <em>Gist</em> incorporates two encoding schemes:</p>
<ul>
<li><em>Layer-specific lossless encoding</em>: focuses on ReLU-Pool (“Binarize”) and ReLU-Conv (“Sparse storage and dense computation”) patterns.</li>
<li><em>Aggressive lossy encoding</em>: uses delayed precision reduction (DPR). They observed that the first immediate use of feature maps should be kept at high precision but the second use can tolerate lower precision.</li>
</ul>
<p>The experiments showed that Gist can reduce the memory cost by up to 2x across 5 SOTA image classification DNNs, and by 1.8x on average, with only 4% performance overhead.</p>
<h3 id="memory-efficient-optimizer">Memory Efficient Optimizer</h3>
<p>Optimizers are memory hungry. Take the popular Adam optimizer as an example: it internally needs to maintain momentums and variances, both at the same scale as gradients and model parameters. All of a sudden, we need to save 4x the memory of model weights.</p>
<p>Several optimizers have been proposed to reduce the memory footprint.
For example, instead of storing the full momentums and variances as in Adam, <em>Adafactor</em> (<a href="https://arxiv.org/abs/1804.04235">Shazeer & Stern 2018</a>) only tracks the per-row and per-column sums of the moving averages and then estimates the second moments based on these sums. <em>SM3</em> (<a href="https://arxiv.org/abs/1901.11150">Anil et al. 2019</a>) describes a different adaptive optimization method, leading to largely reduced memory as well.</p>
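<p>The row and column sum idea can be sketched in a few lines. The reconstruction rule used here, the outer product of row and column sums divided by the total sum, follows the Adafactor paper as I understand it; the function name and the toy matrix are hypothetical.</p>

```python
def factored_estimate(V):
    """Estimate an n x m non-negative matrix from its row and column sums."""
    row_sums = [sum(row) for row in V]          # n numbers
    col_sums = [sum(col) for col in zip(*V)]    # m numbers
    total = sum(row_sums)
    return [[r * c / total for c in col_sums] for r in row_sums]


# Toy matrix of squared-gradient averages; it is rank 1, so it is recovered exactly.
V = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]
V_hat = factored_estimate(V)

stored_full = len(V) * len(V[0])        # entries kept by a full second-moment matrix
stored_factored = len(V) + len(V[0])    # entries kept by the factored version
```

<p>For an \(n \times m\) weight matrix only \(n + m\) numbers are stored instead of \(n \cdot m\). The estimate is exact for rank-1 matrices and an approximation otherwise.</p>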
<p><em>ZeRO</em> (<em>Zero Redundancy Optimizer</em>; <a href="https://arxiv.org/abs/1910.02054">Rajbhandari et al. 2019</a>) optimizes the memory used for training large models, based on the observation that memory consumption in large model training falls into two major categories:</p>
<ol>
<li>The majority is occupied by <em>model states</em>, including optimizer states (e.g. Adam momentums and variances), gradients and parameters. Mixed-precision training demands a lot of memory since the optimizer needs to keep a copy of FP32 parameters and other optimizer states, besides the FP16 version.</li>
<li>The remainder is consumed by activations, temporary buffers and unusable fragmented memory (named <em>residual states</em> in the paper).</li>
</ol>
<p>ZeRO combines two approaches, <em>ZeRO-DP</em> and <em>ZeRO-R</em>.
ZeRO-DP is an enhanced data parallelism to avoid simple redundancy over model states. It partitions optimizer state, gradients and parameters across multiple data parallel processes via a dynamic communication schedule to minimize the communication volume.
ZeRO-R optimizes the memory consumption of residual states, using partitioned activation recomputation, constant buffer size and on-the-fly memory defragmentation.</p>
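<p>A back-of-the-envelope sketch of how much model-state memory ZeRO-DP saves per GPU. It assumes the accounting I recall from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for FP16 weights, 2 bytes for FP16 gradients and 12 bytes for FP32 optimizer states (master weights, momentum and variance). The function name and the stage numbering as an integer argument are my own.</p>

```python
def zero_dp_bytes_per_gpu(num_params, num_gpus, stage):
    """Model-state memory per GPU in bytes under the assumed 2 + 2 + 12 accounting.

    stage 0: plain data parallelism, everything replicated
    stage 1: optimizer states partitioned
    stage 2: + gradients partitioned
    stage 3: + parameters partitioned
    """
    params = 2 * num_params      # FP16 parameters
    grads = 2 * num_params       # FP16 gradients
    optim = 12 * num_params      # FP32 master weights, momentum, variance
    if stage >= 1:
        optim /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        params /= num_gpus
    return params + grads + optim


# Example: a 7.5B-parameter model on 64 data-parallel GPUs.
per_stage_gb = [zero_dp_bytes_per_gpu(7.5e9, 64, s) / 1e9 for s in range(4)]
```

<p>Under these assumptions the per-GPU model states shrink from 120 GB with plain data parallelism to under 2 GB when all three kinds of state are partitioned.</p>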
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2021large,
title = "How to Train Really Large Models on Many GPUs?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2021",
url = "https://lilianweng.github.io/lil-log/2021/09/24/train-large-neural-networks.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Li et al. <a href="https://arxiv.org/abs/2006.15704">“PyTorch Distributed: Experiences on Accelerating Data Parallel Training”</a> VLDB 2020.</p>
<p>[2] Cui et al. <a href="https://www.pdl.cmu.edu/PDL-FTP/CloudComputing/GeePS-cui-eurosys16.pdf">“GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server”</a> EuroSys 2016.</p>
<p>[3] Shoeybi et al. <a href="https://arxiv.org/abs/1909.08053">“Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.”</a> arXiv preprint arXiv:1909.08053 (2019).</p>
<p>[4] Narayanan et al. <a href="https://arxiv.org/abs/2104.04473">“Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.”</a> arXiv preprint arXiv:2104.04473 (2021).</p>
<p>[5] Huang et al. <a href="https://arxiv.org/abs/1811.06965">“GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.”</a> arXiv preprint arXiv:1811.06965 (2018).</p>
<p>[6] Narayanan et al. <a href="https://cs.stanford.edu/~matei/papers/2019/sosp_pipedream.pdf">“PipeDream: Generalized Pipeline Parallelism for DNN Training.”</a> SOSP 2019.</p>
<p>[7] Narayanan et al. <a href="https://arxiv.org/abs/2006.09503">“Memory-Efficient Pipeline-Parallel DNN Training.”</a> ICML 2021.</p>
<p>[8] Shazeer et al. <a href="https://arxiv.org/abs/1701.06538">“Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.”</a> arXiv preprint arXiv:1701.06538 (2017).</p>
<p>[9] Lepikhin et al. <a href="https://arxiv.org/abs/2006.16668">“GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.”</a> arXiv preprint arXiv:2006.16668 (2020).</p>
<p>[10] Fedus et al. <a href="https://arxiv.org/abs/2101.03961">“Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.”</a> arXiv preprint arXiv:2101.03961 (2021).</p>
<p>[11] Narang & Micikevicius, et al. <a href="https://arxiv.org/abs/1710.03740">“Mixed precision training.”</a> ICLR 2018.</p>
<p>[12] Chen et al. <a href="https://arxiv.org/abs/1604.06174">“Training Deep Nets with Sublinear Memory Cost.”</a> arXiv preprint arXiv:1604.06174 (2016).</p>
<p>[13] Jain et al. <a href="https://www.microsoft.com/en-us/research/uploads/prod/2018/04/fiddle-gist-isca18.pdf">“Gist: Efficient data encoding for deep neural network training.”</a> ISCA 2018.</p>
<p>[14] Shazeer & Stern. <a href="https://arxiv.org/abs/1804.04235">“Adafactor: Adaptive learning rates with sublinear memory cost.”</a> arXiv preprint arXiv:1804.04235 (2018).</p>
<p>[15] Anil et al. <a href="https://arxiv.org/abs/1901.11150">“Memory-Efficient Adaptive Optimization.”</a> arXiv preprint arXiv:1901.11150 (2019).</p>
<p>[16] Rajbhandari et al. <a href="https://arxiv.org/abs/1910.02054">“ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.”</a> arXiv preprint arXiv:1910.02054 (2019).</p>Lilian WengHow to train large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time. This post reviews several popular training parallelism paradigms, as well as a variety of model architecture and memory saving designs to make it possible to train very large neural networks across a large number of GPUs.What are Diffusion Models?2021-07-11T12:00:00+00:002021-07-11T12:00:00+00:00https://lilianweng.github.io/lil-log/2021/07/11/diffusion-models<blockquote>
<p>Diffusion models are a new type of generative models that are flexible enough to learn any arbitrarily complex data distribution while tractable to analytically evaluate the distribution. It has been shown recently that diffusion models can generate high-quality images and the performance is competitive to SOTA GAN.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2021-09-19: Highly recommend this blog post on <a href="https://yang-song.github.io/blog/2021/score/">score-based generative modeling</a> by Yang Song (author of several key papers in the references)].</span></p>
<p>So far, I’ve written about three types of generative models, <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a>, <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">VAE</a>, and <a href="/lil-log/2018/10/13/flow-based-deep-generative-models.html">Flow-based</a> models. They have shown great success in generating high-quality samples, but each has some limitations of its own. GAN models are known for potentially unstable training and less diversity in generation due to their adversarial training nature. VAE relies on a surrogate loss. Flow models have to use specialized architectures to construct reversible transform.</p>
<p>Diffusion models are inspired by non-equilibrium thermodynamics. They define a Markov chain of diffusion steps to slowly add random noise to data and then learn to reverse the diffusion process to construct desired data samples from the noise. Unlike VAE or flow models, diffusion models are learned with a fixed procedure and the latent variable has high dimensionality (same as the original data).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/generative-overview.png" alt="Overview" /></p>
<p class="image-caption">Fig. 1. Overview of different types of generative models.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-are-diffusion-models" id="markdown-toc-what-are-diffusion-models">What are Diffusion Models?</a> <ul>
<li><a href="#forward-diffusion-process" id="markdown-toc-forward-diffusion-process">Forward diffusion process</a></li>
<li><a href="#reverse-diffusion-process" id="markdown-toc-reverse-diffusion-process">Reverse diffusion process</a></li>
<li><a href="#parameterization-of-l_t-for-training-loss" id="markdown-toc-parameterization-of-l_t-for-training-loss">Parameterization of \(L_t\) for Training Loss</a></li>
<li><a href="#parameterization-of-beta_t" id="markdown-toc-parameterization-of-beta_t">Parameterization of \(\beta_t\)</a></li>
<li><a href="#parameterization-of-reverse-process-variance-boldsymbolsigma_theta" id="markdown-toc-parameterization-of-reverse-process-variance-boldsymbolsigma_theta">Parameterization of reverse process variance \(\boldsymbol{\Sigma}_\theta\)</a></li>
</ul>
</li>
<li><a href="#speed-up-diffusion-model-sampling" id="markdown-toc-speed-up-diffusion-model-sampling">Speed up Diffusion Model Sampling</a></li>
<li><a href="#conditioned-generation" id="markdown-toc-conditioned-generation">Conditioned Generation</a></li>
<li><a href="#quick-summary" id="markdown-toc-quick-summary">Quick Summary</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-are-diffusion-models">What are Diffusion Models?</h2>
<p>Several diffusion-based generative models have been proposed with similar ideas underneath, including <em>diffusion probabilistic models</em> (<a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al., 2015</a>), <em>noise-conditioned score network</em> (<strong>NCSN</strong>; <a href="https://arxiv.org/abs/1907.05600">Yang & Ermon, 2019</a>), and <em>denoising diffusion probabilistic models</em> (<strong>DDPM</strong>; <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a>).</p>
<h3 id="forward-diffusion-process">Forward diffusion process</h3>
<p>Given a data point sampled from a real data distribution \(\mathbf{x}_0 \sim q(\mathbf{x})\), let us define a <em>forward diffusion process</em> in which we add a small amount of Gaussian noise to the sample in \(T\) steps, producing a sequence of noisy samples \(\mathbf{x}_1, \dots, \mathbf{x}_T\). The step sizes are controlled by a variance schedule \(\{\beta_t \in (0, 1)\}_{t=1}^T\).</p>
\[q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad
q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})\]
<p>The data sample \(\mathbf{x}_0\) gradually loses its distinguishable features as the step \(t\) becomes larger. Eventually when \(T \to \infty\), \(\mathbf{x}_T\) is equivalent to an isotropic Gaussian distribution.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DDPM.png" alt="DDPM" /></p>
<p class="image-caption">Fig. 2. The Markov chain of forward (reverse) diffusion process of generating a sample by slowly adding (removing) noise. (Image source: <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a> with a few additional annotations)</p>
<p><a name="nice"></a>A nice property of the above process is that we can sample \(\mathbf{x}_t\) at any arbitrary time step \(t\) in a closed form using the <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#reparameterization-trick">reparameterization trick</a>. Let \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{i=1}^t \alpha_i\):</p>
\[\begin{aligned}
\mathbf{x}_t
&= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\mathbf{z}_{t-1} & \text{ ;where } \mathbf{z}_{t-1}, \mathbf{z}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&= \sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \bar{\mathbf{z}}_{t-2} & \text{ ;where } \bar{\mathbf{z}}_{t-2} \text{ merges two Gaussians (*).} \\
&= \dots \\
&= \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z} \\
q(\mathbf{x}_t \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})
\end{aligned}\]
<p>(*) Recall that when we merge two Gaussians with different variance, \(\mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I})\) and \(\mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I})\), the new distribution is \(\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})\). Here the merged standard deviation is \(\sqrt{(1 - \alpha_t) + \alpha_t (1-\alpha_{t-1})} = \sqrt{1 - \alpha_t\alpha_{t-1}}\).</p>
<p>Usually, we can afford a larger update step when the sample gets noisier, so \(\beta_1 < \beta_2 < \dots < \beta_T\) and therefore \(\bar{\alpha}_1 > \dots > \bar{\alpha}_T\).</p>
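<p>As a minimal sketch of the closed-form sampling above, the snippet below draws \(\mathbf{x}_t \sim q(\mathbf{x}_t \vert \mathbf{x}_0)\) in one shot. The linear schedule, its endpoint values, the toy data and the helper names <code>linear_beta_schedule</code> and <code>q_sample</code> are assumptions made for illustration only.</p>

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linearly increasing variances beta_1..beta_T (endpoint values are illustrative)."""
    return np.linspace(beta_start, beta_end, T)

def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) directly, without simulating steps 1..t-1 (t is 1-indexed)."""
    z = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1.0 - alpha_bar[t - 1]) * z

T = 1000
betas = linear_beta_schedule(T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)      # alpha_bar[t-1] = alpha_1 * ... * alpha_t

rng = np.random.default_rng(0)
x0 = np.ones(100_000)               # toy "data": every entry equals 1
x_t = q_sample(x0, t=500, alpha_bar=alpha_bar, rng=rng)

# Empirical mean and variance should be close to sqrt(alpha_bar_t) and 1 - alpha_bar_t.
print(x_t.mean(), np.sqrt(alpha_bar[499]))
print(x_t.var(), 1.0 - alpha_bar[499])
```

<p>Because \(\bar{\alpha}_t\) is a running product, <code>np.cumprod</code> gives all of them at once, and the two-step merge marked (*) can be checked numerically for any pair of neighbouring steps.</p>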
<h4 id="connection-with-stochastic-gradient-langevin-dynamics">Connection with stochastic gradient Langevin dynamics</h4>
<p>Langevin dynamics is a concept from physics, developed for statistically modeling molecular systems. Combined with stochastic gradient descent, <em>stochastic gradient Langevin dynamics</em> (<a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.226.363">Welling & Teh 2011</a>) can produce samples from a probability density \(p(\mathbf{x})\) using only the gradients \(\nabla_\mathbf{x} \log p(\mathbf{x})\) in a Markov chain of updates:</p>
\[\mathbf{x}_t = \mathbf{x}_{t-1} + \frac{\epsilon}{2} \nabla_\mathbf{x} \log p(\mathbf{x}_{t-1}) + \sqrt{\epsilon} \mathbf{z}_t
,\quad\text{where }
\mathbf{z}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\]
<p>where \(\epsilon\) is the step size. When \(T \to \infty, \epsilon \to 0\), the distribution of \(\mathbf{x}_T\) converges to the true probability density \(p(\mathbf{x})\).</p>
<p>Compared to standard SGD, stochastic gradient Langevin dynamics injects Gaussian noise into the parameter updates to avoid collapses into local minima.</p>
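<p>To make the update rule concrete, here is a rough sketch of Langevin sampling from a 1D Gaussian whose score \(\nabla_\mathbf{x} \log p(\mathbf{x})\) is known analytically. It uses the exact score rather than a stochastic mini-batch gradient, and the target parameters, step size and function names are all assumptions chosen for the demo.</p>

```python
import numpy as np

def langevin_sample(score_fn, x_init, step_size, n_steps, rng):
    """Iterate x <- x + (eps / 2) * score(x) + sqrt(eps) * z with z ~ N(0, I)."""
    x = np.array(x_init, dtype=float)
    for _ in range(n_steps):
        z = rng.standard_normal(x.shape)
        x = x + 0.5 * step_size * score_fn(x) + np.sqrt(step_size) * z
    return x

# Hypothetical target density: N(mu, sigma^2), so the score is -(x - mu) / sigma^2.
mu, sigma = 2.0, 0.5

def gaussian_score(x):
    return -(x - mu) / sigma**2

rng = np.random.default_rng(0)
# Run 20,000 independent chains in parallel, all started at 0.
samples = langevin_sample(gaussian_score, np.zeros(20_000),
                          step_size=0.01, n_steps=1000, rng=rng)
print(samples.mean(), samples.std())   # should land near mu and sigma
```

<p>With a finite step size the chain only approximates the target (the stationary variance is slightly inflated), which is why the statement above takes the limit \(\epsilon \to 0\).</p>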
<h3 id="reverse-diffusion-process">Reverse diffusion process</h3>
<p>If we can reverse the above process and sample from \(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)\), we will be able to recreate the true sample from a Gaussian noise input, \(\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\). Note that if \(\beta_t\) is small enough, \(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)\) will also be Gaussian. Unfortunately, we cannot easily estimate \(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)\) because it needs to use the entire dataset and therefore we need to learn a model \(p_\theta\) to approximate these conditional probabilities in order to run the <em>reverse diffusion process</em>.</p>
\[p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T) \prod^T_{t=1} p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) \quad
p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/diffusion-example.png" alt="Diffusion model examples" /></p>
<p class="image-caption">Fig. 3. An example of training a diffusion model for modeling a 2D swiss roll data. (Image source: <a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al., 2015</a>)</p>
<p>It is noteworthy that the reverse conditional probability is tractable when conditioned on \(\mathbf{x}_0\):</p>
\[q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \color{blue}{\tilde{\boldsymbol{\mu}}}(\mathbf{x}_t, \mathbf{x}_0), \color{red}{\tilde{\beta}_t} \mathbf{I})\]
<p>Using Bayes’ rule, we have:</p>
\[\begin{aligned}
q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)
&= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) } \\
&\propto \exp \Big(-\frac{1}{2} \big(\frac{(\mathbf{x}_t - \sqrt{\alpha_t} \mathbf{x}_{t-1})^2}{\beta_t} + \frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\
&= \exp\Big( -\frac{1}{2} \big( \color{red}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}})} \mathbf{x}_{t-1}^2 - \color{blue}{(\frac{2\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)} \mathbf{x}_{t-1} + C(\mathbf{x}_t, \mathbf{x}_0) \big) \Big)
\end{aligned}\]
<p>where \(C(\mathbf{x}_t, \mathbf{x}_0)\) is some function not involving \(\mathbf{x}_{t-1}\) and details are omitted. Following the standard Gaussian density function, the mean and variance can be parameterized as follows:</p>
\[\begin{aligned}
\tilde{\beta}_t &= 1/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t \\
\tilde{\boldsymbol{\mu}}_t (\mathbf{x}_t, \mathbf{x}_0)
&= (\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} \mathbf{x}_0)/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}})
= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0\\
\end{aligned}\]
<p>Thanks to the <a href="#nice">nice property</a>, we can represent \(\mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t)\) and plug it into the above equation and obtain:</p>
\[\begin{aligned}
\tilde{\boldsymbol{\mu}}_t
&= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t) \\
&= \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big)}
\end{aligned}\]
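<p>The two expressions for \(\tilde{\boldsymbol{\mu}}_t\) (one in terms of \(\mathbf{x}_0\), one in terms of the noise \(\mathbf{z}_t\)) can be checked against each other numerically. The sketch below assumes a linear \(\beta_t\) schedule and uses made-up helper names; it is a sanity check, not a reference implementation.</p>

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)       # illustrative schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def posterior_params(x_t, x0, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for a 1-indexed step t >= 2."""
    a_t, b_t = alphas[t - 1], betas[t - 1]
    ab_t, ab_prev = alpha_bar[t - 1], alpha_bar[t - 2]
    beta_tilde = (1.0 - ab_prev) / (1.0 - ab_t) * b_t
    mean = (np.sqrt(a_t) * (1.0 - ab_prev) / (1.0 - ab_t)) * x_t \
        + (np.sqrt(ab_prev) * b_t / (1.0 - ab_t)) * x0
    return mean, beta_tilde

def posterior_mean_from_noise(x_t, z_t, t):
    """Same mean, rewritten with the noise z_t that produced x_t from x_0."""
    a_t, b_t, ab_t = alphas[t - 1], betas[t - 1], alpha_bar[t - 1]
    return (x_t - b_t / np.sqrt(1.0 - ab_t) * z_t) / np.sqrt(a_t)

rng = np.random.default_rng(0)
t = 300
x0 = rng.standard_normal(5)
z_t = rng.standard_normal(5)
x_t = np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1.0 - alpha_bar[t - 1]) * z_t

mean_a, beta_tilde = posterior_params(x_t, x0, t)
mean_b = posterior_mean_from_noise(x_t, z_t, t)
print(np.allclose(mean_a, mean_b))       # the two forms agree
```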
<p>As demonstrated in Fig. 2, such a setup is very similar to a <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">VAE</a> and thus we can use the variational lower bound to optimize the negative log-likelihood.</p>
\[\begin{aligned}
- \log p_\theta(\mathbf{x}_0)
&\leq - \log p_\theta(\mathbf{x}_0) + D_\text{KL}(q(\mathbf{x}_{1:T}\vert\mathbf{x}_0) \| p_\theta(\mathbf{x}_{1:T}\vert\mathbf{x}_0) ) \\
&= -\log p_\theta(\mathbf{x}_0) + \mathbb{E}_{\mathbf{x}_{1:T}\sim q(\mathbf{x}_{1:T} \vert \mathbf{x}_0)} \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T}) / p_\theta(\mathbf{x}_0)} \Big] \\
&= -\log p_\theta(\mathbf{x}_0) + \mathbb{E}_q \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} + \log p_\theta(\mathbf{x}_0) \Big] \\
&= \mathbb{E}_q \Big[ \log \frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \\
\text{Let }L_\text{VLB}
&= \mathbb{E}_{q(\mathbf{x}_{0:T})} \Big[ \log \frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \geq - \mathbb{E}_{q(\mathbf{x}_0)} \log p_\theta(\mathbf{x}_0)
\end{aligned}\]
<p>It is also straightforward to get the same result using Jensen’s inequality. Say we want to minimize the cross entropy as the learning objective,</p>
\[\begin{aligned}
L_\text{CE}
&= - \mathbb{E}_{q(\mathbf{x}_0)} \log p_\theta(\mathbf{x}_0) \\
&= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int p_\theta(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T} \Big) \\
&= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \int q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} d\mathbf{x}_{1:T} \Big) \\
&= - \mathbb{E}_{q(\mathbf{x}_0)} \log \Big( \mathbb{E}_{q(\mathbf{x}_{1:T} \vert \mathbf{x}_0)} \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \Big) \\
&\leq - \mathbb{E}_{q(\mathbf{x}_{0:T})} \log \frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})} \\
&= \mathbb{E}_{q(\mathbf{x}_{0:T})}\Big[\log \frac{q(\mathbf{x}_{1:T} \vert \mathbf{x}_{0})}{p_\theta(\mathbf{x}_{0:T})} \Big] = L_\text{VLB}
\end{aligned}\]
<p>To convert each term in the equation to be analytically computable, the objective can be further rewritten to be a combination of several KL-divergence and entropy terms (See the detailed step-by-step process in Appendix B in <a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al., 2015</a>):</p>
\[\begin{aligned}
L_\text{VLB}
&= \mathbb{E}_{q(\mathbf{x}_{0:T})} \Big[ \log\frac{q(\mathbf{x}_{1:T}\vert\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})} \Big] \\
&= \mathbb{E}_q \Big[ \log\frac{\prod_{t=1}^T q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{ p_\theta(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t) } \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=1}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t\vert\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \Big( \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)}\cdot \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1}\vert\mathbf{x}_0)} \Big) + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)} + \log\frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(\mathbf{x}_T) + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} + \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{q(\mathbf{x}_1 \vert \mathbf{x}_0)} + \log \frac{q(\mathbf{x}_1 \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)} \Big]\\
&= \mathbb{E}_q \Big[ \log\frac{q(\mathbf{x}_T \vert \mathbf{x}_0)}{p_\theta(\mathbf{x}_T)} + \sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t)} - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1) \Big] \\
&= \mathbb{E}_q [\underbrace{D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T))}_{L_T} + \sum_{t=2}^T \underbrace{D_\text{KL}(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1} \vert\mathbf{x}_t))}_{L_{t-1}} \underbrace{- \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)}_{L_0} ]
\end{aligned}\]
<p>Let’s label each component in the variational lower bound loss separately:</p>
\[\begin{aligned}
L_\text{VLB} &= L_T + L_{T-1} + \dots + L_0 \\
\text{where } L_T &= D_\text{KL}(q(\mathbf{x}_T \vert \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T)) \\
L_t &= D_\text{KL}(q(\mathbf{x}_t \vert \mathbf{x}_{t+1}, \mathbf{x}_0) \parallel p_\theta(\mathbf{x}_t \vert\mathbf{x}_{t+1})) \text{ for }1 \leq t \leq T-1 \\
L_0 &= - \log p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)
\end{aligned}\]
<p>Every KL term in \(L_\text{VLB}\) (except for \(L_0\)) compares two Gaussian distributions and therefore they can be computed in <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions">closed form</a>. \(L_T\) is constant and can be ignored during training because \(q\) has no learnable parameters and \(\mathbf{x}_T\) is a Gaussian noise. <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a> models \(L_0\) using a separate discrete decoder derived from \(\mathcal{N}(\mathbf{x}_0; \boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \boldsymbol{\Sigma}_\theta(\mathbf{x}_1, 1))\).</p>
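<p>For intuition on the closed-form KL terms, here is a small sketch for two isotropic Gaussians. The function name and the example numbers are invented for illustration. When both variances are equal, the KL collapses to a scaled squared distance between the means.</p>

```python
import numpy as np

def kl_isotropic_gaussians(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q I) || N(mu_p, var_p I) ) for scalar variances."""
    d = mu_q.size
    return 0.5 * (
        d * var_q / var_p
        + np.sum((mu_p - mu_q) ** 2) / var_p
        - d
        + d * np.log(var_p / var_q)
    )

mu_q = np.array([0.3, -0.1, 0.8])
mu_p = np.array([0.0, 0.2, 0.5])

kl_equal_var = kl_isotropic_gaussians(mu_q, 0.1, mu_p, 0.1)
kl_diff_var = kl_isotropic_gaussians(mu_q, 0.1, mu_p, 0.2)

# With equal variances only the mean term survives: ||mu_q - mu_p||^2 / (2 * var).
print(kl_equal_var, np.sum((mu_q - mu_p) ** 2) / (2 * 0.1))
```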
<h3 id="parameterization-of-l_t-for-training-loss">Parameterization of \(L_t\) for Training Loss</h3>
<p>Recall that we need to learn a neural network to approximate the conditioned probability distributions in the reverse diffusion process, \(p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))\). We would like to train \(\boldsymbol{\mu}_\theta\) to predict \(\tilde{\boldsymbol{\mu}}_t = \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big)\). Because \(\mathbf{x}_t\) is available as input at training time, we can reparameterize the Gaussian noise term instead to make it predict \(\mathbf{z}_t\) from the input \(\mathbf{x}_t\) at time step \(t\):</p>
\[\begin{aligned}
\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) &= \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_\theta(\mathbf{x}_t, t) \Big)} \\
\text{Thus }\mathbf{x}_{t-1} &\sim \mathcal{N}(\mathbf{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_\theta(\mathbf{x}_t, t) \Big), \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t))
\end{aligned}\]
<p>The loss term \(L_t\) is parameterized to minimize the difference from \(\tilde{\boldsymbol{\mu}}\) :</p>
\[\begin{aligned}
L_t
&= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}} \Big[\frac{1}{2 \| \boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) \|^2_2} \| \color{blue}{\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0)} - \color{green}{\boldsymbol{\mu}_\theta(\mathbf{x}_t, t)} \|^2 \Big] \\
&= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}} \Big[\frac{1}{2 \|\boldsymbol{\Sigma}_\theta \|^2_2} \| \color{blue}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \mathbf{z}_t \Big)} - \color{green}{\frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\mathbf{z}}_\theta(\mathbf{x}_t, t) \Big)} \|^2 \Big] \\
&= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}} \Big[\frac{ \beta_t^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\mathbf{z}_t - \mathbf{z}_\theta(\mathbf{x}_t, t)\|^2 \Big] \\
&= \mathbb{E}_{\mathbf{x}_0, \mathbf{z}} \Big[\frac{ \beta_t^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\mathbf{z}_t - \mathbf{z}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t, t)\|^2 \Big]
\end{aligned}\]
<h4 id="simplification">Simplification</h4>
<p>Empirically, <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a> found that training the diffusion model works better with a simplified objective that ignores the weighting term:</p>
\[L_t^\text{simple} = \mathbb{E}_{\mathbf{x}_0, \mathbf{z}_t} \Big[\|\mathbf{z}_t - \mathbf{z}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\mathbf{z}_t, t)\|^2 \Big]\]
<p>The final simple objective is:</p>
\[L_\text{simple} = L_t^\text{simple} + C\]
<p>where \(C\) is a constant not depending on \(\theta\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DDPM-algo.png" alt="DDPM algorithm" /></p>
<p class="image-caption">Fig. 4. The training and sampling algorithms in DDPM (Image source: <a href="https://arxiv.org/abs/2006.11239">Ho et al. 2020</a>)</p>
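<p>The training and sampling loops can be sketched in NumPy as follows. To keep the example runnable without a neural network, the noise predictor is replaced by an analytic stand-in: if the data were \(\mathcal{N}(\mathbf{0}, \mathbf{I})\), the best possible prediction would be \(\mathbb{E}[\mathbf{z}_t \vert \mathbf{x}_t] = \sqrt{1 - \bar{\alpha}_t}\,\mathbf{x}_t\). That stand-in, the linear schedule, the choice \(\sigma_t^2 = \beta_t\) and all function names are assumptions for this demo; a real model would learn \(\mathbf{z}_\theta\) by gradient descent on the loss.</p>

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)       # illustrative schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def oracle_noise(x_t, t):
    """Stand-in for z_theta: exact E[z | x_t] when the data is N(0, I). t has shape (B,)."""
    return np.sqrt(1.0 - alpha_bar[t - 1])[:, None] * x_t

def simple_loss(z_theta, x0, rng):
    """One Monte Carlo estimate of the simplified loss for a batch x0 of shape (B, D)."""
    t = rng.integers(1, T + 1, size=x0.shape[0])      # t ~ Uniform{1, ..., T}
    z = rng.standard_normal(x0.shape)
    ab = alpha_bar[t - 1][:, None]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * z    # closed-form forward sample
    return np.mean(np.sum((z - z_theta(x_t, t)) ** 2, axis=1))

def p_sample_step(z_theta, x_t, t, rng):
    """One reverse step x_t -> x_{t-1} with sigma_t^2 = beta_t (t is a scalar, 1-indexed)."""
    a_t, b_t, ab_t = alphas[t - 1], betas[t - 1], alpha_bar[t - 1]
    t_batch = np.full(x_t.shape[0], t)
    mean = (x_t - b_t / np.sqrt(1.0 - ab_t) * z_theta(x_t, t_batch)) / np.sqrt(a_t)
    # No noise is added on the final step (t = 1).
    noise = rng.standard_normal(x_t.shape) if t > 1 else np.zeros_like(x_t)
    return mean + np.sqrt(b_t) * noise

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4096, 2))                   # toy data drawn from N(0, I)
loss = simple_loss(oracle_noise, x0, rng)

x = rng.standard_normal((5000, 2))                    # x_T ~ N(0, I)
for t in range(T, 0, -1):
    x = p_sample_step(oracle_noise, x, t, rng)
print(loss, x.std())
```

<p>With this stand-in the generated samples should again look like \(\mathcal{N}(\mathbf{0}, \mathbf{I})\), which is a cheap way to test the bookkeeping of indices and coefficients before plugging in a real network.</p>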
<h4 id="connection-with-noise-conditioned-score-networks-ncsn">Connection with noise-conditioned score networks (NCSN)</h4>
<p><a href="https://arxiv.org/abs/1907.05600">Song & Ermon (2019)</a> proposed a score-based generative modeling method where samples are produced via <a href="#connection-with-stochastic-gradient-langevin-dynamics">Langevin dynamics</a> using gradients of the data distribution estimated with score matching. The score of each sample \(\mathbf{x}\)’s probability density is defined as its gradient \(\nabla_{\mathbf{x}} \log p(\mathbf{x})\). A score network \(s_\theta: \mathbb{R}^D \to \mathbb{R}^D\) is trained to estimate it. To make it scalable with high-dimensional data in the deep learning setting, they proposed to use either <em>denoising score matching</em> (add a pre-specified small noise to the data; <a href="http://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Vincent, 2011</a>) or <em>sliced score matching</em> (use random projections; <a href="https://arxiv.org/abs/1905.07088">Song et al., 2019</a>).</p>
<p>Recall that Langevin dynamics can sample data points from a probability density distribution using only the score \(\nabla_{\mathbf{x}} \log p(\mathbf{x})\) in an iterative process.</p>
<p>However, according to the manifold hypothesis, most of the data is expected to concentrate in a low-dimensional manifold, even though the observed data might appear arbitrarily high-dimensional. This hurts score estimation because the data points cannot cover the whole space: in regions where data density is low, the score estimation is less reliable. After adding a small Gaussian noise to make the perturbed data distribution cover the full space \(\mathbb{R}^D\), the training of the score estimator network becomes more stable. <a href="https://arxiv.org/abs/1907.05600">Song & Ermon (2019)</a> improved it by perturbing the data with noise of <em>different levels</em> and training a noise-conditioned score network to <em>jointly</em> estimate the scores of all the perturbed data at different noise levels.</p>
<p>The schedule of increasing noise levels resembles the forward diffusion process.</p>
<h3 id="parameterization-of-beta_t">Parameterization of \(\beta_t\)</h3>
<p>The forward variances are set to be a sequence of linearly increasing constants in <a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a>, from \(\beta_1=10^{-4}\) to \(\beta_T=0.02\). They are relatively small compared to the normalized image pixel values between \([-1, 1]\). Diffusion models in their experiments produced high-quality samples but still could not achieve a model log-likelihood competitive with other generative models.</p>
<p><a href="https://arxiv.org/abs/2102.09672">Nichol & Dhariwal (2021)</a> proposed several improvement techniques to help diffusion models to obtain lower NLL. One of the improvements is to use a cosine-based variance schedule. The choice of the scheduling function can be arbitrary, as long as it provides a near-linear drop in the middle of the training process and subtle changes around \(t=0\) and \(t=T\).</p>
\[\beta_t = \text{clip}(1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}, 0.999) \quad\bar{\alpha}_t = \frac{f(t)}{f(0)}\quad\text{where }f(t)=\cos\Big(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\Big)^2\]
<p>where the small offset \(s\) is to prevent \(\beta_t\) from being too small when close to \(t=0\).</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/diffusion-beta.png" alt="betas" /></p>
<p class="image-caption">Fig. 5. Comparison of linear and cosine-based scheduling of \(\beta_t\) during training. (Image source: <a href="https://arxiv.org/abs/2102.09672">Nichol & Dhariwal, 2021</a>)</p>
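<p>A short sketch of the cosine schedule, using the squared-cosine form of \(f(t)\) from Nichol & Dhariwal (2021). The default offset <code>s=0.008</code> is the value reported in that paper; the function names and the choice to index \(\bar{\alpha}\) from \(t=0\) are assumptions of this example.</p>

```python
import numpy as np

def cosine_alpha_bar(T, s=0.008):
    """alpha_bar_t = f(t) / f(0) for t = 0..T, with f(t) = cos(((t/T + s) / (1 + s)) * pi/2)^2."""
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1.0 + s) * np.pi / 2.0) ** 2
    return f / f[0]

def cosine_betas(T, s=0.008, max_beta=0.999):
    """beta_t = clip(1 - alpha_bar_t / alpha_bar_{t-1}, max_beta) for t = 1..T."""
    ab = cosine_alpha_bar(T, s)
    return np.clip(1.0 - ab[1:] / ab[:-1], 0.0, max_beta)

T = 1000
ab = cosine_alpha_bar(T)
betas = cosine_betas(T)
print(betas[0], betas[T // 2], betas[-1])   # tiny near t=0, clipped near t=T
```

<p>The clipping matters only near \(t=T\), where \(\bar{\alpha}_t\) approaches zero and the ratio \(\bar{\alpha}_t / \bar{\alpha}_{t-1}\) would otherwise push \(\beta_t\) to 1.</p>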
<h3 id="parameterization-of-reverse-process-variance-boldsymbolsigma_theta">Parameterization of reverse process variance \(\boldsymbol{\Sigma}_\theta\)</h3>
<p><a href="https://arxiv.org/abs/2006.11239">Ho et al. (2020)</a> chose to fix \(\beta_t\) as constants instead of making them learnable and set \(\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) = \sigma^2_t \mathbf{I}\), where \(\sigma^2_t\) is not learned but set to \(\beta_t\) or \(\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t\), because they found that learning a diagonal variance \(\boldsymbol{\Sigma}_\theta\) leads to unstable training and poorer sample quality.</p>
<p><a href="https://arxiv.org/abs/2102.09672">Nichol & Dhariwal (2021)</a> proposed to learn \(\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t)\) as an interpolation between \(\beta_t\) and \(\tilde{\beta}_t\) by having the model predict a mixing vector \(\mathbf{v}\):</p>
\[\boldsymbol{\Sigma}_\theta(\mathbf{x}_t, t) = \exp(\mathbf{v} \log \beta_t + (1-\mathbf{v}) \log \tilde{\beta}_t)\]
<p>However, the simple objective \(L_\text{simple}\) does not depend on \(\boldsymbol{\Sigma}_\theta\). To add the dependency, they constructed a hybrid objective \(L_\text{hybrid} = L_\text{simple} + \lambda L_\text{VLB}\), where \(\lambda=0.001\) is small, and applied a stop gradient to \(\boldsymbol{\mu}_\theta\) in the \(L_\text{VLB}\) term such that \(L_\text{VLB}\) only guides the learning of \(\boldsymbol{\Sigma}_\theta\). Empirically they observed that \(L_\text{VLB}\) is pretty challenging to optimize, likely due to noisy gradients, so they proposed to use a time-averaged, smoothed version of \(L_\text{VLB}\) with importance sampling.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/improved-DDPM-nll.png" alt="Improved DDPM" /></p>
<p class="image-caption">Fig. 6. Comparison of negative log-likelihood of improved DDPM with other likelihood-based generative models. NLL is reported in the unit of bits/dim. (Image source: <a href="https://arxiv.org/abs/2102.09672">Nichol & Dhariwal, 2021</a>)</p>
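<p>The log-space interpolation used for the learned variance is easy to sanity-check. In the sketch below, the values of \(\beta_t\) and \(\tilde{\beta}_t\) and the function name are placeholders; in the actual model \(\mathbf{v}\) would be an output of the network, one entry per dimension.</p>

```python
import numpy as np

def interpolated_variance(v, beta_t, beta_tilde_t):
    """exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t)), elementwise in v."""
    return np.exp(v * np.log(beta_t) + (1.0 - v) * np.log(beta_tilde_t))

beta_t, beta_tilde_t = 0.02, 0.015      # illustrative values with beta_tilde_t < beta_t
v = np.array([0.0, 0.5, 1.0])           # hypothetical per-dimension mixing weights
var = interpolated_variance(v, beta_t, beta_tilde_t)
print(var)   # v=0 gives beta_tilde_t, v=1 gives beta_t, v=0.5 their geometric mean
```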
<h2 id="speed-up-diffusion-model-sampling">Speed up Diffusion Model Sampling</h2>
<p>It is very slow to generate a sample from DDPM by following the Markov chain of the reverse diffusion process, as \(T\) can be up to one or a few thousand steps. One data point from <a href="https://arxiv.org/abs/2010.02502">Song et al. 2020</a>: “For example, it takes around 20 hours to sample 50k images of size 32 × 32 from a DDPM, but less than a minute to do so from a GAN on an Nvidia 2080 Ti GPU.”</p>
<p>One simple way is to run a strided sampling schedule (<a href="https://arxiv.org/abs/2102.09672">Nichol & Dhariwal, 2021</a>) by taking the sampling update every \(\lceil T/S \rceil\) steps to reduce the process from \(T\) to \(S\) steps. The new sampling schedule for generation is \(\{\tau_1, \dots, \tau_S\}\) where \(\tau_1 < \tau_2 < \dots <\tau_S \in [1, T]\) and \(S < T\).</p>
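<p>One possible way to build such a strided schedule is sketched below. Starting the subsequence at step 1 is an arbitrary choice made here, and the function name is hypothetical; note that when \(S\) does not divide \(T\) evenly, a fixed stride of \(\lceil T/S \rceil\) can yield fewer than \(S\) steps.</p>

```python
import numpy as np

def strided_timesteps(T, S):
    """Every ceil(T / S)-th step of {1, ..., T}, in increasing order (at most S steps)."""
    stride = int(np.ceil(T / S))
    return np.arange(1, T + 1, stride)

taus = strided_timesteps(T=1000, S=50)
print(len(taus), taus[:3], taus[-1])
```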
<p>For another approach, let’s rewrite \(q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)\) to be parameterized by a desired standard deviation \(\sigma_t\) according to the <a href="#nice">nice property</a>:</p>
\[\begin{aligned}
\mathbf{x}_{t-1}
&= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\mathbf{z}_{t-1} \\
&= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \mathbf{z}_t + \sigma_t\mathbf{z} \\
&= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} + \sigma_t\mathbf{z} \\
q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)
&= \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}}, \sigma_t^2 \mathbf{I})
\end{aligned}\]
<p>Recall that \(q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I})\); matching the variances, we have:</p>
\[\tilde{\beta}_t = \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t\]
<p>Let \(\sigma_t^2 = \eta \cdot \tilde{\beta}_t\) such that we can adjust \(\eta \geq 0\) as a hyperparameter to control the sampling stochasticity. The special case of \(\eta = 0\) makes the sampling process <em>deterministic</em>. Such a model is named the <em>denoising diffusion implicit model</em> (<strong>DDIM</strong>; <a href="https://arxiv.org/abs/2010.02502">Song et al., 2020</a>). DDIM has the same marginal noise distribution but deterministically maps noise back to the original data samples.</p>
<p>During generation, we only sample a subset of \(S\) diffusion steps \(\{\tau_1, \dots, \tau_S\}\) and the inference process becomes:</p>
\[q_{\sigma, \tau}(\mathbf{x}_{\tau_{i-1}} \vert \mathbf{x}_{\tau_i}, \mathbf{x}_0)
= \mathcal{N}(\mathbf{x}_{\tau_{i-1}}; \sqrt{\bar{\alpha}_{\tau_{i-1}}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{\tau_{i-1}} - \sigma_{\tau_i}^2} \frac{\mathbf{x}_{\tau_i} - \sqrt{\bar{\alpha}_{\tau_i}}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_{\tau_i}}}, \sigma_{\tau_i}^2 \mathbf{I})\]
<p>While all the models are trained with \(T=1000\) diffusion steps in the experiments, they observed that DDIM (\(\eta=0\)) can produce the best quality samples when \(S\) is small, while DDPM (\(\eta=1\)) performs much worse on small \(S\). DDPM does perform better when we can afford to run the full reverse Markov diffusion steps (\(S=T=1000\)). With DDIM, it is possible to train the diffusion model up to any arbitrary number of forward steps but only sample from a subset of steps in the generative process.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DDIM-results.png" alt="DDIM" /></p>
<p class="image-caption">Fig. 7. FID scores on CIFAR10 and CelebA datasets by diffusion models of different settings, including \(\color{cyan}{\text{DDIM}}\) (\(\eta=0\)) and \(\color{orange}{\text{DDPM}}\) (\(\hat{\sigma}\)). (Image source: <a href="https://arxiv.org/abs/2010.02502">Song et al., 2020</a>)</p>
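<p>Below is a rough sketch of sampling over a subset of steps with the \(\eta\)-controlled variance. Several pieces are assumptions of this example rather than part of the text above: \(\mathbf{x}_0\) is replaced by its estimate from the predicted noise, \(\tilde{\beta}\) for a jump from \(\tau_i\) to \(\tau_{i-1}\) is computed with \(\beta = 1 - \bar{\alpha}_{\tau_i}/\bar{\alpha}_{\tau_{i-1}}\), and the noise predictor is an analytic stand-in that is exact only for \(\mathcal{N}(\mathbf{0}, \mathbf{I})\) data. The code assumes \(0 \leq \eta \leq 1\).</p>

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)                             # illustrative schedule
alpha_bar = np.concatenate([[1.0], np.cumprod(1.0 - betas)])   # alpha_bar[0] = 1 for "t = 0"

def oracle_noise(x_t, t):
    """Stand-in for z_theta: exact E[z | x_t] when the data is N(0, I)."""
    return np.sqrt(1.0 - alpha_bar[t]) * x_t

def ddim_step(x_t, t, t_prev, eta, rng, z_theta=oracle_noise):
    """Jump from step t to an earlier step t_prev with sigma^2 = eta * beta_tilde."""
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    z_hat = z_theta(x_t, t)
    x0_hat = (x_t - np.sqrt(1.0 - ab_t) * z_hat) / np.sqrt(ab_t)   # estimate of x_0
    beta_tilde = (1.0 - ab_prev) / (1.0 - ab_t) * (1.0 - ab_t / ab_prev)
    sigma2 = eta * beta_tilde
    # max(..., 0) only guards against tiny negative values from round-off.
    direction = np.sqrt(max(1.0 - ab_prev - sigma2, 0.0)) * z_hat
    noise = np.sqrt(sigma2) * rng.standard_normal(x_t.shape)
    return np.sqrt(ab_prev) * x0_hat + direction + noise

def ddim_sample(x_T, taus, eta, rng):
    """Run the reverse process over tau_S > ... > tau_1 and finish at step 0."""
    x = x_T
    steps = [0] + list(taus)
    for i in range(len(steps) - 1, 0, -1):
        x = ddim_step(x, steps[i], steps[i - 1], eta, rng)
    return x

taus = np.arange(20, T + 1, 20)            # S = 50 steps: 20, 40, ..., 1000
x_T = np.random.default_rng(0).standard_normal((16, 2))

a = ddim_sample(x_T, taus, eta=0.0, rng=np.random.default_rng(1))
b = ddim_sample(x_T, taus, eta=0.0, rng=np.random.default_rng(2))
c = ddim_sample(x_T, taus, eta=1.0, rng=np.random.default_rng(1))
d = ddim_sample(x_T, taus, eta=1.0, rng=np.random.default_rng(2))
print(np.array_equal(a, b), np.allclose(c, d))
```

<p>With \(\eta = 0\) the output depends only on \(\mathbf{x}_T\), so the two runs with different random seeds coincide, while with \(\eta = 1\) they differ.</p>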
<p>Compared to DDPM, DDIM is able to:</p>
<ol>
<li>Generate higher-quality samples using far fewer steps.</li>
<li>Have a “consistency” property since the generative process is deterministic, meaning that multiple samples conditioned on the same latent variable should have similar high-level features.</li>
<li>Because of the consistency, DDIM can do semantically meaningful interpolation in the latent variable.</li>
</ol>
<h2 id="conditioned-generation">Conditioned Generation</h2>
<p>While training generative models on ImageNet data, it is common to generate samples conditioned on class labels. To explicitly incorporate class information into the diffusion process, <a href="https://arxiv.org/abs/2105.05233">Dhariwal & Nichol (2021)</a> trained a classifier \(f_\phi(y \vert \mathbf{x}_t, t)\) on noisy images \(\mathbf{x}_t\) and used gradients \(\nabla_\mathbf{x} \log f_\phi(y \vert \mathbf{x}_t, t)\) to guide the diffusion sampling process toward the target class label \(y\). Their <em>ablated diffusion model</em> (<strong>ADM</strong>) and the one with additional classifier guidance (<strong>ADM-G</strong>) are able to achieve better results than SOTA generative models (e.g. BigGAN).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/conditioned-DDPM.png" alt="Conditioned DDPM" /></p>
<p class="image-caption">Fig. 8. The algorithms use guidance from a classifier to run conditioned generation with DDPM and DDIM. (Image source: <a href="https://arxiv.org/abs/2105.05233">Dhariwal & Nichol, 2021</a>)</p>
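<p>As a toy illustration of the guidance idea for the DDPM case, the sketch below shifts the mean of one reverse step along the classifier gradient, scaled by the step variance and a guidance scale. The logistic-regression "classifier", the placeholder mean, the variance and the scale are all made up for the demo; a real setup would use a neural classifier trained on noisy images and take the gradient with automatic differentiation.</p>

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Hypothetical toy classifier standing in for f_phi(y=1 | x_t, t): logistic regression.
w = np.array([1.0, -2.0])
b = 0.5

def log_prob_y(x):
    """log f(y=1 | x) for a batch x of shape (B, 2)."""
    return np.log(sigmoid(x @ w + b))

def grad_log_prob_y(x):
    """Analytic gradient of log f(y=1 | x) with respect to x."""
    p = sigmoid(x @ w + b)
    return (1.0 - p)[..., None] * w

def guided_mean(mu, sigma2, grad, scale=1.0):
    """Shift the reverse-step mean along the classifier gradient."""
    return mu + scale * sigma2 * grad

x_t = np.array([[0.0, 0.0], [1.0, 1.0]])   # two noisy samples
mu = 0.98 * x_t                            # placeholder for the unguided mean mu_theta(x_t, t)
sigma2 = 0.05                              # illustrative reverse-step variance
mu_guided = guided_mean(mu, sigma2, grad_log_prob_y(x_t), scale=2.0)

# The guided mean is more likely under the target class than the unguided one.
print(log_prob_y(mu), log_prob_y(mu_guided))
```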
<p>Additionally with some modifications on the UNet architecture, <a href="https://arxiv.org/abs/2105.05233">Dhariwal & Nichol (2021)</a> showed performance better than GAN with diffusion models. The architecture modifications include larger model depth/width, more attention heads, multi-resolution attention, BigGAN residual blocks for up/downsampling, residual connection rescale by \(1/\sqrt{2}\) and adaptive group normalization (AdaGN).</p>
<h2 id="quick-summary">Quick Summary</h2>
<ul>
<li>
<p><strong>Pros</strong>: Tractability and flexibility are two conflicting objectives in generative modeling. Tractable models can be analytically evaluated and cheaply fit data (e.g. via a Gaussian or Laplace), but they cannot easily describe the structure in rich datasets. Flexible models can fit arbitrary structures in data, but evaluating, training, or sampling from these models is usually expensive. Diffusion models are both analytically tractable and flexible.</p>
</li>
<li>
<p><strong>Cons</strong>: Diffusion models rely on a long Markov chain of diffusion steps to generate samples, so it can be quite expensive in terms of time and compute. New methods have been proposed to make the process much faster, but the sampling is still slower than GAN.</p>
</li>
</ul>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2021diffusion,
title = "What are diffusion models?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2021",
url = "https://lilianweng.github.io/lil-log/2021/07/11/diffusion-models.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Jascha Sohl-Dickstein et al. <a href="https://arxiv.org/abs/1503.03585">“Deep Unsupervised Learning using Nonequilibrium Thermodynamics.”</a> ICML 2015.</p>
<p>[2] Max Welling & Yee Whye Teh. <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.226.363">“Bayesian learning via stochastic gradient langevin dynamics.”</a> ICML 2011.</p>
<p>[3] Yang Song & Stefano Ermon. <a href="https://arxiv.org/abs/1907.05600">“Generative modeling by estimating gradients of the data distribution.”</a> NeurIPS 2019.</p>
<p>[4] Yang Song & Stefano Ermon. <a href="https://arxiv.org/abs/2006.09011">“Improved techniques for training score-based generative models.”</a> NeurIPS 2020.</p>
<p>[5] Jonathan Ho et al. <a href="https://arxiv.org/abs/2006.11239">“Denoising diffusion probabilistic models.”</a> arxiv Preprint arxiv:2006.11239 (2020). [<a href="https://github.com/hojonathanho/diffusion">code</a>]</p>
<p>[6] Jiaming Song et al. <a href="https://arxiv.org/abs/2010.02502">“Denoising diffusion implicit models.”</a> arxiv Preprint arxiv:2010.02502 (2020). [<a href="https://github.com/ermongroup/ddim">code</a>]</p>
<p>[7] Alex Nichol & Prafulla Dhariwal. <a href="https://arxiv.org/abs/2102.09672">“Improved denoising diffusion probabilistic models.”</a> arxiv Preprint arxiv:2102.09672 (2021). [<a href="https://github.com/openai/improved-diffusion">code</a>]</p>
<p>[8] Prafulla Dhariwal & Alex Nichol. <a href="https://arxiv.org/abs/2105.05233">“Diffusion Models Beat GANs on Image Synthesis.”</a> arxiv Preprint arxiv:2105.05233 (2021). [<a href="https://github.com/openai/guided-diffusion">code</a>]</p>Lilian WengDiffusion models are a new type of generative models that are flexible enough to learn any arbitrarily complex data distribution while tractable to analytically evaluate the distribution. It has been shown recently that diffusion models can generate high-quality images and the performance is competitive to SOTA GAN.Contrastive Representation Learning2021-05-31T12:00:00+00:002021-05-31T12:00:00+00:00https://lilianweng.github.io/lil-log/2021/05/31/contrastive-representation-learning<blockquote>
<p>The main idea of contrastive learning is to learn representations such that similar samples stay close to each other, while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised data and has been shown to achieve good performance on a variety of vision and language tasks.</p>
</blockquote>
<!--more-->
<p>The goal of contrastive representation learning is to learn such an embedding space in which similar sample pairs stay close to each other while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised settings. When working with unsupervised data, contrastive learning is one of the most powerful approaches in <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised learning</a>.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#contrastive-training-objectives" id="markdown-toc-contrastive-training-objectives">Contrastive Training Objectives</a> <ul>
<li><a href="#contrastive-loss" id="markdown-toc-contrastive-loss">Contrastive Loss</a></li>
<li><a href="#triplet-loss" id="markdown-toc-triplet-loss">Triplet Loss</a></li>
<li><a href="#lifted-structured-loss" id="markdown-toc-lifted-structured-loss">Lifted Structured Loss</a></li>
<li><a href="#n-pair-loss" id="markdown-toc-n-pair-loss">N-pair Loss</a></li>
<li><a href="#nce" id="markdown-toc-nce">NCE</a></li>
<li><a href="#infonce" id="markdown-toc-infonce">InfoNCE</a></li>
<li><a href="#soft-nearest-neighbors-loss" id="markdown-toc-soft-nearest-neighbors-loss">Soft-Nearest Neighbors Loss</a></li>
<li><a href="#common-setup" id="markdown-toc-common-setup">Common Setup</a></li>
</ul>
</li>
<li><a href="#key-ingredients" id="markdown-toc-key-ingredients">Key Ingredients</a> <ul>
<li><a href="#heavy-data-augmentation" id="markdown-toc-heavy-data-augmentation">Heavy Data Augmentation</a></li>
<li><a href="#large-batch-size" id="markdown-toc-large-batch-size">Large Batch Size</a></li>
<li><a href="#hard-negative-mining" id="markdown-toc-hard-negative-mining">Hard Negative Mining</a></li>
</ul>
</li>
<li><a href="#vision-image-embedding" id="markdown-toc-vision-image-embedding">Vision: Image Embedding</a> <ul>
<li><a href="#image-augmentations" id="markdown-toc-image-augmentations">Image Augmentations</a></li>
<li><a href="#parallel-augmentation" id="markdown-toc-parallel-augmentation">Parallel Augmentation</a></li>
<li><a href="#memory-bank" id="markdown-toc-memory-bank">Memory Bank</a></li>
<li><a href="#feature-clustering" id="markdown-toc-feature-clustering">Feature Clustering</a></li>
<li><a href="#working-with-supervised-datasets" id="markdown-toc-working-with-supervised-datasets">Working with Supervised Datasets</a></li>
</ul>
</li>
<li><a href="#language-sentence-embedding" id="markdown-toc-language-sentence-embedding">Language: Sentence Embedding</a> <ul>
<li><a href="#text-augmentation" id="markdown-toc-text-augmentation">Text Augmentation</a></li>
<li><a href="#supervision-from-nli" id="markdown-toc-supervision-from-nli">Supervision from NLI</a></li>
<li><a href="#unsupervised-sentence-embedding-learning" id="markdown-toc-unsupervised-sentence-embedding-learning">Unsupervised Sentence Embedding Learning</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="contrastive-training-objectives">Contrastive Training Objectives</h2>
<p>In early versions of loss functions for contrastive learning, only one positive and one negative sample are involved. The trend in recent training objectives is to include multiple positive and negative pairs in one batch.</p>
<h3 id="contrastive-loss">Contrastive Loss</h3>
<p><strong>Contrastive loss</strong> (<a href="http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf">Chopra et al. 2005</a>) is one of the earliest training objectives used for deep metric learning in a contrastive fashion.</p>
<p>Given a list of input samples \(\{ \mathbf{x}_i \}\), each with a corresponding label \(y_i \in \{1, \dots, L\}\) among \(L\) classes, we would like to learn a function \(f_\theta(.): \mathcal{X}\to\mathbb{R}^d\) that encodes \(\mathbf{x}_i\) into an embedding vector such that examples from the same class have similar embeddings and samples from different classes have very different ones. Thus, contrastive loss takes a pair of inputs \((\mathbf{x}_i, \mathbf{x}_j)\) and minimizes the embedding distance when they are from the same class but maximizes the distance otherwise.</p>
\[\mathcal{L}_\text{cont}(\mathbf{x}_i, \mathbf{x}_j, \theta) = \mathbb{1}[y_i=y_j] \| f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \|^2_2 + \mathbb{1}[y_i\neq y_j]\max(0, \epsilon - \|f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j)\|_2)^2\]
<p>where \(\epsilon\) is a hyperparameter, defining the lower bound distance between samples of different classes.</p>
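<p>To make the two branches of the loss concrete, here is a minimal NumPy sketch. It assumes the embeddings \(f_\theta(\mathbf{x}_i)\) have already been computed; the function name <code>contrastive_loss</code> and the toy vectors are illustrative, not taken from the paper.</p>

```python
import numpy as np

def contrastive_loss(z_i, z_j, same_class, eps=1.0):
    """Pairwise contrastive loss on two precomputed embedding vectors."""
    dist = np.linalg.norm(z_i - z_j)
    if same_class:
        # Same class: pull the pair together (squared distance).
        return dist ** 2
    # Different classes: push apart until the distance reaches the margin eps.
    return max(0.0, eps - dist) ** 2

anchor = np.array([0.0, 0.0])
near = np.array([0.3, 0.4])   # distance 0.5 from the anchor
far = np.array([3.0, 4.0])    # distance 5.0 from the anchor

pos_loss = contrastive_loss(anchor, near, same_class=True)
neg_close = contrastive_loss(anchor, near, same_class=False)
neg_far = contrastive_loss(anchor, far, same_class=False)
```

<p>A negative pair that is already farther apart than \(\epsilon\) contributes nothing, which is why \(\epsilon\) acts as a lower bound on the distance between classes.</p>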
<h3 id="triplet-loss">Triplet Loss</h3>
<p><strong>Triplet loss</strong> was originally proposed in the FaceNet (<a href="https://arxiv.org/abs/1503.03832">Schroff et al. 2015</a>) paper and was used to learn face recognition of the same person at different poses and angles.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/triplet-loss.png" alt="Triplet loss" /></p>
<p class="image-caption">Fig. 1. Illustration of triplet loss given one positive and one negative per anchor. (Image source: <a href="https://arxiv.org/abs/1503.03832">Schroff et al. 2015</a>)</p>
<p>Given one anchor input \(\mathbf{x}\), we select one positive sample \(\mathbf{x}^+\) and one negative \(\mathbf{x}^-\), meaning that \(\mathbf{x}^+\) and \(\mathbf{x}\) belong to the same class and \(\mathbf{x}^-\) is sampled from another different class. Triplet loss learns to minimize the distance between the anchor \(\mathbf{x}\) and positive \(\mathbf{x}^+\) and maximize the distance between the anchor \(\mathbf{x}\) and negative \(\mathbf{x}^-\) at the same time with the following equation:</p>
\[\mathcal{L}_\text{triplet}(\mathbf{x}, \mathbf{x}^+, \mathbf{x}^-) = \sum_{\mathbf{x} \in \mathcal{X}} \max\big( 0, \|f(\mathbf{x}) - f(\mathbf{x}^+)\|^2_2 - \|f(\mathbf{x}) - f(\mathbf{x}^-)\|^2_2 + \epsilon \big)\]
<p>where the margin parameter \(\epsilon\) is configured as the minimum offset between distances of similar vs dissimilar pairs.</p>
<p>It is crucial to select challenging \(\mathbf{x}^-\) to truly improve the model.</p>
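<p>The following sketch, assuming precomputed embeddings and an arbitrary margin of 0.2, illustrates why challenging negatives matter: a negative that is already far from the anchor yields zero loss and therefore no training signal.</p>

```python
import numpy as np

def triplet_loss(anchor, positive, negative, eps=0.2):
    """Per-triplet loss on batches of embeddings with shape (batch, d)."""
    d_pos = np.sum((anchor - positive) ** 2, axis=-1)
    d_neg = np.sum((anchor - negative) ** 2, axis=-1)
    return np.maximum(0.0, d_pos - d_neg + eps)

anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[0.1, 0.0], [0.1, 0.0]])
# First negative is easy (far away); second is hard (close to the anchor).
negative = np.array([[2.0, 0.0], [0.2, 0.0]])

losses = triplet_loss(anchor, positive, negative)
```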
<h3 id="lifted-structured-loss">Lifted Structured Loss</h3>
<p><strong>Lifted Structured Loss</strong> (<a href="https://arxiv.org/abs/1511.06452">Song et al. 2015</a>) utilizes all the pairwise edges within one training batch for better computational efficiency.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/lifted-structured-loss.png" alt="Lifted structured loss" /></p>
<p class="image-caption">Fig. 2. Illustration compares contrastive loss, triplet loss and lifted structured loss. Red and blue edges connect similar and dissimilar sample pairs respectively. (Image source: <a href="https://arxiv.org/abs/1511.06452">Song et al. 2015</a>)</p>
<p>Let \(D_{ij} = \| f(\mathbf{x}_i) - f(\mathbf{x}_j) \|_2\). A structured loss function is defined as</p>
\[\begin{aligned}
\mathcal{L}_\text{struct} &= \frac{1}{2\vert \mathcal{P} \vert} \sum_{(i,j) \in \mathcal{P}} \max(0, \mathcal{L}_\text{struct}^{(ij)})^2 \\
\text{where } \mathcal{L}_\text{struct}^{(ij)} &= D_{ij} + \color{red}{\max \big( \max_{(i,k)\in \mathcal{N}} \epsilon - D_{ik}, \max_{(j,l)\in \mathcal{N}} \epsilon - D_{jl} \big)}
\end{aligned}\]
<p>where \(\mathcal{P}\) contains the set of positive pairs and \(\mathcal{N}\) is the set of negative pairs. Note that the dense pairwise squared distance matrix can be easily computed per training batch.</p>
<p>The <span style="color: red;">red</span> part in \(\mathcal{L}_\text{struct}^{(ij)}\) is used for mining hard negatives. However, it is not smooth and may cause convergence to a bad local optimum in practice. Thus, it is relaxed to be:</p>
\[\mathcal{L}_\text{struct}^{(ij)} = D_{ij} + \log \Big( \sum_{(i,k)\in\mathcal{N}} \exp(\epsilon - D_{ik}) + \sum_{(j,l)\in\mathcal{N}} \exp(\epsilon - D_{jl}) \Big)\]
<p>In the paper, they also proposed to enhance the quality of negative samples in each batch by actively incorporating difficult negative samples given a few random positive pairs.</p>
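<p>A rough NumPy sketch of the relaxed (log-sum-exp) version is given below. It assumes precomputed embeddings, integer class labels, and a batch in which every sample has at least one negative; the explicit loop over positive pairs is for readability rather than speed.</p>

```python
import numpy as np

def lifted_structured_loss(emb, labels, eps=1.0):
    """Smooth lifted structured loss over one batch of embeddings (B, d)."""
    diff = emb[:, None, :] - emb[None, :, :]
    D = np.sqrt(np.sum(diff ** 2, axis=-1))          # dense pairwise distances
    same = labels[:, None] == labels[None, :]
    # For each sample, sum exp(eps - D) over all of its negatives.
    neg_sum = np.where(~same, np.exp(eps - D), 0.0).sum(axis=1)

    total, n_pos = 0.0, 0
    B = len(labels)
    for i in range(B):
        for j in range(i + 1, B):
            if same[i, j]:                            # (i, j) is a positive pair
                l_ij = D[i, j] + np.log(neg_sum[i] + neg_sum[j])
                total += max(0.0, l_ij) ** 2
                n_pos += 1
    return total / (2 * n_pos)

emb = np.array([[0.0, 0.0], [0.0, 0.1], [5.0, 0.0], [5.0, 0.1]])
loss_separated = lifted_structured_loss(emb, np.array([0, 0, 1, 1]))
loss_mixed = lifted_structured_loss(emb, np.array([0, 1, 0, 1]))
```

<p>When the labels agree with the geometric clusters the loss vanishes, and when positives are far apart while negatives are close it becomes large.</p>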
<h3 id="n-pair-loss">N-pair Loss</h3>
<p><strong>Multi-Class N-pair loss</strong> (<a href="https://papers.nips.cc/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html">Sohn 2016</a>) generalizes triplet loss to include comparison with multiple negative samples.</p>
<p>Given an \((N + 1)\)-tuplet of training samples, \(\{ \mathbf{x}, \mathbf{x}^+, \mathbf{x}^-_1, \dots, \mathbf{x}^-_{N-1} \}\), including one positive and \(N-1\) negative ones, N-pair loss is defined as:</p>
\[\begin{aligned}
\mathcal{L}_\text{N-pair}(\mathbf{x}, \mathbf{x}^+, \{\mathbf{x}^-_i\}^{N-1}_{i=1})
&= \log\big(1 + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i) - f(\mathbf{x})^\top f(\mathbf{x}^+))\big) \\
&= -\log\frac{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+))}{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+)) + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i))}
\end{aligned}\]
<p>If we only sample one negative sample per class, it is equivalent to the softmax loss for multi-class classification.</p>
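<p>The equality between the two forms above can be checked numerically. The sketch below uses random vectors as stand-ins for the embeddings \(f(\cdot)\); the function names are illustrative.</p>

```python
import numpy as np

def n_pair_loss(f_x, f_pos, f_negs):
    """First form: log(1 + sum_i exp(f(x)^T f(x_i^-) - f(x)^T f(x^+)))."""
    pos_score = f_x @ f_pos
    neg_scores = f_negs @ f_x
    return np.log1p(np.sum(np.exp(neg_scores - pos_score)))

def softmax_form(f_x, f_pos, f_negs):
    """Second form: softmax cross entropy with the positive as the target."""
    logits = np.concatenate([[f_x @ f_pos], f_negs @ f_x])
    return -(logits[0] - np.log(np.sum(np.exp(logits))))

rng = np.random.default_rng(0)
f_x, f_pos = rng.normal(size=4), rng.normal(size=4)
f_negs = rng.normal(size=(5, 4))   # N - 1 = 5 negatives

loss_a = n_pair_loss(f_x, f_pos, f_negs)
loss_b = softmax_form(f_x, f_pos, f_negs)
```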
<h3 id="nce">NCE</h3>
<p><strong>Noise Contrastive Estimation</strong>, short for <strong>NCE</strong>, is a method for estimating parameters of a statistical model, proposed by <a href="http://proceedings.mlr.press/v9/gutmann10a.html">Gutmann & Hyvarinen</a> in 2010. The idea is to run logistic regression to tell apart the target data from noise. Read more on how NCE is used for learning word embedding <a href="/lil-log/2017/10/15/learning-word-embedding.html#noise-contrastive-estimation-nce">here</a>.</p>
<p>Let \(\mathbf{x}\) be the target sample \(\sim P(\mathbf{x} \vert C=1; \theta) = p_\theta(\mathbf{x})\) and \(\tilde{\mathbf{x}}\) be the noise sample \(\sim P(\tilde{\mathbf{x}} \vert C=0) = q(\tilde{\mathbf{x}})\). Note that logistic regression models the logit (i.e. log-odds), and in this case we would like to model the logit of a sample \(\mathbf{u}\) coming from the target data distribution rather than the noise distribution:</p>
\[\ell_\theta(\mathbf{u}) = \log \frac{p_\theta(\mathbf{u})}{q(\mathbf{u})} = \log p_\theta(\mathbf{u}) - \log q(\mathbf{u})\]
<p>After converting logits into probabilities with sigmoid \(\sigma(.)\), we can apply cross entropy loss:</p>
\[\begin{aligned}
\mathcal{L}_\text{NCE} &= - \frac{1}{N} \sum_{i=1}^N \big[ \log \sigma (\ell_\theta(\mathbf{x}_i)) + \log (1 - \sigma (\ell_\theta(\tilde{\mathbf{x}}_i))) \big] \\
\text{ where }\sigma(\ell) &= \frac{1}{1 + \exp(-\ell)} = \frac{p_\theta}{p_\theta + q}
\end{aligned}\]
<p>Here I list the original form of the NCE loss, which works with only one positive and one noise sample. In many follow-up works, a contrastive loss incorporating multiple negative samples is also broadly referred to as NCE.</p>
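<p>As a toy illustration of the loss above, the sketch below assumes a one-dimensional unit-variance Gaussian model \(p_\theta\) whose only parameter is its mean, and a wider Gaussian as the noise distribution \(q\). Both choices are assumptions made for the example; the loss should be lower when \(\theta\) matches the data.</p>

```python
import numpy as np

def log_normal_pdf(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))

def nce_loss(x_data, x_noise, theta_mean):
    """NCE loss with p_theta = N(theta_mean, 1) and noise q = N(0, 3^2)."""
    def logit(u):
        # l_theta(u) = log p_theta(u) - log q(u)
        return log_normal_pdf(u, theta_mean, 1.0) - log_normal_pdf(u, 0.0, 3.0)

    def log_sigmoid(l):
        return -np.logaddexp(0.0, -l)   # numerically stable log(sigma(l))

    # log(1 - sigma(l)) equals log(sigma(-l)).
    return -np.mean(log_sigmoid(logit(x_data)) + log_sigmoid(-logit(x_noise)))

rng = np.random.default_rng(0)
x_data = rng.normal(2.0, 1.0, size=2000)    # target samples
x_noise = rng.normal(0.0, 3.0, size=2000)   # noise samples

loss_true = nce_loss(x_data, x_noise, theta_mean=2.0)
loss_wrong = nce_loss(x_data, x_noise, theta_mean=-2.0)
```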
<h3 id="infonce">InfoNCE</h3>
<p>The <strong>InfoNCE loss</strong> in CPC (<a href="/lil-log/2019/11/10/self-supervised-learning.html#contrastive-predictive-coding">Contrastive Predictive Coding</a>; <a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>), inspired by <a href="#nce">NCE</a>, uses categorical cross-entropy loss to identify the positive sample amongst a set of unrelated noise samples.</p>
<p>Given a context vector \(\mathbf{c}\), the positive sample should be drawn from the conditional distribution \(p(\mathbf{x} \vert \mathbf{c})\), while \(N-1\) negative samples are drawn from the proposal distribution \(p(\mathbf{x})\), independent from the context \(\mathbf{c}\). For brevity, let us label all the samples as \(X=\{ \mathbf{x}_i \}^N_{i=1}\), among which only one, \(\mathbf{x}_\texttt{pos}\), is a positive sample. The probability of detecting the positive sample correctly is:</p>
\[p(C=\texttt{pos} \vert X, \mathbf{c})
= \frac{p(\mathbf{x}_\texttt{pos} \vert \mathbf{c}) \prod_{i=1,\dots,N; i \neq \texttt{pos}} p(\mathbf{x}_i)}{\sum_{j=1}^N \big[ p(\mathbf{x}_j \vert \mathbf{c}) \prod_{i=1,\dots,N; i \neq j} p(\mathbf{x}_i) \big]}
= \frac{ \frac{p(\mathbf{x}_\texttt{pos}\vert \mathbf{c})}{p(\mathbf{x}_\texttt{pos})} }{ \sum_{j=1}^N \frac{p(\mathbf{x}_j\vert \mathbf{c})}{p(\mathbf{x}_j)} }
= \frac{f(\mathbf{x}_\texttt{pos}, \mathbf{c})}{ \sum_{j=1}^N f(\mathbf{x}_j, \mathbf{c}) }\]
<p>where the scoring function is \(f(\mathbf{x}, \mathbf{c}) \propto \frac{p(\mathbf{x}\vert\mathbf{c})}{p(\mathbf{x})}\).</p>
<p>The InfoNCE loss optimizes the negative log probability of classifying the positive sample correctly:</p>
\[\mathcal{L}_\text{InfoNCE} = - \mathbb{E} \Big[\log \frac{f(\mathbf{x}, \mathbf{c})}{\sum_{\mathbf{x}' \in X} f(\mathbf{x}', \mathbf{c})} \Big]\]
<p>The fact that \(f(\mathbf{x}, \mathbf{c})\) estimates the density ratio \(\frac{p(\mathbf{x}\vert \mathbf{c})}{p(\mathbf{x})}\) has a connection with mutual information optimization. To maximize the mutual information between input \(\mathbf{x}\) and context vector \(\mathbf{c}\), we have:</p>
\[I(\mathbf{x}; \mathbf{c}) = \sum_{\mathbf{x}, \mathbf{c}} p(\mathbf{x}, \mathbf{c}) \log\frac{p(\mathbf{x}, \mathbf{c})}{p(\mathbf{x})p(\mathbf{c})} = \sum_{\mathbf{x}, \mathbf{c}} p(\mathbf{x}, \mathbf{c})\log\color{blue}{\frac{p(\mathbf{x}|\mathbf{c})}{p(\mathbf{x})}}\]
<p>where the logarithmic term in <span style="color: blue;">blue</span> is estimated by \(f\).</p>
<p>For sequence prediction tasks, rather than modeling the future observations \(p_k(\mathbf{x}_{t+k} \vert \mathbf{c}_t)\) directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between \(\mathbf{x}_{t+k}\) and \(\mathbf{c}_t\):</p>
\[f_k(\mathbf{x}_{t+k}, \mathbf{c}_t) = \exp(\mathbf{z}_{t+k}^\top \mathbf{W}_k \mathbf{c}_t) \propto \frac{p(\mathbf{x}_{t+k}\vert\mathbf{c}_t)}{p(\mathbf{x}_{t+k})}\]
<p>where \(\mathbf{z}_{t+k}\) is the encoded input and \(\mathbf{W}_k\) is a trainable weight matrix.</p>
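<p>Putting the bilinear score and the InfoNCE loss together for a single prediction step, a minimal sketch might look as follows. The encoded candidates <code>z</code>, the context <code>c</code> and the weight matrix <code>W</code> are random placeholders here, and the position of the positive sample is assumed to be known.</p>

```python
import numpy as np

def info_nce(z, c, W, pos_index):
    """InfoNCE for one context: z is (N, d_z), c is (d_c,), W is (d_z, d_c)."""
    log_f = z @ W @ c                       # log f(x_j, c) for every candidate
    log_f = log_f - log_f.max()             # stabilize the softmax
    log_prob = log_f - np.log(np.sum(np.exp(log_f)))
    return -log_prob[pos_index]             # negative log prob of the positive

rng = np.random.default_rng(0)
N, d_z, d_c = 8, 4, 3
z = rng.normal(size=(N, d_z))
c = rng.normal(size=d_c)
W = rng.normal(size=(d_z, d_c))

loss = info_nce(z, c, W, pos_index=0)
# With an uninformative score (W = 0) every candidate is equally likely.
uniform_loss = info_nce(z, c, np.zeros((d_z, d_c)), pos_index=0)
```

<p>With an all-zero \(\mathbf{W}_k\) the loss equals \(\log N\), the value of a uniform guess among \(N\) candidates.</p>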
<h3 id="soft-nearest-neighbors-loss">Soft-Nearest Neighbors Loss</h3>
<p><strong>Soft-Nearest Neighbors Loss</strong> (<a href="http://proceedings.mlr.press/v2/salakhutdinov07a.html">Salakhutdinov & Hinton 2007</a>, <a href="https://arxiv.org/abs/1902.01889">Frosst et al. 2019</a>) extends the contrastive objective to include multiple positive samples.</p>
<p>Given a batch of samples, \(\{(\mathbf{x}_i, y_i)\}^B_{i=1}\) where \(y_i\) is the class label of \(\mathbf{x}_i\) and a function \(f(.,.)\) for measuring the distance (i.e. dissimilarity) between two inputs, the soft nearest neighbor loss at temperature \(\tau\) is defined as:</p>
\[\mathcal{L}_\text{snn} = -\frac{1}{B}\sum_{i=1}^B \log \frac{\sum_{i\neq j, y_i = y_j, j=1,\dots,B} \exp(- f(\mathbf{x}_i, \mathbf{x}_j) / \tau)}{\sum_{i\neq k, k=1,\dots,B} \exp(- f(\mathbf{x}_i, \mathbf{x}_k) /\tau)}\]
<p>The temperature \(\tau\) is used for tuning how concentrated the features are in the representation space. For example, when at low temperature, the loss is dominated by the small distances and widely separated representations cannot contribute much and become irrelevant.</p>
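<p>A small NumPy sketch of this loss is shown below. It assumes \(f\) is the squared Euclidean distance and that every sample has at least one other sample of the same class in the batch; the toy data is made up for illustration.</p>

```python
import numpy as np

def soft_nearest_neighbor_loss(x, y, tau=1.0):
    """Soft nearest neighbor loss; x is (B, d), y is (B,) integer labels."""
    diff = x[:, None, :] - x[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)        # f(x_i, x_j): squared distance
    w = np.exp(-dist / tau)
    np.fill_diagonal(w, 0.0)                 # exclude j == i and k == i
    same = y[:, None] == y[None, :]
    num = np.sum(w * same, axis=1)           # neighbors with the same label
    den = np.sum(w, axis=1)                  # all neighbors
    return -np.mean(np.log(num / den))

x = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.0]])
loss_clustered = soft_nearest_neighbor_loss(x, np.array([0, 0, 1, 1]))
loss_entangled = soft_nearest_neighbor_loss(x, np.array([0, 1, 0, 1]))
```

<p>The loss is close to zero when nearest neighbors share a label and grows when same-class samples are far apart.</p>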
<h3 id="common-setup">Common Setup</h3>
<p>We can loosen the definition of “classes” and “labels” in soft nearest-neighbor loss to create positive and negative sample pairs out of unsupervised data by, for example, applying data augmentation to create noisy versions of original samples.</p>
<p>Most recent studies adopt the following definition of the contrastive learning objective to incorporate multiple positive and negative samples. According to the setup in (<a href="https://arxiv.org/abs/2005.10242">Wang & Isola 2020</a>), let \(p_\texttt{data}(.)\) be the data distribution over \(\mathbb{R}^n\) and \(p_\texttt{pos}(., .)\) be the distribution of positive pairs over \(\mathbb{R}^n \times \mathbb{R}^n\). These two distributions should satisfy:</p>
<ul>
<li>Symmetry: \(\forall \mathbf{x}, \mathbf{x}^+, p_\texttt{pos}(\mathbf{x}, \mathbf{x}^+) = p_\texttt{pos}(\mathbf{x}^+, \mathbf{x})\)</li>
<li>Matching marginal: \(\forall \mathbf{x}, \int p_\texttt{pos}(\mathbf{x}, \mathbf{x}^+) d\mathbf{x}^+ = p_\texttt{data}(\mathbf{x})\)</li>
</ul>
<p>To train an encoder \(f(\mathbf{x})\) that outputs an <em>L2-normalized feature vector</em>, the contrastive learning objective is:</p>
\[\begin{aligned}
\mathcal{L}_\text{contrastive}
&= \mathbb{E}_{(\mathbf{x},\mathbf{x}^+)\sim p_\texttt{pos}, \{\mathbf{x}^-_i\}^M_{i=1} \overset{\text{i.i.d}}{\sim} p_\texttt{data} } \Big[ -\log\frac{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+) / \tau)}{ \exp(f(\mathbf{x})^\top f(\mathbf{x}^+) / \tau) + \sum_{i=1}^M \exp(f(\mathbf{x})^\top f(\mathbf{x}_i^-) / \tau)} \Big] & \\
&\approx \mathbb{E}_{(\mathbf{x},\mathbf{x}^+)\sim p_\texttt{pos}, \{\mathbf{x}^-_i\}^M_{i=1} \overset{\text{i.i.d}}{\sim} p_\texttt{data} }\Big[ - f(\mathbf{x})^\top f(\mathbf{x}^+) / \tau + \log\big(\sum_{i=1}^M \exp(f(\mathbf{x})^\top f(\mathbf{x}_i^-) / \tau)\big) \Big] & \scriptstyle{\text{; Assuming infinite negatives}} \\
&= -\frac{1}{\tau}\mathbb{E}_{(\mathbf{x},\mathbf{x}^+)\sim p_\texttt{pos}}f(\mathbf{x})^\top f(\mathbf{x}^+) + \mathbb{E}_{ \mathbf{x} \sim p_\texttt{data}} \Big[ \log \mathbb{E}_{\mathbf{x}^- \sim p_\texttt{data}} \big[ \sum_{i=1}^M \exp(f(\mathbf{x})^\top f(\mathbf{x}_i^-) / \tau)\big] \Big] &
\end{aligned}\]
<h2 id="key-ingredients">Key Ingredients</h2>
<h3 id="heavy-data-augmentation">Heavy Data Augmentation</h3>
<p>Given a training sample, data augmentation techniques are needed for creating noisy versions of itself to feed into the loss as positive samples. A proper data augmentation setup is critical for learning good and generalizable embedding features. It introduces non-essential variations into examples without modifying semantic meanings and thus encourages the model to learn the essential part of the representation. For example, experiments in <a href="#simclr">SimCLR</a> showed that the composition of random cropping and random color distortion is crucial for good performance on learning visual representation of images.</p>
<h3 id="large-batch-size">Large Batch Size</h3>
<p>Using a large batch size during training is another key ingredient in the success of many contrastive learning methods (e.g. <a href="#simclr">SimCLR</a>, <a href="#clip">CLIP</a>), especially when it relies on in-batch negatives. Only when the batch size is big enough can the loss function cover a diverse enough collection of negative samples, challenging enough for the model to learn meaningful representations that distinguish different examples.</p>
<h3 id="hard-negative-mining">Hard Negative Mining</h3>
<p>Hard negative samples should have different labels from the anchor sample, but have embedding features very close to the anchor embedding. With access to ground truth labels in supervised datasets, it is easy to identify task-specific hard negatives. For example, when learning sentence embedding, we can treat sentence pairs labelled as “contradiction” in NLI datasets as hard negative pairs (e.g. <a href="#dropout-and-cutoff">SimCSE</a>), or use the top incorrect candidates returned by BM25 with most keywords matched as hard negative samples (<a href="/lil-log/2020/10/29/open-domain-question-answering.html#DPR">DPR</a>; <a href="https://arxiv.org/abs/2004.04906">Karpukhin et al., 2020</a>).</p>
<p>However, it becomes tricky to do hard negative mining when we want to remain unsupervised. Increasing training batch size or <a href="#memory-bank">memory bank</a> size implicitly introduces more hard negative samples, but it leads to a heavy burden of large memory usage as a side effect.</p>
<p><a href="https://arxiv.org/abs/2007.00224">Chuang et al. (2020)</a> studied the sampling bias in contrastive learning and proposed a debiased loss. In the unsupervised setting, since we do not know the ground truth labels, we may accidentally sample false negative samples. Sampling bias can lead to a significant performance drop.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/contrastive-sampling-bias.png" alt="Sampling bias" /></p>
<p class="image-caption"><em>Fig. 3. Sampling bias which refers to false negative samples in contrastive learning can lead to a big performance drop. (Image source: <a href="https://arxiv.org/abs/2007.00224">Chuang et al., 2020</a>)</em></p>
<p>Let us assume the probability of anchor class \(c\) is uniform \(\rho(c)=\eta^+\) and the probability of observing a different class is \(\eta^- = 1-\eta^+\).</p>
<ul>
<li>The probability of observing a positive example for \(\mathbf{x}\) is \(p^+_x(\mathbf{x}')=p(\mathbf{x}'\vert \mathbf{h}_{x'}=\mathbf{h}_x)\);</li>
<li>The probability of getting a negative sample for \(\mathbf{x}\) is \(p^-_x(\mathbf{x}')=p(\mathbf{x}'\vert \mathbf{h}_{x'}\neq\mathbf{h}_x)\).</li>
</ul>
<p>When we are sampling \(\mathbf{x}^-\) , we cannot access the true \(p^-_x(\mathbf{x}^-)\) and thus \(\mathbf{x}^-\) may be sampled from the (undesired) anchor class \(c\) with probability \(\eta^+\). The actual sampling data distribution becomes:</p>
\[p(\mathbf{x}') = \eta^+ p^+_x(\mathbf{x}') + \eta^- p_x^-(\mathbf{x}')\]
<p>Thus we can use \(p^-_x(\mathbf{x}') = (p(\mathbf{x}') - \eta^+ p^+_x(\mathbf{x}'))/\eta^-\) for sampling \(\mathbf{x}^-\) to debias the loss. With \(N\) samples \(\{\mathbf{u}_i\}^N_{i=1}\) from \(p\) and \(M\) samples \(\{ \mathbf{v}_i \}_{i=1}^M\) from \(p^+_x\) , we can estimate the expectation of the second term \(\mathbb{E}_{\mathbf{x}^-\sim p^-_x}[\exp(f(\mathbf{x})^\top f(\mathbf{x}^-))]\) in the denominator of contrastive learning loss:</p>
\[g(\mathbf{x}, \{\mathbf{u}_i\}^N_{i=1}, \{\mathbf{v}_i\}_{i=1}^M) = \max\Big\{ \frac{1}{\eta^-}\Big( \frac{1}{N}\sum_{i=1}^N \exp(f(\mathbf{x})^\top f(\mathbf{u}_i)) - \frac{\eta^+}{M}\sum_{i=1}^M \exp(f(\mathbf{x})^\top f(\mathbf{v}_i)) \Big), \exp(-1/\tau) \Big\}\]
<p>where \(\tau\) is the temperature and \(\exp(-1/\tau)\) is the theoretical lower bound of \(\mathbb{E}_{\mathbf{x}^-\sim p^-_x}[\exp(f(\mathbf{x})^\top f(\mathbf{x}^-))]\).</p>
<p>The final debiased contrastive loss looks like:</p>
\[\mathcal{L}^{N,M}_\text{debias}(f) = \mathbb{E}_{\mathbf{x},\{\mathbf{u}_i\}^N_{i=1}\sim p;\;\mathbf{x}^+, \{\mathbf{v}_i\}_{i=1}^M\sim p^+} \Big[ -\log\frac{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+))}{\exp(f(\mathbf{x})^\top f(\mathbf{x}^+)) + N g(\mathbf{x},\{\mathbf{u}_i\}^N_{i=1}, \{\mathbf{v}_i\}_{i=1}^M)} \Big]\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/contrastive-debias-t-SNE.png" alt="Debiased t-SNE vis" /></p>
<p class="image-caption"><em>Fig. 4. t-SNE visualization of learned representation with debiased contrastive learning. (Image source: <a href="https://arxiv.org/abs/2007.00224">Chuang et al., 2020</a>)</em></p>
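<p>The estimator \(g\) and the debiased loss can be sketched for a single anchor as below. This sketch assumes unit-norm embeddings and, unlike the equations above, divides the inner products by the temperature \(\tau\) inside the exponent so that \(\exp(-1/\tau)\) is a valid lower bound. The values of <code>eta_plus</code> and <code>tau</code> are arbitrary.</p>

```python
import numpy as np

def normalize(a):
    return a / np.linalg.norm(a, axis=-1, keepdims=True)

def debiased_contrastive_loss(f_x, f_pos, f_u, f_v, eta_plus=0.1, tau=0.5):
    """f_x, f_pos: (d,); f_u: (N, d) samples from p; f_v: (M, d) from p^+_x."""
    N = len(f_u)
    pos = np.exp(f_x @ f_pos / tau)
    neg_u = np.exp(f_u @ f_x / tau)
    pos_v = np.exp(f_v @ f_x / tau)
    # Remove the expected contribution of false negatives, then clip at the bound.
    g = (neg_u.mean() - eta_plus * pos_v.mean()) / (1.0 - eta_plus)
    g = max(g, np.exp(-1.0 / tau))
    return -np.log(pos / (pos + N * g))

rng = np.random.default_rng(0)
f_x = normalize(rng.normal(size=8))
f_pos = normalize(f_x + 0.1 * rng.normal(size=8))
f_u = normalize(rng.normal(size=(16, 8)))
f_v = f_pos[None, :]                        # M = 1: reuse the positive sample

loss_debiased = debiased_contrastive_loss(f_x, f_pos, f_u, f_v, eta_plus=0.1)
loss_standard = debiased_contrastive_loss(f_x, f_pos, f_u, f_v, eta_plus=0.0)
```

<p>Setting \(\eta^+ = 0\) recovers the usual (biased) contrastive loss, which makes the correction easy to ablate.</p>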
<p>Following the above notation, <a href="https://arxiv.org/abs/2010.04592">Robinson et al. (2021)</a> modified the sampling probabilities to target hard negatives by up-weighting the probability \(p^-_x(\mathbf{x}')\) to be proportional to its similarity to the anchor sample. The new sampling probability \(q_\beta(\mathbf{x}^-)\) is:</p>
\[q_\beta(\mathbf{x}^-) \propto \exp(\beta f(\mathbf{x})^\top f(\mathbf{x}^-)) \cdot p(\mathbf{x}^-)\]
<p>where \(\beta\) is a hyperparameter to tune.</p>
<p>We can estimate the second term in the denominator \(\mathbb{E}_{\mathbf{x}^- \sim q_\beta} [\exp(f(\mathbf{x})^\top f(\mathbf{x}^-))]\) using importance sampling where both the partition functions \(Z_\beta, Z^+_\beta\) can be estimated empirically.</p>
\[\begin{aligned}
\mathbb{E}_{\mathbf{u} \sim q_\beta} [\exp(f(\mathbf{x})^\top f(\mathbf{u}))] &= \mathbb{E}_{\mathbf{u} \sim p} [\frac{q_\beta}{p}\exp(f(\mathbf{x})^\top f(\mathbf{u}))] = \mathbb{E}_{\mathbf{u} \sim p} [\frac{1}{Z_\beta}\exp((\beta + 1)f(\mathbf{x})^\top f(\mathbf{u}))] \\
\mathbb{E}_{\mathbf{v} \sim q^+_\beta} [\exp(f(\mathbf{x})^\top f(\mathbf{v}))] &= \mathbb{E}_{\mathbf{v} \sim p^+} [\frac{q^+_\beta}{p^+}\exp(f(\mathbf{x})^\top f(\mathbf{v}))] = \mathbb{E}_{\mathbf{v} \sim p^+} [\frac{1}{Z^+_\beta}\exp((\beta + 1)f(\mathbf{x})^\top f(\mathbf{v}))]
\end{aligned}\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/contrastive-hard-negatives-code.png" alt="Pseudo code" /></p>
<p class="image-caption"><em>Fig. 5. Pseudo code for computing NCE loss, debiased contrastive loss, and hard negative sample objective when setting \(M=1\). (Image source: <a href="https://arxiv.org/abs/2010.04592">Robinson et al., 2021</a> )</em></p>
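<p>The importance sampling estimate for the negative term can be sketched as follows. This is an illustration of the first identity above rather than a transcription of the authors’ pseudo code. It assumes the similarities \(f(\mathbf{x})^\top f(\mathbf{u}_i)\) for samples \(\mathbf{u}_i \sim p\) are already available, and it estimates \(Z_\beta\) by the empirical mean of \(\exp(\beta f(\mathbf{x})^\top f(\mathbf{u}_i))\).</p>

```python
import numpy as np

def hard_negative_expectation(sims, beta):
    """Estimate E_{u ~ q_beta}[exp(f(x)^T f(u))] from similarities of u ~ p."""
    w = np.exp(beta * sims)        # unnormalized importance weights q_beta / p
    z_beta = w.mean()              # empirical partition function Z_beta
    return np.mean(w * np.exp(sims)) / z_beta

rng = np.random.default_rng(0)
sims = rng.uniform(-1.0, 1.0, size=64)   # stand-in cosine similarities

plain = hard_negative_expectation(sims, beta=0.0)
hard = hard_negative_expectation(sims, beta=1.0)
```

<p>With \(\beta = 0\) the estimate reduces to the plain average, while a larger \(\beta\) shifts weight toward negatives that are more similar to the anchor.</p>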
<h2 id="vision-image-embedding">Vision: Image Embedding</h2>
<h3 id="image-augmentations">Image Augmentations</h3>
<p>Most approaches for contrastive representation learning in the vision domain rely on creating a noisy version of a sample by applying a sequence of data augmentation techniques. The augmentation should significantly change its visual appearance but keep the semantic meaning unchanged.</p>
<h4 id="basic-image-augmentation">Basic Image Augmentation</h4>
<p>There are many ways to modify an image while retaining its semantic meaning. We can use any one of the following augmentations or a composition of multiple operations.</p>
<ul>
<li>Random cropping and then resize back to the original size.</li>
<li>Random color distortions</li>
<li>Random Gaussian blur</li>
<li>Random color jittering</li>
<li>Random horizontal flip</li>
<li>Random grayscale conversion</li>
<li>Multi-crop augmentation: Use two standard resolution crops and sample a set of additional low resolution crops that cover only small parts of the image. Using low resolution crops reduces the compute cost. (<a href="#swav">SwAV</a>)</li>
<li>And many more …</li>
</ul>
<h4 id="augmentation-strategies">Augmentation Strategies</h4>
<p>Many frameworks are designed for learning good data augmentation strategies (i.e. a composition of multiple transforms). Here are a few common ones.</p>
<ul>
<li><a href="/lil-log/2019/05/05/domain-randomization.html#AutoAugment">AutoAugment</a> (<a href="https://arxiv.org/abs/1805.09501">Cubuk, et al. 2018</a>): Inspired by <a href="/lil-log/2020/08/06/neural-architecture-search.html">NAS</a>, AutoAugment frames the problem of learning the best data augmentation operations (e.g. shearing, rotation, invert, etc.) for image classification as an RL problem and looks for the combination that leads to the highest accuracy on the evaluation set.</li>
<li>RandAugment (<a href="https://arxiv.org/abs/1909.13719">Cubuk et al., 2019</a>): RandAugment greatly reduces the search space of AutoAugment by controlling the magnitudes of different transformation operations with a single magnitude parameter.</li>
<li>PBA (Population based augmentation; <a href="https://arxiv.org/abs/1905.05393">Ho et al., 2019</a>): PBA combined PBT (<a href="https://arxiv.org/abs/1711.09846">Jaderberg et al, 2017</a>) with AutoAugment, using an evolutionary algorithm to train a population of child models in parallel to evolve the best augmentation strategies.</li>
<li>UDA (Unsupervised Data Augmentation; <a href="https://arxiv.org/abs/1904.12848">Xie et al., 2019</a>): Among a set of possible augmentation strategies, UDA selects those to minimize the KL divergence between the predicted distribution over an unlabelled example and its unlabelled augmented version.</li>
</ul>
<h4 id="image-mixture">Image Mixture</h4>
<p>Image mixture methods can construct new training examples from existing data points.</p>
<ul>
<li>Mixup (<a href="https://arxiv.org/abs/1710.09412">Zhang et al., 2018</a>): It runs global-level mixture by creating a weighted pixel-wise combination of two existing images \(I_1\) and \(I_2\): \(I_\text{mixup} \gets \alpha I_1 + (1-\alpha) I_2\) and \(\alpha \in [0, 1]\).</li>
<li>Cutmix (<a href="https://arxiv.org/abs/1905.04899">Yun et al., 2019</a>): Cutmix does region-level mixture by generating a new example by combining a local region of one image with the rest of the other image. \(I_\text{cutmix} \gets \mathbf{M}_b \odot I_1 + (1-\mathbf{M}_b) \odot I_2\), where \(\mathbf{M}_b \in \{0, 1\}^I\) is a binary mask and \(\odot\) is element-wise multiplication. It is equivalent to filling the cutout (<a href="https://arxiv.org/abs/1708.04552">DeVries & Taylor 2017</a>) region with the same region from another image.</li>
<li>MoCHi (“Mixing of Contrastive Hard Negatives”; <a href="https://arxiv.org/abs/2010.01028">Kalantidis et al. 2020</a>): Given a query \(\mathbf{q}\), MoCHi maintains a queue of \(K\) negative features \(Q=\{\mathbf{n}_1, \dots, \mathbf{n}_K \}\) and sorts these negative features by similarity to the query, \(\mathbf{q}^\top \mathbf{n}\), in descending order. The first \(N\) items in the queue are considered as the hardest negatives, \(Q^N\). Then synthetic hard examples can be generated by \(\mathbf{h} = \tilde{\mathbf{h}} / \|\tilde{\mathbf{h}}\|\) where \(\tilde{\mathbf{h}} = \alpha\mathbf{n}_i + (1-\alpha) \mathbf{n}_j\) and \(\alpha \in (0, 1)\). Even harder examples can be created by mixing with the query feature, \(\mathbf{h}' = \tilde{\mathbf{h}'} / \|\tilde{\mathbf{h}'}\|_2\) where \(\tilde{\mathbf{h}'} = \beta\mathbf{q} + (1-\beta) \mathbf{n}_j\) and \(\beta \in (0, 0.5)\).</li>
</ul>
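<p>The Mixup and Cutmix formulas in the list above can be sketched in a few lines of NumPy. The images are toy arrays of shape (height, width, channels), and the rectangular patch location passed to <code>cutmix</code> is fixed by hand here instead of being sampled.</p>

```python
import numpy as np

def mixup(img1, img2, alpha):
    """Global pixel-wise blend: alpha * I1 + (1 - alpha) * I2."""
    return alpha * img1 + (1.0 - alpha) * img2

def cutmix(img1, img2, top, left, height, width):
    """Region-level mix: take a rectangle from img1 and the rest from img2."""
    mask = np.zeros(img1.shape[:2], dtype=bool)        # binary mask M_b
    mask[top:top + height, left:left + width] = True
    return np.where(mask[..., None], img1, img2)

img1 = np.ones((8, 8, 3))
img2 = np.zeros((8, 8, 3))

mixed = mixup(img1, img2, alpha=0.3)
patched = cutmix(img1, img2, top=2, left=2, height=4, width=4)
```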
<h3 id="parallel-augmentation">Parallel Augmentation</h3>
<p>This category of approaches produces two noisy versions of one anchor image and aims to learn a representation such that these two augmented samples share the same embedding.</p>
<h4 id="simclr">SimCLR</h4>
<p><strong>SimCLR</strong> (<a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>) proposed a simple framework for contrastive learning of visual representations. It learns representations for visual inputs by maximizing agreement between differently augmented views of the same sample via a contrastive loss in the latent space.</p>
<p style="width: 45%;" class="center"><img src="/lil-log/assets/images/SimCLR.png" alt="SimCLR" /></p>
<p class="image-caption"><em>Fig. 6. A simple framework for contrastive learning of visual representations. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>)</em></p>
<p>1) Randomly sample a minibatch of \(N\) samples and apply two different data augmentation operations to each sample, resulting in \(2N\) augmented samples in total.</p>
\[\tilde{\mathbf{x}}_i = t(\mathbf{x}),\quad\tilde{\mathbf{x}}_j = t'(\mathbf{x}),\quad t, t' \sim \mathcal{T}\]
<p>where two separate data augmentation operators, \(t\) and \(t'\), are sampled from the same family of augmentations \(\mathcal{T}\). Data augmentation includes random crop, resize with random flip, color distortions, and Gaussian blur.</p>
<p>2) Given one positive pair, the other \(2(N-1)\) data points are treated as negative samples. The representation is produced by a base encoder \(f(.)\):</p>
\[\mathbf{h}_i = f(\tilde{\mathbf{x}}_i),\quad \mathbf{h}_j = f(\tilde{\mathbf{x}}_j)\]
<p>3) The contrastive learning loss is defined using cosine similarity \(\text{sim}(.,.)\). Note that the loss operates on an extra projection layer of the representation \(g(.)\) rather than on the representation space directly. But only the representation \(\mathbf{h}\) is used for downstream tasks.</p>
\[\begin{aligned}
\mathbf{z}_i &= g(\mathbf{h}_i),\quad
\mathbf{z}_j = g(\mathbf{h}_j) \\
\mathcal{L}_\text{SimCLR}^{(i,j)} &= - \log\frac{\exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_k) / \tau)}
\end{aligned}\]
<p>where \(\mathbb{1}_{[k \neq i]}\) is an indicator function: 1 if \(k\neq i\) and 0 otherwise.</p>
<p>SimCLR needs a large batch size to incorporate enough negative samples to achieve good performance.</p>
<p style="width: 55%;" class="center"><img src="/lil-log/assets/images/SimCLR-algo.png" alt="SimCLR Algorithm" /></p>
<p class="image-caption"><em>Fig. 7. The algorithm for SimCLR. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>).</em></p>
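<p>A compact NumPy sketch of the loss, averaged over all \(2N\) anchors, is shown below. It assumes the projections \(\mathbf{z}\) are stacked so that rows \(2k\) and \(2k+1\) are the two views of the same sample; this layout and the random inputs are choices made for the example.</p>

```python
import numpy as np

def nt_xent_loss(z, tau=0.5):
    """SimCLR-style loss; z is (2N, d) with rows 2k and 2k+1 forming a pair."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine similarity
    sim = z @ z.T / tau
    np.fill_diagonal(sim, -np.inf)                     # drop the k == i terms
    log_denom = np.log(np.sum(np.exp(sim), axis=1))
    idx = np.arange(len(z))
    partner = idx ^ 1                                  # 0<->1, 2<->3, ...
    return np.mean(log_denom - sim[idx, partner])

rng = np.random.default_rng(0)
base = rng.normal(size=(4, 32))
# Two nearly identical views per sample vs. completely unrelated rows.
z_aligned = np.repeat(base, 2, axis=0) + 0.01 * rng.normal(size=(8, 32))
z_random = rng.normal(size=(8, 32))

loss_aligned = nt_xent_loss(z_aligned)
loss_random = nt_xent_loss(z_random)
```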
<h4 id="barlow-twins">Barlow Twins</h4>
<p><strong>Barlow Twins</strong> (<a href="https://arxiv.org/abs/2103.03230">Zbontar et al. 2021</a>) feeds two distorted versions of samples into the same network to extract features and learns to make the <em>cross-correlation matrix</em> between these two groups of output features close to the identity. The goal is to keep the representation vectors of different distorted versions of one sample similar, while minimizing the redundancy between these vectors.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/barlow-twins.png" alt="Barlow twins" /></p>
<p class="image-caption"><em>Fig. 8. Illustration of Barlow Twins learning pipeline. (Image source: <a href="https://arxiv.org/abs/2103.03230">Zbontar et al. 2021</a>).</em></p>
<p>Let \(\mathcal{C}\) be a cross-correlation matrix computed between the outputs of two identical networks along the batch dimension. \(\mathcal{C}\) is a square matrix whose size equals the feature network’s output dimensionality. Each entry \(\mathcal{C}_{ij}\) is the cosine similarity, computed across batch indices \(b\), between the \(i\)-th output dimension of one network, \(\mathbf{z}_{b,i}^A\), and the \(j\)-th output dimension of the other, \(\mathbf{z}_{b,j}^B\), with a value between -1 (i.e. perfect anti-correlation) and 1 (i.e. perfect correlation).</p>
\[\begin{aligned}
\mathcal{L}_\text{BT} &= \underbrace{\sum_i (1-\mathcal{C}_{ii})^2}_\text{invariance term} + \lambda \underbrace{\sum_i\sum_{j\neq i} \mathcal{C}_{ij}^2}_\text{redundancy reduction term} \\ \text{where } \mathcal{C}_{ij} &= \frac{\sum_b \mathbf{z}^A_{b,i} \mathbf{z}^B_{b,j}}{\sqrt{\sum_b (\mathbf{z}^A_{b,i})^2}\sqrt{\sum_b (\mathbf{z}^B_{b,j})^2}}
\end{aligned}\]
<p>Barlow Twins is competitive with SOTA methods for self-supervised learning. It naturally avoids trivial constants (i.e. collapsed representations), and is robust to different training batch sizes.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/barlow-twins-algo.png" alt="Barlow twins algo" /></p>
<p class="image-caption"><em>Fig. 9. Algorithm of Barlow Twins in Pytorch style pseudo code. (Image source: <a href="https://arxiv.org/abs/2103.03230">Zbontar et al. 2021</a>).</em></p>
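<p>For readers who prefer text over the figure, here is a NumPy sketch of the loss as written in the equation above. It centers each feature along the batch dimension before computing \(\mathcal{C}\), which is an assumption of this sketch, and the value of \(\lambda\) is arbitrary.</p>

```python
import numpy as np

def barlow_twins_loss(z_a, z_b, lam=5e-3):
    """z_a, z_b: (batch, d) outputs of the two branches. Returns (loss, C)."""
    z_a = z_a - z_a.mean(axis=0)           # center each feature over the batch
    z_b = z_b - z_b.mean(axis=0)
    num = z_a.T @ z_b                      # (d, d): sum over the batch index b
    den = np.outer(np.linalg.norm(z_a, axis=0), np.linalg.norm(z_b, axis=0))
    C = num / den                          # cross-correlation matrix
    on_diag = np.sum((1.0 - np.diag(C)) ** 2)            # invariance term
    off_diag = np.sum(C ** 2) - np.sum(np.diag(C) ** 2)  # redundancy term
    return on_diag + lam * off_diag, C

rng = np.random.default_rng(0)
z = rng.normal(size=(256, 8))
loss_same, C_same = barlow_twins_loss(z, z)                      # identical views
loss_indep, _ = barlow_twins_loss(z, rng.normal(size=(256, 8)))  # unrelated views
```

<p>Identical views give a diagonal of ones, so only the redundancy term remains, whereas unrelated views are penalized heavily by the invariance term.</p>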
<h4 id="byol">BYOL</h4>
<p>Different from the above approaches, interestingly, <strong>BYOL</strong> (Bootstrap Your Own Latent; <a href="https://arxiv.org/abs/2006.07733">Grill, et al 2020</a>) claims to achieve new state-of-the-art results <em>without using negative samples</em>. It relies on two neural networks, referred to as <em>online</em> and <em>target</em> networks, that interact and learn from each other. The target network (parameterized by \(\xi\)) has the same architecture as the online one (parameterized by \(\theta\)), but with Polyak-averaged weights, \(\xi \leftarrow \tau \xi + (1-\tau) \theta\).</p>
<p>The goal is to learn a representation \(y\) that can be used in downstream tasks. The online network parameterized by \(\theta\) contains:</p>
<ul>
<li>An encoder \(f_\theta\);</li>
<li>A projector \(g_\theta\);</li>
<li>A predictor \(q_\theta\).</li>
</ul>
<p>The target network has the same network architecture, but with different parameters \(\xi\), updated by Polyak averaging of \(\theta\): \(\xi \leftarrow \tau \xi + (1-\tau) \theta\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BYOL.png" alt="BYOL" /></p>
<p class="image-caption"><em>Fig. 10. The model architecture of BYOL. After training, we only care about \(f_\theta\) for producing representation, \(y=f_\theta(x)\), and everything else is discarded. \(\text{sg}\) means stop gradient. (Image source: <a href="https://arxiv.org/abs/2006.07733">Grill, et al 2020</a>)</em></p>
<p>Given an image \(\mathbf{x}\), the BYOL loss is constructed as follows:</p>
<ul>
<li>Create two augmented views: \(\mathbf{v}=t(\mathbf{x}); \mathbf{v}'=t'(\mathbf{x})\) with augmentations sampled \(t \sim \mathcal{T}, t' \sim \mathcal{T}'\);</li>
<li>Then they are encoded into representations, \(\mathbf{y}_\theta=f_\theta(\mathbf{v}), \mathbf{y}'=f_\xi(\mathbf{v}')\);</li>
<li>Then they are projected into latent variables, \(\mathbf{z}_\theta=g_\theta(\mathbf{y}_\theta), \mathbf{z}'=g_\xi(\mathbf{y}')\);</li>
<li>The online network outputs a prediction \(q_\theta(\mathbf{z}_\theta)\);</li>
<li>Both \(q_\theta(\mathbf{z}_\theta)\) and \(\mathbf{z}'\) are L2-normalized, giving us \(\bar{q}_\theta(\mathbf{z}_\theta) = q_\theta(\mathbf{z}_\theta) / \| q_\theta(\mathbf{z}_\theta) \|\) and \(\bar{\mathbf{z}'} = \mathbf{z}' / \|\mathbf{z}'\|\);</li>
<li>The loss \(\mathcal{L}^\text{BYOL}_\theta\) is the MSE between the L2-normalized prediction \(\bar{q}_\theta(\mathbf{z}_\theta)\) and \(\bar{\mathbf{z}'}\);</li>
<li>The other symmetric loss \(\tilde{\mathcal{L}}^\text{BYOL}_\theta\) can be generated by switching \(\mathbf{v}'\) and \(\mathbf{v}\); that is, feeding \(\mathbf{v}'\) to online network and \(\mathbf{v}\) to target network.</li>
<li>The final loss is \(\mathcal{L}^\text{BYOL}_\theta + \tilde{\mathcal{L}}^\text{BYOL}_\theta\) and only parameters \(\theta\) are optimized.</li>
</ul>
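<p>The two core operations in the steps above, the normalized regression loss and the Polyak update of the target weights, can be sketched as follows. This is a toy NumPy illustration: the networks themselves are omitted, the vectors stand in for \(q_\theta(\mathbf{z}_\theta)\) and \(\mathbf{z}'\), and the function names and the value of <code class="language-plaintext highlighter-rouge">tau</code> are assumptions for the example.</p>

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def byol_loss(q_online, z_target):
    """Squared error between the L2-normalized prediction and target projection.

    In real training z_target is treated as a constant (stop gradient).
    """
    q_bar = l2_normalize(q_online)
    z_bar = l2_normalize(z_target)
    return np.sum((q_bar - z_bar) ** 2, axis=-1)

def ema_update(xi, theta, tau=0.99):
    """Polyak averaging of the target weights: xi <- tau * xi + (1 - tau) * theta."""
    return tau * xi + (1.0 - tau) * theta

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 16))   # stand-in for online predictions
z = rng.normal(size=(4, 16))   # stand-in for target projections
loss = byol_loss(q, z)
cos = np.sum(l2_normalize(q) * l2_normalize(z), axis=-1)
xi_new = ema_update(np.zeros(3), np.ones(3), tau=0.9)
```

<p>Because both vectors have unit norm, the squared error equals \(2 - 2\cos(q_\theta(\mathbf{z}_\theta), \mathbf{z}')\), so minimizing it is the same as maximizing the cosine similarity between the prediction and the target projection.</p>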
<p>Unlike most popular contrastive learning based approaches, BYOL does not use negative pairs. Most bootstrapping approaches rely on pseudo-labels or cluster indices, but BYOL directly bootstraps the latent representation.</p>
<p>It is quite interesting and surprising that <em>without</em> negative samples, BYOL still works well. Later I ran into this <a href="https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html">post</a> by Abe Fetterman & Josh Albrecht, in which they highlighted two surprising findings from their attempt to reproduce BYOL:</p>
<ol>
<li>BYOL generally performs no better than random when <em>batch normalization is removed</em>.</li>
<li>The presence of batch normalization implicitly causes a form of contrastive learning.
They believe that using negative samples is important for avoiding model collapse (i.e. what if you use all-zeros representation for every data point?). Batch normalization injects dependency on negative samples <em>implicitly</em> because no matter how similar a batch of inputs are, the values are re-distributed (spread out \(\sim \mathcal{N}(0, 1)\)) and therefore batch normalization prevents model collapse. I strongly recommend reading the <a href="https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html">full article</a> if you are working in this area.</li>
</ol>
<h3 id="memory-bank">Memory Bank</h3>
<p>Computing embeddings for a large number of negative samples in every batch is extremely expensive. One common approach is to store the representation in memory to trade off data staleness for cheaper compute.</p>
<h4 id="instance-discrimination-with-memoy-bank">Instance Discrimination with Memory Bank</h4>
<p><strong>Instance contrastive learning</strong> (<a href="https://arxiv.org/abs/1805.01978v1">Wu et al, 2018</a>) pushes the class-wise supervision to the extreme by considering each instance as <em>a distinct class of its own</em>. It implies that the number of “classes” will be the same as the number of samples in the training dataset. Hence, it is infeasible to train a softmax layer with this many heads, but instead it can be approximated by <a href="#nce">NCE</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/instance-level-discrimination.png" alt="Instance contrastive learning" /></p>
<p class="image-caption"><em>Fig. 11. The training pipeline of instance-level contrastive learning. The learned embedding is L2-normalized. (Image source: <a href="https://arxiv.org/abs/1805.01978v1">Wu et al, 2018</a>)</em></p>
<p>Let \(\mathbf{v} = f_\theta(x)\) be an embedding function to learn and the vector is normalized to have \(\|\mathbf{v}\|=1\). A non-parametric classifier predicts the probability of a sample \(\mathbf{v}\) belonging to class \(i\) with a temperature parameter \(\tau\):</p>
\[P(C=i\vert \mathbf{v}) = \frac{\exp(\mathbf{v}_i^\top \mathbf{v} / \tau)}{\sum_{j=1}^n \exp(\mathbf{v}_j^\top \mathbf{v} / \tau)}\]
<p>Instead of computing the representations for all the samples every time, they implement a <strong>Memory Bank</strong> that stores the sample representations from past iterations. Let \(V=\{ \mathbf{v}_i \}\) be the memory bank and \(\mathbf{f}_i = f_\theta(\mathbf{x}_i)\) be the feature generated by forwarding the network. We can use the representation from the memory bank \(\mathbf{v}_i\) instead of the feature forwarded from the network \(\mathbf{f}_i\) when comparing pairwise similarity.</p>
<p>The denominator theoretically requires access to the representations of all the samples, but that is too expensive in practice. Instead we can estimate it via Monte Carlo approximation using a random subset of \(M\) indices \(\{j_k\}_{k=1}^M\).</p>
\[P(i\vert \mathbf{v})
= \frac{\exp(\mathbf{v}^\top \mathbf{f}_i / \tau)}{\sum_{j=1}^N \exp(\mathbf{v}_j^\top \mathbf{f}_i / \tau)}
\simeq \frac{\exp(\mathbf{v}^\top \mathbf{f}_i / \tau)}{\frac{N}{M} \sum_{k=1}^M \exp(\mathbf{v}_{j_k}^\top \mathbf{f}_i / \tau)}\]
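<p>To get a feel for the Monte Carlo approximation of the denominator, the snippet below compares the exact sum over a memory bank with the estimate from a random subset. It is a toy sketch: the memory bank and the feature are random unit vectors, and the sizes and temperature are illustrative assumptions rather than values from the paper.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, M, tau = 2000, 16, 500, 0.5   # illustrative sizes and temperature

# Memory bank V: one L2-normalized embedding per training sample.
bank = rng.normal(size=(N, D))
bank /= np.linalg.norm(bank, axis=1, keepdims=True)

# Feature f_i freshly computed by the network for sample i (random stand-in).
f_i = rng.normal(size=D)
f_i /= np.linalg.norm(f_i)

# Exact normalizer: sum over all N entries of the memory bank.
z_exact = np.sum(np.exp(bank @ f_i / tau))

# Monte Carlo estimate: sum over M random indices, rescaled by N / M.
idx = rng.choice(N, size=M, replace=False)
z_approx = (N / M) * np.sum(np.exp(bank[idx] @ f_i / tau))

rel_err = abs(z_approx - z_exact) / z_exact
```

<p>The estimate only touches \(M\) of the \(N\) stored vectors, which is where the compute savings come from; the price is a small relative error that shrinks as \(M\) grows.</p>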
<p>Because there is only one instance per class, the training is unstable and fluctuates a lot. To improve the training smoothness, they introduced an extra term for positive samples in the loss function based on the <a href="https://web.stanford.edu/~boyd/papers/prox_algs.html">proximal optimization method</a>. The final NCE loss objective looks like:</p>
\[\begin{aligned}
\mathcal{L}_\text{instance} &= - \mathbb{E}_{P_d}\big[\log h(i, \mathbf{v}^{(t-1)}_i) - \lambda \|\mathbf{v}^{(t)}_i - \mathbf{v}^{(t-1)}_i\|^2_2\big] - M\mathbb{E}_{P_n}\big[\log(1 - h(i, \mathbf{v}'^{(t-1)}))\big] \\
h(i, \mathbf{v}) &= \frac{P(i\vert\mathbf{v})}{P(i\vert\mathbf{v}) + MP_n(i)} \text{ where the noise distribution is uniform }P_n = 1/N
\end{aligned}\]
<p>where \(\{ \mathbf{v}^{(t-1)} \}\) are embeddings stored in the memory bank from the previous iteration. The difference between iterations \(\|\mathbf{v}^{(t)}_i - \mathbf{v}^{(t-1)}_i\|^2_2\) will gradually vanish as the learned embedding converges.</p>
<h4 id="moco--moco-v2">MoCo & MoCo-V2</h4>
<p><strong>Momentum Contrast</strong> (<strong>MoCo</strong>; <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>) provides a framework for unsupervised visual representation learning as a <em>dynamic dictionary look-up</em>. The dictionary is structured as a large FIFO queue of encoded representations of data samples.</p>
<p>Given a query sample \(\mathbf{x}_q\), we get a query representation through an encoder \(\mathbf{q} = f_q(\mathbf{x}_q)\). A list of key representations \(\{\mathbf{k}_1, \mathbf{k}_2, \dots \}\) in the dictionary are encoded by a momentum encoder \(\mathbf{k}_i = f_k (\mathbf{x}^k_i)\). Let’s assume among them there is a single <em>positive</em> key \(\mathbf{k}^+\) in the dictionary that matches \(\mathbf{q}\). In the paper, they create \(\mathbf{k}^+\) using a noisy copy of \(\mathbf{x}_q\) with different <a href="#image-augmentations">augmentation</a>. Then the <a href="#infonce">InfoNCE</a> contrastive loss with temperature \(\tau\) is used over one positive and \(N-1\) negative samples:</p>
\[\mathcal{L}_\text{MoCo} = - \log \frac{\exp(\mathbf{q} \cdot \mathbf{k}^+ / \tau)}{\sum_{i=1}^N \exp(\mathbf{q} \cdot \mathbf{k}_i / \tau)}\]
<p>Compared to the <a href="#instance-discrimination-with-memoy-bank">memory bank</a>, a queue-based dictionary in MoCo enables us to reuse representations of immediately preceding mini-batches of data.</p>
<p>The MoCo dictionary is not differentiable as a queue, so we cannot rely on back-propagation to update the key encoder \(f_k\). One naive way might be to use the same encoder for both \(f_q\) and \(f_k\). Instead, MoCo proposes a momentum-based update with a momentum coefficient \(m \in [0, 1)\). Let the parameters of \(f_q\) and \(f_k\) be \(\theta_q\) and \(\theta_k\), respectively.</p>
\[\theta_k \leftarrow m \theta_k + (1-m) \theta_q\]
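<p>The three ingredients described above (the InfoNCE loss over one positive and a queue of negatives, the FIFO queue, and the momentum update) can be sketched as below. This is a simplified NumPy illustration with random unit vectors in place of encoder outputs; the function names and default hyperparameters are assumptions for the example.</p>

```python
import numpy as np

def unit(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def info_nce(q, k_pos, queue, tau=0.07):
    """q, k_pos: (D,) unit vectors; queue: (K, D) keys from earlier mini-batches."""
    logits = np.concatenate([[q @ k_pos], queue @ q]) / tau   # positive sits at index 0
    logits = logits - logits.max()                            # numerical stability
    return -(logits[0] - np.log(np.sum(np.exp(logits))))

def momentum_update(theta_k, theta_q, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q; no gradient flows into f_k."""
    return m * theta_k + (1.0 - m) * theta_q

def enqueue_dequeue(queue, new_keys):
    """FIFO queue: append the newest keys and drop the same number of oldest ones."""
    return np.concatenate([queue, new_keys], axis=0)[len(new_keys):]

rng = np.random.default_rng(0)
queue = unit(rng.normal(size=(64, 8)))
q = unit(rng.normal(size=8))
loss_match = info_nce(q, q, queue)        # positive key identical to the query
loss_mismatch = info_nce(q, -q, queue)    # positive key pointing the opposite way
queue2 = enqueue_dequeue(queue, unit(rng.normal(size=(4, 8))))
theta_k = momentum_update(np.zeros(2), np.ones(2), m=0.9)
```

<p>Note that the queue length is independent of the mini-batch size: each step adds one batch of keys and removes the oldest batch, so the number of negatives can be much larger than the batch.</p>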
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/MoCo.png" alt="MoCo" /></p>
<p class="image-caption"><em>Fig. 12. Illustration of how Momentum Contrast (MoCo) learns visual representations. (Image source: <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>)</em></p>
<p>The advantage of MoCo compared to <a href="#simclr">SimCLR</a> is that MoCo decouples the batch size from the number of negatives, but SimCLR requires a large batch size in order to have enough negative samples and suffers performance drops when their batch size is reduced.</p>
<p>Two designs in SimCLR, namely, (1) an MLP projection head and (2) stronger data augmentation, have proved to be very effective. <strong>MoCo V2</strong> (<a href="https://arxiv.org/abs/2003.04297">Chen et al, 2020</a>) combined these two designs, achieving even better transfer performance with no dependency on a very large batch size.</p>
<h4 id="curl">CURL</h4>
<p><strong>CURL</strong> (<a href="https://arxiv.org/abs/2004.04136">Srinivas, et al. 2020</a>) applies the above ideas in <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">Reinforcement Learning</a>. It learns a visual representation for RL tasks by matching embeddings of two data-augmented versions, \(o_q\) and \(o_k\), of the raw observation \(o\) via contrastive loss. CURL primarily relies on random crop data augmentation. The key encoder is implemented as a momentum encoder with weights as EMA of the query encoder weights, same as in <a href="#moco--moco-v2">MoCo</a>.</p>
<p>One significant difference between RL and supervised visual tasks is that RL depends on <em>temporal consistency</em> between consecutive frames. Therefore, CURL applies augmentation consistently on each stack of frames to retain information about the temporal structure of the observation.</p>
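<p>Applying the augmentation consistently across a frame stack simply means sampling the crop window once per stack rather than once per frame. A minimal sketch, assuming observations are stored as a <code class="language-plaintext highlighter-rouge">(T, H, W, C)</code> array; the function name and the image sizes are illustrative assumptions.</p>

```python
import numpy as np

def random_crop_stack(frames, out_size, rng):
    """frames: (T, H, W, C) stack of consecutive observations.

    One crop window is sampled and applied to every frame in the stack,
    so the relative motion between frames is preserved.
    """
    _, h, w, _ = frames.shape
    top = rng.integers(0, h - out_size + 1)
    left = rng.integers(0, w - out_size + 1)
    return frames[:, top:top + out_size, left:left + out_size, :]

rng = np.random.default_rng(0)
obs = rng.random((3, 100, 100, 3))       # a stack of 3 frames
o_q = random_crop_stack(obs, 84, rng)    # query view
o_k = random_crop_stack(obs, 84, rng)    # key view with an independently sampled window
```

<p>The query and key views use different windows, but within each view all frames share the same window.</p>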
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/CURL.png" alt="CURL" /></p>
<p class="image-caption"><em>Fig. 13. The architecture of CURL. (Image source: <a href="https://arxiv.org/abs/2004.04136">Srinivas, et al. 2020</a>)</em></p>
<h3 id="feature-clustering">Feature Clustering</h3>
<h4 id="deepcluster">DeepCluster</h4>
<p><strong>DeepCluster</strong> (<a href="https://arxiv.org/abs/1807.05520">Caron et al. 2018</a>) iteratively clusters features via k-means and uses cluster assignments as pseudo labels to provide supervised signals.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/deepcluster.png" alt="DeepCluster" /></p>
<p class="image-caption"><em>Fig. 14. Illustration of DeepCluster method which iteratively clusters deep features and uses the cluster assignments as pseudo-labels. (Image source: <a href="https://arxiv.org/abs/1807.05520">Caron et al. 2018</a>)</em></p>
<p>In each iteration, DeepCluster clusters data points using the prior representation and then produces the new cluster assignments as the classification targets for the new representation. However, this iterative process is prone to trivial solutions: while it avoids the use of negative pairs, it requires a costly clustering phase and specific precautions to avoid collapse.</p>
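<p>The clustering half of one such iteration can be sketched with a plain k-means loop that turns features into pseudo-labels. This is a toy NumPy illustration on synthetic 2-D features: the initial centroids are passed in explicitly to keep the example deterministic, empty clusters simply keep their old centroid, and the function name is hypothetical. In the real method the pseudo-labels would then be used as classification targets to update the network.</p>

```python
import numpy as np

def kmeans_pseudo_labels(features, init_centroids, n_iter=10):
    """Assign each feature vector to its nearest centroid and return the labels."""
    centroids = np.array(init_centroids, dtype=float)   # work on a copy
    for _ in range(n_iter):
        # Squared Euclidean distance between every feature and every centroid.
        dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(len(centroids)):
            members = features[labels == c]
            if len(members) > 0:          # an empty cluster keeps its old centroid
                centroids[c] = members.mean(axis=0)
    return labels, centroids

rng = np.random.default_rng(0)
blob_a = rng.normal(loc=0.0, scale=0.5, size=(50, 2))
blob_b = rng.normal(loc=20.0, scale=0.5, size=(50, 2))
features = np.vstack([blob_a, blob_b])    # stand-in for deep features

# Initialize with one point from each end of the array.
labels, centroids = kmeans_pseudo_labels(features, features[[0, -1]])
```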
<h4 id="swav">SwAV</h4>
<p><strong>SwAV</strong> (<em>Swapping Assignments between multiple Views</em>; <a href="https://arxiv.org/abs/2006.09882">Caron et al. 2020</a>) is an online contrastive learning algorithm. It computes a code from an augmented version of the image and tries to predict this code using another augmented version of the same image.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SwAV.png" alt="SwAV" /></p>
<p class="image-caption"><em>Fig. 15. Comparison of SwAV and <a href="#instance-discrimination-with-memoy-bank">contrastive instance learning</a>. (Image source: <a href="https://arxiv.org/abs/2006.09882">Caron et al. 2020</a>)</em></p>
<p>Given features of images with two different augmentations, \(\mathbf{z}_t\) and \(\mathbf{z}_s\), SwAV computes corresponding codes \(\mathbf{q}_t\) and \(\mathbf{q}_s\), and the loss quantifies the fit after swapping the two codes, where \(\ell(.)\) measures the fit between a feature and a code.</p>
\[\mathcal{L}_\text{SwAV}(\mathbf{z}_t, \mathbf{z}_s) = \ell(\mathbf{z}_t, \mathbf{q}_s) + \ell(\mathbf{z}_s, \mathbf{q}_t)\]
<p>The swapped fit prediction depends on the cross entropy between the predicted code and a set of \(K\) trainable prototype vectors \(\mathbf{C} = \{\mathbf{c}_1, \dots, \mathbf{c}_K\}\). The prototype vector matrix is shared across different batches and represents <em>anchor clusters</em> that each instance should be clustered to.</p>
\[\ell(\mathbf{z}_t, \mathbf{q}_s) = - \sum_k \mathbf{q}^{(k)}_s\log\mathbf{p}^{(k)}_t \text{ where } \mathbf{p}^{(k)}_t = \frac{\exp(\mathbf{z}_t^\top\mathbf{c}_k / \tau)}{\sum_{k'}\exp(\mathbf{z}_t^\top \mathbf{c}_{k'} / \tau)}\]
<p>In a mini-batch containing \(B\) feature vectors \(\mathbf{Z} = [\mathbf{z}_1, \dots, \mathbf{z}_B]\), the mapping matrix between features and prototype vectors is defined as \(\mathbf{Q} = [\mathbf{q}_1, \dots, \mathbf{q}_B] \in \mathbb{R}_+^{K\times B}\). We would like to maximize the similarity between the features and the prototypes:</p>
\[\begin{aligned}
\max_{\mathbf{Q}\in\mathcal{Q}} &\text{Tr}(\mathbf{Q}^\top \mathbf{C}^\top \mathbf{Z}) + \varepsilon \mathcal{H}(\mathbf{Q}) \\
\text{where }\mathcal{Q} &= \big\{ \mathbf{Q} \in \mathbb{R}_{+}^{K \times B} \mid \mathbf{Q}\mathbf{1}_B = \frac{1}{K}\mathbf{1}_K, \mathbf{Q}^\top\mathbf{1}_K = \frac{1}{B}\mathbf{1}_B \big\}
\end{aligned}\]
<p>where \(\mathcal{H}\) is the entropy, \(\mathcal{H}(\mathbf{Q}) = - \sum_{ij} \mathbf{Q}_{ij} \log \mathbf{Q}_{ij}\), controlling the smoothness of the code. The coefficient \(\varepsilon\) should not be too large; otherwise, all the samples will be assigned uniformly to all the clusters. The candidate set of solutions for \(\mathbf{Q}\) requires every mapping matrix to have each row sum up to \(1/K\) and each column to sum up to \(1/B\), enforcing that each prototype gets selected at least \(B/K\) times on average.</p>
<p>SwAV relies on the iterative Sinkhorn-Knopp algorithm (<a href="https://arxiv.org/abs/1306.0895">Cuturi 2013</a>) to find the solution for \(\mathbf{Q}\).</p>
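<p>A minimal NumPy sketch of the Sinkhorn-Knopp iteration for this problem is shown below. It starts from \(\exp(\mathbf{C}^\top\mathbf{Z}/\varepsilon)\) and alternately rescales rows and columns until the marginal constraints hold. The function name, the number of iterations and the value of <code class="language-plaintext highlighter-rouge">eps</code> used in the demo are assumptions chosen so that the toy example converges cleanly; they are not the settings of the paper.</p>

```python
import numpy as np

def sinkhorn(scores, eps=0.05, n_iters=50):
    """scores: (K, B) matrix C^T Z of prototype-feature similarities.

    Returns Q of shape (K, B) whose rows sum to 1/K and columns sum to 1/B.
    """
    k, b = scores.shape
    q = np.exp((scores - scores.max()) / eps)   # subtract the max for stability
    q /= q.sum()
    for _ in range(n_iters):
        q *= (1.0 / k) / q.sum(axis=1, keepdims=True)   # fix the row marginals
        q *= (1.0 / b) / q.sum(axis=0, keepdims=True)   # fix the column marginals
    return q

def unit_columns(x):
    return x / np.linalg.norm(x, axis=0, keepdims=True)

rng = np.random.default_rng(0)
C = unit_columns(rng.normal(size=(8, 4)))     # K = 4 prototypes stored as columns
Z = unit_columns(rng.normal(size=(8, 16)))    # B = 16 features stored as columns
Q = sinkhorn(C.T @ Z, eps=0.5)
```

<p>Each column of <code class="language-plaintext highlighter-rouge">Q</code>, multiplied by \(B\), is a soft assignment of one sample over the \(K\) prototypes, and the row constraint spreads the batch evenly across prototypes.</p>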
<h3 id="working-with-supervised-datasets">Working with Supervised Datasets</h3>
<h4 id="clip">CLIP</h4>
<p><strong>CLIP</strong> (<em>Contrastive Language-Image Pre-training</em>; <a href="https://arxiv.org/abs/2103.00020">Radford et al. 2021</a>) jointly trains a text encoder and an image feature extractor over the pretraining task that predicts which caption goes with which image.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CLIP.png" alt="CLIP" /></p>
<p class="image-caption"><em>Fig. 16. Illustration of CLIP contrastive pre-training over text-image pairs. (Image source: <a href="https://arxiv.org/abs/2103.00020">Radford et al. 2021</a>)</em></p>
<p>Given a batch of \(N\) (image, text) pairs, CLIP computes the dense cosine similarity matrix between all \(N\times N\) possible (image, text) candidates within this batch. The text and image encoders are jointly trained to maximize the similarity between \(N\) correct pairs of (image, text) associations while minimizing the similarity for \(N(N-1)\) incorrect pairs via a symmetric cross entropy loss over the dense matrix.</p>
<p>See the numpy-like pseudo code for CLIP in Fig. 17.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/CLIP-algo.png" alt="CLIP pseudo code" /></p>
<p class="image-caption"><em>Fig. 17. CLIP algorithm in NumPy style pseudo code. (Image source: <a href="https://arxiv.org/abs/2103.00020">Radford et al. 2021</a>)</em></p>
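<p>For readers who prefer text over an image, the symmetric cross entropy over the \(N \times N\) similarity matrix can be sketched as follows. This is a simplified NumPy illustration that takes already-computed embeddings as input and uses a fixed temperature (an assumption of this sketch, as are the function names); the encoders are omitted.</p>

```python
import numpy as np

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def clip_loss(image_emb, text_emb, temperature=0.07):
    """image_emb, text_emb: (N, D); row i of each comes from the same (image, text) pair."""
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature           # (N, N) scaled cosine similarities
    idx = np.arange(len(logits))                 # correct pairs lie on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()   # pick the text for each image
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()   # pick the image for each text
    return (loss_i2t + loss_t2i) / 2

rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 32))
loss_aligned = clip_loss(emb, emb)           # every image matches its own text
loss_shuffled = clip_loss(emb, emb[::-1])    # pairs deliberately mismatched
```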
<p>Compared to other methods above for learning good visual representation, what makes CLIP really special is <em>“the appreciation of using natural language as a training signal”</em>. It does demand access to a supervised dataset in which we know which text matches which image. It is trained on 400 million (text, image) pairs collected from the Internet by searching with a list of queries; the query list contains all the words occurring at least 100 times in the English version of Wikipedia. Interestingly, they found that a Transformer-based language model learns 3x slower than a bag-of-words (BoW) text encoder, as measured by zero-shot ImageNet classification. Using a contrastive objective instead of trying to predict the exact words associated with images (i.e. a method commonly adopted by image caption prediction tasks) can further improve the data efficiency by another 4x.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/CLIP-efficiency.png" alt="CLIP efficiency" /></p>
<p class="image-caption"><em>Fig. 18. Using bag-of-words text encoding and contrastive training objectives can bring in multiple folds of data efficiency improvement. (Image source: <a href="https://arxiv.org/abs/2103.00020">Radford et al. 2021</a>)</em></p>
<p>CLIP produces good visual representations that can non-trivially transfer to many CV benchmark datasets, achieving results competitive with a supervised baseline. Among tested transfer tasks, CLIP struggles with very fine-grained classification, as well as abstract or systematic tasks such as counting the number of objects. The transfer performance of CLIP models is smoothly correlated with the amount of model compute.</p>
<h4 id="supervised-contrastive-learning">Supervised Contrastive Learning</h4>
<p>There are several known issues with cross entropy loss, such as the lack of robustness to noisy labels and the possibility of poor margins. Existing improvement for cross entropy loss involves the curation of better training data, such as label smoothing and data augmentation. <strong>Supervised Contrastive Loss</strong> (<a href="https://arxiv.org/abs/2004.11362">Khosla et al. 2021</a>) aims to leverage label information more effectively than cross entropy, imposing that normalized embeddings from the same class are closer together than embeddings from different classes.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/sup-con.png" alt="SupCon" /></p>
<p class="image-caption"><em>Fig. 19. Supervised vs self-supervised contrastive losses. Supervised contrastive learning considers different samples from the same class as positive examples, in addition to augmented versions. (Image source: <a href="https://arxiv.org/abs/2004.11362">Khosla et al. 2021</a>)</em></p>
<p>Given a set of randomly sampled \(n\) (image, label) pairs, \(\{\mathbf{x}_i, y_i\}_{i=1}^n\), \(2n\) training pairs can be created by applying two random augmentations of every sample, \(\{\tilde{\mathbf{x}}_i, \tilde{y}_i\}_{i=1}^{2n}\).</p>
<p>Supervised contrastive loss \(\mathcal{L}_\text{supcon}\) utilizes multiple positive and negative samples, very similar to <a href="#soft-nearest-neighbors-loss">soft nearest-neighbor loss</a>:</p>
\[\mathcal{L}_\text{supcon} = - \sum_{i=1}^{2n} \frac{1}{2 \vert N_i \vert - 1} \sum_{j \in N_i, j \neq i} \log \frac{\exp(\mathbf{z}_i \cdot \mathbf{z}_j / \tau)}{\sum_{k \in I, k \neq i}\exp({\mathbf{z}_i \cdot \mathbf{z}_k / \tau})}\]
<p>where \(\mathbf{z}_k=P(E(\tilde{\mathbf{x}}_k))\), in which \(E(.)\) is an encoder network (augmented image mapped to vector) and \(P(.)\) is a projection network (one vector mapped to another). \(N_i= \{j \in I: \tilde{y}_j = \tilde{y}_i \}\) contains a set of indices of samples with label \(y_i\). Including more positive samples into the set \(N_i\) leads to improved results.</p>
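<p>A compact NumPy sketch of this loss is given below. It takes L2-normalized projections and integer labels as input. One assumption to flag: the sketch averages the inner sum over the actual number of positives of each anchor, which plays the role of the normalizer in front of the inner sum; the function name and the toy data are also illustrative.</p>

```python
import numpy as np

def unit(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def supcon_loss(z, labels, tau=0.1):
    """z: (2n, D) L2-normalized projections; labels: (2n,) integer class labels.

    Every anchor is assumed to have at least one positive (its other augmentation).
    """
    m = len(z)
    sim = z @ z.T / tau
    not_self = ~np.eye(m, dtype=bool)
    # Log of the denominator: sum over all k != i, computed stably.
    masked = np.where(not_self, sim, -np.inf)
    row_max = masked.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(masked - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - log_denom
    # Positives: same label as the anchor, excluding the anchor itself.
    positives = (labels[:, None] == labels[None, :]) & not_self
    per_anchor = -(log_prob * positives).sum(axis=1) / positives.sum(axis=1)
    return per_anchor.sum()

rng = np.random.default_rng(0)
labels = np.array([0, 0, 1, 1, 0, 0, 1, 1])
centers = np.array([[1.0, 0.0], [0.0, 1.0]])
z_clustered = unit(centers[labels] + 0.01 * rng.normal(size=(8, 2)))
z_random = unit(rng.normal(size=(8, 2)))
loss_clustered = supcon_loss(z_clustered, labels)
loss_random = supcon_loss(z_random, labels)
```

<p>Embeddings that are tightly grouped by class give a lower loss than randomly placed ones, which is the behavior the loss is designed to reward.</p>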
<p>According to their experiments, supervised contrastive loss:</p>
<ul>
<li>does outperform the base cross entropy, but only by a small amount.</li>
<li>outperforms the cross entropy on robustness benchmark (ImageNet-C, which applies common naturally occurring perturbations such as noise, blur and contrast changes to the ImageNet dataset).</li>
<li>is less sensitive to hyperparameter changes.</li>
</ul>
<h2 id="language-sentence-embedding">Language: Sentence Embedding</h2>
<p>In this section, we focus on how to learn sentence embedding.</p>
<h3 id="text-augmentation">Text Augmentation</h3>
<p>Most contrastive methods in vision applications depend on creating an augmented version of each image. However, it is more challenging to construct text augmentation which does not alter the semantics of a sentence. In this section we look into three approaches for augmenting text sequences, including lexical edits, back-translation and applying cutoff or dropout.</p>
<h4 id="lexical-edits">Lexical Edits</h4>
<p><strong>EDA</strong> (<em>Easy Data Augmentation</em>; <a href="https://arxiv.org/abs/1901.11196">Wei & Zou 2019</a>) defines a set of simple but powerful operations for text augmentation. Given a sentence, EDA randomly chooses and applies one of four simple operations:</p>
<ol>
<li>Synonym replacement (SR): Replace \(n\) random non-stop words with their synonyms.</li>
<li>Random insertion (RI): Place a random synonym of a randomly selected non-stop word in the sentence at a random position.</li>
<li>Random swap (RS): Randomly swap two words and repeat \(n\) times.</li>
<li>Random deletion (RD): Randomly delete each word in the sentence with probability \(p\).</li>
</ol>
<p>where \(p=\alpha\) and \(n=\alpha \times \text{sentence_length}\), with the intuition that longer sentences can absorb more noise while maintaining the original label. The hyperparameter \(\alpha\) roughly indicates the percent of words in one sentence that may be changed by one augmentation.</p>
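<p>Two of the four operations, random swap and random deletion, need no external resources and can be sketched with the standard library alone. Synonym replacement and random insertion are left out because they require a thesaurus. The function names, the example sentence and the fallback that keeps one word when everything would be deleted are assumptions of this sketch.</p>

```python
import random

def random_swap(words, n, rng):
    """Swap two randomly chosen positions, repeated n times."""
    words = list(words)
    for _ in range(n):
        if len(words) < 2:
            break
        i, j = rng.sample(range(len(words)), 2)
        words[i], words[j] = words[j], words[i]
    return words

def random_deletion(words, p, rng):
    """Delete each word independently with probability p (words must be non-empty)."""
    kept = [w for w in words if rng.random() >= p]
    # Fall back to a single random word so the sentence never becomes empty.
    return kept if kept else [rng.choice(words)]

alpha = 0.1
sentence = "contrastive learning pulls similar samples together in the embedding space".split()
n = max(1, round(alpha * len(sentence)))   # n = alpha * sentence_length

rng = random.Random(0)
swapped = random_swap(sentence, n, rng)
deleted = random_deletion(sentence, alpha, rng)
```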
<p>EDA is shown to improve the classification accuracy on several classification benchmark datasets compared to a baseline without EDA. The performance lift is more significant on a smaller training set. All four operations in EDA help improve the classification accuracy, but reach their optimum at different values of \(\alpha\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/EDA-exp1.png" alt="EDA classification" /></p>
<p class="image-caption"><em>Fig. 20. EDA leads to performance improvement on several classification benchmarks. (Image source: <a href="https://arxiv.org/abs/1901.11196">Wei & Zou 2019</a>)</em></p>
<p>In <strong>Contextual Augmentation</strong> (<a href="https://arxiv.org/abs/1805.06201">Sosuke Kobayashi, 2018</a>), new substitutes for word \(w_i\) at position \(i\) can be smoothly sampled from a given probability distribution, \(p(.\mid S\setminus\{w_i\})\), which is predicted by a bidirectional LM like BERT.</p>
<h4 id="back-translation">Back-translation</h4>
<p><strong>CERT</strong> (<em>Contrastive self-supervised Encoder Representations from Transformers</em>; <a href="https://arxiv.org/abs/2005.12766">Fang et al. (2020)</a>; <a href="https://github.com/UCSD-AI4H/CERT">code</a>) generates augmented sentences via <strong>back-translation</strong>. Various translation models for different languages can be employed for creating different versions of augmentations. Once we have a noisy version of text samples, many contrastive learning frameworks introduced above, such as <a href="#moco--moco-v2">MoCo</a>, can be used to learn sentence embedding.</p>
<h4 id="dropout-and-cutoff">Dropout and Cutoff</h4>
<p><a href="https://arxiv.org/abs/2009.13818">Shen et al. (2020)</a> proposed to apply <strong>Cutoff</strong> to text augmentation, inspired by <a href="/lil-log/2019/01/31/generalized-language-models.html#cross-view-training">cross-view training</a>. They proposed three cutoff augmentation strategies:</p>
<ol>
<li><em>Token cutoff</em> removes the information of a few selected tokens. To make sure there is no data leakage, corresponding tokens in the input, positional and other relevant embedding matrices should all be zeroed out.</li>
<li><em>Feature cutoff</em> removes a few feature columns.</li>
<li><em>Span cutoff</em> removes a continuous chunk of texts.</li>
</ol>
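<p>On an input embedding matrix of shape (sequence length, embedding dimension), the three strategies above amount to zeroing rows, columns, or a contiguous block of rows. A minimal NumPy sketch, with hypothetical function names and a toy all-ones matrix; in practice the indices would be sampled at random and the same zeroing would be applied to the positional and other embedding matrices.</p>

```python
import numpy as np

def token_cutoff(emb, token_idx):
    """Zero out whole rows: the selected tokens lose all their information."""
    out = emb.copy()
    out[token_idx, :] = 0.0
    return out

def feature_cutoff(emb, feature_idx):
    """Zero out whole columns: the selected feature dimensions are removed."""
    out = emb.copy()
    out[:, feature_idx] = 0.0
    return out

def span_cutoff(emb, start, length):
    """Zero out a contiguous chunk of tokens."""
    out = emb.copy()
    out[start:start + length, :] = 0.0
    return out

emb = np.ones((6, 4))                    # 6 tokens, 4 embedding dimensions
view_token = token_cutoff(emb, [1, 4])
view_feature = feature_cutoff(emb, [0])
view_span = span_cutoff(emb, start=2, length=3)
```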
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/text-cutoff.png" alt="Text cutoff" /></p>
<p class="image-caption"><em>Fig. 21. Schematic illustration of token, feature and span cutoff augmentation strategies. (Image source: <a href="https://arxiv.org/abs/2009.13818">Shen et al. 2020</a>)</em></p>
<p>Multiple augmented versions of one sample can be created. When training, <a href="https://arxiv.org/abs/2009.13818">Shen et al. (2020)</a> applied an additional KL-divergence term to measure the consensus between predictions from different augmented samples.</p>
<p><strong>SimCSE</strong> (<a href="https://arxiv.org/abs/2104.08821">Gao et al. 2021</a>; <a href="https://github.com/princeton-nlp/SimCSE">code</a>) learns from unsupervised data by predicting a sentence from itself with only <strong>dropout</strong> noise. In other words, they treat dropout as data augmentation for text sequences. A sample is simply fed into the encoder twice with different dropout masks and these two versions form the positive pair, while the other in-batch samples are considered as negative pairs. It feels quite similar to the cutoff augmentation, but dropout is more flexible, with a less well-defined semantic meaning of what content can be masked off.</p>
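<p>The idea of dropout as augmentation can be illustrated with a toy encoder. The sketch below is an assumption-heavy stand-in: a single random linear layer with dropout replaces the pre-trained Transformer, random vectors replace sentences, and the function names and hyperparameters are chosen only for illustration.</p>

```python
import numpy as np

def encode(x, w, rng, drop_p=0.1):
    """Toy encoder: linear layer, (inverted) dropout, then L2 normalization."""
    h = x @ w
    mask = rng.random(h.shape) >= drop_p        # a fresh dropout mask on every call
    h = h * mask / (1.0 - drop_p)
    return h / np.linalg.norm(h, axis=1, keepdims=True)

def in_batch_contrastive_loss(h1, h2, tau=0.05):
    logits = h1 @ h2.T / tau                    # (N, N); the diagonal holds positive pairs
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 32))     # stand-in for 8 input sentences
w = rng.normal(size=(32, 64))    # stand-in for encoder weights

h1 = encode(x, w, rng)           # first pass
h2 = encode(x, w, rng)           # second pass: same inputs, different dropout mask
loss = in_batch_contrastive_loss(h1, h2)
pos_sim = np.diag(h1 @ h2.T)     # similarity of each sentence with its own second view
```

<p>The two passes give slightly different embeddings for the same input, and these stay much closer to each other than to the embeddings of other samples in the batch.</p>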
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SimCSE.png" alt="SimCSE" /></p>
<p class="image-caption"><em>Fig. 22. SimCSE creates augmented samples by applying different dropout masks. The supervised version leverages NLI datasets to predict positive (entailment) or negative (contradiction) given a pair of sentences. (Image source: <a href="https://arxiv.org/abs/2104.08821">Gao et al. 2021</a>)</em></p>
<p>They ran experiments on 7 STS (Semantic Text Similarity) datasets and computed cosine similarity between sentence embeddings. They also tried out an optional MLM auxiliary objective loss to help avoid catastrophic forgetting of token-level knowledge. This aux loss was found to help improve performance on transfer tasks, but it caused a consistent drop on the main STS tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SimCSE-STS-exp.png" alt="SimCSE experiments" /></p>
<p class="image-caption"><em>Fig. 23. Experiment numbers on a collection of STS benchmarks with SimCSE. (Image source: <a href="https://arxiv.org/abs/2104.08821">Gao et al. 2021</a>)</em></p>
<h3 id="supervision-from-nli">Supervision from NLI</h3>
<p>The pre-trained BERT sentence embedding without any fine-tuning has been found to have poor performance for semantic similarity tasks. Instead of using the raw embeddings directly, we need to refine the embedding with further fine-tuning.</p>
<p><strong>Natural Language Inference (NLI)</strong> tasks are the main data sources to provide supervised signals for learning sentence embedding, such as <a href="https://nlp.stanford.edu/projects/snli/">SNLI</a>, <a href="https://cims.nyu.edu/~sbowman/multinli/">MNLI</a>, and <a href="https://www.kaggle.com/c/quora-question-pairs">QQP</a>.</p>
<h4 id="sentence-bert">Sentence-BERT</h4>
<p><strong>SBERT (Sentence-BERT)</strong> (<a href="https://arxiv.org/abs/1908.10084">Reimers & Gurevych, 2019</a>) relies on siamese and triplet network architectures to learn sentence embeddings such that the sentence similarity can be estimated by cosine similarity between pairs of embeddings. Note that learning SBERT depends on supervised data, as it is fine-tuned on several NLI datasets.</p>
<p>They experimented with a few different prediction heads on top of BERT model:</p>
<ul>
<li>Softmax classification objective: The classification head of the siamese network is built on the concatenation of two embeddings \(f(\mathbf{x}), f(\mathbf{x}')\) and \(\vert f(\mathbf{x}) - f(\mathbf{x}') \vert\). The predicted output is \(\hat{y}=\text{softmax}(\mathbf{W}_t [f(\mathbf{x}); f(\mathbf{x}'); \vert f(\mathbf{x}) - f(\mathbf{x}') \vert])\). They showed that the most important component is the element-wise difference \(\vert f(\mathbf{x}) - f(\mathbf{x}') \vert\).</li>
<li>Regression objective: This is the regression loss on \(\cos(f(\mathbf{x}), f(\mathbf{x}'))\), in which the pooling strategy has a big impact. In the experiments, they observed that <code class="language-plaintext highlighter-rouge">max</code> performs much worse than <code class="language-plaintext highlighter-rouge">mean</code> and <code class="language-plaintext highlighter-rouge">CLS</code>-token.</li>
<li>Triplet objective: \(\max(0, \|f(\mathbf{x}) - f(\mathbf{x}^+)\|- \|f(\mathbf{x}) - f(\mathbf{x}^-)\| + \epsilon)\), where \(\mathbf{x}, \mathbf{x}^+, \mathbf{x}^-\) are the anchor, positive and negative sentences and \(f(.)\) maps each of them to its embedding.</li>
</ul>
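<p>The building blocks of the three objectives above are simple vector operations on the two (or three) sentence embeddings. A minimal NumPy sketch, with hypothetical function names and an illustrative margin; the trainable weight matrix \(\mathbf{W}_t\) and the BERT encoder are omitted.</p>

```python
import numpy as np

def softmax_head_features(u, v):
    """Input to the classification head: the concatenation [u; v; |u - v|]."""
    return np.concatenate([u, v, np.abs(u - v)], axis=-1)

def cosine_similarity(u, v):
    """Score that the regression objective is trained on."""
    return np.sum(u * v, axis=-1) / (np.linalg.norm(u, axis=-1) * np.linalg.norm(v, axis=-1))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(0, ||a - p|| - ||a - n|| + margin) with Euclidean distance."""
    d_pos = np.linalg.norm(anchor - positive, axis=-1)
    d_neg = np.linalg.norm(anchor - negative, axis=-1)
    return np.maximum(0.0, d_pos - d_neg + margin)

a = np.array([1.0, 0.0])
features = softmax_head_features(np.ones(4), np.zeros(4))   # length 3 * 4 = 12
sim = cosine_similarity(a, a)
easy = triplet_loss(a, a, a + 3.0)    # positive is close, negative is far
hard = triplet_loss(a, a + 3.0, a)    # positive is far, negative is close
```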
<p>In the experiments, which objective function works the best depends on the datasets, so there is no universal winner.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/SBERT.png" alt="SBERT" /></p>
<p class="image-caption"><em>Fig. 24. Illustration of Sentence-BERT training framework with softmax classification head and regression head. (Image source: <a href="https://arxiv.org/abs/1908.10084">Reimers & Gurevych, 2019</a>)</em></p>
<p>The <a href="https://github.com/facebookresearch/SentEval">SentEval</a> library (<a href="https://arxiv.org/abs/1803.05449">Conneau and Kiela, 2018</a>) is commonly used for evaluating the quality of learned sentence embedding. SBERT outperformed other baselines at that time (Aug 2019) on 5 out of 7 tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SBERT-SentEval.png" alt="SBERT SentEval results" /></p>
<p class="image-caption"><em>Fig. 25. The performance of Sentence-BERT on the SentEval benchmark. (Image source: <a href="https://arxiv.org/abs/1908.10084">Reimers & Gurevych, 2019</a>)</em></p>
<h4 id="bert-flow">BERT-flow</h4>
<p><a name="isotropy"></a>The embedding representation space is deemed <em>isotropic</em> if embeddings are uniformly distributed on each dimension; otherwise, it is <em>anisotropic</em>. <a href="https://arxiv.org/abs/2011.05864">Li et al. (2020)</a> showed that a pre-trained BERT learns a non-smooth <em>anisotropic</em> semantic space of sentence embeddings and thus leads to poor performance for text similarity tasks without fine-tuning. Empirically, they observed two issues with BERT sentence embedding:</p>
<ol>
<li>Word frequency biases the embedding space. High-frequency words are close to the origin, but low-frequency ones are far away from the origin.</li>
<li>Low-frequency words scatter sparsely. The embeddings of low-frequency words tend to be farther from their \(k\)-NN neighbors, while the embeddings of high-frequency words concentrate more densely.</li>
</ol>
<p><strong>BERT-flow</strong> (<a href="https://arxiv.org/abs/2011.05864">Li et al, 2020</a>; <a href="https://github.com/bohanli/BERT-flow">code</a>) was proposed to transform the embedding to a smooth and isotropic Gaussian distribution via <a href="/lil-log/2018/10/13/flow-based-deep-generative-models.html#what-is-normalizing-flows">normalizing flows</a>.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/BERT-flow.png" alt="BERT-flow" /></p>
<p class="image-caption"><em>Fig. 26. Illustration of the flow-based calibration over the original sentence embedding space in BERT-flow. (Image source: <a href="https://arxiv.org/abs/2011.05864">Li et al, 2020</a>)</em></p>
<p>Let \(\mathcal{U}\) be the observed BERT sentence embedding space and \(\mathcal{Z}\) be the desired latent space which is a standard Gaussian. Thus, \(p_\mathcal{Z}\) is a Gaussian density function and \(f_\phi: \mathcal{Z}\to\mathcal{U}\) is an invertible transformation:</p>
\[\mathbf{z}\sim p_\mathcal{Z}(\mathbf{z}) \quad
\mathbf{u}=f_\phi(\mathbf{z}) \quad
\mathbf{z}=f^{-1}_\phi(\mathbf{u})\]
<p>A flow-based generative model learns the invertible mapping function by maximizing the likelihood of \(\mathcal{U}\)’s marginal:</p>
\[\max_\phi\mathbb{E}_{\mathbf{u}=\text{BERT}(s), s\sim\mathcal{D}} \Big[ \log p_\mathcal{Z}(f^{-1}_\phi(\mathbf{u})) + \log\big\vert\det\frac{\partial f^{-1}_\phi(\mathbf{u})}{\partial\mathbf{u}}\big\vert \Big]\]
<p>where \(s\) is a sentence sampled from the text corpus \(\mathcal{D}\). Only the flow parameters \(\phi\) are optimized while parameters in the pretrained BERT stay unchanged.</p>
<p>BERT-flow was shown to improve the performance on most STS tasks either with or without supervision from NLI datasets. Because learning normalizing flows for calibration does not require labels, it can utilize the entire dataset including validation and test sets.</p>
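<p>To make the objective concrete, here is a minimal sketch of the change-of-variables log-likelihood. It assumes a toy element-wise affine flow \(f^{-1}_\phi(\mathbf{u}) = (\mathbf{u} - \text{shift}) \odot e^{-\text{log\_scale}}\) and random vectors in place of frozen BERT embeddings; it is not the flow architecture used in the paper, and all names are hypothetical.</p>

```python
import numpy as np

def standard_normal_logpdf(z):
    """log p_Z(z) for a standard Gaussian, summed over dimensions."""
    d = z.shape[-1]
    return -0.5 * np.sum(z ** 2, axis=-1) - 0.5 * d * np.log(2 * np.pi)

def flow_log_likelihood(u, shift, log_scale):
    """log p_U(u) with z = f^{-1}(u) = (u - shift) * exp(-log_scale)."""
    z = (u - shift) * np.exp(-log_scale)
    # The Jacobian of f^{-1} is diagonal, so log|det| is minus the sum of log-scales.
    log_det_inv_jacobian = -np.sum(log_scale)
    return standard_normal_logpdf(z) + log_det_inv_jacobian

rng = np.random.default_rng(0)
u = 3.0 + 2.0 * rng.standard_normal((1000, 4))  # stand-in for frozen BERT embeddings

# Identity flow: treats the raw embeddings as if they were already standard Gaussian.
identity_ll = flow_log_likelihood(u, np.zeros(4), np.zeros(4)).mean()
# For this toy flow the maximum-likelihood parameters are the per-dimension mean and std.
fitted_ll = flow_log_likelihood(u, u.mean(axis=0), np.log(u.std(axis=0))).mean()
```

<p>Only <code class="language-plaintext highlighter-rouge">shift</code> and <code class="language-plaintext highlighter-rouge">log_scale</code> play the role of \(\phi\) here; the embeddings themselves are never updated, mirroring the frozen BERT encoder.</p>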
<h4 id="whitening-operation">Whitening Operation</h4>
<p><a href="https://arxiv.org/abs/2103.15316">Su et al. (2021)</a> applied a <strong>whitening</strong> operation to improve the <a href="#isotropy">isotropy</a> of the learned representation and also to reduce the dimensionality of the sentence embedding.</p>
<p>They transform the mean value of the sentence vectors to 0 and the covariance matrix to the identity matrix. Given a set of samples \(\{\mathbf{x}_i\}_{i=1}^N\), let \(\tilde{\mathbf{x}}_i\) and \(\tilde{\Sigma}\) be the transformed samples and corresponding covariance matrix:</p>
\[\begin{aligned}
\mu &= \frac{1}{N}\sum_{i=1}^N \mathbf{x}_i \quad \Sigma = \frac{1}{N}\sum_{i=1}^N (\mathbf{x}_i - \mu)^\top (\mathbf{x}_i - \mu) \\
\tilde{\mathbf{x}}_i &= (\mathbf{x}_i - \mu)W \quad \tilde{\Sigma} = W^\top\Sigma W = I \text{ thus } \Sigma = (W^{-1})^\top W^{-1}
\end{aligned}\]
<p>If we take the <a href="https://en.wikipedia.org/wiki/Singular_value_decomposition">SVD</a> of \(\Sigma = U\Lambda U^\top\), we have \(W^{-1}=\sqrt{\Lambda} U^\top\) and \(W=U\sqrt{\Lambda^{-1}}\). Note that here \(U\) is an orthogonal matrix whose columns are eigenvectors, and \(\Lambda\) is a diagonal matrix whose entries are the positive eigenvalues, sorted in descending order.</p>
<p>A dimensionality reduction strategy can be applied by only taking the first \(k\) columns of \(W\), named <code class="language-plaintext highlighter-rouge">Whitening</code>-\(k\).</p>
<p style="width: 52%;" class="center"><img src="/lil-log/assets/images/whitening-SBERT.png" alt="Whitening-SBERT" /></p>
<p class="image-caption"><em>Fig. 27. Pseudo code of the whitening-\(k\) operation. (Image source: <a href="https://arxiv.org/abs/2103.15316">Su et al. 2021</a>)</em></p>
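<p>A runnable NumPy sketch of the same idea, following the equations above with samples stored as rows. The function name and the toy data are my own; the check at the end only confirms that the transformed covariance is the identity.</p>

```python
import numpy as np

def fit_whitening(x, k=None):
    """Return (mu, W) such that (x - mu) @ W has identity covariance."""
    mu = x.mean(axis=0, keepdims=True)
    centered = x - mu
    sigma = centered.T @ centered / x.shape[0]
    # For a symmetric PSD matrix, SVD gives sigma = U diag(lam) U^T, lam sorted descending.
    u, lam, _ = np.linalg.svd(sigma)
    w = u @ np.diag(1.0 / np.sqrt(lam))
    if k is not None:
        w = w[:, :k]  # Whitening-k: keep only the first k columns
    return mu, w

rng = np.random.default_rng(0)
# Correlated toy "sentence embeddings": 500 samples, 8 dimensions.
x = rng.standard_normal((500, 8)) @ rng.standard_normal((8, 8))

mu, w = fit_whitening(x, k=3)
x_tilde = (x - mu) @ w
cov = x_tilde.T @ x_tilde / x.shape[0]  # should be the 3x3 identity
```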
<p>Whitening operations were shown to outperform BERT-flow and achieve SOTA with 256-dimensional sentence embeddings on many STS benchmarks, either with or without NLI supervision.</p>
<h3 id="unsupervised-sentence-embedding-learning">Unsupervised Sentence Embedding Learning</h3>
<h4 id="context-prediction">Context Prediction</h4>
<p><strong>Quick-Thought (QT) vectors</strong> (<a href="https://arxiv.org/abs/1803.02893">Logeswaran & Lee, 2018</a>) formulate sentence representation learning as a <em>classification</em> problem: Given a sentence and its context, a classifier distinguishes context sentences from other contrastive sentences based on their vector representations (<a href="/lil-log/2019/01/31/generalized-language-models.html#MLM">“cloze test”</a>). Such a formulation removes the softmax output layer, which slows down training.</p>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/quick-thought.png" alt="Quick-Thought vectors" /></p>
<p class="image-caption"><em>Fig. 28. Illustration of how Quick-Thought sentence embedding vectors are learned. (Image source: <a href="https://arxiv.org/abs/1803.02893">Logeswaran & Lee, 2018</a>)</em></p>
<p>Let \(f(.)\) and \(g(.)\) be two functions that encode a sentence \(s\) into a fixed-length vector. Let \(C(s)\) be the set of sentences in the context of \(s\) and \(S(s)\) be the set of candidate sentences, including only one context sentence \(s_c \in C(s)\) and many other non-context negative sentences. The Quick-Thought model learns to maximize the probability of predicting the only true context sentence \(s_c \in S(s)\). It is essentially the NCE loss, with \((s, s_c)\) as the positive pair and all other pairs \((s, s')\), where \(s' \in S(s), s'\neq s_c\), as negatives.</p>
\[\mathcal{L}_\text{QT}
= - \sum_{s \in \mathcal{D}} \sum_{s_c \in C(s)} \log p(s_c \vert s, S(s))
= - \sum_{s \in \mathcal{D}} \sum_{s_c \in C(s)} \log \frac{\exp(f(s)^\top g(s_c))}{\sum_{s'\in S(s)} \exp(f(s)^\top g(s'))}\]
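<p>For a single sentence \(s\), the loss is just a softmax cross-entropy over the candidate set. A minimal sketch, assuming the encodings \(f(s)\) and \(g(s')\) are already computed; the toy vectors and the function name are made up for illustration.</p>

```python
import numpy as np

def quick_thought_loss(f_s, g_candidates, context_index):
    """-log softmax of f(s)^T g(s') over S(s), read off at the true context sentence."""
    scores = g_candidates @ f_s          # one score per candidate in S(s)
    scores = scores - scores.max()       # shift for numerical stability
    log_probs = scores - np.log(np.exp(scores).sum())
    return -log_probs[context_index]

f_s = np.array([1.0, 0.0])
candidates = np.array([
    [2.0, 0.0],    # g(s_c): the true context sentence
    [0.0, 2.0],    # negative
    [-2.0, 0.0],   # negative
])

loss_true = quick_thought_loss(f_s, candidates, context_index=0)
loss_wrong = quick_thought_loss(f_s, candidates, context_index=2)
```

<p>The loss is small when the true context sentence has the highest inner product with \(f(s)\) and grows when a negative scores higher.</p>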
<h4 id="mutual-information-maximization">Mutual Information Maximization</h4>
<p><strong>IS-BERT (Info-Sentence BERT)</strong> (<a href="https://arxiv.org/abs/2009.12061">Zhang et al. 2020</a>; <a href="https://github.com/yanzhangnlp/IS-BERT">code</a>) adopts a self-supervised learning objective based on <em>mutual information maximization</em> to learn good sentence embeddings in an <em>unsupervised</em> manner.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/IS-BERT.png" alt="IS-BERT" /></p>
<p class="image-caption"><em>Fig. 29. Illustration of Info-Sentence BERT. (Image source: <a href="https://arxiv.org/abs/2009.12061">Zhang et al. 2020</a>)</em></p>
<p>IS-BERT works as follows:</p>
<ol>
<li>
<p>Use BERT to encode an input sentence \(s\) into a sequence of token embeddings of length \(l\), \(\mathbf{h}_{1:l}\).</p>
</li>
<li>Then apply 1-D conv nets with different kernel sizes (e.g. 1, 3, 5) to the token embedding sequence to capture the n-gram local contextual dependencies: \(\mathbf{c}_i = \text{ReLU}(\mathbf{w} \cdot \mathbf{h}_{i:i+k-1} + \mathbf{b})\). The output sequences are padded to keep the same length as the inputs.</li>
<li>The final local representation of the \(i\)-th token \(\mathcal{F}_\theta^{(i)} (\mathbf{x})\) is the concatenation of representations of different kernel sizes.</li>
<li>The global sentence representation \(\mathcal{E}_\theta(\mathbf{x})\) is computed by applying a mean-over-time pooling layer on the token representations \(\mathcal{F}_\theta(\mathbf{x}) = \{\mathcal{F}_\theta^{(i)} (\mathbf{x}) \in \mathbb{R}^d\}_{i=1}^l\).</li>
</ol>
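<p>Steps 2 to 4 can be sketched in NumPy as below. Random vectors stand in for the BERT token embeddings, the weights are untrained, and the symmetric zero padding is my assumption about how the output length is preserved; none of the names come from the IS-BERT code.</p>

```python
import numpy as np

def conv1d_same(h, weight, bias):
    """ReLU 1-D convolution over the token axis, zero-padded to keep length l.

    h: (l, d_in); weight: (k, d_in, d_out) with odd kernel size k; bias: (d_out,)
    """
    k = weight.shape[0]
    pad = k // 2
    padded = np.pad(h, ((pad, pad), (0, 0)))  # zero padding on the token axis only
    out = np.stack([
        np.tensordot(padded[i:i + k], weight, axes=([0, 1], [0, 1])) + bias
        for i in range(h.shape[0])
    ])
    return np.maximum(out, 0.0)

rng = np.random.default_rng(0)
l, d_in, d_out = 6, 16, 4
h = rng.standard_normal((l, d_in))  # stand-in for BERT token embeddings h_{1:l}

# One untrained filter bank per kernel size.
kernels = {
    k: (0.1 * rng.standard_normal((k, d_in, d_out)), np.zeros(d_out))
    for k in (1, 3, 5)
}

# Local representation F(x): concatenate the outputs of all kernel sizes per token.
local = np.concatenate([conv1d_same(h, w, b) for w, b in kernels.values()], axis=-1)
# Global representation E(x): mean-over-time pooling of the token representations.
global_repr = local.mean(axis=0)
```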
<p>Since the mutual information estimation is generally intractable for continuous and high-dimensional random variables, IS-BERT relies on the Jensen-Shannon estimator (<a href="https://arxiv.org/abs/1606.00709">Nowozin et al., 2016</a>, <a href="https://arxiv.org/abs/1808.06670">Hjelm et al., 2019</a>) to maximize the mutual information between \(\mathcal{E}_\theta(\mathbf{x})\) and \(\mathcal{F}_\theta^{(i)} (\mathbf{x})\).</p>
\[I^\text{JSD}_\omega(\mathcal{F}_\theta^{(i)} (\mathbf{x}); \mathcal{E}_\theta(\mathbf{x})) = \mathbb{E}_{\mathbf{x}\sim P} [-\text{sp}(-T_\omega(\mathcal{F}_\theta^{(i)} (\mathbf{x}); \mathcal{E}_\theta(\mathbf{x})))] \\ - \mathbb{E}_{\mathbf{x}\sim P, \mathbf{x}' \sim\tilde{P}} [\text{sp}(T_\omega(\mathcal{F}_\theta^{(i)} (\mathbf{x}'); \mathcal{E}_\theta(\mathbf{x})))]\]
<p>where \(T_\omega: \mathcal{F}\times\mathcal{E} \to \mathbb{R}\) is a learnable network with parameters \(\omega\), generating discriminator scores. The negative sample \(\mathbf{x}'\) is sampled from the distribution \(\tilde{P}=P\). And \(\text{sp}(x)=\log(1+e^x)\) is the softplus activation function.</p>
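<p>Given discriminator scores for positive pairs (local and global features from the same sentence) and negative pairs (local features from a different sentence), the estimator is a two-line computation. A small sketch with made-up scores in place of a learned \(T_\omega\):</p>

```python
import numpy as np

def softplus(x):
    """sp(x) = log(1 + e^x), computed stably."""
    return np.logaddexp(0.0, x)

def jsd_mi_estimate(pos_scores, neg_scores):
    """E[-sp(-T(positive pairs))] - E[sp(T(negative pairs))]."""
    return np.mean(-softplus(-pos_scores)) - np.mean(softplus(neg_scores))

# A discriminator that cannot tell pairs apart outputs 0 everywhere.
uninformative = jsd_mi_estimate(np.zeros(8), np.zeros(8))
# A discriminator that separates positives from negatives well.
separated = jsd_mi_estimate(np.full(8, 5.0), np.full(8, -5.0))
```

<p>The estimate is bounded above by 0 and equals \(-2\log 2\) for an uninformative discriminator, so maximizing it pushes positive scores up and negative scores down.</p>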
<p>The unsupervised IS-BERT results on SentEval outperform most of the unsupervised baselines (Sep 2020), but are unsurprisingly weaker than supervised runs. When using labelled NLI datasets, IS-BERT produces results comparable with SBERT (See Fig. 25 & 30).</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/IS-BERT-SentEval.png" alt="IS-BERT SentEval results" /></p>
<p class="image-caption"><em>Fig. 30. The performance of IS-BERT on the SentEval benchmark. (Image source: <a href="https://arxiv.org/abs/2009.12061">Zhang et al. 2020</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2021contrastive,
title = "Contrastive Representation Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2021",
url = "https://lilianweng.github.io/lil-log/2021/05/31/contrastive-representation-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Sumit Chopra, Raia Hadsell and Yann LeCun. <a href="http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf">“Learning a similarity metric discriminatively, with application to face verification.”</a> CVPR 2005.</p>
<p>[2] Florian Schroff, Dmitry Kalenichenko and James Philbin. <a href="https://arxiv.org/abs/1503.03832">“FaceNet: A Unified Embedding for Face Recognition and Clustering.”</a> CVPR 2015.</p>
<p>[3] Hyun Oh Song et al. <a href="https://arxiv.org/abs/1511.06452">“Deep Metric Learning via Lifted Structured Feature Embedding.”</a> CVPR 2016. [<a href="https://github.com/rksltnl/Deep-Metric-Learning-CVPR16">code</a>]</p>
<p>[4] Ruslan Salakhutdinov and Geoff Hinton. <a href="http://proceedings.mlr.press/v2/salakhutdinov07a.html">“Learning a Nonlinear Embedding by Preserving Class Neighbourhood Structure”</a> AISTATS 2007.</p>
<p>[5] Michael Gutmann and Aapo Hyvärinen. <a href="http://proceedings.mlr.press/v9/gutmann10a.html">“Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.”</a> AISTATS 2010.</p>
<p>[6] Kihyuk Sohn et al. <a href="https://papers.nips.cc/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html">“Improved Deep Metric Learning with Multi-class N-pair Loss Objective”</a> NIPS 2016.</p>
<p>[7] Nicholas Frosst, Nicolas Papernot and Geoffrey Hinton. <a href="http://proceedings.mlr.press/v97/frosst19a.html">“Analyzing and Improving Representations with the Soft Nearest Neighbor Loss.”</a> ICML 2019.</p>
<p>[8] Tongzhou Wang and Phillip Isola. <a href="https://arxiv.org/abs/2005.10242">“Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.”</a> ICML 2020. [<a href="https://ssnl.github.io/hypersphere/">code</a>]</p>
<p>[9] Zhirong Wu et al. <a href="https://arxiv.org/abs/1805.01978">“Unsupervised feature learning via non-parametric instance-level discrimination.”</a> CVPR 2018.</p>
<p>[10] Ekin D. Cubuk et al. <a href="https://arxiv.org/abs/1805.09501">“AutoAugment: Learning augmentation policies from data.”</a> arXiv preprint arXiv:1805.09501 (2018).</p>
<p>[11] Daniel Ho et al. <a href="https://arxiv.org/abs/1905.05393">“Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules.”</a> ICML 2019.</p>
<p>[12] Ekin D. Cubuk & Barret Zoph et al. <a href="https://arxiv.org/abs/1909.13719">“RandAugment: Practical automated data augmentation with a reduced search space.”</a> arXiv preprint arXiv:1909.13719 (2019).</p>
<p>[13] Hongyi Zhang et al. <a href="https://arxiv.org/abs/1710.09412">“mixup: Beyond Empirical Risk Minimization.”</a> ICLR 2018.</p>
<p>[14] Sangdoo Yun et al. <a href="https://arxiv.org/abs/1905.04899">“CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features.”</a> ICCV 2019.</p>
<p>[15] Yannis Kalantidis et al. <a href="https://arxiv.org/abs/2010.01028">“Mixing of Contrastive Hard Negatives”</a> NeurIPS 2020.</p>
<p>[16] Ashish Jaiswal et al. <a href="https://arxiv.org/abs/2011.00362">“A Survey on Contrastive Self-Supervised Learning.”</a> arXiv preprint arXiv:2011.00362 (2021).</p>
<p>[17] Jure Zbontar et al. <a href="https://arxiv.org/abs/2103.03230">“Barlow Twins: Self-Supervised Learning via Redundancy Reduction.”</a> arXiv preprint arXiv:2103.03230 (2021). [<a href="https://github.com/facebookresearch/barlowtwins">code</a>]</p>
<p>[18] Alec Radford, et al. <a href="https://arxiv.org/abs/2103.00020">“Learning Transferable Visual Models From Natural Language Supervision”</a> arXiv preprint arXiv:2103.00020 (2021).</p>
<p>[19] Mathilde Caron et al. <a href="https://arxiv.org/abs/2006.09882">“Unsupervised Learning of Visual Features by Contrasting Cluster Assignments (SwAV).”</a> NeurIPS 2020.</p>
<p>[20] Mathilde Caron et al. <a href="https://arxiv.org/abs/1807.05520">“Deep Clustering for Unsupervised Learning of Visual Features.”</a> ECCV 2018.</p>
<p>[21] Prannay Khosla et al. <a href="https://arxiv.org/abs/2004.11362">“Supervised Contrastive Learning.”</a> NeurIPS 2020.</p>
<p>[22] Aaron van den Oord, Yazhe Li & Oriol Vinyals. <a href="https://arxiv.org/abs/1807.03748">“Representation Learning with Contrastive Predictive Coding”</a> arXiv preprint arXiv:1807.03748 (2018).</p>
<p>[23] Jason Wei and Kai Zou. <a href="https://arxiv.org/abs/1901.11196">“EDA: Easy data augmentation techniques for boosting performance on text classification tasks.”</a> EMNLP-IJCNLP 2019.</p>
<p>[24] Sosuke Kobayashi. <a href="https://arxiv.org/abs/1805.06201">“Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations.”</a> NAACL 2018.</p>
<p>[25] Hongchao Fang et al. <a href="https://arxiv.org/abs/2005.12766">“CERT: Contrastive self-supervised learning for language understanding.”</a> arXiv preprint arXiv:2005.12766 (2020).</p>
<p>[26] Dinghan Shen et al. <a href="https://arxiv.org/abs/2009.13818">“A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation.”</a> arXiv preprint arXiv:2009.13818 (2020). [<a href="https://github.com/dinghanshen/cutoff">code</a>]</p>
<p>[27] Tianyu Gao et al. <a href="https://arxiv.org/abs/2104.08821">“SimCSE: Simple Contrastive Learning of Sentence Embeddings.”</a> arXiv preprint arXiv:2104.08821 (2021). [<a href="https://github.com/princeton-nlp/SimCSE">code</a>]</p>
<p>[28] Nils Reimers and Iryna Gurevych. <a href="https://arxiv.org/abs/1908.10084">“Sentence-BERT: Sentence embeddings using Siamese BERT-networks.”</a> EMNLP 2019.</p>
<p>[29] Jianlin Su et al. <a href="https://arxiv.org/abs/2103.15316">“Whitening sentence representations for better semantics and faster retrieval.”</a> arXiv preprint arXiv:2103.15316 (2021). [<a href="https://github.com/bojone/BERT-whitening">code</a>]</p>
<p>[30] Yan Zhang et al. <a href="https://arxiv.org/abs/2009.12061">“An unsupervised sentence embedding method by mutual information maximization.”</a> EMNLP 2020. [<a href="https://github.com/yanzhangnlp/IS-BERT">code</a>]</p>
<p>[31] Bohan Li et al. <a href="https://arxiv.org/abs/2011.05864">“On the sentence embeddings from pre-trained language models.”</a> EMNLP 2020.</p>
<p>[32] Lajanugen Logeswaran and Honglak Lee. <a href="https://arxiv.org/abs/1803.02893">“An efficient framework for learning sentence representations.”</a> ICLR 2018.</p>
<p>[33] Joshua Robinson, et al. <a href="https://arxiv.org/abs/2010.04592">“Contrastive Learning with Hard Negative Samples.”</a> ICLR 2021.</p>
<p>[34] Ching-Yao Chuang et al. <a href="https://arxiv.org/abs/2007.00224">“Debiased Contrastive Learning.”</a> NeurIPS 2020.</p>
Lilian Weng
The main idea of contrastive learning is to learn representations such that similar samples stay close to each other, while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised data and has been shown to achieve good performance on a variety of vision and language tasks.
Reducing Toxicity in Language Models
2021-03-21T12:00:00+00:00
2021-03-21T12:00:00+00:00
https://lilianweng.github.io/lil-log/2021/03/21/reducing-toxicity-in-language-models
<blockquote>
<p>Toxicity prevents us from safely deploying powerful pretrained language models for real-world applications. To reduce toxicity in language models, in this post, we will delve into three aspects of the problem: training dataset collection, toxic content detection and model detoxification.</p>
</blockquote>
<!--more-->
<p>Large pretrained <a href="/lil-log/2019/01/31/generalized-language-models.html">language models</a> are trained over a sizable collection of online data. They unavoidably acquire certain toxic behavior and biases from the Internet. Pretrained language models are very powerful and have shown great success in many NLP tasks. However, safely deploying them in practical real-world applications demands strong safety control over the model generation process.</p>
<p>Many challenges are associated with the effort to diminish various types of unsafe content:</p>
<ul>
<li>First, there are a variety of unsafe content types, such as toxicity, abusiveness, hate speech, biases, stereotypes, cyberbullying, identity attacks and more, which may or may not demand different treatment.</li>
<li>Second, there is no clearly and widely agreed-upon categorization and definition of unsafe behavior in pretrained language models. Individual perceptions could vary a lot due to different social backgrounds.</li>
</ul>
<p>In this post, we delve into the issue of toxicity in language models. As I’m still struggling to find a concrete definition of toxic content, I list a few from the literature below.</p>
<blockquote>
<p>[<a href="https://support.perspectiveapi.com/s/about-the-api-attributes-and-languages">Perspective API</a>] A rude, disrespectful, or unreasonable comment; likely to make people leave a discussion.</p>
</blockquote>
<blockquote>
<p>[<a href="https://arxiv.org/abs/1912.06872">Kurita et al. 2019</a>] Content that can offend or harm its recipients, including hate speech, racism, and offensive language.</p>
</blockquote>
<blockquote>
<p>[<a href="https://arxiv.org/abs/2006.00998">Pavlopoulos et al. 2020</a>] We use the term ‘toxic’ as an umbrella term, but we note that the literature uses several terms for different kinds of toxic language or related phenomena: ‘offensive’, ‘abusive’, ‘hateful’, etc.</p>
</blockquote>
<p>Overall, toxicity is a broad term to describe several types of unsafe content. Methodologies in this post can be applied given some form of definition of toxicity, e.g. one presented in the instructions for annotators. How to properly define the concept of toxicity and thus collect accurate annotation labels is out of the scope of this post.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#categorization-of-toxic-content" id="markdown-toc-categorization-of-toxic-content">Categorization of Toxic Content</a></li>
<li><a href="#data-collection" id="markdown-toc-data-collection">Data Collection</a> <ul>
<li><a href="#human-annotations" id="markdown-toc-human-annotations">Human Annotations</a></li>
<li><a href="#semi-supervised-dataset" id="markdown-toc-semi-supervised-dataset">Semi-supervised Dataset</a></li>
</ul>
</li>
<li><a href="#toxicity-detection" id="markdown-toc-toxicity-detection">Toxicity Detection</a> <ul>
<li><a href="#adversarial-attacks" id="markdown-toc-adversarial-attacks">Adversarial Attacks</a></li>
<li><a href="#perspective-api" id="markdown-toc-perspective-api">Perspective API</a></li>
<li><a href="#prompt-based-detection" id="markdown-toc-prompt-based-detection">Prompt-based Detection</a></li>
</ul>
</li>
<li><a href="#detoxification" id="markdown-toc-detoxification">Detoxification</a> <ul>
<li><a href="#blacklisting" id="markdown-toc-blacklisting">Blacklisting</a></li>
<li><a href="#prompt-based-detox" id="markdown-toc-prompt-based-detox">Prompt-based Detox</a></li>
<li><a href="#text-style-transfer" id="markdown-toc-text-style-transfer">Text Style Transfer</a></li>
<li><a href="#controllable-generation" id="markdown-toc-controllable-generation">Controllable Generation</a></li>
<li><a href="#system-level-safety-solution" id="markdown-toc-system-level-safety-solution">System-level Safety Solution</a></li>
</ul>
</li>
<li><a href="#appendix-datasets" id="markdown-toc-appendix-datasets">Appendix: Datasets</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="categorization-of-toxic-content">Categorization of Toxic Content</h2>
<p>How to categorize toxic content is not a straightforward task. Which content should be considered toxic and what types of toxic content exist can be very subjective. Language that does not look offensive to one group might seem inappropriate to another.</p>
<p>One popular categorization of offensive language is proposed by <a href="https://arxiv.org/abs/1902.09666">Zampieri et al. (2019)</a>, a three-level hierarchical taxonomy considering both the type and the target of offense. The Offensive Language Identification Dataset (<a href="#OLID">OLID</a>) is collected based on this taxonomy.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/offensive-taxonomy.png" alt="Offensiveness categorization" /></p>
<p class="image-caption"><em>Fig. 1. The three-level hierarchical taxonomy for categorizing offensive language, proposed by <a href="https://arxiv.org/abs/1902.09666">Zampieri et al. (2019)</a>.</em></p>
<ul>
<li>Level A: “Is it offensive?”
<ul>
<li><code class="language-plaintext highlighter-rouge">[OFF]</code> Offensive: Inappropriate language, insults, or threats.</li>
<li><code class="language-plaintext highlighter-rouge">[NOT]</code> Not offensive: No offense or profanity.</li>
</ul>
</li>
<li>Level B: “Is the offensive text targeted?”
<ul>
<li><code class="language-plaintext highlighter-rouge">[TIN]</code> Targeted Insult: Targeted insult or threat towards an individual, a group or other.</li>
<li><code class="language-plaintext highlighter-rouge">[UNT]</code> Untargeted: Non-targeted profanity and swearing.</li>
</ul>
</li>
<li>Level C: “What is the target?”
<ul>
<li><code class="language-plaintext highlighter-rouge">[IND]</code> The offense targets an individual, often defined as “cyberbullying”.</li>
<li><code class="language-plaintext highlighter-rouge">[GRP]</code> The offense targets a group of people based on ethnicity, gender, sexual orientation, religion, or other common characteristic, often defined as “hate speech”.</li>
<li><code class="language-plaintext highlighter-rouge">[OTH]</code> The target can belong to other categories, such as an organization, an event, an issue, etc.</li>
</ul>
</li>
</ul>
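<p>The hierarchy can be encoded as a small data structure that rejects inconsistent label combinations. This sketch assumes, as the level descriptions suggest, that Level B is only annotated for offensive text and Level C only for targeted insults; the class and field names are hypothetical.</p>

```python
from dataclasses import dataclass
from typing import Optional

LEVEL_A = {"OFF", "NOT"}
LEVEL_B = {"TIN", "UNT"}
LEVEL_C = {"IND", "GRP", "OTH"}

@dataclass(frozen=True)
class OffenseLabel:
    level_a: str
    level_b: Optional[str] = None  # only set when level_a == "OFF"
    level_c: Optional[str] = None  # only set when level_b == "TIN"

    def is_valid(self) -> bool:
        if self.level_a not in LEVEL_A:
            return False
        if self.level_a == "NOT":
            # Non-offensive text carries no lower-level labels.
            return self.level_b is None and self.level_c is None
        if self.level_b not in LEVEL_B:
            return False
        if self.level_b == "UNT":
            # Untargeted profanity has no target category.
            return self.level_c is None
        return self.level_c in LEVEL_C

hate_speech = OffenseLabel("OFF", "TIN", "GRP")
swearing = OffenseLabel("OFF", "UNT")
inconsistent = OffenseLabel("NOT", "TIN")
```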
<h2 id="data-collection">Data Collection</h2>
<p>Preparing a dataset of samples labelled as “safe” vs “unsafe” is the foundation for training a toxic language classifier and further providing signals for model detoxification.</p>
<h3 id="human-annotations">Human Annotations</h3>
<p><a href="https://arxiv.org/abs/2004.01670">Vidgen & Derczynski (2020)</a> summarized that training data annotations for toxicity detection on the high level can be collected by:</p>
<ol>
<li><em>Expert coding</em>: An expert has enough knowledge or training to complete the annotation tasks with good quality, such as a researcher who studies prejudice, a student with a moderate level of training, or an NLP practitioner. It is more expensive but produces high-quality data.</li>
<li><em>Crowdsourcing</em>: A crowdsourcing platform pairs a large number of non-expert annotators with tasks. It is easier to scale up but demands more attention on quality control.</li>
<li><em>Professional moderators</em>: Professional moderators are experienced and well-trained on the tasks, but they are likely to optimize for outcomes specific to their platform.</li>
<li><em>Synthetic data</em>: A training dataset can also be manually created by relevant content creators to cover a broad range of toxic content types.</li>
</ol>
<p>Crowdsourcing is the most common approach among them (<a href="https://arxiv.org/abs/1703.04009">Davidson et al. 2017</a>, <a href="https://arxiv.org/abs/1902.09666">Zampieri et al. 2019</a>) and there are several good practices to improve the data quality:</p>
<ol>
<li><em>Test data</em>: A small set of annotations collected from a few experts can be used as test questions (<a href="https://arxiv.org/abs/1902.09666">Zampieri et al. 2019</a>) to filter out human annotators on the crowdsourcing platform who cannot achieve a certain threshold.</li>
<li><em>Clear guidelines</em>: Detailed instructions are useful to guide annotators to produce aligned and consistent labels. Without any guideline, annotators are encouraged to apply their personal perceptions, which could be problematic because (1) subjective interpretation of toxic content varies across individuals greatly and (2) it is tricky to mark certain types of noise like sarcasm and irony without any guideline.</li>
<li><em>Majority vote</em>: It is very common to collect labels from multiple annotators per sample and take the majority vote.</li>
<li><em>Understanding annotators’ identities</em>: Demographic background has a big impact on the annotator’s understanding of the task. We should aim to recruit diverse and qualified annotators.</li>
</ol>
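<p>As a small illustration of the majority-vote practice above, here is one way to aggregate labels per sample. Returning <code class="language-plaintext highlighter-rouge">None</code> on ties is my own choice for the sketch; a real pipeline might instead request another annotation or defer to an expert.</p>

```python
from collections import Counter

def majority_vote(labels):
    """Return the most common label, or None when there is no clear winner."""
    if not labels:
        return None
    counts = Counter(labels).most_common()
    if len(counts) > 1 and counts[0][1] == counts[1][1]:
        return None  # tie between the top labels
    return counts[0][0]

clear = majority_vote(["toxic", "safe", "toxic"])
tied = majority_vote(["toxic", "safe"])
```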
<h3 id="semi-supervised-dataset">Semi-supervised Dataset</h3>
<p><a href="https://arxiv.org/abs/1811.12900">Khatri et al. (2018)</a> proposed a simple approach to bootstrap a large semi-supervised dataset for learning toxic content classifiers. Their approach relies on a small annotated dataset and a large unlabelled dataset.</p>
<ol>
<li>First, they gather a blacklist of 800+ words covering topics of profanity, hate, sexual content and insults. A blacklist of profanities may have high precision and low recall, but it can provide weakly supervised signals.</li>
<li>Subreddits are sorted by the percentage of blacklisted words. Then sensitive examples are sampled from the top subreddits and non-sensitive ones from the bottom, respectively.</li>
<li>Train a weak binary classifier to further select more samples from the sorted subreddits:
<ul>
<li>Sensitive: contains blacklisted words or toxic classifier confidence > 0.8;</li>
<li>Non-sensitive: does not contain blacklisted words and toxic classifier confidence < 0.3.</li>
</ul>
</li>
<li>Given this large expanded dataset, train a new classifier named “Two-stage bootstrap” (<strong>TS bootstrap</strong>).</li>
</ol>
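<p>The selection rule in step 3 can be written down directly. In this sketch the blacklist is a one-word placeholder, tokenization is a plain whitespace split, and the classifier confidence is passed in as a number; all three are simplifying assumptions.</p>

```python
BLACKLIST = {"badword"}  # placeholder for the 800+ word blacklist

def bootstrap_label(text, toxic_confidence, blacklist=BLACKLIST):
    """Assign a weak label using the blacklist and the weak classifier's confidence."""
    tokens = set(text.lower().split())
    has_blacklisted = bool(tokens & blacklist)
    if has_blacklisted or toxic_confidence > 0.8:
        return "sensitive"
    if toxic_confidence < 0.3:
        return "non-sensitive"
    return None  # ambiguous: leave the sample out of the bootstrapped dataset

labels = [
    bootstrap_label("you are a badword", 0.10),
    bootstrap_label("hello there", 0.90),
    bootstrap_label("hello there", 0.10),
    bootstrap_label("hello there", 0.50),
]
```

<p>Samples with a confidence between 0.3 and 0.8 and no blacklisted words fall into neither bucket, which keeps the bootstrapped labels relatively clean.</p>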
<p>Their experiments showed that the TS bootstrap classifier achieved strong F1 score, accuracy and recall, and that it could also transfer to out-of-domain test data.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/TS-bootstrap.png" alt="Two-stage bootstrap" /></p>
<p class="image-caption"><em>Fig. 2. The two-stage bootstrap classifier is trained on a dataset bootstrapped by a weak toxic binary classifier on Reddit data. (Image source: <a href="https://arxiv.org/abs/1811.12900">Khatri et al. 2018</a>)</em></p>
<p><a href="#SOLID">SOLID</a> (Semi-Supervised Offensive Language Identification Dataset; <a href="https://arxiv.org/abs/2004.14454">Rosenthal et al. 2020</a>) contains 9+ M tweets annotated with the same taxonomy system as for <a href="#OLID">OLID</a>. SOLID treats OLID as a seed and extends it via a semi-supervised technique called <strong>democratic co-training</strong>. Democratic co-training (<a href="https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.76.3152&rep=rep1&type=pdf">Zhou & Goldman, 2004</a>) creates a large dataset from noisy labels provided by a collection of diverse models trained on a small supervised dataset. SOLID is constructed by:</p>
<ol>
<li>First, train a diverse set of supervised models on the labeled dataset OLID. The paper experimented with PMI (n-gram-based similarity), FastText (shallow neural model similar to BoW model), LSTM and BERT.</li>
<li>For each sample in the unannotated dataset, each model predicts a confidence score for the target class. The scores are aggregated by taking <code class="language-plaintext highlighter-rouge">avg()</code> or <code class="language-plaintext highlighter-rouge">min()</code>. Samples with high confidence are added into the dataset.</li>
</ol>
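<p>The aggregation in step 2 might look like the sketch below. The confidence scores are made up and the 0.85 threshold is hypothetical; the point is only to show how <code class="language-plaintext highlighter-rouge">avg()</code> and <code class="language-plaintext highlighter-rouge">min()</code> select different samples.</p>

```python
import numpy as np

def select_confident(confidences, aggregate="avg", threshold=0.85):
    """confidences: (n_models, n_samples) scores for the target class.

    Returns the indices of samples whose aggregated score passes the threshold,
    together with the aggregated scores.
    """
    confidences = np.asarray(confidences, dtype=float)
    if aggregate == "avg":
        agg = confidences.mean(axis=0)
    elif aggregate == "min":
        agg = confidences.min(axis=0)
    else:
        raise ValueError(f"unknown aggregate: {aggregate}")
    return np.flatnonzero(agg >= threshold), agg

# Rows: three different models; columns: three unlabeled samples.
scores = [
    [0.95, 0.99, 0.20],
    [0.97, 0.70, 0.10],
    [0.92, 0.95, 0.30],
]
idx_avg, _ = select_confident(scores, aggregate="avg")
idx_min, _ = select_confident(scores, aggregate="min")
```

<p>With <code class="language-plaintext highlighter-rouge">min()</code>, a sample is kept only when every model is confident, so it is the stricter of the two choices.</p>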
<p>BERT model performance does not improve when the supervised dataset is large enough for a simple task, but can benefit from a big semi-supervised dataset if the original supervised dataset is too small for the task.</p>
<h2 id="toxicity-detection">Toxicity Detection</h2>
<p>Given a supervised dataset, we can train a text classifier from scratch or fine-tune a pretrained language model to perform the classification task. But what if the training samples are not good enough or not sufficient? What if we don’t have access to such a supervised dataset?</p>
<h3 id="adversarial-attacks">Adversarial Attacks</h3>
<p>To create a toxicity detection model that is robust to adversarial attacks, <a href="https://arxiv.org/abs/1908.06083">Dinan et al. (2019)</a> proposed an iterative “<strong>build it, break it, fix it</strong>” strategy to improve dialogue system safety with humans in the loop.</p>
<ol>
<li><em>Build it</em>: A BERT model is trained to classify toxic comments on the <a href="#jigsaw">Jigsaw dataset</a>.</li>
<li><em>Break it</em>: Crowdsourced workers are asked to write toxic messages that are mistakenly labelled as “safe” by the model.</li>
<li><em>Fix it</em>: The model is re-trained on the combination of the original dataset and newly collected adversarial samples.</li>
<li><em>Repeat</em>: Redeploy the robustified model and repeat a new round from step 2.</li>
</ol>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/build-break-fix.png" alt="build-break-fix" /></p>
<p class="image-caption"><em>Fig. 3. The illustration of iteratively improving a toxic content detection model via the “build it, break it, fix it” process. (Image source: <a href="https://arxiv.org/abs/1908.06083">Dinan et al. 2019</a>)</em></p>
<p>One baseline in their experiments is to replace the adversarial collection in the “break it” step with the standard collection, where workers are asked to submit “offensive” messages directly. Compared to the standard collection, the adversarial collection has less explicit profanity and more negations to trick the model. The tasks become more challenging in the later rounds.</p>
<p>Adversarial models are more robust against adversarial attacks than baseline models trained on the standard collection. The third-round adversarial model has worse performance on the standard task than the standard model, likely due to overfitting. I’m curious about what the model performance would be like if it were trained on both adversarial and standard collections, but I didn’t find it in the paper.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/build-break-fix-it-results.png" alt="build-break-fix results" /></p>
<p class="image-caption"><em>Fig. 4. The comparison of performance on standard and adversarial tasks of models trained on standard (\(S_i\)) and adversarial data collection (\(A_i\)). The subscript \(i\) indicates the number of training rounds. (Image source: <a href="https://arxiv.org/abs/1908.06083">Dinan et al. 2019</a>)</em></p>
<p>Another type of adversarial attack is to trick the detection model into mistakenly classifying a toxic sentence as safe by replacing or scrambling a subset of characters. <a href="https://arxiv.org/abs/1912.06872">Kurita et al. (2019)</a> developed a method of generating such model-agnostic adversarial attacks, incorporating several types of character-level perturbations:</p>
<ol>
<li><em>Character scrambling</em>: randomly permute character positions.</li>
<li><em>Homoglyph substitution</em>: replace one or multiple letters with similar-looking international letters.</li>
<li><em>Dictionary-based near-neighbor replacement</em>: find the closest but distinct token in terms of Levenshtein distance.</li>
<li><em>Distractor injection</em>: inject distractor tokens by repeating randomly selected sequences of non-toxic tokens.</li>
</ol>
<p>Adversarial noise combining token obfuscation and distractor tokens leads to substantial performance degradation of a toxic classifier. Character-level perturbation degrades performance more than distractors.</p>
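<p>To make the perturbation types above concrete, here is a minimal sketch of three of them (near-neighbor replacement is omitted). It is an illustration under stated assumptions, not the implementation from Kurita et al.: the function names, the tiny homoglyph table and the default parameters are all hypothetical.</p>

```python
import random

# Hypothetical, tiny homoglyph table (Latin -> Cyrillic look-alikes).
# A real attack would use a much larger mapping.
HOMOGLYPHS = {"a": "\u0430", "e": "\u0435", "o": "\u043e", "i": "\u0456"}


def scramble(word, rng):
    """Character scrambling: randomly permute character positions."""
    chars = list(word)
    rng.shuffle(chars)
    return "".join(chars)


def substitute_homoglyphs(word, rng, prob=0.5):
    """Replace each letter that has a look-alike with probability `prob`."""
    return "".join(
        HOMOGLYPHS[c] if c in HOMOGLYPHS and rng.random() < prob else c
        for c in word
    )


def inject_distractors(tokens, benign_tokens, rng, span_len=3, repeats=2):
    """Append a randomly selected span of non-toxic tokens, repeated."""
    start = rng.randrange(max(1, len(benign_tokens) - span_len + 1))
    span = benign_tokens[start:start + span_len]
    return tokens + span * repeats
```

<p>Each function takes an explicit <code>random.Random</code> instance so that a noisy dataset can be regenerated deterministically from a seed.</p>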
<p>The paper proposed two ways to resolve adversarial attacks:</p>
<ul>
<li><em>Adversarial training</em> refers to training the model on a dataset with noise. However, you need to know the details of the incoming attacks in advance. And there is no guarantee that training samples with arbitrary noise would generalize to the test set.</li>
<li><em>CDAE (contextual denoising autoencoder)</em> uses character-level and contextual information to denoise obfuscated tokens. CDAE takes a noisy sample and predicts the denoised version. Still, you need to know what types of character-level perturbation can be applied to create noisy samples. CDAE performs comparably to BERT, but not substantially better.</li>
</ul>
<h3 id="perspective-api">Perspective API</h3>
<p><strong>Perspective API</strong> (<a href="https://www.perspectiveapi.com/">www.perspectiveapi.com</a>) is the most widely used commercial API for toxic content detection. Perspective trains machine learning models to provide scores for several different <a href="https://support.perspectiveapi.com/s/about-the-api-attributes-and-languages">attributes</a>: toxicity, severe toxicity, insult, profanity, identity attack, threat, and sexually explicit. Each score is a number in [0, 1], indicating how likely the message is to contain a given attribute (i.e. the confidence of a binary classifier); it does not signify the severity of the attribute.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/about-perspective-api.png" alt="Perspective API" /></p>
<p class="image-caption"><em>Fig. 5. The overview of Perspective API scores. (Image source: <a href="https://support.perspectiveapi.com/s/about-the-api">About Perspective API</a>)</em></p>
<p><a href="https://arxiv.org/abs/2009.11462">Gehman et al. (2020)</a> measured the Perspective API toxicity scores of unprompted generations sampled from several pretrained language models. “Unprompted” means that the generation is only conditioned on the start-of-sentence tokens, without injecting any additional context. Noticeably, all the tested models reach an expected maximum toxicity > 0.5 after 100 generations. They also pointed out that training datasets for large LMs contain a non-negligible amount of toxic content.</p>
<p style="width: 45%;" class="center"><img src="/lil-log/assets/images/unprompted-toxicity.png" alt="Unprompted toxicity" /></p>
<p class="image-caption"><em>Fig. 6. Perspective API toxicity scores of unprompted generations. Each model generates a pool of 10k samples and the expected maximum toxicity score is estimated via bootstrapping. (Image source: <a href="https://arxiv.org/abs/2009.11462">Gehman et al. 2020</a>)</em></p>
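<p>The bootstrapped estimate mentioned in the caption can be sketched as follows. This is a minimal illustration, not the authors' code: the function name, the number of bootstrap rounds and the seed are assumptions, and <code>scores</code> stands for a pool of toxicity scores already obtained for the generated samples.</p>

```python
import random


def expected_max_toxicity(scores, n_generations, n_bootstrap=1000, seed=0):
    """Estimate E[max toxicity over `n_generations` samples] by bootstrapping.

    `scores` is a pool of per-sample toxicity scores in [0, 1].
    Each bootstrap round draws `n_generations` scores with replacement
    and records the maximum; the estimate is the mean of those maxima.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_bootstrap):
        sample = rng.choices(scores, k=n_generations)  # with replacement
        total += max(sample)
    return total / n_bootstrap
```

<p>Sweeping <code>n_generations</code> from 10 up to the pool size would trace out a curve of the kind shown in Fig. 6.</p>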
<p>They collected the <a href="#RealToxicityPrompt"><strong>RealToxicityPrompts</strong> dataset</a> for studying toxicity in conditional language model generation. It contains 100k naturally occurring prompts with associated toxicity scores from Perspective API. Some prompts that do not contain any toxic language can still trigger very offensive completions.</p>
<p><a name="perspective-api-biases"></a>Despite its popularity, Perspective API contains known biases, as summarized by <a href="https://arxiv.org/abs/2009.11462">Gehman et al. (2020)</a>:</p>
<blockquote>
<p>… exhibit biases against minorities and suffer from low agreement in annotations, partially due to annotator identity influencing their perception of hate speech and differences in annotation task setup.</p>
</blockquote>
<blockquote>
<p>Notably, recent work has found that systems are overestimating the prevalence of toxicity in text that contains a minority identity mention (e.g., “I’m a gay man”) or text by racial minorities (e.g., text in African American English). This is partially due to detectors’ over-reliance on lexical cues of toxicity (including swearwords, slurs, and other “bad” words).</p>
</blockquote>
<h3 id="prompt-based-detection">Prompt-based Detection</h3>
<p><strong>Self-diagnosis</strong> (<a href="https://arxiv.org/abs/2103.00453">Schick, et al. 2021</a>) is a process of exploiting the capacity of a pretrained language model to detect socially undesired attributes in its own outputs. The diagnosis depends on a predefined prompt template where the attributes are described in short text and measures the normalized probability of the model outputting “yes” versus “no”. Note that self-diagnosis does not need to access a labelled dataset for training.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Self-diagnosis template
"{sample text}"
Question: Does the above text contain {undesired attribute}?
Answer: {yes/no}
</code></pre></div></div>
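<p>As a rough sketch of how the template turns into a score: fill in the prompt, read the model's log-probabilities for “yes” and “no” as the next token, and normalize over the two. The function names are hypothetical, and obtaining the log-probabilities from an actual language model is assumed to happen elsewhere.</p>

```python
import math

TEMPLATE = (
    '"{text}"\n'
    "Question: Does the above text contain {attribute}?\n"
    "Answer:"
)


def build_prompt(text, attribute):
    """Fill the self-diagnosis template for one sample and one attribute."""
    return TEMPLATE.format(text=text, attribute=attribute)


def self_diagnosis_score(logprob_yes, logprob_no):
    """Normalized probability p(yes) / (p(yes) + p(no)).

    Inputs are the model's log-probabilities of "yes" and "no" as the
    next token after the prompt (assumed to be computed elsewhere).
    """
    m = max(logprob_yes, logprob_no)  # subtract the max for stability
    p_yes = math.exp(logprob_yes - m)
    p_no = math.exp(logprob_no - m)
    return p_yes / (p_yes + p_no)
```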
<p>They use the RealToxicityPrompts dataset and Perspective API for evaluation in the experiments. The self-diagnosis performance is positively correlated with the model size.</p>
<p style="width: 55%;" class="center"><img src="/lil-log/assets/images/self-diagnosis-toxicity-score.png" alt="Self-diagnosis" /></p>
<p class="image-caption"><em>Fig. 7. Self-diagnosis abilities for identifying undesired attributes. The ground truth is provided by Perspective API. (Image source: <a href="https://arxiv.org/abs/2103.00453">Schick, et al. 2021</a>)</em></p>
<h2 id="detoxification">Detoxification</h2>
<h3 id="blacklisting">Blacklisting</h3>
<p><strong>Bad word filtering</strong> is a pretty intuitive and effective way to avoid explicit profane <a href="https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words">words</a> in the language model generation. At decoding time, we can manually reduce the probabilities of blocked words to avoid sampling them. However, it is not perfect, as it is still possible to have unsafe content composed of safe tokens.</p>
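<p>A minimal sketch of filtering at decoding time, assuming we have the next-token logits as a plain list and a set of blocked token ids (both names are hypothetical). Setting a logit to negative infinity gives that token zero probability after the softmax; this sketch assumes at least one token stays unblocked.</p>

```python
import math


def block_tokens(logits, blocked_ids):
    """Set the logits of blocked token ids to -inf."""
    return [-math.inf if i in blocked_ids else x for i, x in enumerate(logits)]


def softmax(logits):
    """Softmax over a list of logits; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]
```

<p>For example, <code>softmax(block_tokens(logits, {2}))</code> redistributes the probability mass of token 2 over the remaining tokens.</p>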
<p><strong>Vocabulary shifting</strong> (<a href="https://arxiv.org/abs/2009.11462">Gehman et al. 2020</a>) learns a 2-dimensional representation of toxicity versus non-toxicity for every token in the vocabulary of the pretrained model. Then the representation that encodes the non-toxicity is used to boost the likelihood of non-toxic tokens at decoding time.</p>
<h3 id="prompt-based-detox">Prompt-based Detox</h3>
<p><strong>Self-debiasing</strong> (<a href="https://arxiv.org/abs/2103.00453">Schick et al. 2021</a>) follows a similar idea to <a href="#prompt-based-detection">self-diagnosis</a>. It is a process for using the internal knowledge of a pretrained language model to reduce the probability of undesired attributes in the model generation.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Self-debiasing template, denoted as sdb(.)
The following text contains {undesired attribute s}:
{sample text x}
</code></pre></div></div>
<p>Given an input prompt \(\mathbf{x}\), a textual description of undesired attributes \(s\), and the language model \(M\), self-debiasing computes the difference between the probability of next words without and with the self-debiasing template \(\text{sdb}(.)\):</p>
\[\Delta(w, \mathbf{x}, s) = p_M(w\vert\mathbf{x}) - p_M(w\vert\text{sdb}(\mathbf{x}, s))\]
<p>Because \(\text{sdb}(.)\) is expected to boost the probabilities of undesired words, \(\Delta(w, \mathbf{x}, s)\) should be negative for undesirable words.</p>
<p>In self-debiasing decoding, a scaling function of the probability difference \(\alpha(\Delta(w, \mathbf{x}, s)): \mathbb{R}\to[0,1]\) is used to alter the true sampling distribution,</p>
\[\tilde{p}_M(w\vert\mathbf{x}) \propto \alpha(\Delta(w, \mathbf{x}, s)) p_M(w\vert\mathbf{x})\]
<p>In the paper, they used a soft variant where the probabilities of the words with negative \(\Delta\) are reduced w.r.t. the magnitude of \(\Delta(w, \mathbf{x}, s)\):</p>
\[\alpha(x)=\begin{cases} 1 & \text{ if } x\geq 0 \\ e^{\lambda\cdot x} & \text{ otherwise} \end{cases}\]
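<p>Putting the three formulas together, one decoding step can be sketched as below. The dictionaries stand for next-token distributions with and without the self-debiasing template, and the value of \(\lambda\) is an arbitrary choice for illustration; the function names are hypothetical.</p>

```python
import math


def alpha(delta, lam):
    """Soft scaling: 1 if delta >= 0, else exp(lam * delta)."""
    return 1.0 if delta >= 0 else math.exp(lam * delta)


def self_debias(p_plain, p_sdb, lam=10.0):
    """Rescale and renormalize a next-token distribution.

    p_plain[w] = p_M(w | x), p_sdb[w] = p_M(w | sdb(x, s)).
    Both dicts are assumed to share the same vocabulary.
    """
    scaled = {
        w: alpha(p_plain[w] - p_sdb[w], lam) * p_plain[w] for w in p_plain
    }
    z = sum(scaled.values())
    return {w: v / z for w, v in scaled.items()}
```

<p>With <code>p_plain = {"nice": 0.6, "awful": 0.4}</code> and <code>p_sdb = {"nice": 0.2, "awful": 0.8}</code>, only “awful” has a negative \(\Delta\), so its probability is scaled down and the mass shifts to “nice” after renormalization.</p>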
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-debiasing-decoding.png" alt="Self-debiasing" /></p>
<p class="image-caption"><em>Fig. 8. Self-debiasing decoding can reduce the probabilities of undesirable attributes. The scores are provided by Perspective API. (Image source: <a href="https://arxiv.org/abs/2103.00453">Schick et al. 2021</a>)</em></p>
<p>There are a couple of major limitations in self-debiasing detoxification:</p>
<ol>
<li>The evaluation relies solely on Perspective API, so it cannot capture bias and toxicity attributes that are <a href="#perspective-api-biases">not covered</a> by Perspective API, such as gender biases. Human evaluation is an alternative, but its scale is limited.</li>
<li>Self-debiasing sometimes acts too aggressively and filters out harmless words and it does not maintain the same level of perplexity as the original model.</li>
<li>The approach is constrained by the internal capacity of the model. For example, if the model is not aware of certain biases, it would not be able to correct them.</li>
</ol>
<h3 id="text-style-transfer">Text Style Transfer</h3>
<p><strong>Unsupervised style transfer</strong> can be used to translate offensive sentences into innocuous ones (<a href="https://arxiv.org/abs/1805.07685">Santos et al. 2018</a>). The approach should work for non-parallel datasets, meaning that we only have access to two separate datasets of offensive and non-offensive samples, but not paired versions. To preserve the content when transferring the text into another style, a cycle consistency loss (<a href="https://arxiv.org/abs/1703.10593">Zhu et al. 2017</a>) is adopted.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/offensive-text-style-transfer.png" alt="Text style transfer" /></p>
<p class="image-caption"><em>Fig. 9. The training process of a neural text style transfer algorithm using non-parallel data. (Image source: <a href="https://arxiv.org/abs/1805.07685">Santos et al. 2018</a>)</em></p>
<p>Let \(s_i\) be the desired style (\(i=0\) for offensive and \(i=1\) for non-offensive), and \(\mathbf{x}^i_k\) be the \(k\)-th sample of style \(s_i\), \(k = 1, \dots, n\). Both the encoder \(E\) and decoder \(G\) take a sample (or hidden state) along with a style label. The classifier \(C\) predicts a probability distribution over the style labels given an input sample.</p>
<p>Following the illustration in Fig. 9:</p>
<ul>
<li>The top branch of forward transfer is an auto-encoder: \(E(\mathbf{x}^i_k, s_i) \to H^i_k \to G(H^i_k, s_i) \to \hat{\mathbf{x}}^{i\to i}_k\). A reconstruction loss is computed:
<ul>
<li>Reconstruction loss measures how well the decoder can reconstruct the sample back:</li>
</ul>
\[\mathcal{L}_\text{self} = \mathbb{E}_{\mathbf{x}^i_k \sim \mathcal{X}} [-\log p_G(\mathbf{x}_k^i \mid E(\mathbf{x}^i_k, s_i), s_i)]\]
</li>
<li>The bottom branch of forward transfer: \(E(\mathbf{x}^i_k, s_i) \to H^i_k \to G(H^i_k, s_j) \to \hat{\mathbf{x}}^{i\to j}_k\)
<ul>
<li>Classification loss measures the effectiveness of style transfer:</li>
</ul>
\[\mathcal{L}_\text{style_fwd} = \mathbb{E}_{\hat{\mathbf{x}}^{i\to j}_k \sim \hat{\mathcal{X}}} [-\log p_C(s_j \mid \hat{\mathbf{x}}^{i\to j}_k)]\]
</li>
<li>The back transfer uses cycle consistency loss: \(E(\hat{\mathbf{x}}^{i\to j}_k, s_j) \to H^{i\to j}_k \to G(H^{i\to j}_k, s_i) \to \hat{\mathbf{x}}^{i\to j \to i}_k\)
<ul>
<li>The cycle consistency loss controls how well the transferred sample can be converted back to the original form to encourage content preservation:</li>
</ul>
\[\mathcal{L}_\text{cycle} = \mathbb{E}_{\mathbf{x}^i_k \sim \mathcal{X}} [-\log p_G(\mathbf{x}_k^i \mid E(\hat{\mathbf{x}}^{i \to j}_k, s_j), s_i)]\]
<ul>
<li>The classification loss ensures that the back-transferred sample has the correct label:</li>
</ul>
\[\mathcal{L}_\text{style_back} = \mathbb{E}_{\hat{\mathbf{x}}^{i\to j}_k \sim \hat{\mathcal{X}}} [-\log p_C(s_i \mid G(E(\hat{\mathbf{x}}^{i\to j}_k, s_j), s_i))]\]
</li>
<li>There is an additional supervised classification loss for training an accurate classifier:</li>
</ul>
\[\mathcal{L}_\text{class} = \mathbb{E}_{\mathbf{x}^i_k \sim \mathcal{X}} [-\log p_C(s_i \mid \mathbf{x}^i_k)]\]
<p>The final training objective is as follows and the encoder, decoder and classifier are jointly trained:</p>
\[\mathcal{L}(\theta_E, \theta_G, \theta_C) = \min_{E, G, C} \mathcal{L}_\text{self} + \mathcal{L}_\text{style_fwd} + \mathcal{L}_\text{cycle} + \mathcal{L}_\text{style_back}+ \mathcal{L}_\text{class}\]
<p><strong>Style Transformer</strong> (<a href="https://arxiv.org/abs/1905.05621">Dai et al. 2019</a>) also aims to learn unsupervised text style transfer. Different from the encoder-decoder model in <a href="https://arxiv.org/abs/1805.07685">Santos et al. 2018</a>, it learns a Transformer-based style transfer function \(f_\theta(\mathbf{x}, s)\) for a given input sample \(\mathbf{x}\) and a desired style control variable \(s\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/style-transformer.png" alt="Style transformer" /></p>
<p class="image-caption"><em>Fig. 10. The comparison of style transformer and previous models that depend on disentangled latent representation. (Image source: <a href="https://arxiv.org/abs/1905.05621">Dai et al. 2019</a>)</em></p>
<p>Without access to a parallel corpus, the style transformer adopts a discriminator to create supervision from the non-parallel dataset.</p>
<p>Let \(s\) and \(\hat{s}\) be two mutually exclusive style variables and let \(\mathbf{x}\) be a sample of style \(s\). The style transformer computes several losses:</p>
<ul>
<li>Self reconstruction loss: \(\mathcal{L}_\text{self} = - p_\theta (\mathbf{x} \vert \mathbf{x}, s)\)</li>
<li>Cycle-consistency loss: \(\mathcal{L}_\text{cycle} = - p_\theta (\mathbf{x} \vert f_\theta(\mathbf{x}, \hat{s}), s)\)</li>
<li>Style controlling loss: This is necessary because otherwise the model would simply learn to copy the input over.</li>
</ul>
\[\mathcal{L}_\text{style} = - p_\phi(\text{class} = 1 \vert f_\theta(\mathbf{x}, \hat{s}), \hat{s})\]
<p>where the discriminator is a simple binary classifier trained to minimize the negative log-likelihood of the correct style. The discriminator is trained by labelling</p>
<ul>
<li>\(\{(\mathbf{x}, s), (f_\theta(\mathbf{x}, s), s), (f_\theta(\mathbf{x}, \hat{s}), \hat{s})\}\) as positive class 1</li>
<li>\(\{(\mathbf{x}, \hat{s}), (f_\theta(\mathbf{x}, s), \hat{s}), (f_\theta(\mathbf{x}, \hat{s}), s)\}\) as negative class 0.</li>
</ul>
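<p>The labelling scheme above can be written out directly. In this sketch <code>f</code> is a hypothetical callable standing in for the transfer function \(f_\theta\), and the function name is an assumption for illustration.</p>

```python
def discriminator_examples(x, s, s_hat, f):
    """Build (text, style, label) triples for the discriminator.

    x is a sample of style s; s_hat is the other style;
    f(x, style) is the style transfer function.
    Label 1 means the text matches the paired style, 0 means it does not.
    """
    positives = [(x, s), (f(x, s), s), (f(x, s_hat), s_hat)]
    negatives = [(x, s_hat), (f(x, s), s_hat), (f(x, s_hat), s)]
    return [(t, style, 1) for t, style in positives] + [
        (t, style, 0) for t, style in negatives
    ]
```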
<p style="width: 40%;" class="center"><img src="/lil-log/assets/images/style-transformer-training.png" alt="Style transformer training" /></p>
<p class="image-caption"><em>Fig. 11. The training process of Style Transformer. (Image source: <a href="https://arxiv.org/abs/1905.05621">Dai et al. 2019</a>)</em></p>
<p>Driven by the research question “Can we fine-tune a pre-trained language model to suggest civil rephrasings of rude comments using a dataset solely annotated in toxicity?”, <a href="https://arxiv.org/abs/2102.05456">Laugier et al. (2021)</a> fine-tuned a pretrained text-to-text transformer with a denoising and cyclic auto-encoder loss.</p>
<p>Let \(s\) be the attribute of \(\mathbf{x}\) (e.g. “civil”) and \(\bar{s}\) be the opposite attribute (e.g. “toxic”). These two attributes are mutually exclusive. The goal is to learn a mapping function \(f_\theta\) such that it translates \(\mathbf{x}\) to a new fluent sequence \(\mathbf{y}\) with the target attribute \(\bar{s}\) while preserving \(\mathbf{x}\)’s content.</p>
<p>The encoder-decoder model is trained with the loss:</p>
\[\mathcal{L} = \lambda_\text{DAE} \mathcal{L}_\text{DAE} + \lambda_\text{cycle} \mathcal{L}_\text{cycle}\]
<ul>
<li>The denoising auto-encoder loss trains the model to reconstruct \(\mathbf{x}\) from a corrupted copy \(\eta(\mathbf{x})\), where \(\eta\) is a <a href="/lil-log/2019/01/31/generalized-language-models.html#pre-training-tasks">masking</a> function, the same as in BERT training:</li>
</ul>
\[\mathcal{L}_\text{DAE} = \mathbb{E}_{\mathbf{x} \sim \mathcal{X}} [-\log p_\theta(\mathbf{x} \mid \eta(\mathbf{x}), s)]\]
<ul>
<li>The cycle consistency loss (<a href="https://arxiv.org/abs/1703.10593">Zhu et al. 2017</a>) uses \(\tilde{\theta}\), a copy of the parameters that receives no gradient, to produce a non-differentiable pseudo-prediction \(\hat{\mathbf{y}} = f_{\tilde{\theta}}(\mathbf{x}, \bar{s})\); no gradient is backpropagated through \(\hat{\mathbf{y}}\).</li>
</ul>
\[\mathcal{L}_\text{cycle} = \mathbb{E}_{\mathbf{x} \sim \mathcal{X}} [-\log p_\theta(\mathbf{x} \mid f_{\tilde{\theta}}(\mathbf{x}, \bar{s}), s)]\]
<p>They used the above loss to fine-tune a T5 model, resulting in a model named <strong>CAE-T5</strong>. The conditioning is implemented like CTRL via control code (“civil” or “toxic”) prepended to the start of a sequence.</p>
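<p>One way to picture the two loss terms is to look at how a (source, target) training pair could be assembled for each. The sketch below is a simplification under several assumptions: the <code>&lt;mask&gt;</code> placeholder, the <code>"attribute:"</code> prefix format and independent per-token masking are all hypothetical choices, and <code>f_frozen</code> is a stand-in callable for \(f_{\tilde{\theta}}\).</p>

```python
import random

MASK = "<mask>"  # hypothetical mask placeholder


def mask_tokens(tokens, rng, mask_prob=0.15):
    """Noise function eta: mask each token independently (simplified)."""
    return [MASK if rng.random() < mask_prob else t for t in tokens]


def dae_example(tokens, attribute, rng, mask_prob=0.15):
    """Denoising pair: reconstruct x from eta(x), conditioned on s."""
    source = [attribute + ":"] + mask_tokens(tokens, rng, mask_prob)
    return source, list(tokens)


def cycle_example(tokens, attribute, other_attribute, f_frozen):
    """Cycle pair: reconstruct x from the pseudo-prediction f(x, s_bar)."""
    pseudo = f_frozen(tokens, other_attribute)  # no gradient flows here
    source = [attribute + ":"] + list(pseudo)
    return source, list(tokens)
```

<p>In both cases the target is the original sequence and the control code names the attribute of that target, which matches the conditioning on \(s\) in the two losses.</p>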
<p>Automatic evaluation of the text style transferred results relies on three metrics:</p>
<ol>
<li><em>Accuracy</em>: Classification accuracy measures how successful the style transfer is.</li>
<li><em>Fluency</em>: Fluency is commonly measured by the perplexity of the transferred text under a separate LM trained on non-toxic samples.</li>
<li><em>Content preservation</em>: It is the content similarity between transferred and original sentences, measured by BLEU or embedding-based content similarity.</li>
</ol>
<p>Human evaluation is also necessary but more costly.</p>
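<p>The first two automatic metrics reduce to short formulas. Here is a minimal sketch, assuming the style classifier's predictions and the evaluation LM's per-token log-probabilities (natural log) are already available; the function names are hypothetical and BLEU is left out.</p>

```python
import math


def accuracy(predicted_styles, target_styles):
    """Fraction of transferred sentences classified as the target style."""
    hits = sum(p == t for p, t in zip(predicted_styles, target_styles))
    return hits / len(target_styles)


def perplexity(token_logprobs):
    """exp of the negative mean per-token log-probability."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))
```

<p>Lower perplexity under the evaluation LM is read as more fluent output.</p>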
<p>Compared to the baseline (<a href="https://arxiv.org/abs/1705.09655">Shen et al. 2017</a>), the style transfer method by <a href="https://arxiv.org/abs/1805.07685">Santos et al. 2018</a> achieves better classification accuracy, better content preservation, but worse perplexity. CAE-T5 has worse classification accuracy, competitive content preservation, and better perplexity compared to a set of baselines including Style Transformer.</p>
<h3 id="controllable-generation">Controllable Generation</h3>
<p>We can try to avoid toxic outputs via <em>controllable text generation</em>. There are several popular approaches for steering a pretrained language model toward desired styles, topics or safety criteria:</p>
<ol>
<li>Apply guided decoding strategies and select desired outputs at test time.</li>
<li>Optimize for the most desired outcomes via good prompt design.</li>
<li>Fine-tune the base model or steerable layers to do conditioned content generation.</li>
</ol>
<p>Read more in my <a href="/lil-log/2021/01/02/controllable-neural-text-generation.html">last post</a> on controllable neural text generation, introducing methods like <a href="https://arxiv.org/abs/2010.15980">AutoPrompt</a>, <a href="https://arxiv.org/abs/1909.05858">CTRL</a>, <a href="https://arxiv.org/abs/1912.02164">PPLM</a>, <a href="https://arxiv.org/abs/2009.06367">GeDi</a> and many more.</p>
<p><a href="https://arxiv.org/abs/2009.11462">Gehman et al. (2020)</a> experimented with both data-based (supervised fine-tuning, CTRL training) and decoding-based (vocabulary shifting, blocked word filtering, PPLM) methods for language model detoxification. They found that toxicity control tokens (CTRL) and swear word filters are <em>less successful</em> than more computationally or data-intensive methods like fine-tuning on non-toxic corpora and PPLM.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RealToxicityPrompts-experiments.png" alt="RealToxicityPrompts detox experiments" /></p>
<p class="image-caption"><em>Fig. 12. The table lists the expected maximum toxicity score over 25 generations (left) and the empirical probability of generating toxic text over 25 generations (right) for several detoxification methods. Scores are provided by Perspective API. (Image source: <a href="https://arxiv.org/abs/2009.11462">Gehman et al., 2020</a>)</em></p>
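<p>Both metrics in the table can be computed from per-prompt toxicity scores. The sketch below assumes that a generation counts as toxic when its score is at least 0.5; that threshold and the function name are assumptions made for illustration.</p>

```python
def toxicity_metrics(scores_per_prompt, threshold=0.5):
    """Return (expected max toxicity, empirical toxicity probability).

    scores_per_prompt[i] holds the toxicity scores of the generations
    sampled for prompt i (e.g. 25 scores per prompt).
    """
    max_scores = [max(scores) for scores in scores_per_prompt]
    expected_max = sum(max_scores) / len(max_scores)
    # Share of prompts with at least one generation at or above threshold.
    toxic_prob = sum(m >= threshold for m in max_scores) / len(max_scores)
    return expected_max, toxic_prob
```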
<h3 id="system-level-safety-solution">System-level Safety Solution</h3>
<p><a href="https://arxiv.org/abs/2010.07079">Xu et al. (2020)</a> presented a thorough system-level design for building safe chatbots.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/safe-chatbot-system.png" alt="Safe chatbot system" /></p>
<p class="image-caption"><em>Fig. 13. Illustration of a safe chat bot system. (Image source: <a href="https://arxiv.org/abs/2010.07079">Xu et al. 2020</a>)</em></p>
<p>They consider four general strategies in the recipes for making the bot safer:</p>
<ul>
<li><em>Detect unsafe content</em>: Adopt a classifier for detecting unsafe language on both the input and output side, as an extra safety layer on top of the language model.
<ul>
<li>The classifier is trained on an enhanced version of the <a href="#jigsaw">Jigsaw toxic</a> comment dataset (safe vs unsafe binary labels), extended with <a href="#adversarial-attacks">adversarial human attacks</a> (<a href="https://arxiv.org/abs/1908.06083">Dinan et al. 2019</a>) and <a href="#semi-supervised-dataset">semi-supervision</a> (<a href="https://arxiv.org/abs/1811.12900">Khatri et al. 2018</a>).</li>
<li>The safety classifier can be used on both the user input and the model output. If it detects unsafe content, the system is configured to return a canned, predefined response (e.g. “I’m sorry I’m not sure what to say.”), or decide to change topics. It is worth noting that this approach relies on a high-quality classifier. The conversation experience would be drastically disrupted with too many false positives.</li>
<li>Bot adversarial dialogue (BAD) safety: The idea is to collect data on humans adversarially probing the system to make mistakes and then use the data for further training. During annotation, human labellers can tag the bot’s response with an unsafe-safe rating based on the percentage of population who may consider it as unsafe. This probing data collection is used to train a multi-turn safety classifier, predicting whether a response is offensive given the dialogue context.</li>
</ul>
</li>
<li><em>Safe generation</em>: Train a model that is less likely to output unsafe responses.
<ul>
<li>A predefined list of unsafe words/n-grams can be <a href="#blacklisting">blocked</a> at decoding time.</li>
<li>The pretraining data is filtered by the above safety classifier, or filtered based on known authors.</li>
<li>The problem with pre-training only with safe datasets is that if the model has never seen toxic language during training, it would not know how to respond at test time (OOD; e.g. may just copy the offensive content). They instead prepare a collection of training samples where the last utterance is labelled as “unsafe” and then attach a safe response following that unsafe attack. Then the model is fine-tuned on the “baked-in” safety data.</li>
<li>Do <a href="https://arxiv.org/abs/1909.05858">CTRL</a> style training by assigning “safe” vs “unsafe” label using the safety classifier.</li>
</ul>
</li>
<li><em>Avoid sensitive topics</em>:
<ul>
<li>In order to avoid sensitive topics (politics, religion, drug use, medical advice, NSFW content, and relationships/dating), they trained a multi-class classifier to detect those topics using crowdsourced lists of subreddits. The classifier can be periodically re-trained to capture the changes within topics over time.</li>
<li>A small validation set is collected by recruiting crowdsourced workers to discuss one of the target topics.</li>
</ul>
</li>
<li><em>Gender bias mitigation</em>:
<ul>
<li>They used <a href="https://arxiv.org/abs/1909.05858">CTRL</a> style training to mitigate gender biases.</li>
<li>Precisely, given a gendered word list, tag the training samples with \(F^0 M^0\), \(F^0 M^+\), \(F^+ M^+\), and \(F^+ M^0\) labels, indicating whether the response contains female / male words (\(+\) contains, \(0\) does not contain). At test time, the system runs with a control label \(F^0 M^0\) to avoid outputting gender specific words.</li>
</ul>
</li>
</ul>
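<p>Two of the recipes above are simple enough to sketch in a few lines: the safety layer wrapped around the generator, and the gender bucket used as a control label. This is an illustration only, not the system from Xu et al.; <code>generate</code> and <code>is_unsafe</code> are hypothetical callables standing in for the dialogue model and the safety classifier, and whitespace tokenization is a simplification.</p>

```python
CANNED_RESPONSE = "I'm sorry I'm not sure what to say."


def safe_reply(user_message, generate, is_unsafe):
    """Run the safety classifier on both the input and the output."""
    if is_unsafe(user_message):
        return CANNED_RESPONSE
    reply = generate(user_message)
    return CANNED_RESPONSE if is_unsafe(reply) else reply


def gender_bucket(response, female_words, male_words):
    """Tag a response as F0M0, F0M+, F+M0 or F+M+ from gendered word sets."""
    tokens = set(response.lower().split())
    f = "+" if tokens & female_words else "0"
    m = "+" if tokens & male_words else "0"
    return f"F{f}M{m}"
```

<p>The bucket string would be attached to each training sample as the control label, and generation at test time would be conditioned on <code>"F0M0"</code>.</p>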
<h2 id="appendix-datasets">Appendix: Datasets</h2>
<p>(*Only datasets in English are listed here.)</p>
<p><strong>Hate Speech and Offensive Language</strong> Dataset (2017): contains about 25k tweets, each labelled manually as one of three categories: hate speech, offensive but not hate speech, or neither offensive nor hate speech. [<a href="https://github.com/t-davidson/hate-speech-and-offensive-language/blob/master/data/readme.md">Download</a>]</p>
<p><a name="jigsaw"></a><strong>Jigsaw Toxic</strong> Comments Classification Dataset (2018): contains about 160k examples extracted from Wikipedia discussion pages, each annotated for 7 classes: toxic, severe toxic, obscene, threat, insult, identity hate and non-toxic. The labelling process involved 5000 crowdsourced annotators. [<a href="https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge">Download</a>]</p>
<p><strong>Jigsaw Unintended Bias in Toxicity</strong> Classification Dataset (2019): contains about 2 million comments from the Civil Comments platform, which shut down in 2017. This data is annotated for toxicity, toxicity sub-types, and mentions of identities, which enables evaluation of unintended bias with respect to identity mentions. [<a href="https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification">Download</a>]</p>
<p><a name="OLID"></a><strong>OLID</strong> (Offensive Language Identification Dataset; 2019): contains 14,100 English tweets, annotated according to the three-level taxonomy as described <a href="#categorization-of-toxic-content">here</a>. [<a href="https://sites.google.com/site/offensevalsharedtask/olid">Download</a>]</p>
<p><a name="SOLID"></a><strong>SOLID</strong> (Semi-Supervised Offensive Language Identification Dataset; 2020): contains 9+ million tweets annotated following OLID’s three-level taxonomy. [<a href="https://sites.google.com/site/offensevalsharedtask/solid">Download</a>]</p>
<p><a name="RealToxicityPrompt"></a><strong>RealToxicityPrompts</strong> dataset (2020): contains 100k sentence snippets from the web with Perspective API toxicity scores for studying the risk of neural toxic degeneration in language models. [<a href="https://allenai.org/data/real-toxicity-prompts">Download</a>]</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2021toxic,
title = "Reducing Toxicity in Language Models.",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2021",
url = "https://lilianweng.github.io/lil-log/2021/03/21/reducing-toxicity-in-language-models.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Vidgen, et al. <a href="https://www.aclweb.org/anthology/W19-3509/">“Challenges and frontiers in abusive content detection.”</a> Workshop on Abusive Language Online 2019.</p>
<p>[2] Zampieri et al. <a href="https://arxiv.org/abs/1902.09666">“Predicting the type and target of offensive posts in social media.”</a> NAACL 2019.</p>
<p>[3] Vidgen & Derczynski. <a href="https://arxiv.org/abs/2004.01670">“Directions in abusive language training data, a systematic review: Garbage in, garbage out.”</a> PLoS ONE 15(12): e0243300 (2020).</p>
<p>[4] Davidson et al. <a href="https://arxiv.org/abs/1703.04009">“Automated hate speech detection and the problem of offensive language.”</a> ICWSM 2017.</p>
<p>[5] Khatri et al. <a href="https://arxiv.org/abs/1811.12900">“Detecting offensive content in open-domain conversations using two stage semi-supervision.”</a> NeurIPS CONVAI Workshop 2018.</p>
<p>[6] Rosenthal et al. <a href="https://arxiv.org/abs/2004.14454">“A Large-Scale Semi-Supervised Dataset for Offensive Language Identification”</a> arXiv:2004.14454 (2020).</p>
<p>[7] Pavlopoulos et al. <a href="https://arxiv.org/abs/2006.00998">“Toxicity Detection: Does Context Really Matter?”</a> arXiv:2006.00998 (2020).</p>
<p>[8] Dinan et al. <a href="https://arxiv.org/abs/1908.06083">“Build it, break it, fix it for dialogue safety: Robustness from adversarial human attack.”</a> arXiv:1908.06083 (2019).</p>
<p>[9] Kurita et al. <a href="https://arxiv.org/abs/1912.06872">“Towards Robust Toxic Content Classification”</a> arXiv:1912.06872 (2019).</p>
<p>[10] Santos et al. <a href="https://arxiv.org/abs/1805.07685">“Fighting offensive language on social media with unsupervised text style transfer.”</a> arXiv:1805.07685 (2018).</p>
<p>[11] Dai et al. <a href="https://arxiv.org/abs/1905.05621">“Style Transformer: Unpaired Text Style Transfer without Disentangled Latent Representation”</a> ACL 2019.</p>
<p>[12] Laugier et al. <a href="https://arxiv.org/abs/2102.05456">“Civil Rephrases Of Toxic Texts With Self-Supervised Transformers”</a> arXiv:2102.05456 (2021). <a href="https://github.com/LeoLaugier/conditional-auto-encoder-text-to-text-transfer-transformer">code</a></p>
<p>[13] Schick et al. <a href="https://arxiv.org/abs/2103.00453">“Self-Diagnosis and Self-Debiasing: A Proposal for Reducing Corpus-Based Bias in NLP”</a> arXiv:2103.00453 (2021).</p>
<p>[14] Gehman et al. <a href="https://arxiv.org/abs/2009.11462">“RealToxicityPrompts: Evaluating Neural Toxic Degeneration in Language Models”</a> EMNLP 2020.</p>
<p>[15] Xu et al. <a href="https://arxiv.org/abs/2010.07079">“Recipes for Safety in Open-domain Chatbots”</a> arXiv:2010.07079 (2020).</p>
Lilian Weng
Toxicity prevents us from safely deploying powerful pretrained language models for real-world applications. To reduce toxicity in language models, in this post, we will delve into three aspects of the problem: training dataset collection, toxic content detection and model detoxification.
Controllable Neural Text Generation
2021-01-02T12:00:00+00:00
2021-01-02T12:00:00+00:00
https://lilianweng.github.io/lil-log/2021/01/02/controllable-neural-text-generation
<blockquote>
<p>The modern language model with SOTA results on many NLP tasks is trained on large scale free text on the Internet. It is challenging to steer such a model to generate content with desired attributes. Although still not perfect, there are several approaches for controllable text generation, such as guided or learned decoding strategy, smart prompt design, or fine-tuning the model with various methods.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2021-02-01: Updated to version 2.0 with several works added and many typos fixed.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2021-05-26: Add P-tuning and Prompt Tuning in the <a href="#gradient-based-search">“prompt design”</a> section.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2021-09-19: Add <a href="#unlikelihood-training">“unlikelihood training”</a>.]</span></p>
<p>There is a gigantic amount of free text on the Web, several orders of magnitude more than labelled benchmark datasets. The state-of-the-art language models (LM) are trained with unsupervised Web data at large scale. When generating samples from LM by iteratively sampling the next token, we do not have much control over attributes of the output text, such as the topic, the style, the sentiment, etc. Many applications would demand a good control over the model output. For example, if we plan to use LM to generate reading materials for kids, we would like to guide the output stories to be safe, educational and easily understood by children.</p>
<p>How to steer a powerful unconditioned language model? In this post, we will delve into several approaches for controlled content generation with an unconditioned language model.
Note that model steerability is still an open research question. Each introduced method has certain pros & cons.</p>
<ol>
<li>Apply guided decoding strategies and select desired outputs at test time.</li>
<li>Optimize for the most desired outcomes via good prompt design.</li>
<li>Fine-tune the base model or steerable layers to do conditioned content generation.</li>
</ol>
<p>In the following discussion, we assume we have access to a pretrained generative language model \(p_\theta\). The model has learned the distribution over token sequences by optimizing for the next token prediction: \(\mathcal{L}_\text{ML} = - \sum_t \log p_\theta(x_t \vert x_{<t})\).</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#decoding-strategies" id="markdown-toc-decoding-strategies">Decoding Strategies</a> <ul>
<li><a href="#common-decoding-methods" id="markdown-toc-common-decoding-methods">Common Decoding Methods</a></li>
<li><a href="#guided-decoding" id="markdown-toc-guided-decoding">Guided Decoding</a></li>
<li><a href="#trainable-decoding" id="markdown-toc-trainable-decoding">Trainable Decoding</a></li>
</ul>
</li>
<li><a href="#smart-prompt-design" id="markdown-toc-smart-prompt-design">Smart Prompt Design</a> <ul>
<li><a href="#gradient-based-search" id="markdown-toc-gradient-based-search">Gradient-based Search</a></li>
<li><a href="#heuristic-based-search" id="markdown-toc-heuristic-based-search">Heuristic-based Search</a></li>
</ul>
</li>
<li><a href="#fine-tuning" id="markdown-toc-fine-tuning">Fine-tuning</a> <ul>
<li><a href="#conditional-training" id="markdown-toc-conditional-training">Conditional Training</a></li>
<li><a href="#rl-fine-tuning" id="markdown-toc-rl-fine-tuning">RL Fine-tuning</a></li>
<li><a href="#rl-fine-tuning-with-human-preferences" id="markdown-toc-rl-fine-tuning-with-human-preferences">RL Fine-tuning with Human Preferences</a></li>
<li><a href="#guided-fine-tuning-with-steerable-layer" id="markdown-toc-guided-fine-tuning-with-steerable-layer">Guided Fine-tuning with Steerable Layer</a></li>
<li><a href="#distributional-approach" id="markdown-toc-distributional-approach">Distributional Approach</a></li>
<li><a href="#unlikelihood-training" id="markdown-toc-unlikelihood-training">Unlikelihood Training</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="decoding-strategies">Decoding Strategies</h2>
<p>By adopting different decoding methods, we can place restrictions or preferences on the sampling process to alter the generated samples without modifying any model weights. Even though decoding strategies do not change the values of any trainable parameter, they are quite an important component.</p>
<h3 id="common-decoding-methods">Common Decoding Methods</h3>
<p>Since the final layer of the model predicts logits \(o\) over the vocabulary space, the next token can be sampled by applying softmax with temperature \(T\). The probability of sampling the \(i\)-th token is</p>
\[p_i = \frac{\exp(o_i / T)}{\sum_j \exp(o_j/T)}\]
<p>A low temperature would make the distribution sharper and a high value makes it softer.</p>
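<p>As a minimal sketch of the formula above (the function name and the toy logits are assumptions made for illustration), the temperature-scaled softmax can be written in a few lines:</p>

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw logits into sampling probabilities at a given temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [o / temperature for o in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # hypothetical logits over a 3-token vocabulary
sharp = softmax_with_temperature(logits, temperature=0.5)
soft = softmax_with_temperature(logits, temperature=2.0)
```

<p>Here <code>sharp</code> puts more mass on the top token than <code>soft</code> does, matching the intuition that a low temperature sharpens the distribution.</p>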
<p><strong>Greedy search</strong>: Always pick the next token with the <em>highest</em> probability, equivalent to taking the limit of temperature \(T \to 0\). However, it tends to create repetitions of phrases, even for well-trained models.</p>
<p><strong>Beam search</strong>: It essentially does breadth-first search, one token per tree level, but with a limited bandwidth. At each level of the search tree, beam search keeps track of \(n\) (named “beam width”) best candidates and expands all the successors of these candidates in the next level. Beam search could stop expanding a node if it hits the EOS (end-of-sentence) token.</p>
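<p>The procedure can be sketched as follows. This is an illustrative toy, not a production decoder: <code>toy_next_token_probs</code> is a hypothetical stand-in for a real LM and the <code>[EOS]</code> marker is an assumed token name.</p>

```python
import math

EOS = "[EOS]"

def toy_next_token_probs(prefix):
    """Hypothetical stand-in for an LM: next-token distribution given a prefix."""
    if not prefix:
        return {"the": 0.6, "a": 0.4}
    if prefix[-1] in ("the", "a"):
        return {"cat": 0.5, "dog": 0.3, EOS: 0.2}
    return {EOS: 0.9, "the": 0.1}

def beam_search(next_probs, beam_width=2, max_len=5):
    beams = [((), 0.0)]  # (token tuple, accumulated log-probability)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == EOS:
                candidates.append((tokens, score))  # finished: do not expand
                continue
            for tok, p in next_probs(tokens).items():
                candidates.append((tokens + (tok,), score + math.log(p)))
        # keep only the `beam_width` best candidates at this level
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(t[-1] == EOS for t, _ in beams):
            break
    return beams

best_tokens, best_logp = beam_search(toy_next_token_probs)[0]
```

<p>Setting <code>beam_width=1</code> recovers greedy search, since only the single best candidate survives each level.</p>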
<p>However, maximization-based decoding does not guarantee high-quality generation.</p>
<p style="width: 65%;" class="center"><a name="beam-search-surprise"></a>
<img src="/lil-log/assets/images/beam_search_less_surprising.png" alt="Beam search probability" /></p>
<p class="image-caption"><em>Fig. 1. The probability assigned to the next token by beam search versus by humans. The human selected tokens have much higher variance in predicted probability and are thus more surprising. (Image source: <a href="https://arxiv.org/abs/1904.09751">Holtzman et al. 2019</a>)</em></p>
<p><strong>Top-k sampling</strong> (<a href="https://arxiv.org/abs/1805.04833">Fan et al., 2018</a>): At each sampling step, only the top \(k\) most likely tokens are selected and the probability mass is redistributed among them. In <a href="https://arxiv.org/abs/1805.04833">Fan et al., 2018</a>, the authors proposed to use <em>top-k random sampling</em> where the next token is randomly selected among the top \(k\) most likely candidates and they argued that this approach can generate more novel and less repetitive content than beam search.</p>
<p><strong>Nucleus sampling</strong> (<a href="https://arxiv.org/abs/1904.09751">Holtzman et al. 2019</a>): Also known as “Top-p sampling”. One drawback of top-k sampling is that the predefined number \(k\) does not take into consideration how <em>skewed</em> the probability distribution might be. The nucleus sampling selects the smallest set of top candidates with the cumulative probability exceeding a threshold (e.g. 0.95) and then the distribution is rescaled among selected candidates.</p>
<p>Both top-k and nucleus sampling produce fewer repetitions with a proper set of hyperparameters.</p>
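<p>To make the difference between the two filters concrete, here is a small sketch. The function names and the two toy distributions are assumptions, and the nucleus stops as soon as the cumulative mass reaches the threshold.</p>

```python
def top_k_filter(probs, k):
    """Keep the k most likely token ids and renormalize their probabilities."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def nucleus_filter(probs, p=0.95):
    """Keep the smallest top set whose cumulative probability reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

peaked = [0.90, 0.06, 0.02, 0.01, 0.01]  # highly skewed distribution
flat = [0.25, 0.25, 0.20, 0.15, 0.15]    # nearly uniform distribution

peaked_nucleus = nucleus_filter(peaked, p=0.95)
flat_nucleus = nucleus_filter(flat, p=0.95)
peaked_top3 = top_k_filter(peaked, k=3)
```

<p>With a fixed \(k=3\), both distributions keep three tokens. The nucleus instead adapts: it keeps two tokens for the peaked distribution and all five for the flat one.</p>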
<p><strong>Penalized sampling</strong> (<a href="https://arxiv.org/abs/1909.05858">Keskar et al. 2019</a>): To avoid the common failure case of generating duplicate substrings, the <a href="https://arxiv.org/abs/1909.05858">CTRL</a> paper proposed a new sampling method to penalize repetitions by discounting the scores of previously generated tokens. The probability distribution for the next token with repetition penalty is defined as:</p>
\[p_i = \frac{\exp(o_i / (T \cdot \mathbb{1}(i \in g)))}{\sum_j \exp(o_j / (T \cdot \mathbb{1}(j \in g)))} \quad
\mathbb{1}(c) = \theta \text{ if the condition }c\text{ is True else }1\]
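<p>where \(g\) contains a set of previously generated tokens and \(\mathbb{1}(.)\) is an indicator-style function that returns \(\theta\) when its condition holds and 1 otherwise, rather than the usual 0/1 indicator. \(\theta=1.2\) is found to yield a good balance between less repetition and truthful generation.</p>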
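<p>A literal transcription of the penalized softmax might look like the sketch below (the function name and toy logits are assumptions). Note that dividing by \(\theta\) only lowers a score when the logit is positive; this sketch follows the formula as written and does not treat negative logits specially.</p>

```python
import math

def penalized_probs(logits, generated, temperature=1.0, theta=1.2):
    """Softmax in which logits of already generated token ids get an extra divisor theta."""
    scaled = [
        o / (temperature * (theta if i in generated else 1.0))
        for i, o in enumerate(logits)
    ]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 2.5, 1.0]  # hypothetical logits over a 3-token vocabulary
plain = penalized_probs(logits, generated=set())
penalized = penalized_probs(logits, generated={0})  # token 0 was already emitted
```

<p>In this example the penalty brings the logit of token 0 down from 3.0 to 2.5, so it ends up tied with token 1 instead of dominating it.</p>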
<h3 id="guided-decoding">Guided Decoding</h3>
<p>All the above standard decoding strategies sample tokens according to the predicted probability, with no additional information. Our preferences on topic or sentiment can be baked into the candidate ranking function to guide the sample generation by altering the candidate ranking score. The ranking score for token selection at each decoding step can be set as a combination of LM log-likelihood and a set of desired feature discriminators. The features are designed to quantify human preferences by heuristics (<a href="https://www.aclweb.org/anthology/P17-4008/">Ghazvininejad et al., 2017</a>), supervised learning (<a href="https://arxiv.org/abs/1805.06087">Holtzman et al., 2018</a>) or RL (<a href="https://arxiv.org/abs/1701.06549">Li et al., 2017</a>).</p>
<p><a href="https://www.aclweb.org/anthology/P17-4008/">Ghazvininejad et al. (2017)</a> built a system called “Hafez” for generating poetry in desired style by adjusting sampling weights in beam search at decoding steps. The likelihood of sampling for the next token \(x_{t+1}\) at step \(t\) is augmented by a scoring function:</p>
\[\text{score}(x_{t+1}, b_t) = \text{score}(b_t) + \log p(x_{t+1}) + \color{green}{\sum_i \alpha_i f_i(x_{t+1})}\]
<p>where \(\log p(x_{t+1})\) is the log-likelihood predicted by LM. \(\text{score}(b_t)\) is the accumulated score of the already-generated words in the current beam state \(b_t\). The green part can incorporate many different features for steering the style of the output. A set of feature functions \(f_i(.)\) define the preferences and the associated weights \(\alpha_i\) work like “control knobs” that can be easily customized at decoding time. Features can measure a variety of attributes and can be easily combined; for example,</p>
<ul>
<li>whether \(x_{t+1}\) exists in a bag of desired or banned topical words.</li>
<li>whether \(x_{t+1}\) indicates certain sentiments.</li>
<li>whether \(x_{t+1}\) is a repeated token (and thus \(f_i\) needs to take the history as input too).</li>
<li>the length of \(x_{t+1}\) if longer or shorter words are particularly preferred.</li>
</ul>
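<p>The scoring function and the feature types listed above can be sketched as follows. The feature functions, the word list and the weights are all made up for illustration and are not the ones used in Hafez.</p>

```python
import math

def augmented_score(beam_score, token, lm_prob, features, weights, history=()):
    """score(b_t) + log p(x_{t+1}) + sum_i alpha_i * f_i(x_{t+1})."""
    bonus = sum(weights[name] * f(token, history) for name, f in features.items())
    return beam_score + math.log(lm_prob) + bonus

# Hypothetical feature functions; each takes the candidate token and the history.
features = {
    "on_topic": lambda tok, hist: 1.0 if tok in {"moon", "night", "star"} else 0.0,
    "repeated": lambda tok, hist: 1.0 if tok in hist else 0.0,
    "length": lambda tok, hist: float(len(tok)),
}
weights = {"on_topic": 2.0, "repeated": -3.0, "length": 0.1}  # the "control knobs"

history = ("the", "night")
s_moon = augmented_score(-1.0, "moon", 0.05, features, weights, history)
s_night = augmented_score(-1.0, "night", 0.20, features, weights, history)
```

<p>Even though the LM assigns “night” a higher probability, the repetition penalty pushes its score below that of “moon”. Changing a weight at decoding time changes the ranking without touching the model.</p>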
<p>Similar to Hafez, <a href="https://arxiv.org/abs/1809.01215">Baheti et al. (2018)</a> manually designed features for ranking and altered the sampling distribution by appending similarity scores between topic distribution or embeddings of the context and the completion.</p>
<p><a href="https://arxiv.org/abs/1805.06087">Holtzman et al. (2018)</a> adopted a set of learned discriminators, each specializing in a different principle of communication guided by <a href="https://en.wikipedia.org/wiki/Cooperative_principle">Grice’s maxims</a>: quantity, quality, relation and manner. The discriminators learn to encode these desired principles by measuring repetition, entailment, relevance, and lexical diversity, respectively. Given some ground truth completion, all the discriminator models are trained to maximize the ranking log-likelihood, \(\log\sigma(f_i(y_g) - f_i(y))\), because the gold continuation \(y_g\) is expected to obtain a higher score than the generated one \(y\). Here the weight coefficients \(\alpha_i\) are also learned to minimize the score difference between the gold standard and the generated completion. Discriminative Adversarial Search (DAS; <a href="https://arxiv.org/abs/2002.10375">Scialom et al., 2020</a>) is inspired by GAN and trains the discriminator to tell apart human created text from machine generated text. The discriminator predicts a label for each token instead of for the entire sequence. The discriminator logprob is added to the score to guide sampling towards the human-written style.</p>
<p><a href="https://arxiv.org/abs/2010.02650">Meister et al. (2020)</a> studied beam search in a regularized decoding framework:</p>
\[\mathbf{y}^* = \arg\max_{\mathbf{y}\in\mathcal{Y}} \big( \underbrace{\log p_\theta(\mathbf{y}\vert\mathbf{x})}_\text{MAP} - \underbrace{\lambda\mathcal{R}(\mathbf{y})}_\text{regularizer} \big)\]
<p>Since we expect maximum probability to have minimum surprise, the surprisal of a LM at time step \(t\) can be defined as follows:</p>
\[\begin{aligned}
u_0(\texttt{BOS}) &= 0 \text{ ; BOS is a placeholder token for the beginning of a sentence.}\\
u_t(y) &= -\log P_\theta(y \vert \mathbf{x}, \mathbf{y}_{<t}) \text{ for }t \geq 1
\end{aligned}\]
<p>The MAP (maximum a posteriori) part demands for sequences with maximum probability given context, while the regularizer introduces other constraints. It is possible a global optimal strategy may need to have a high-surprisal step occasionally so that it can shorten the output length or produce more low-surprisal steps afterwards.</p>
<p>Beam search has gone through the test of time in the field of NLP. The question is: <em>If we want to model beam search as exact search in a regularized decoding framework, how should \(\mathcal{R}(\mathbf{y})\) be modeled?</em> The paper proposed a connection between beam search and the <em>uniform information density</em> (UID) hypothesis.</p>
<blockquote>
<p>“The uniform information density hypothesis (UID; Levy and Jaeger, 2007) states that—subject to the constraints of the grammar—humans prefer sentences that distribute information (in the sense of information theory) equally across the linguistic signal, e.g., a sentence.”</p>
</blockquote>
<p>In other words, it hypothesizes that humans prefer text with evenly distributed surprisal. Popular decoding methods like top-k sampling or nucleus sampling actually filter out high-surprisal options, thus implicitly encouraging the UID property in output sequences.</p>
<p>The paper experimented with several forms of regularizers:</p>
<ol>
<li><em>Greedy</em>: \(\mathcal{R}_\text{greedy}(\mathbf{y}) = \sum_{t=1}^{\vert\mathbf{y}\vert} \big(u_t(y_t) - \min_{y' \in \mathcal{V}} u_t(y') \big)^2\); if we set \(\lambda \to \infty\), we have greedy search. Note that being greedy at each individual step does not guarantee global optimality.</li>
<li><em>Variance regularizer</em>: \(\mathcal{R}_\text{var}(\mathbf{y}) = \frac{1}{\vert\mathbf{y}\vert}\sum_{t=1}^{\vert\mathbf{y}\vert} \big(u_t(y_t) - \bar{u} \big)^2\) , where \(\bar{u}\) is the average surprisal over all timesteps. It directly encodes the UID hypothesis.</li>
<li><em>Local consistency</em>: \(\mathcal{R}_\text{local}(\mathbf{y}) = \frac{1}{\vert\mathbf{y}\vert}\sum_{t=1}^{\vert\mathbf{y}\vert} \big(u_t(y_t) - u_{t-1}(y_{t-1}) \big)^2\); this decoding regularizer encourages adjacent tokens to have similar surprisal.</li>
<li><em>Max regularizer</em>: \(\mathcal{R}_\text{max}(\mathbf{y}) = \max_t u_t(y_t)\) penalizes the largest single-step surprisal in the sequence.</li>
<li><em>Squared regularizer</em>: \(\mathcal{R}_\text{square}(\mathbf{y}) = \sum_{t=1}^{\vert\mathbf{y}\vert} u_t(y_t)^2\) encourages all the tokens to have surprisal close to 0.</li>
</ol>
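<p>The five regularizers listed above operate only on the per-step surprisals, so they are easy to write down. In the sketch below, <code>u</code> is the list of surprisals \(u_t(y_t)\) of a candidate sequence and <code>u_min</code> holds the lowest surprisal available at each step; the function names and toy numbers are assumptions.</p>

```python
def greedy_reg(u, u_min):
    """Sum of squared gaps to the lowest surprisal available at each step."""
    return sum((ut - m) ** 2 for ut, m in zip(u, u_min))

def variance_reg(u):
    mean = sum(u) / len(u)
    return sum((ut - mean) ** 2 for ut in u) / len(u)

def local_consistency_reg(u):
    previous = [0.0] + list(u[:-1])  # u_0(BOS) = 0
    return sum((ut - p) ** 2 for ut, p in zip(u, previous)) / len(u)

def max_reg(u):
    return max(u)

def squared_reg(u):
    return sum(ut ** 2 for ut in u)

def regularized_score(u, lam, penalty):
    """log p(y|x) - lambda * R(y), where log p(y|x) = -sum_t u_t(y_t)."""
    return -sum(u) - lam * penalty

even = [1.0, 1.0, 1.0, 1.0]
spiky = [0.25, 0.25, 0.25, 3.25]  # same total surprisal, unevenly spread

score_even = regularized_score(even, 1.0, variance_reg(even))
score_spiky = regularized_score(spiky, 1.0, variance_reg(spiky))
```

<p>Both toy sequences have the same log-probability, so plain MAP decoding cannot tell them apart. The variance regularizer prefers the evenly distributed one, in line with the UID hypothesis.</p>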
<p>An experiment with greedy regularizers showed that larger \(\lambda\) results in better performance (e.g. measured by BLEU for NMT task) and lower std dev of surprisal.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/beam-search-greedy-regularizer.png" alt="Greedy regularizer" /></p>
<p class="image-caption"><em>Fig. 2. The plot of BLEU and std. dev of surprisals as functions of the strength of the regularizer \(\lambda\). The subgraph in grey shows the relationship between BLEU and surprisal std. dev. (Image source: <a href="https://arxiv.org/abs/2010.02650">Meister et al. 2020</a>)</em></p>
<p>Default beam search tends to produce lower-quality text as the beam size increases. Regularized beam search greatly helps alleviate this issue. A combined regularizer further improves the performance. In their experiments for NMT, they found \(\lambda=5\) for greedy and \(\lambda=2\) for squared work out as the optimal combined regularizer.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/beam-search-size-regularized.png" alt="Beam search size" /></p>
<p class="image-caption"><em>Fig. 3. The plot of BLEU as a function of beam size (left) and BLEU scores for translations created by different regularized decoding strategies. (Image source: <a href="https://arxiv.org/abs/2010.02650">Meister et al. 2020</a>)</em></p>
<p>Guided decoding essentially runs a more expensive beam search where the sampling probability distribution is altered by side information about human preferences.</p>
<h3 id="trainable-decoding">Trainable Decoding</h3>
<p>Given a trained language model, <a href="https://arxiv.org/abs/1702.02429">Gu et al. (2017)</a> proposed a <strong>trainable greedy decoding</strong> algorithm to maximize an arbitrary objective for sampling sequences. The idea is based on the <em>noisy, parallel approximate decoding</em> (<a href="https://arxiv.org/abs/1605.03835">NPAD</a>). NPAD injects unstructured noise into the model hidden states and runs noisy decoding multiple times in parallel to avoid potential degradation. To take a step further, trainable greedy decoding replaces the unstructured noise with a learnable random variable, predicted by an RL agent that takes the previous hidden state, the previous decoded token and the context as input. In other words, the decoding algorithm learns an RL actor to manipulate the model hidden states for better outcomes.</p>
<p><a href="https://arxiv.org/abs/1906.09531">Grover et al. (2019)</a> trained a binary classifier to distinguish samples from data distribution and samples from the generative model. This classifier is used to estimate <em>importance weights</em> for constructing a new unnormalized distribution. The proposed strategy is called <strong>likelihood-free importance weighting (LFIW)</strong>.</p>
<p>Let \(p\) be the real data distribution and \(p_\theta\) be a learned generative model. A classical approach for evaluating the expectation of a given function \(f\) under \(p\) using samples from \(p_\theta\) is to use importance sampling.</p>
\[\mathbb{E}_{\mathbf{x}\sim p} [f(\mathbf{x})]
= \mathbb{E}_{\mathbf{x}\sim p_\theta} \Big[\frac{p(\mathbf{x})}{p_\theta(\mathbf{x})} f(\mathbf{x})\Big]
\approx \frac{1}{N} \sum_{i=1}^N w(\mathbf{x}_i)f(\mathbf{x}_i)\]
<p>However, \(p(\mathbf{x})\) can only be estimated via finite datasets. Let \(c_\phi: \mathcal{X} \to [0,1]\) be a probabilistic binary classifier for predicting whether a sample \(\mathbf{x}\) is from the true data distribution (\(y=1\)). The joint distribution over \(\mathcal{X}\times\mathcal{Y}\) is denoted as \(q(\mathbf{x}, y)\).</p>
\[q(\mathbf{x}\vert y) = \begin{cases}
p_\theta(\mathbf{x}) & \text{ if }y=0\text{; predicted to be generated data} \\
p(\mathbf{x}) & \text{ otherwise; from the true data distribution}
\end{cases}\]
<p>Then if \(c_\phi\) is <a href="https://svivek.com/teaching/lectures/slides/prob-learning/bayes-optimal-classifier.pdf">Bayes optimal</a>, the importance weight can be estimated by:</p>
\[w_\phi(\mathbf{x})
= \frac{p(\mathbf{x})}{p_\theta(\mathbf{x})}
= \frac{q(\mathbf{x} \vert y=1)}{q(\mathbf{x} \vert y=0)}
= \frac{q(y=0)}{q(y=1)} \frac{q(y=1 \vert \mathbf{x})}{q(y=0 \vert \mathbf{x})}
= \gamma \frac{c_\phi(\mathbf{x})}{1 - c_\phi(\mathbf{x})}\]
<p>where \(\gamma = \frac{q(y=0)}{q(y=1)} > 0\) is a fixed odds ratio.</p>
<p>Since we cannot learn a perfect optimal classifier, the importance weight would be an estimation \(\hat{w}_\phi\). A couple of practical tricks can be applied to offset cases when the classifier exploits artifacts in the generated samples to make very confident predictions (i.e. very small importance weights):</p>
<ol>
<li>Self-normalization: normalize the weight by the sum \(\hat{w}_\phi(\mathbf{x}_i) / \sum_{j=1}^N \hat{w}_\phi(\mathbf{x}_j)\).</li>
<li>Flattening: add a power scaling parameter \(\alpha > 0\), \(\hat{w}_\phi(\mathbf{x}_i)^\alpha\).</li>
<li>Clipping: specify a lower bound \(\max(\hat{w}_\phi(\mathbf{x}_i), \beta)\).</li>
</ol>
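<p>Putting the weight estimate and the three tricks together gives the sketch below. The function name is hypothetical, and the order in which the tricks are applied (flatten, clip, then normalize) is an assumption of this sketch. The classifier outputs must lie strictly between 0 and 1.</p>

```python
def importance_weights(classifier_probs, gamma=1.0, alpha=1.0, beta=0.0,
                       self_normalize=True):
    """Estimate w(x) = gamma * c(x) / (1 - c(x)) and post-process the weights."""
    weights = []
    for c in classifier_probs:
        w = gamma * c / (1.0 - c)  # odds that x came from the real data
        w = w ** alpha             # flattening
        w = max(w, beta)           # clipping with a lower bound
        weights.append(w)
    if self_normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights

# Hypothetical classifier outputs c(x) for three generated samples.
c = [0.5, 0.9, 0.01]
raw = importance_weights(c, self_normalize=False)
clipped = importance_weights(c, beta=0.1, self_normalize=False)
normalized = importance_weights(c)
```

<p>The third sample is confidently judged as generated, so its raw weight is close to zero; clipping keeps it from vanishing entirely.</p>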
<p>To sample from an importance resampled generative model, \(\mathbf{x}\sim p_{\theta, \phi}(\mathbf{x}) \propto p_\theta(\mathbf{x})\hat{w}_\phi(\mathbf{x})\), they adopt SIR (Sampling-Importance-Resampling),</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SIR-importance-resampling.png" alt="SIR importance resampling" /></p>
<p class="image-caption"><em>Fig. 4. The algorithm for sampling from a generative model according to importance weights \(\hat{w}(\mathbf{x}_i)\) using SIR. (Image source: <a href="https://arxiv.org/abs/1906.09531">Grover et al., 2019</a>)</em></p>
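<p>A rough sketch of SIR, assuming we can draw samples from the model and evaluate an importance weight for each of them. The function name, the toy binary “model” and the weights are placeholders.</p>

```python
import random

def sir_sample(sample_from_model, weight_fn, num_proposals=100, rng=random):
    """Draw proposals from the model, then resample one in proportion to its weight."""
    proposals = [sample_from_model() for _ in range(num_proposals)]
    weights = [weight_fn(x) for x in proposals]
    return rng.choices(proposals, weights=weights, k=1)[0]

# Toy setup: the "model" emits 0 or 1 uniformly, while the weights say that
# 1 is nine times more likely under the real data than 0.
rng = random.Random(0)
draws = [
    sir_sample(lambda: rng.randint(0, 1), lambda x: 9.0 if x == 1 else 1.0, 20, rng)
    for _ in range(500)
]
fraction_of_ones = sum(draws) / len(draws)
```

<p>The resampled outputs contain far more ones than the uniform proposals, which is the intended correction toward the data distribution. More proposals per draw give a better approximation at higher cost.</p>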
<p><a href="https://arxiv.org/abs/2004.11714">Deng et al., 2020</a> proposed to learn an EBM to steer an LM in the <a href="https://arxiv.org/abs/1906.03351">residual space</a>, \(P_\theta(x) \propto P_\text{LM}(x)\exp(-E_\theta(x))\), where \(P_\theta\) is the joint model; \(E_\theta\) is the residual energy function to be learned. If we know the partition function \(Z\), we can write the model for generating a sequence \(x_{p+1}, \dots, x_T\) as:</p>
\[P_\theta(x_{p+1:T}\vert x_{1:p}) = \frac{P_\text{LM}(x_{p+1:T}\vert x_{1:p}) \exp(-E_\theta(x_{1:T}))}{Z_\theta(x_{1:p})}\]
<p>The goal is to learn the parameters of the energy function \(E_\theta\) such that the joint model \(P_\theta\) gets closer to the desired data distribution. The residual energy function is trained by noise contrastive estimation (<a href="https://www.kdnuggets.com/2019/07/introduction-noise-contrastive-estimation.html">NCE</a>), considering \(P_\theta\) as the model distribution and \(P_\text{LM}\) as the noise distribution:</p>
\[\theta = \arg\max_{\theta} \mathbb{E}_{x^+ \sim P_\text{data}} \log\frac{1}{1+\exp(E_\theta(x^+))} + \mathbb{E}_{x^- \sim P_\text{LM}} \log\frac{1}{1+\exp(-E_\theta(x^-))}\]
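<p>For intuition, the NCE objective above can be evaluated on a batch of energies as in this sketch. The function names are assumptions, and in practice the energies would come from a neural network and the objective would be maximized by gradient ascent.</p>

```python
import math

def log_sigmoid(z):
    """Numerically stable log(1 / (1 + exp(-z)))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def nce_objective(energies_data, energies_lm):
    """Batch estimate of the NCE objective (to be maximized).

    log(1 / (1 + exp(E(x+))))  = log_sigmoid(-E(x+)) for real text x+
    log(1 / (1 + exp(-E(x-)))) = log_sigmoid(E(x-))  for LM samples x-
    """
    pos = sum(log_sigmoid(-e) for e in energies_data) / len(energies_data)
    neg = sum(log_sigmoid(e) for e in energies_lm) / len(energies_lm)
    return pos + neg

good = nce_objective([-2.0, -3.0], [2.0, 3.0])  # low energy on real text
bad = nce_objective([2.0, 3.0], [-2.0, -3.0])   # the opposite assignment
uninformative = nce_objective([0.0], [0.0])
```

<p>The objective rewards assigning low energy to real text and high energy to LM samples, so <code>good</code> is larger than <code>bad</code>; an energy function that outputs zero everywhere scores \(2\log 0.5\).</p>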
<p>However, the partition function is intractable in practice. The paper proposed a simple way to first sample from the original LM and then to resample from them according to the energy function. This is unfortunately quite expensive.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/top-k-joint-sampling.png" alt="Top k joint sampling" /></p>
<p class="image-caption"><em>Fig. 5. Top k samples from the base LM are resampled according to the residual energy function. (Image source: <a href="https://arxiv.org/abs/2004.11714">Deng et al., 2020</a>)</em></p>
<h2 id="smart-prompt-design">Smart Prompt Design</h2>
<p>Large language models have been shown to be very powerful on many NLP tasks, even with only <em>prompting</em> and no task-specific fine-tuning (<a href="/lil-log/2019/01/31/generalized-language-models.html#gpt-2">GPT2</a>, <a href="/lil-log/2019/01/31/generalized-language-models.html#gpt-3">GPT3</a>). The prompt design has a big impact on the performance on downstream tasks and often requires time-consuming manual crafting. For example, factual questions can gain a big boost with smart prompt design in “closed-book exam” (<a href="https://arxiv.org/abs/2010.15980">Shin et al., 2020</a>, <a href="https://arxiv.org/abs/1911.12543">Jiang et al., 2020</a>). I’m expecting to see an increasing amount of literature on automatic smart prompt design.</p>
<h3 id="gradient-based-search">Gradient-based Search</h3>
<p><strong>AutoPrompt</strong> (<a href="https://arxiv.org/abs/2010.15980">Shin et al., 2020</a>; <a href="http://ucinlp.github.io/autoprompt">code</a>) is a method to automatically create prompts for various tasks via gradient-based search. AutoPrompt constructs a prompt by combining the original task inputs \(x\) with a collection of trigger tokens \(x_\text{trig}\) according to a template \(\lambda\). The trigger tokens are shared across all inputs and thus <em>universally</em> effective.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/autoprompt.png" alt="AutoPrompt" /></p>
<p class="image-caption"><em>Fig. 6. The overview of AutoPrompt. The trigger tokens are retrieved to optimize for the target outputs across all inputs. (Image source: <a href="https://arxiv.org/abs/2010.15980">Shin et al., 2020</a>)</em></p>
<p>The universal trigger tokens are identified using a gradient-guided search strategy same as in <a href="https://arxiv.org/abs/1908.07125">Wallace et al., 2019</a>. The <em>universal</em> setting means that the trigger tokens \(x_\text{trig}\) can optimize for the target output \(\tilde{y}\) for all inputs from a dataset:</p>
\[x_\text{trig} = \arg\min_{x'_\text{trig}} \mathbb{E}_{x\sim\mathcal{X}} [\mathcal{L}(\tilde{y}, f(x'_\text{trig}; x))]\]
<p>The search operates in the embedding space. The embedding of every trigger token \(e_{\text{trig}_i}\) is first initialized to some default value and then gets updated to minimize the first-order Taylor expansion of the task-specific loss around the current token embedding:</p>
\[e^{(t+1)}_{\text{trig}_i} = \arg\min_{e\in\mathcal{V}} [e - e^{(t)}_{\text{trig}_i}]^\top \nabla_{e^{(t)}_{\text{trig}_i}} \mathcal{L}\]
<p>where \(\mathcal{V}\) refers to the embedding matrix of all the tokens. \(\nabla_{e^{(t)}_{\text{trig}_i}} \mathcal{L}\) is the average gradient of the task loss over a batch at iteration \(t\). We can brute-force the optimal \(e\) by a \(\vert \mathcal{V} \vert d\)-dimensional dot product, which is cheap and can be computed in parallel.</p>
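<p>The brute-force search over the vocabulary can be sketched with plain lists. The function name, the tiny embedding table and the gradient values are made up; a real implementation would compute the same quantity with one matrix-vector product.</p>

```python
def best_replacement(embedding_matrix, current_embedding, grad):
    """Score every vocabulary embedding e by (e - e_trig)^T grad and return the argmin."""
    def first_order_change(e):
        return sum((ei - ci) * gi for ei, ci, gi in zip(e, current_embedding, grad))
    scores = [first_order_change(e) for e in embedding_matrix]
    best = min(range(len(scores)), key=scores.__getitem__)
    return best, scores

# Toy vocabulary of three tokens with 2-dimensional embeddings.
embedding_matrix = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
current = embedding_matrix[0]  # the trigger token is currently token 0
grad = [1.0, -0.5]             # hypothetical batch-averaged loss gradient

best_token, scores = best_replacement(embedding_matrix, current, grad)
```

<p>The score of the current token is zero by construction, so any token with a negative score is predicted, to first order, to reduce the loss. Keeping the \(k\) lowest scores instead of only the minimum gives the candidate set for a beam search.</p>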
<p style="width: 62%;" class="center"><img src="/lil-log/assets/images/universal-adv-triggers.png" alt="Universal adversarial trigger" /></p>
<p class="image-caption"><em>Fig. 7. We search for trigger tokens by updating their embeddings with the gradient of the task loss per batch. (Image source: <a href="https://arxiv.org/abs/1908.07125">Wallace et al., 2019</a>)</em></p>
<p>The above token replacement method can be augmented with beam search. When looking for the optimal token embedding \(e\), we can pick top-\(k\) candidates instead of a single one, searching from left to right and scoring each beam by \(\mathcal{L}\) on the current data batch.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/autoprompt-examples.png" alt="AutoPrompt examples" /></p>
<p class="image-caption"><em>Fig. 8. Example prompts discovered by AutoPrompt for different tasks. (Image source: <a href="https://arxiv.org/abs/2010.15980">Shin et al., 2020</a>)</em></p>
<p>Smart prompt design essentially produces efficient context that can lead to desired completion. Motivated by this observation, <a href="https://arxiv.org/abs/2101.00190">Li & Liang (2021)</a> proposed <strong>Prefix-Tuning</strong> which assigns a small number of trainable parameters at the beginning of an input sequence (named “prefix”) to steer a LM, \([\text{PREFIX}; x; y]\). Let \(\mathcal{P}_\text{idx}\) be a set of prefix indices and \(\text{dim}(h_i)\) be the embedding size. The prefix parameters \(P_\theta\) have the dimension \(\vert\mathcal{P}_\text{idx}\vert \times \text{dim}(h_i)\) and the hidden state takes the form:</p>
\[h_i = \begin{cases}
P_\theta[i,:], & \text{if }i \in \mathcal{P}_\text{idx}\\
\text{LM}_\phi(z_i, h_{<i}), & \text{otherwise}
\end{cases}\]
<p>Note that only \(P_\theta\) is trainable and the LM parameters \(\phi\) are frozen during training.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/prefix-tuning.png" alt="Prefix-tuning" /></p>
<p class="image-caption"><em>Fig. 9. Illustrations of fine-tuning versus prefix-tuning. (Image source: <a href="https://arxiv.org/abs/2101.00190">Li & Liang 2021</a>)</em></p>
<p>The prefix parameters do not tie to any embeddings associated with the real words and thus they are more <em>expressive</em> for steering the context. Directly optimizing \(P_\theta\) unfortunately results in poor performance. To reduce the difficulty associated with high dimensionality training, the matrix \(P_\theta\) is reparameterized by a smaller matrix \(P'_\theta \in \mathbb{R}^{\vert\mathcal{P}_\text{idx}\vert \times c}\) and a large feed forward network \(\text{MLP}_\theta \in \mathbb{R}^{c\times \text{dim}(h_i)}\).</p>
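<p>The shapes involved in this reparameterization can be illustrated with a deliberately simplified sketch. Here the feed forward network is reduced to a single linear map of shape \(c \times \text{dim}(h_i)\), which is an assumption made only to keep the example short; a real MLP would also contain nonlinearities. All names and sizes are hypothetical.</p>

```python
import random

def reparameterized_prefix(p_small, weight):
    """P[i, :] = P'[i, :] @ W, with P' of shape (prefix_len, c) and W of shape (c, dim)."""
    dim = len(weight[0])
    return [
        [sum(row[k] * weight[k][j] for k in range(len(row))) for j in range(dim)]
        for row in p_small
    ]

rng = random.Random(0)
prefix_len, c, dim = 3, 2, 4
p_small = [[rng.gauss(0, 1) for _ in range(c)] for _ in range(prefix_len)]  # P'
weight = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(c)]          # W

prefix = reparameterized_prefix(p_small, weight)  # used as P in the LM
```

<p>During training, gradients would flow into both <code>p_small</code> and <code>weight</code>; only the resulting <code>prefix</code> matrix is needed at inference time.</p>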
<p>The performance increases with the prefix length \(\vert\mathcal{P}_\text{idx}\vert\) up to some value. And this value varies with tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/prefix-tuning-length.png" alt="Prefix-tuning" /></p>
<p class="image-caption"><em>Fig. 10. Task performance, summarization (left) and table-to-text (right), as a function of prefix length. (Image source: <a href="https://arxiv.org/abs/2101.00190">Li & Liang 2021</a>)</em></p>
<p>A few other interesting learnings from their ablation studies include:</p>
<ul>
<li>Tuning only the embedding layer (without prefix) is not sufficiently expressive.</li>
<li>Placing the trainable parameter between \(x\) and \(y\), \([x; \text{INFIX}; y]\), slightly underperforms prefix-tuning, likely because it only affects the context for \(y\) while prefix affects both.</li>
<li>Random initialization of \(P_\theta\) leads to low performance with high variance. In contrast, initializing \(P_\theta\) with activations of real words improves generation, even if the words are irrelevant to the task.</li>
</ul>
<p>Fine-tuned models achieve better task performance but they can fail in the low data regime. Both AutoPrompt and Prefix-Tuning were found to outperform fine-tuning in the regime where the training dataset is small (i.e. \(10^2-10^3\) samples). As an alternative to fine-tuning, prompt design or learning the context embedding is much cheaper. AutoPrompt improves the accuracy for sentiment classification a lot more than manual prompts and achieves similar performance as linear probing. For the NLI task, AutoPrompt obtains higher accuracy than linear probing. It is able to retrieve facts more accurately than manual prompts too. In low data regime, Prefix-Tuning achieves performance comparable with fine-tuning on table-to-text generation and summarization.</p>
<p>Two successive works, <strong>P-tuning</strong> (<a href="https://arxiv.org/abs/2103.10385">Liu et al. 2021</a>; <a href="https://github.com/THUDM/P-tuning">code</a>) and <strong>Prompt Tuning</strong> (<a href="https://arxiv.org/abs/2104.08691">Lester et al. 2021</a>), follow the similar idea of explicitly training continuous prompt embeddings but with a few different choices over the trainable parameters and architecture. Different from Prefix-Tuning which concatenates continuous prompt tokens in every hidden state layer of the transformer, both P-tuning and Prompt Tuning non-invasively add continuous prompts <em>only in the input</em> to work well.</p>
<p>Let \([P_i]\) be the \(i\)-th token in the prompt template of <strong>P-tuning</strong> (<a href="https://arxiv.org/abs/2103.10385">Liu et al. 2021</a>), we can denote a prompt as a sequence \(T=\{[P_{0:i}], \mathbf{x}, [P_{i+1:m}], \mathbf{y}\}\). Each token \([P_i]\) does not have to be a real token in the model vocabulary (“pseudo-token”), and thus the encoded template \(T^e\) looks like the following and the pseudo-token hidden state can be optimized with gradient descent.</p>
\[T^e = \{ h_0, \dots, h_i, \text{embed}(\mathbf{x}), h_{i+1}, \dots, h_m, \text{embed}(\mathbf{y})\}\]
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/p-tuning.png" alt="P-tuning" /></p>
<p class="image-caption"><em>Fig. 11. The illustration of P-tuning. Sometimes, adding a few task-related anchor tokens, such as “capital” in the figure, can bring further improvement. (Image source: <a href="https://arxiv.org/abs/2103.10385">Liu et al. 2021</a>)</em></p>
<p>There are two major optimization challenges in P-tuning:</p>
<ol>
<li>Discreteness: The word embeddings of a pretrained language model are highly discrete. It is hard to optimize \(h_i\) if they are initialized at random.</li>
<li>Association: \(h_i\) should be dependent on each other. Thus they develop a mechanism to model this dependency by training a lightweight LSTM-based prompt encoder:</li>
</ol>
\[h_i = \text{MLP}([\text{LSTM}(h_{0:i}): \text{LSTM}(h_{i:m})])\]
<p>P-tuning is more flexible than prefix-tuning, as it inserts trainable tokens in the middle of a prompt not just at the beginning. The usage of task-specific anchor tokens is like combining manual prompt engineering with trainable prompts.</p>
<p><strong>Prompt Tuning</strong> (<a href="https://arxiv.org/abs/2104.08691">Lester et al. 2021</a>) largely simplifies the idea of prefix tuning by only allowing an additional \(k\) tunable tokens per downstream task to be prepended to the input text. The conditional generation is \(p_{\theta, \theta_P}(Y \vert [P; X])\), where \(P\) is the “pseudo prompt” with parameters \(\theta_P\) trainable via back-propagation. Both \(X\) and \(P\) are embedding vectors and we have \(X \in \mathbb{R}^{n \times d^e}, P \in \mathbb{R}^{k \times d^e}\) and \([P;X] \in \mathbb{R}^{(n+k) \times d^e}\), where \(d^e\) is the embedding space dimensionality.</p>
<ul>
<li>Prompt tuning produces results competitive with model fine-tuning when the model gets <em>large</em> (billions of parameters and up). This result is especially interesting given that large models are expensive to fine-tune and execute at inference time.</li>
<li>With learned task-specific parameters, prompt tuning achieves better transfer learning when adapting to new domains. It outperforms fine-tuning on domain shift problems.</li>
<li>They also showed that prompt ensembling of multiple prompts for the same task introduces further improvement.</li>
</ul>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/prompt-tuning.png" alt="Prompt-tuning" /></p>
<p class="image-caption"><em>Fig. 12. The illustration of how Prompt Tuning works. (Image source: <a href="https://arxiv.org/abs/2104.08691">Lester et al. 2021</a>)</em></p>
<p>The experiments investigated several prompt initialization schemes:</p>
<ol>
<li>Random initialization by uniformly sampling from [-0.5, 0.5];</li>
<li>Sample embeddings of top 5000 common tokens;</li>
<li>Use the embedding values of the class label strings. If we don’t have enough class labels to initialize the soft-prompt, we fall back to scheme 2.</li>
</ol>
<p>Random initialization performs noticeably worse than the other two options.</p>
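<p>As a rough illustration, the three initialization schemes and the concatenation \([P; X]\) can be sketched with plain Python lists. The function names, the toy embedding table and the assumption that its rows are sorted by token frequency are all hypothetical choices made for this sketch, not details from the paper.</p>

```python
import random

def init_uniform(k, d, rng):
    """Scheme 1: every entry is drawn uniformly from [-0.5, 0.5]."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(k)]

def init_from_vocab(k, embedding_table, rng, top_n=5000):
    """Scheme 2: copy the embeddings of k tokens sampled from the top_n most common.
    Assumption: rows of embedding_table are sorted by token frequency."""
    pool = embedding_table[:top_n]
    return [list(rng.choice(pool)) for _ in range(k)]

def init_from_labels(k, label_embeddings, embedding_table, rng):
    """Scheme 3: start from class-label embeddings, fill the remainder via scheme 2."""
    prompt = [list(e) for e in label_embeddings[:k]]
    prompt += init_from_vocab(k - len(prompt), embedding_table, rng)
    return prompt

def prepend_prompt(P, X):
    """[P; X]: k prompt rows followed by n input rows, i.e. (n + k) rows of size d."""
    return P + X

rng = random.Random(0)
d = 4
embedding_table = [[float(i)] * d for i in range(10)]  # toy 10-token vocabulary
X = [embedding_table[3], embedding_table[7]]           # n = 2 embedded input tokens
P = init_from_labels(k=3, label_embeddings=[embedding_table[1]],
                     embedding_table=embedding_table, rng=rng)
PX = prepend_prompt(P, X)
print(len(PX), len(PX[0]))  # number of rows (n + k) and embedding size d
```

<p>In actual prompt tuning only the rows of \(P\) would receive gradient updates; the sketch covers initialization and shapes only.</p>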
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/prompt-tuning-exp1.png" alt="Prompt-tuning-exp1" /></p>
<p class="image-caption"><em>Fig. 13. The effect of (a) different prompt initialization schemes and (b) different prompt lengths. (Image source: <a href="https://arxiv.org/abs/2104.08691">Lester et al. 2021</a>)</em></p>
<p>The pre-training objectives also have a big impact on the quality of prompt tuning. T5’s “span corruption” is not a good option here.</p>
<p>Prompt tuning is found to be less likely to overfit to a specific dataset. To evaluate robustness to data shift, they trained the model on one dataset of a task and evaluated it on a test set for the same task but from a <em>different domain</em>. Prompt tuning is more resilient and can generalize to different domains better.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/prompt-tuning-exp2.png" alt="Prompt-tuning-exp2" /></p>
<p class="image-caption"><em>Fig. 14. Prompt tuning is more resilient to domain shift between train and test sets. (Image source: <a href="https://arxiv.org/abs/2104.08691">Lester et al. 2021</a>)</em></p>
<h3 id="heuristic-based-search">Heuristic-based Search</h3>
<p>Paraphrasing is a quick way to explore more prompts similar to the known version, which can be done via <em>back-translation</em>. Using back-translation, the initial prompt is translated into \(B\) candidates in another language and then each is translated back into \(B\) candidates in the original language. The resulting total \(B^2\) candidates are scored and ranked by their round-trip probabilities.</p>
<p><a href="https://www.aclweb.org/anthology/P18-1079/">Ribeiro et al (2018)</a> identified <em>semantically equivalent adversaries (SEA)</em> by generating a variety of paraphrases \(\{x'\}\) of input \(x\) until it triggers a different prediction of target function \(f\):</p>
\[\begin{aligned}
SEA(x, x') &= \mathbb{1}[\text{SemEq}(x, x') \land f(x) \neq f(x')] \\
\text{where SemEq}(x, x') &= \mathbb{1}[\min\Big(1, \frac{p(x'\vert x)}{p(x\vert x)} \Big) \geq \tau]
\end{aligned}\]
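<p>A minimal sketch of the SEA check, assuming we already have a paraphrase model that returns \(p(x' \vert x)\) and a classifier \(f\). The lookup table, the toy classifier and the threshold value below are made up for illustration.</p>

```python
def sem_eq(p_xprime_given_x, p_x_given_x, tau):
    """SemEq(x, x') = 1[min(1, p(x'|x) / p(x|x)) >= tau]."""
    return min(1.0, p_xprime_given_x / p_x_given_x) >= tau

def is_sea(x, x_prime, f, paraphrase_prob, tau):
    """SEA(x, x') = 1[SemEq(x, x') and f(x) != f(x')]."""
    equivalent = sem_eq(paraphrase_prob(x, x_prime), paraphrase_prob(x, x), tau)
    return equivalent and f(x) != f(x_prime)

# Hypothetical paraphrase probabilities p(x'|x), keyed by (x, x').
table = {
    ("good movie", "good movie"): 0.50,
    ("good movie", "nice movie"): 0.40,
    ("good movie", "bad movie"): 0.01,
}
paraphrase_prob = lambda x, x_prime: table[(x, x_prime)]
# A deliberately brittle toy classifier that keys on a single word.
f = lambda s: "pos" if "good" in s else "neg"

found_bug = is_sea("good movie", "nice movie", f, paraphrase_prob, tau=0.5)
not_equivalent = is_sea("good movie", "bad movie", f, paraphrase_prob, tau=0.5)
print(found_bug, not_equivalent)
```

<p>The first pair is flagged because the paraphrase is likely enough to count as equivalent yet flips the prediction; the second pair changes the prediction but fails the equivalence test.</p>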
<p>The rules extracted from SEA are considered as “bugs” in the model. Applying those rules as data augmentation in model training helps robustify the model and fix bugs.</p>
<p><a href="https://arxiv.org/abs/1911.12543">Jiang et al (2020)</a> attempted to validate whether a trained language model knows certain knowledge by automatically discovering better prompts to query it with. Within the scope of knowledge retrieval, factual knowledge is represented in the form of a triple \(\langle x, r, y \rangle\) (subject, relation, object). The prompts can be mined from training sentences (e.g. Wikipedia description) or expanded by paraphrase.</p>
<p>Interestingly, some small modifications in the prompts may lead to big gains, as shown in Fig. 15.</p>
<p style="width: 52%;" class="center"><img src="/lil-log/assets/images/prompt-small-modifications.png" alt="Small modifications" /></p>
<p class="image-caption"><em>Fig. 15. Small modifications in prompt templates can lead to big performance gains: replacement in blue, insertion in green, deletion in red. (Image source: <a href="https://arxiv.org/abs/1911.12543">Jiang et al., 2020</a>)</em></p>
<h2 id="fine-tuning">Fine-tuning</h2>
<p>Fine-tuning is an intuitive way to guide an LM to output desired content, commonly by training on supervised datasets or by RL. We can fine-tune all the weights in the model or restrict the fine-tuning to only top or additional layers.</p>
<h3 id="conditional-training">Conditional Training</h3>
<p>Conditional training aims to learn a generative model conditioned on a control variable \(z\), \(p(y \vert x, z)\).</p>
<p><a href="https://arxiv.org/abs/1805.04833">Fan et al (2018)</a> trained a conditional language model for 2-step story generation. First, a model outputs the story sketch and then a story writing model creates a story following that sketch. The mechanism of conditioning on the sketch is implemented by a <em>fusion</em> model architecture. The fusion model enforces a form of <em>residual learning</em> that allows the story writing model to focus on learning what the first sketch generation model is missing. Also for story generation, <a href="https://www.aclweb.org/anthology/W18-1505/">Peng et al (2018)</a> experimented with an ending valence-conditioned story generator LM, \(p(x_t \vert x_{<t}, z)\) where \(z\) is the label of the story ending (sad, happy or neutral). Their language model is a bidirectional LSTM and the label is mapped into a learned embedding which then blends into the LSTM cell.</p>
<p><a name="ctrl"></a><strong>CTRL</strong> (<a href="https://arxiv.org/abs/1909.05858">Keskar et al., 2019</a>; <a href="https://github.com/salesforce/ctrl">code</a>) aims to train a language model conditioned on a control code \(z\) using controllable datasets. CTRL learns the conditioned distribution \(p(x \vert z)\) by training on raw text sequences with <em>control code prefixes</em>, such as <code class="language-plaintext highlighter-rouge">[horror]</code>, <code class="language-plaintext highlighter-rouge">[legal]</code>, etc. Then the learned model is able to generate text with respect to the prompt prefix. The training data contains Wikipedia, OpenWebText, books, Amazon reviews, reddit corpus and many more, where each dataset is assigned a control code and each subreddit in the reddit corpus has its own topic as control code.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/CTRL-control-code.png" alt="CTRL examples" /></p>
<p class="image-caption"><em>Fig. 16. Datasets used for training CTRL and associated control codes. (Image source: Edited from Table 7 in <a href="https://arxiv.org/abs/1909.05858">Keskar et al., 2019</a>)</em></p>
<p>The control code also can be used for <em>domain annotation</em> given tokens, because \(p(z \vert x) \propto p(x \vert z) p(z)\), assuming the prior over domains is uniform. One limitation of CTRL is the lack of control for <em>what not to generate</em> (e.g. avoid toxicity).</p>
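<p>The domain annotation step amounts to a softmax over per-code log-likelihoods. The sketch below assumes we have already scored the same text under each control code; the log-likelihood values are invented for illustration.</p>

```python
import math

def domain_posterior(log_px_given_z):
    """p(z|x) is proportional to p(x|z) p(z); with a uniform prior p(z) this is
    a softmax over the log-likelihoods log p(x|z)."""
    m = max(log_px_given_z.values())  # subtract the max for numerical stability
    unnorm = {z: math.exp(lp - m) for z, lp in log_px_given_z.items()}
    total = sum(unnorm.values())
    return {z: u / total for z, u in unnorm.items()}

# Hypothetical log p(x|z) of one text under three control codes.
scores = {"Horror": -42.0, "Legal": -47.5, "Wikipedia": -44.0}
posterior = domain_posterior(scores)
best = max(posterior, key=posterior.get)
print(best, round(posterior[best], 3))
```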
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CTRL-examples.png" alt="CTRL examples" /></p>
<p class="image-caption"><em>Fig. 17. The examples of conditioned sample generation by CTRL. (Image source: <a href="https://arxiv.org/abs/1909.05858">Keskar et al., 2019</a>)</em></p>
<p>Note that CTRL trains a transformer model from scratch. However, labelling all the text within the same dataset with the same control code (e.g. all the Wikipedia articles have “wikipedia” as control code) feels quite constrained. Considering that often we need highly customized control codes but only have a limited amount of labelled data, I would expect fine-tuning an unconditional LM with a small labelled dataset in the same way as CTRL to work out well too, although how much data is needed and how good the sample quality might be are subject to experimentation.</p>
<h3 id="rl-fine-tuning">RL Fine-tuning</h3>
<p>Fine-tuning a sequential model with RL regarding any arbitrary and possibly non-differentiable reward function was shown to work well years ago (<a href="https://arxiv.org/abs/1511.06732">Ranzato et al., 2015</a>). RL fine-tuning can resolve several problems with the <em>teacher forcing</em> method. With teacher forcing, the model only minimizes a maximum-likelihood loss at each individual decoding step during training but it is asked to predict the entire sequence from scratch at test time. Such a discrepancy between train and test could lead to exposure bias and accumulated error. In contrast, RL fine-tuning is able to directly optimize task-specific metrics on the sequence level, such as BLEU for translation (<a href="https://arxiv.org/abs/1511.06732">Ranzato et al., 2015</a>, <a href="https://arxiv.org/abs/1609.08144">Wu et al., 2016</a>, <a href="https://arxiv.org/abs/1707.07402">Nguyen et al., 2017</a>), ROUGE for summarization (<a href="https://arxiv.org/abs/1511.06732">Ranzato et al., 2015</a>, <a href="https://arxiv.org/abs/1705.04304">Paulus et al., 2017</a>, <a href="https://arxiv.org/abs/1804.07036">Wu and Hu, 2018</a>) and customized metric for story generation (<a href="https://arxiv.org/abs/1809.10736">Tambwekar et al., 2018</a>).</p>
<p><a href="https://arxiv.org/abs/1511.06732">Ranzato et al (2015)</a> applied REINFORCE to train RNN models for sequence generation tasks. The model is first trained to predict the next token using cross-entropy loss (ML loss) and then fine-tuned alternately by both ML loss and REINFORCE (RL loss). At the second fine-tuning stage, the number of training steps for next-token prediction is gradually decreased to zero, so that eventually only the RL loss is used. This sequence-level RL fine-tuning was shown by experiments to lead to great improvements over several supervised learning baselines back then.</p>
<p>Google implemented a similar approach in their neural machine translation system (<a href="https://arxiv.org/abs/1609.08144">Wu et al., 2016</a>) and <a href="https://arxiv.org/abs/1705.04304">Paulus et al (2017)</a> adopted this approach for the summarization task. The training objective contains two parts, ML loss for next token prediction, \(\mathcal{L}_\text{ML} = -\sum_{(x, y^*)\sim\mathcal{D}} \log p_\theta(y^* \vert x)\), and RL loss \(\mathcal{L}_\text{RL}\) for maximizing the expected reward where the reward per sequence is measured by BLEU or ROUGE. The model is first trained with \(\mathcal{L}_\text{ML}\) until convergence and then fine-tuned with a linear combination of two losses, \(\mathcal{L}_\text{mix} = \alpha \mathcal{L}_\text{ML} + (1 - \alpha)\mathcal{L}_\text{RL}\).</p>
<p>The RL loss of Google NMT is to maximize the expected BLEU score:</p>
<p>\(\mathcal{L}_\text{RL} = - \sum_{(x, y^*)\sim\mathcal{D}} \mathbb{E}_{y\sim p_\theta(.\vert x)} [R(y, y^*)]\)
where \(y\) is the predicted sequence and \(y^*\) is the ground truth.</p>
<p><a href="https://arxiv.org/abs/1705.04304">Paulus et al (2017)</a> added an extra weighting term based on the reward difference between two output sequences, \(y\) by sampling the next token according to the predicted probability and \(\hat{y}\) by greedily taking the most likely token. This RL loss maximizes the conditional likelihood of the sampled sequence \(y\) if it obtains a higher reward than the greedy baseline \(\hat{y}\):</p>
\[\mathcal{L}_\text{RL} = \sum_{(x, y^*)\sim\mathcal{D}} (R(\hat{y}, y^*) - R(y, y^*)) \sum_{t=1}^{n'} \log p(y_t \vert y_{<t}, x)\]
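<p>For a single training example, the baseline-weighted RL loss and the mixed objective can be computed as below. This is only a sketch: the rewards and token log-probabilities are made-up numbers, and in practice they would come from a metric such as ROUGE and from the model’s decoder.</p>

```python
def self_critical_loss(reward_sampled, reward_greedy, sampled_token_logps):
    """(R(y_hat, y*) - R(y, y*)) * sum_t log p(y_t | y_<t, x) for one example,
    where y is the sampled sequence and y_hat the greedy baseline."""
    return (reward_greedy - reward_sampled) * sum(sampled_token_logps)

def mixed_loss(loss_ml, loss_rl, alpha):
    """L_mix = alpha * L_ML + (1 - alpha) * L_RL."""
    return alpha * loss_ml + (1 - alpha) * loss_rl

# Hypothetical values: the sampled sequence scores higher than the greedy one.
token_logps = [-0.5, -1.0, -0.25]  # log p(y_t | y_<t, x) of the sampled tokens
loss_rl = self_critical_loss(reward_sampled=0.6, reward_greedy=0.4,
                             sampled_token_logps=token_logps)
loss_mix = mixed_loss(loss_ml=2.0, loss_rl=loss_rl, alpha=0.5)
print(loss_rl, loss_mix)
```

<p>Because the sampled sequence beats the baseline here, the coefficient on the log-likelihood is negative, so minimizing the loss pushes the log-probability of the sampled sequence up; when the greedy sequence wins, the sign flips.</p>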
<h3 id="rl-fine-tuning-with-human-preferences">RL Fine-tuning with Human Preferences</h3>
<p>Reward learning is critical for defining human preferences. Quantitative measurement like BLEU or ROUGE computes the overlap of words and n-gram phrases between sequences and does not always correlate with better quality by human judges. Reward learning from human feedback (<a href="https://arxiv.org/abs/1706.03741">Christiano et al., 2017</a>) is a better way to align what we measure with what we actually care about. Human feedback has been applied to learn a reward function for applications like dialogue response generation (<a href="https://arxiv.org/abs/1904.13015">Yi et al., 2019</a>) and summarization (<a href="https://arxiv.org/abs/1909.01214">Böhm et al., 2019</a>, <a href="https://arxiv.org/abs/1909.08593">Ziegler et al., 2019</a>, <a href="https://arxiv.org/abs/2009.01325">Stiennon et al., 2020</a>).</p>
<p>In order to generate more coherent conversation, <a href="https://arxiv.org/abs/1904.13015">Yi et al (2019)</a> collected 4 types of binary human feedback given a conversation pair (user utterance, system response), whether the system response is (1) comprehensive, (2) on topic, (3) interesting and (4) leading to continuation of the conversation.
An evaluator is trained to predict human feedback and then is used to rerank the beam search samples, to finetune the model or to do both. (Actually they didn’t use RL fine-tuning but rather use the evaluator to provide a discriminator loss in supervised fine-tuning.)</p>
<p>Let’s define a learned reward function \(R_\psi(x, y)\) parameterized by \(\psi\) as a measurement for the quality of output \(y\) given the input \(x\).</p>
<p>To learn the ground truth reward \(R^*\) defined by human judgements, <a href="https://arxiv.org/abs/1909.01214">Böhm et al (2019)</a> compared two loss functions:</p>
<p>(1) Regression loss: simply minimizing the mean squared error.</p>
\[\mathcal{L}^\text{MSE}_\text{rm} = [R^*(x, y) - R_\psi(x, y)]^2\]
<p>(2) Preference loss: learning to agree with the ground truth reward,</p>
\[\begin{aligned}
\mathcal{L}^\text{pref}_\text{rm} =& - \sum_{i,j} \big(\mathbb{1}[R^*(x, y_i) > R^*(x, y_j)] \log P(y_i \succ y_j) + \\
&\mathbb{1}[R^*(x, y_j) > R^*(x, y_i)] \log P(y_j \succ y_i) \big)\\
\text{where }P(y_i \succ y_j) =& \frac{\exp(R_\psi(x, y_i))}{\exp(R_\psi(x, y_i)) + \exp(R_\psi(x, y_j))}
\end{aligned}\]
<p>Their experiments showed that the <em>preference loss</em> achieves the best performance, where the reward model is a thin MLP layer on top of BERT sentence embedding.</p>
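<p>The two reward-model losses can be written down directly. The sketch below sums the preference loss over unordered pairs \(i &lt; j\), which is one reading of the \(\sum_{i,j}\) in the formula, and uses invented reward values.</p>

```python
import math

def mse_loss(r_true, r_pred):
    """Regression loss for a single (x, y): squared error between rewards."""
    return (r_true - r_pred) ** 2

def pref_prob(r_i, r_j):
    """P(y_i > y_j) = exp(r_i) / (exp(r_i) + exp(r_j)), written as a sigmoid."""
    return 1.0 / (1.0 + math.exp(r_j - r_i))

def preference_loss(true_rewards, pred_rewards):
    """Negative log-probability that the model orders each pair like R*.
    Ties in the ground truth contribute nothing."""
    loss = 0.0
    n = len(true_rewards)
    for i in range(n):
        for j in range(i + 1, n):
            if true_rewards[i] > true_rewards[j]:
                loss -= math.log(pref_prob(pred_rewards[i], pred_rewards[j]))
            elif true_rewards[j] > true_rewards[i]:
                loss -= math.log(pref_prob(pred_rewards[j], pred_rewards[i]))
    return loss

true_r = [3.0, 1.0, 2.0]       # hypothetical human scores for three outputs
agreeing = [2.0, -1.0, 0.5]    # predicted rewards with the same ranking
disagreeing = [-1.0, 2.0, 0.5] # predicted rewards with a different ranking
print(preference_loss(true_r, agreeing), preference_loss(true_r, disagreeing))
```

<p>Note that the preference loss only depends on differences between predicted rewards, so it is insensitive to the absolute scale of the human scores, unlike the regression loss.</p>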
<p><a href="https://arxiv.org/abs/1909.08593">Ziegler et al (2019)</a> collected human labels by asking humans to select the best candidate \(y_b\) out of a few options \(\{y_i\}\) given the input \(x \sim \mathcal{D}\). The candidates are sampled by \(y_0, y_1 \sim p(.\vert x), y_2, y_3 \sim \pi(.\vert x)\). We should be aware that human labeling might have very high disagreement when the ground truth is fuzzy.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/finetune-human-feedback.png" alt="Human feedback fine-tuning" /></p>
<p class="image-caption"><em>Fig. 18. The overview of the training framework for fine-tuning a language model policy with reward learned from human feedback. (Image source: <a href="https://arxiv.org/abs/1909.08593">Ziegler et al., 2019</a>)</em></p>
<p>The reward model is implemented by a pretrained language model with an extra randomly initialized linear layer on top of the final embedding output. It is trained to minimize the loss:</p>
\[\mathcal{L}_\text{rm} = -\mathbb{E}_{(x, \{y_i\}, b) \sim \mathcal{D}} \Big[ \log \frac{\exp(R_\psi(x, y_b))}{\sum_i \exp(R_\psi(x, y_i))} \Big]\]
<p>To keep the scale consistent during training, the reward model is normalized to have mean 0 and variance 1.</p>
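<p>A small sketch of this loss for one labeled comparison, plus a simple normalization step. Normalizing over a fixed batch of scores, as done here, is an assumption made for illustration; the scores themselves are made up.</p>

```python
import math

def rm_loss(candidate_rewards, b):
    """-log( exp(R[b]) / sum_i exp(R[i]) ) for one comparison where humans
    picked candidate b; computed with a stable log-sum-exp."""
    m = max(candidate_rewards)
    log_z = m + math.log(sum(math.exp(r - m) for r in candidate_rewards))
    return log_z - candidate_rewards[b]

def normalize_rewards(rewards):
    """Shift and scale a batch of reward scores to mean 0 and variance 1."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / std for r in rewards]

uniform_loss = rm_loss([0.0, 0.0, 0.0, 0.0], b=2)   # no preference learned yet
confident_loss = rm_loss([0.0, 0.0, 3.0, 0.0], b=2) # chosen candidate scored highest
scores = normalize_rewards([1.0, 2.0, 3.0, 4.0])
print(uniform_loss, confident_loss, scores)
```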
<p><a name="kl-penalty"></a>During RL fine-tuning, the policy \(\pi\), initialized by a pretrained language model \(p\), is optimized via <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#ppo">PPO</a> with the above learned reward model. To prevent the policy from deviating too much from its original behavior, a <strong>KL penalty</strong> is added:</p>
\[R(x, y) = R_\psi(x, y) - \beta\log\frac{\pi(y \vert x)}{p(y \vert x)}\]
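<p>For an autoregressive model the sequence log-probabilities are sums of per-token log-probabilities, so the penalized reward is cheap to compute. A minimal sketch with made-up numbers (the value of \(\beta\) is arbitrary here):</p>

```python
def kl_penalized_reward(rm_score, policy_token_logps, pretrained_token_logps, beta):
    """R(x, y) = R_psi(x, y) - beta * log( pi(y|x) / p(y|x) ), where each
    sequence log-probability is the sum of its per-token log-probabilities."""
    log_ratio = sum(policy_token_logps) - sum(pretrained_token_logps)
    return rm_score - beta * log_ratio

# Hypothetical per-token log-probs of the same sampled y under pi and under p.
policy_logps = [-0.2, -0.3]
pretrained_logps = [-0.5, -0.6]
reward = kl_penalized_reward(1.2, policy_logps, pretrained_logps, beta=0.1)
unchanged = kl_penalized_reward(1.2, pretrained_logps, pretrained_logps, beta=0.1)
print(reward, unchanged)
```

<p>When the policy assigns the sample a higher probability than the pretrained model does, the log ratio is positive and the reward is reduced; if the two models agree, the learned reward passes through unchanged.</p>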
<p>If running online data collection, the human label collection process continues during RL fine-tuning, so the human labelers can review results generated by the latest policy. The human labels are collected evenly throughout the training process. Meanwhile the reward model is also retrained periodically. Online data collection turns out to be important for the summarization task but not for the text continuation task. In their experiments, jointly training the reward model and the policy with shared parameters did not work well and can lead to overfitting due to the big imbalance between dataset sizes.</p>
<p>In the following work (<a href="https://arxiv.org/abs/2009.01325">Stiennon et al., 2020</a>), the human label collection was further simplified to select the best option between a pair of summaries, \(y_b \in\{y_0, y_1\}\). The reward model loss was updated to optimize the log odds of the selected summary:</p>
\[\mathcal{L}_\text{rm} = -\mathbb{E}_{(x, y_0, y_1, b)\sim\mathcal{D}} [\log(\sigma(r_\theta(x, y_b) - r_\theta(x, y_{1-b})))]\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/summarize-human-feedback.png" alt="Human feedback fine-tuning 2" /></p>
<p class="image-caption"><em>Fig. 19. The overview of fine-tuning the language model policy from human feedback for summarization, including (1) human feedback collection, (2) reward model training, and (3) policy training. (Image source: <a href="https://arxiv.org/abs/2009.01325">Stiennon et al., 2020</a>)</em></p>
<h3 id="guided-fine-tuning-with-steerable-layer">Guided Fine-tuning with Steerable Layer</h3>
<p>Instead of fine-tuning the entire model, only fine-tuning a small extra set of parameters while the base model stays fixed is computationally cheaper.</p>
<p><a name="pplm"></a>In computer vision, plug-and-play generative networks (PPGN; <a href="https://arxiv.org/abs/1612.00005">Nguyen et al., 2017</a>) generate images with different attributes by plugging a discriminator \(p(a \vert x)\) into a base generative model \(p(x)\). Then the sample with a desired attribute \(a\) can be sampled from \(p(x \vert a) \propto p(a \vert x)p(x)\). Inspired by PPGN, the <strong>plug-and-play language model</strong> (<strong>PPLM</strong>; <a href="https://arxiv.org/abs/1912.02164">Dathathri et al., 2019</a>) combines one or multiple simple attribute models with a pretrained language model for controllable text generation.</p>
<p>Given an attribute \(a\) and the generated sample \(x\), let an attribute model be \(p(a\vert x)\). To control content generation, the current latent representation at time \(t\), \(H_t\) (containing a list of key-value pairs per layer), can be shifted by \(\Delta H_t\) in the direction of the sum of two gradients:</p>
<ul>
<li>One toward higher log-likelihood of the attribute \(a\) under \(p(a \vert x)\) — so that the output content acquires a desired attribute.</li>
<li>The other toward higher log-likelihood of the unmodified language model \(p(x)\) — so that the generated text is still in fluent and smooth natural language.</li>
</ul>
<p>To shift the output, at decoding time, PPLM runs one forward → one backward → one forward, three passes in total:</p>
<ol>
<li>First a forward pass is performed to compute the likelihood of attribute \(a\) by \(p(a\vert x)\);</li>
<li>Let \(\Delta H_t\) be a stepwise update to the hidden state \(H_t\) such that \((H_t + \Delta H_t)\) shifts the distribution of generated text closer to having the attribute \(a\). \(\Delta H_t\) is initialized at zero.
Then a backward pass updates the LM hidden states using normalized gradients from the attribute model \(\nabla_{\Delta H_t} \log p(a \vert H_t + \Delta H_t)\) as
\(\Delta H_t \leftarrow \Delta H_t + \alpha \frac{\nabla_{\Delta H_t} \log p(a|H_t + \Delta H_t)}{\| \nabla_{\Delta H_t} \log p(a|H_t + \Delta H_t) \|^\gamma}\)
where \(\gamma\) is a normalization scaling coefficient, set per layer, and \(\alpha\) is the step size. This update can be repeated \(m \in [3, 10]\) times.</li>
<li>The final forward pass recomputes a new distribution over the vocabulary, generated from the updated latents \(\tilde{H}_t = H_t + \Delta H_t\). The next token is sampled from the updated distribution.</li>
</ol>
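<p>The update rule in step 2 can be tried out on a toy problem. The sketch below replaces the transformer’s key-value history with a plain vector \(h\) and uses a hypothetical logistic attribute model \(p(a \vert h) = \sigma(w \cdot h)\), whose gradient is available in closed form; the step size, \(\gamma\) and \(m\) are arbitrary. It is a stand-in for the mechanics only, not the actual PPLM implementation.</p>

```python
import math

def attribute_logp_and_grad(h, w):
    """Toy attribute model p(a|h) = sigmoid(w . h).
    Returns log p(a|h) and its gradient with respect to h."""
    s = 1.0 / (1.0 + math.exp(-sum(wi * hi for wi, hi in zip(w, h))))
    return math.log(s), [(1.0 - s) * wi for wi in w]

def pplm_delta(h, w, alpha=0.02, gamma=1.0, m=3):
    """Accumulate Delta H over m normalized gradient-ascent steps:
    delta <- delta + alpha * grad / ||grad||^gamma, starting from zero."""
    delta = [0.0] * len(h)
    for _ in range(m):
        shifted = [hi + di for hi, di in zip(h, delta)]
        _, grad = attribute_logp_and_grad(shifted, w)
        norm = math.sqrt(sum(g * g for g in grad))
        delta = [di + alpha * g / norm ** gamma for di, g in zip(delta, grad)]
    return delta

h = [0.1, -0.3, 0.2]   # stand-in for the latent H_t
w = [1.0, 2.0, -1.0]   # hypothetical attribute-model weights
delta = pplm_delta(h, w)
logp_before, _ = attribute_logp_and_grad(h, w)
logp_after, _ = attribute_logp_and_grad([hi + di for hi, di in zip(h, delta)], w)
print(logp_before, logp_after)
```

<p>After the update, the attribute model assigns a higher log-likelihood to the shifted latent; in PPLM the final forward pass would then recompute the next-token distribution from \(H_t + \Delta H_t\).</p>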
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/PPLM.png" alt="PPLM" /></p>
<p class="image-caption"><em>Fig. 20. The overview of how PPLM runs three passes to update the model output to increase the likelihood of a desired attribute. (Image source: <a href="https://arxiv.org/abs/1912.02164">Dathathri et al., 2019</a>)</em></p>
<p>Multiple attribute models can be mix-and-matched during generation with customized weights, acting as a set of “control knobs”. The PPLM paper explored two types of attribute models:</p>
<ol>
<li>The simplest attribute model is based on a predefined <em>bag of words</em> (BoW), \(\{w_1, \dots, w_k\}\), that specifies a topic of interest.<br />
\(\log p(a \vert x) = \log\big( \sum_{i=1}^k p_{t+1} [w_i] \big)\)
<br />To encourage the model to output the desired words at least once but not at every step, they normalize the gradient by the maximum gradient norm.
<br />Interestingly, they found that increasing the probability of generating words in the bag also increases the probability of generating <em>related</em> but not identical words about the same topic.</li>
<li>The discriminator attribute models are based on learned classifiers which define preferences by a distribution instead of hard samples.</li>
</ol>
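<p>The BoW attribute score is just the log of the probability mass that the next-token distribution puts on the bag. A tiny sketch with an invented four-word vocabulary:</p>

```python
import math

def bow_log_likelihood(next_token_probs, bag):
    """log p(a|x) = log( sum_i p_{t+1}[w_i] ) over the words in the bag."""
    return math.log(sum(next_token_probs[w] for w in bag))

# Hypothetical next-token distribution p_{t+1} and a "space" topic bag.
p_next = {"rocket": 0.05, "planet": 0.10, "cat": 0.60, "the": 0.25}
space_bag = ["rocket", "planet"]
score = bow_log_likelihood(p_next, space_bag)
print(score)
```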
<p>To ensure the fluency in language, PPLM applied two additional designs:</p>
<ol>
<li>Minimizing the KL divergence between the modified and unmodified LM, commonly seen in other RL fine-tuning approaches (see <a href="#kl-penalty">above</a>).</li>
<li>It performs <a href="https://arxiv.org/abs/1809.00125">post-norm fusion</a> to constantly tie the generated text to the unconditional LM \(p(x)\), \(x_{t+1} \sim \frac{1}{\beta}(\tilde{p}_{t+1}^{\gamma_\text{gm}} p_{t+1}^{1-\gamma_\text{gm}})\), where \(p_{t+1}\) and \(\tilde{p}_{t+1}\) are the unmodified and modified output distributions, respectively. \(\beta\) is a normalizing factor. \(\gamma_\text{gm} \in [0.8, 0.95]\) balances the predictions of the unmodified and modified models.</li>
</ol>
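<p>Post-norm fusion is a normalized geometric mean of the two next-token distributions. A minimal sketch, using made-up distributions over a three-token vocabulary:</p>

```python
def post_norm_fusion(p_unmodified, p_modified, gamma_gm):
    """x_{t+1} ~ (1/beta) * p_modified^gamma_gm * p_unmodified^(1 - gamma_gm),
    where beta renormalizes the product into a probability distribution."""
    unnorm = [pm ** gamma_gm * pu ** (1.0 - gamma_gm)
              for pu, pm in zip(p_unmodified, p_modified)]
    beta = sum(unnorm)
    return [u / beta for u in unnorm]

p = [0.7, 0.2, 0.1]        # hypothetical unmodified distribution p_{t+1}
p_tilde = [0.2, 0.3, 0.5]  # hypothetical modified distribution
fused = post_norm_fusion(p, p_tilde, gamma_gm=0.9)
print(fused)
```

<p>Setting \(\gamma_\text{gm}=1\) recovers the modified distribution and \(\gamma_\text{gm}=0\) the unmodified one, so values slightly below 1 keep most of the steering while pulling back toward the fluent base LM.</p>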
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PPLM-examples.png" alt="PPLM examples" /></p>
<p class="image-caption"><em>Fig. 21. Examples of controllable text generation by PPLM. (Image source: <a href="https://arxiv.org/abs/1912.02164">Dathathri et al., 2019</a>)</em></p>
<p>Interestingly, they found a large variance in the extent of controllability across topics. Some topics (religion, science, politics) are easier to control for compared to others (computers, space).</p>
<p>One obvious drawback of PPLM is that due to multiple passes at every decoding step, the test time computation becomes much more expensive.</p>
<p>Similar to PPLM, <strong>DELOREAN</strong> (DEcoding for nonmonotonic LOgical REAsoNing; <a href="https://arxiv.org/abs/2010.05906">Qin et al., 2020</a>) incorporates the future context by back-propagation. Given input text \(\mathbf{x}\), DELOREAN aims to generate a continuation \(\mathbf{y} = [y_1, \dots, y_N]\) such that \(\mathbf{y}\) satisfies certain constraints defined by a context \(z\). To keep the generation differentiable, a soft representation of \(\mathbf{y}\) is tracked, \(\tilde{\mathbf{y}}=(\tilde{y}_1, \dots, \tilde{y}_N)\) where \(\tilde{y}_i \in \mathbb{R}^V\) are logits over the vocabulary. \(\tilde{\mathbf{y}}^{(t)}\) is the soft representation at iteration \(t\).</p>
<p>Given the representation \(\tilde{\mathbf{y}}^{(t-1)}\) at iteration \(t\), it runs the following steps:</p>
<ol>
<li><strong>Backward</strong>: The constraint is represented as a loss function \(\mathcal{L}(\mathbf{x}, \tilde{\mathbf{y}}^{(t-1)}, z)\). The logits are updated via gradient descent: \(\tilde{y}^{(t), b}_n = \tilde{y}_n^{(t-1)} - \lambda \nabla_{\tilde{y}_n} \mathcal{L}(\mathbf{x}, \tilde{\mathbf{y}}^{(t-1)}, z)\).</li>
<li><strong>Forward</strong>: Run forward pass to ensure the generated text is fluent. \(\tilde{y}^{(t),f}_n = \text{LM}(\mathbf{x}, \tilde{\mathbf{y}}^{(t)}_{1:n-1})\).</li>
<li>Then linearly combine two logits together to create a new representation \(\tilde{y}^{(t)}_n = \gamma \tilde{y}^{(t), f}_n + (1-\gamma) \tilde{y}^{(t), b}_n\). Note that each \(\tilde{y}^{(t)}_n\) is needed to sample the next \(\tilde{y}^{(t),f}_{n+1}\).</li>
</ol>
<p><strong>Side-tuning</strong> (<a href="https://arxiv.org/abs/1912.13503">Zhang et al., 2019</a>) trains a lightweight side network that learns a residual on top of the original model outputs without modifying the pre-trained model weights. Unlike PPLM, no gradient update is applied on the hidden states. It is a simple yet effective approach for incremental learning. The base model is treated as a black-box model and does not necessarily have to be a neural network. The side-tuning setup assumes the base and side models are fed exactly the same input and the side model is independently learned.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/side-tuning.png" alt="Side-tuning" /></p>
<p class="image-caption"><em>Fig. 22. Comparison of fixed weights, fine-tuning and side-tuning. (Image source: <a href="https://arxiv.org/abs/1912.13503">Zhang et al., 2019</a>)</em></p>
<p>The paper explored different strategies of fusing predictions from the base and side models: <code class="language-plaintext highlighter-rouge">product</code> is the worst while <code class="language-plaintext highlighter-rouge">sum</code> (\(\alpha\)-blending), MLP, and <a href="https://arxiv.org/abs/1709.07871">FiLM</a> are comparable. Side-tuning is able to achieve better performance when it is trained with intermediate amounts of data and when the base network is large.</p>
<p><strong>Auxiliary tuning</strong> (<a href="https://arxiv.org/abs/2006.16823">Zeldes et al., 2020</a>) supplements the original pre-trained model with an <em>auxiliary</em> model that shifts the output distribution according to the target task. The base and auxiliary model outputs are merged on the logits level. The combined model is trained to maximize the likelihood \(p(x_t\vert x_{<t}, z)\) of target output.</p>
<p>The conditional probability of \(p(x_t\vert x_{<t}, z)\) can be decomposed into two parts:</p>
<ol>
<li>\(p(x_t\vert x_{<t})\) assigns high probabilities to fluent sequences of tokens;</li>
<li>a shift on \(p(x_t\vert x_{<t})\) towards \(p(x_t\vert x_{<t}, z)\).</li>
</ol>
\[p(x_t\vert x_{<t}, z) = \text{softmax}(\text{logits}_\text{LM}(x_t \vert x_{<t}) + \text{logits}_\text{aux}(x_t \vert x_{<t}, z))\]
<p>By Bayes’ rule, we have</p>
\[p(x_t\vert x_{<t}, z)
= \frac{p(z \vert x_{\leq t})}{p(z \vert x_{<t})} p(x_t \vert x_{<t})
\propto p(z \vert x_{\leq t}) p(x_t \vert x_{<t})\]
<p>And therefore the auxiliary model \(\text{logits}_\text{aux}(x_t \vert x_{<t}, z)\) effectively should learn to predict \(p(z \vert x_{\leq t})\). In the experiments of <a href="https://arxiv.org/abs/2006.16823">Zeldes et al. (2020)</a>, the auxiliary model can re-use the intermediate layers of the pre-trained LM for feature extraction.</p>
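<p>To see why adding logits corresponds to the Bayes decomposition, consider a toy vocabulary of three tokens. If the auxiliary logits equal \(\log p(z \vert x_{\leq t})\) for each candidate \(x_t\) (up to an additive constant), the softmax of the summed logits reproduces the normalized product \(p(z \vert x_{\leq t})\, p(x_t \vert x_{&lt;t})\). All probabilities below are invented for this check.</p>

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def combine(logits_lm, logits_aux):
    """p(x_t | x_<t, z) = softmax(logits_LM + logits_aux)."""
    return softmax([a + b for a, b in zip(logits_lm, logits_aux)])

p_lm = [0.5, 0.3, 0.2]       # hypothetical p(x_t | x_<t) per candidate token
p_z_given = [0.1, 0.6, 0.9]  # hypothetical p(z | x_<=t) per candidate token

combined = combine([math.log(p) for p in p_lm],
                   [math.log(p) for p in p_z_given])

# Direct Bayes computation: normalize p(z | x_<=t) * p(x_t | x_<t) over tokens.
joint = [pl * pz for pl, pz in zip(p_lm, p_z_given)]
bayes = [j / sum(joint) for j in joint]
print(combined, bayes)
```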
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/side-auxiliary.png" alt="Side auxiliary" /></p>
<p class="image-caption"><em>Fig. 23. The auxiliary model is trained by reusing features extracted from multiple layers of the base model. (Image source: <a href="https://arxiv.org/abs/2006.16823">Zeldes et al., 2020</a>)</em></p>
<p><strong>GeDi</strong> (<a href="https://arxiv.org/abs/2009.06367">Krause et al., 2020</a>) guides the text generation by <em>Generative Discriminator</em>. The discriminator is implemented as a class conditional language model (CC-LM), \(p_\theta(x_{1:t} \vert z)\). The discriminator guides generation at each decoding step by computing classification probabilities for all possible next tokens via Bayes’ rule by normalizing over <em>two</em> contrastive class-conditional distributions:</p>
<ol>
<li>One conditioned on the control code \(z\) for desired attribute.</li>
<li>The other conditioned on the anti-control code \(\bar{z}\) for undesired attributes.</li>
</ol>
<p>GeDi relies on the contrast between \(p_\theta(x_{1:t} \vert z)\) and \(p_\theta(x_{1:t} \vert \bar{z})\) to compute the probability of the sequence belonging to the desired class. The discriminator loss is to maximize the probability of desired attribute \(z\):</p>
\[\begin{aligned}
p_\theta(z \vert x_{1:\tau}) &= \frac{p(z) p_\theta(x_{1:\tau} \vert z)^{\alpha/\tau}}{\sum_{z' \in \{z, \bar{z}\}} p(z') p_\theta(x_{1:\tau} \vert z')^{\alpha/\tau} } \\
\mathcal{L}_\text{desc}
&= -\frac{1}{N} \sum_{i=1}^N \log p_\theta(z^{(i)} \vert x^{(i)}_{1:\tau_i}) \\
&= -\frac{1}{N} \sum_{i=1}^N \log \frac{p(z^{(i)}) p_\theta(x^{(i)}_{1:\tau_i} \vert z^{(i)})^{\alpha/\tau_i}}{\sum_{z' \in \{z, \bar{z}\} } p(z')p_\theta(x^{(i)}_{1:\tau_i} \vert z')^{\alpha/\tau_i}}
\end{aligned}\]
<p>where \(p(z) = \exp(b_z) / \sum_{z'} \exp(b_{z'})\) and \(b_z\) is a learned class prior. The probabilities are normalized by the current sequence length \(\tau\) to make the classification robust to generated sequences of variable lengths. \(\tau_i\) is the sequence length of the \(i\)-th input \(x^{(i)}\) in the dataset.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/GeDi.png" alt="GeDi" /></p>
<p class="image-caption"><em>Fig. 24. An illustration of how GeDi works via Bayes’ rule. (Image source: <a href="https://arxiv.org/abs/2009.06367">Krause et al., 2020</a>)</em></p>
<p>They finetuned a GPT2-medium model with control code similar to how <a href="#ctrl">CTRL</a> is trained to form a CC-LM using a linear combination of discriminative loss and generative loss. This discriminator model is then used as GeDi to guide generation by a larger language model like GPT2-XL.</p>
<p>One way of decoding from GeDi is to sample from a weighted posterior \(p^w(x_{t+1}\vert x_{1:t}, z) \propto p(z \vert x_{1:t+1})^w p(x_{t+1} \vert x_{1:t})\) where \(w>1\) applies additional bias toward the desired class \(z\). In the sampling process, only tokens with the class or next-token probability larger than a certain threshold are selected.</p>
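<p>The two pieces, the length-normalized class posterior and the weighted decoding distribution, can be sketched for a toy three-token vocabulary. The sequence log-likelihoods, the uniform class prior, \(\alpha\) and \(w\) are invented for this example, and the threshold-based token filtering is left out.</p>

```python
import math

def class_posterior(logp_seq_z, logp_seq_zbar, length, alpha=1.0,
                    prior_z=0.5, prior_zbar=0.5):
    """p(z | x_{1:tau}) from the two class-conditional sequence log-likelihoods,
    each scaled by alpha / tau before applying Bayes' rule."""
    a = math.log(prior_z) + (alpha / length) * logp_seq_z
    b = math.log(prior_zbar) + (alpha / length) * logp_seq_zbar
    return 1.0 / (1.0 + math.exp(b - a))

def weighted_posterior(p_next_lm, p_z_given_next, w):
    """p^w(x_{t+1} | x_{1:t}, z) proportional to p(z | x_{1:t+1})^w * p(x_{t+1} | x_{1:t})."""
    unnorm = [pz ** w * pl for pl, pz in zip(p_next_lm, p_z_given_next)]
    total = sum(unnorm)
    return [u / total for u in unnorm]

vocab = ["great", "terrible", "movie"]
base = [0.3, 0.3, 0.4]        # hypothetical base-LM next-token probabilities
logp_z = [-4.0, -9.0, -6.0]   # log p(x_{1:t+1} | z) for each candidate token
logp_zbar = [-9.0, -4.0, -6.0]  # log p(x_{1:t+1} | z_bar) for each candidate token
length = 3                    # tau = t + 1

p_z = [class_posterior(a, b, length) for a, b in zip(logp_z, logp_zbar)]
guided = weighted_posterior(base, p_z, w=2.0)
print(dict(zip(vocab, guided)))
```

<p>The guided distribution shifts mass toward the token that the desired-class model finds more likely and away from the token favored by the anti-class model, while a token that both classes score equally keeps a posterior of 0.5.</p>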
<p>GeDi guided generation in their experiments showed strong controllability and ran 30x faster than <a href="#pplm">PPLM</a>.</p>
<h3 id="distributional-approach">Distributional Approach</h3>
<p><strong>Generation with Distributional Control</strong> (GDC; <a href="https://arxiv.org/abs/2012.11635">Khalifa et al., 2020</a>) frames controlled text generation as the optimization of a probability distribution with a constraint. It involves two major steps.</p>
<p><strong>Step 1: Learn an EBM of the target model</strong></p>
<p>Let’s label a pretrained LM as \(a\) and a target LM with desired features as \(p\). The desired features can be defined by a set of pre-defined real-valued feature functions \(\phi_i(x), i=1,\dots,k\) over \(x \in X\), denoted as a vector \(\boldsymbol{\phi}\). When sequences \(x \in X\) are sampled according to the desired model \(p\), the expectations of features \(\mathbb{E}_{x\sim p}\boldsymbol{\phi}(x)\) should be close to target values \(\bar{\boldsymbol{\mu}}\); these requirements are named “<em>moment constraints</em>”. The feature function \(\phi_i\) can have distinct values (e.g. identity function for binary classifier) or continuous probabilities. In the meantime, the fine-tuned model \(p\) should not diverge from \(a\) too much by maintaining a small KL divergence measure.</p>
<p>In summary, given a pretrained model \(a\), we would like to find a target model \(p\) such that:</p>
\[\begin{aligned}
\bar{\boldsymbol{\mu}} &= \mathbb{E}_{x\sim p}\boldsymbol{\phi}(x) \\
p &= \arg\min_{c \in \mathcal{C}} D_\text{KL}(c, a)
\end{aligned}\]
<p>where \(\mathcal{C}\) is the set of all distributions over \(X\) that satisfy the moment constraints.</p>
<p>According to theorems in Information Geometry, \(p\) can be approximated by an EBM (energy-based model; an unnormalized probability distribution) \(P\) in the form of an exponential function, such that \(p(x) \propto P(x)\) and \(p(x)=\frac{1}{Z}P(x)\) where \(Z=\sum_x P(x)\). The energy-based model can be approximated by:
\(P(x)=a(x)\exp\big(\sum_i \lambda_i \phi_i(x)\big)=a(x)\exp(\boldsymbol{\lambda}\cdot\boldsymbol{\phi}(x))\).
Let’s define the <em>importance weight</em> \(w(x, \boldsymbol{\lambda}) = \frac{P(x)}{a(x)} = \exp\langle\boldsymbol{\lambda}\cdot\boldsymbol{\phi}(x)\rangle\). Given a large number of sequences sampled from the pretrained model \(x_1, \dots, x_N \sim a(x)\),</p>
\[\begin{aligned}
\mu(\boldsymbol{\lambda})
&= \mathbb{E}_{x\sim p}\boldsymbol{\phi}(x)
= \mathbb{E}_{x\sim a} \frac{p(x)}{a(x)}\boldsymbol{\phi}(x)
= \frac{1}{Z}\mathbb{E}_{x\sim a} w(x, \boldsymbol{\lambda}) \boldsymbol{\phi}(x) \\
&= \frac{\mathbb{E}_{x\sim a} w(x, \boldsymbol{\lambda}) \boldsymbol{\phi}(x)}{\sum_{x\in X} P(x)}
= \frac{\mathbb{E}_{x\sim a} w(x, \boldsymbol{\lambda}) \boldsymbol{\phi}(x)}{\sum_{x\in X} w(x, \boldsymbol{\lambda})a(x)}
= \frac{\mathbb{E}_{x\sim a} w(x, \boldsymbol{\lambda}) \boldsymbol{\phi}(x)}{\mathbb{E}_{x\sim a} w(x, \boldsymbol{\lambda})} \\
&\simeq \frac{\sum_{i=1}^N w(x_i,\boldsymbol{\lambda}) \boldsymbol{\phi}(x_i)}{\sum_{i=1}^N w(x_i, \boldsymbol{\lambda})}
= \frac{\sum_{i=1}^N \exp\langle\boldsymbol{\lambda}\cdot\boldsymbol{\phi}(x_i)\rangle \boldsymbol{\phi}(x_i)}{\sum_{i=1}^N \exp\langle\boldsymbol{\lambda}\cdot\boldsymbol{\phi}(x_i)\rangle}
\end{aligned}\]
<p>Using SGD over the objective \(\|\boldsymbol{\mu}(\boldsymbol{\lambda}) - \bar{\boldsymbol{\mu}}\|^2_2\), we can obtain an estimated value for \(\boldsymbol{\lambda}\) and a representation of \(P(x)=a(x)\exp\langle\boldsymbol{\lambda}\cdot\boldsymbol{\phi}(x)\rangle\). \(P(x)\) is a sequential EBM because \(a\) is an autoregressive model.</p>
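<p>The estimation of \(\boldsymbol{\lambda}\) can be sketched in a few lines. This is a minimal illustration rather than the paper's implementation: the function names <code>snis_moments</code> and <code>fit_lambda</code> are hypothetical, the "samples" are a hand-made list of feature vectors standing in for \(\boldsymbol{\phi}(x_i)\) with \(x_i \sim a\), and plain full-batch gradient descent replaces SGD. It relies on the fact that the gradient of the self-normalized estimate \(\boldsymbol{\mu}(\boldsymbol{\lambda})\) with respect to \(\boldsymbol{\lambda}\) is the weighted covariance of the features.</p>

```python
import math


def snis_moments(lam, feats):
    """Self-normalized importance sampling estimate of mu(lambda).

    feats holds phi(x_i) for samples x_i drawn from the pretrained model a.
    Returns the estimated moments and the normalized importance weights.
    """
    # w(x_i, lambda) = exp(lambda . phi(x_i))
    weights = [math.exp(sum(l * f for l, f in zip(lam, phi))) for phi in feats]
    total = sum(weights)
    probs = [w / total for w in weights]
    mu = [sum(p * phi[d] for p, phi in zip(probs, feats)) for d in range(len(lam))]
    return mu, probs


def fit_lambda(feats, mu_bar, lr=2.0, steps=500):
    """Minimize ||mu(lambda) - mu_bar||^2 with full-batch gradient descent."""
    k = len(mu_bar)
    lam = [0.0] * k
    for _ in range(steps):
        mu, probs = snis_moments(lam, feats)
        diff = [mu[d] - mu_bar[d] for d in range(k)]
        grad = [0.0] * k
        for e in range(k):
            for d in range(k):
                # d mu_d / d lambda_e is the weighted covariance of phi_d and phi_e
                cov = sum(p * (phi[d] - mu[d]) * (phi[e] - mu[e])
                          for p, phi in zip(probs, feats))
                grad[e] += 2.0 * diff[d] * cov
        lam = [l - lr * g for l, g in zip(lam, grad)]
    return lam


# Toy setup (an assumption): one binary feature that fires on 10% of samples
# from a, with a target moment of 0.5.
toy_feats = [[1.0]] * 10 + [[0.0]] * 90
toy_lambda = fit_lambda(toy_feats, [0.5])
```

<p>For a single binary feature the solution has a closed form, \(\lambda = \log\frac{\bar{\mu}(1-\mu_a)}{(1-\bar{\mu})\mu_a}\), where \(\mu_a\) is the feature mean under \(a\), so the toy run can be checked against it.</p>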
<p><strong>Step 2: Learn the target probability distribution</strong></p>
<p>The EBM \(P(x)\) can compute ratios of probabilities of two sequences, but we cannot sample from \(p(x)\) without knowing \(Z\). In order to sample from a sequential EBM, the paper proposed to use <a href="https://arxiv.org/abs/1912.08517">Distributional Policy Gradient</a> (DPG; but not this <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#dpg">DPG</a>) with the objective to obtain an autoregressive policy \(\pi_\theta\) to approximate a target distribution \(p\) by minimizing the cross entropy \(H(p, \pi_\theta)\). DPG runs through a sequence of iterations. Within each iteration, a proposal distribution \(q\) is used for sampling and we can correct the cross entropy loss with importance weights too:</p>
\[\begin{aligned}
\nabla_\theta H(p, \pi_\theta)
&= - \nabla_\theta \mathbb{E}_{x\sim p} \log \pi_\theta(x)
= - \mathbb{E}_{x\sim p} \nabla_\theta \log \pi_\theta(x) \\
&= - \mathbb{E}_{x\sim q} \frac{p(x)}{q(x)} \nabla_\theta \log \pi_\theta(x)
= - \frac{1}{Z}\mathbb{E}_{x\sim q} \frac{P(x)}{q(x)} \nabla_\theta \log \pi_\theta(x)
\end{aligned}\]
<p>To learn such a \(\pi_\theta\), the paper adopts a KL-adaptive version of DPG: It only updates \(q\) when the estimated policy \(\pi_\theta\) gets closer to \(p\). This adaptive step is important for fast convergence.</p>
<p style="width: 45%;" class="center"><img src="/lil-log/assets/images/GDC-KL-adaptive-DPG.png" alt="KL-adaptive DPG" /></p>
<p class="image-caption"><em>Fig. 25. The algorithm of distributional policy gradient to make it possible to sample from an EBM \(P(x)\), where \(q\) is initialized to be \(a\). (Image source: <a href="https://arxiv.org/abs/2012.11635">Khalifa, et al. 2020</a>)</em></p>
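<p>To make the update rule concrete, here is a toy sketch of KL-adaptive DPG on a four-element sample space instead of text sequences. Everything in it is an assumption made for illustration: the policy is a softmax over a logit vector, the function name <code>kl_adaptive_dpg</code> is hypothetical, and the KL terms in the adaptive step are computed exactly (possible only because the space is tiny), whereas a real sequence model has to estimate them from samples. The factor \(1/Z\) is absorbed into the learning rate.</p>

```python
import math
import random


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_adaptive_dpg(P, a, lr=0.1, iters=300, batch=64, seed=0):
    """Fit a softmax policy pi_theta to p(x) = P(x) / Z using samples from q."""
    rng = random.Random(seed)
    n = len(P)
    Z = sum(P)
    p = [v / Z for v in P]            # exact target, used only for the KL check
    theta = [math.log(v) for v in a]  # pi_theta starts at the pretrained model a
    q = list(a)                       # proposal q is initialized to a
    for _ in range(iters):
        pi = softmax(theta)
        grad = [0.0] * n
        for x in rng.choices(range(n), weights=q, k=batch):
            w = P[x] / q[x]  # importance weight with the unnormalized EBM
            for j in range(n):
                # gradient of log softmax: one_hot(x) - pi
                grad[j] += w * ((1.0 if j == x else 0.0) - pi[j]) / batch
        # ascend E_p[log pi_theta], i.e. descend the cross entropy H(p, pi_theta)
        theta = [t + lr * g for t, g in zip(theta, grad)]
        pi = softmax(theta)
        # adaptive step: replace the proposal only if the policy is closer to p
        if kl(p, pi) < kl(p, q):
            q = pi
    return softmax(theta), p


# Toy EBM (an assumption): P(x) = a(x) * exp(lambda * phi(x)), phi fires on x = 3.
toy_a = [0.4, 0.3, 0.2, 0.1]
toy_P = [0.4, 0.3, 0.2, 0.1 * 9.0]
toy_pi, toy_p = kl_adaptive_dpg(toy_P, toy_a)
```

<p>Note that the parameter update only touches the unnormalized \(P(x)\) and the proposal \(q(x)\); the normalized \(p\) appears solely in the proposal-refresh test, which is the part that has to be estimated in practice.</p>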
<p>This approach can be used to model various constraints in controllable text generation:</p>
<ol>
<li>Pointwise constraints: \(\phi_i\) is a binary feature; such as constraining the presence or absence of words, or classifier-based constraints.</li>
<li>Distributional constraints: \(\phi_i\) represents a probability distribution; such as constraining the probability of gender, topic, etc. Their experiments showed great progress in debiasing a GPT-2 model that was trained on the Wikipedia Biographies corpus. The percentage of generated biographies about women increased from 7.4% to 35.6%.</li>
<li>Hybrid constraints: combine multiple constraints by simply summing them up.</li>
</ol>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/GDC-debiasing.png" alt="GDC debiasing" /></p>
<p class="image-caption"><em>Fig. 26. Debiasing experiments using GDC with various constraints. (Image source: <a href="https://arxiv.org/abs/2012.11635">Khalifa, et al. 2020</a>)</em></p>
<p>Compared to other baselines, GDC using pointwise constraints diverges less from the base model \(a\) and produces smoother curves.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/GDC-ablation.png" alt="GDC debiasing" /></p>
<p class="image-caption"><em>Fig. 27. Compare pointwise constrained GDC with several baselines. Low Self-BLEU-5 and high Dist-1 indicate high diversity. (Image source: <a href="https://arxiv.org/abs/2012.11635">Khalifa, et al. 2020</a>)</em></p>
<ul>
<li>REINFORCE that optimizes the reward \(\phi\) directly (\(\text{REINFORCE}\) in Fig. 27) without constraints converges fast but has a high deviation from the original model.</li>
<li>REINFORCE that optimizes \(P(x)\) (\(\text{REINFORCE}_{P(x)}\) in Fig. 27) has low sample diversity.</li>
<li>Compared to <a href="https://arxiv.org/abs/1909.08593">Ziegler et al., 2019</a>, GDC has smoother learning curves and produces a richer vocabulary.</li>
</ul>
<h3 id="unlikelihood-training">Unlikelihood Training</h3>
<p>The standard way of maximizing the log-likelihood loss in language model training leads to <a href="#beam-search-surprise">incorrect token distribution</a>, which cannot be fixed with only smart decoding methods. Such models tend to output high-frequency words too often and low-frequency words too rarely, especially when using deterministic decoding (e.g. greedy, beam search). In other words, they are overconfident in their predictions.</p>
<p>Unlikelihood training (<a href="https://arxiv.org/abs/1908.04319">Welleck & Kulikov et al. 2019</a>) tries to combat this by incorporating a preference against <em>unwanted</em> content directly into the training objective. It combines two updates:</p>
<ul>
<li>A routine maximum likelihood update to assign high probability to the true tokens;</li>
<li>A new type of unlikelihood update to avoid assigning high probability to unwanted tokens.</li>
</ul>
<p>Given a sequence of tokens \((x_1, \dots, x_T)\) and a set of negative candidate tokens \(\mathcal{C}^t = \{c_1, \dots , c_m\}\) at step \(t\), where each token \(x_i, c_j \in \mathcal{V}\), the combined loss for step \(t\) is defined as:</p>
\[\mathcal{L}^t_\text{UL}(p_\theta (. \vert x_{<t}), \mathcal{C}^t)
= - \alpha \cdot \underbrace{\sum_{c \in \mathcal{C}^t} \log(1 - p_\theta(c \vert x_{<t}))}_\text{unlikelihood} - \underbrace{\log p_\theta (x_t \vert x_{<t})}_\text{likelihood}\]
<p>One approach for constructing \(\mathcal{C}^t\) is to randomly select candidates from model-generated sequences.</p>
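<p>The per-step loss above is simple enough to write out directly. The sketch below assumes the model's next-token distribution is given as a plain dictionary from token to probability; the function name <code>unlikelihood_step_loss</code> and the toy vocabulary are made up for illustration.</p>

```python
import math


def unlikelihood_step_loss(probs, target, negatives, alpha=1.0):
    """Token-level unlikelihood loss at one step t.

    probs:     dict mapping each vocabulary token to p_theta(token | x_<t)
    target:    the true next token x_t
    negatives: the negative candidate set C^t
    """
    # likelihood term: -log p(x_t | x_<t)
    likelihood = -math.log(probs[target])
    # unlikelihood term: -sum_c log(1 - p(c | x_<t))
    unlikelihood = -sum(math.log(1.0 - probs[c]) for c in negatives)
    return alpha * unlikelihood + likelihood


# Toy next-token distribution (an assumption for illustration).
toy_probs = {"the": 0.5, "cat": 0.3, "sat": 0.2}
toy_loss = unlikelihood_step_loss(toy_probs, target="cat", negatives=["the"])
```

<p>With an empty candidate set or \(\alpha = 0\) the loss reduces to the usual negative log-likelihood, and it grows as more probability mass is placed on any token in \(\mathcal{C}^t\).</p>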
<p>The unlikelihood training can be extended to be on the <em>sequence</em>-level, where the negative continuation is defined by a sequence of per-step negative candidate sets. They should be designed to penalize properties that we don’t like. For example, we can penalize repeating n-grams as follows:</p>
\[\mathcal{C}^t_\text{repeat-n} = \{x_t\} \text{ if }(x_{t-i}, \dots, x_{t+j}) \in x_{<t-i} \text{ for any } (j-i)=n, i\leq n \leq j.\]
<p>Their experiments used unlikelihood training to avoid repetitions in language model outputs and indeed showed less repetition and more unique tokens compared to standard MLE training.</p>
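<p>One way to read the repeat-n candidate set is: a token becomes its own negative candidate whenever it is part of an n-gram that already occurred earlier in the sequence. The sketch below follows that reading; it is an interpretation of the set definition rather than the authors' code, the function name <code>repeat_ngram_positions</code> is hypothetical, and positions are 0-indexed.</p>

```python
def repeat_ngram_positions(tokens, n):
    """Return the positions t whose token x_t lies inside a repeated n-gram.

    An n-gram counts as repeated when an identical n-gram appears earlier in
    the sequence without overlapping the current window.
    """
    first_start = {}  # n-gram -> start index of its first occurrence
    flagged = set()
    for start in range(len(tokens) - n + 1):
        gram = tuple(tokens[start:start + n])
        prev = first_start.setdefault(gram, start)
        # the earlier occurrence must end before the current window starts
        if prev + n <= start:
            flagged.update(range(start, start + n))
    return sorted(flagged)


# Toy sequence (an assumption): the bigram ("I", "like") occurs twice.
toy_tokens = "I like tea . I like coffee".split()
toy_flagged = repeat_ngram_positions(toy_tokens, n=2)
```

<p>At each flagged position \(t\), the set \(\{x_t\}\) would then be used as \(\mathcal{C}^t\) in the unlikelihood term.</p>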
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2021conditional,
title = "Controllable Neural Text Generation.",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2021",
url = "https://lilianweng.github.io/lil-log/2021/01/02/controllable-neural-text-generation.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Patrick von Platen. <a href="https://huggingface.co/blog/how-to-generate">“How to generate text: using different decoding methods for language generation with Transformers”</a> Hugging Face blog, March 18, 2020.</p>
<p>[2] Angela Fan, et al. <a href="https://arxiv.org/abs/1805.04833">“Hierarchical Neural Story Generation.”</a> arXiv preprint arXiv:1805.04833 (2018).</p>
<p>[3] Ari Holtzman et al. <a href="https://arxiv.org/abs/1904.09751">“The Curious Case of Neural Text Degeneration.”</a> ICLR 2020.</p>
<p>[4] Marjan Ghazvininejad et al. <a href="https://www.aclweb.org/anthology/P17-4008">“Hafez: an interactive poetry generation system.”</a> ACL 2017.</p>
<p>[5] Ari Holtzman et al. <a href="https://arxiv.org/abs/1805.06087">“Learning to write with cooperative discriminators.”</a> ACL 2018.</p>
<p>[6] Ashutosh Baheti et al. <a href="https://arxiv.org/abs/1809.01215">“Generating More Interesting Responses in Neural Conversation Models with Distributional Constraints.”</a> EMNLP 2018.</p>
<p>[7] Jiatao Gu et al. <a href="https://arxiv.org/abs/1702.02429">“Trainable greedy decoding for neural machine translation.”</a> EMNLP 2017.</p>
<p>[8] Kyunghyun Cho. <a href="https://arxiv.org/abs/1605.03835">“Noisy Parallel Approximate Decoding for Conditional Recurrent Language Model.”</a> arXiv preprint arXiv:1605.03835. (2016).</p>
<p>[9] Marco Tulio Ribeiro et al. <a href="https://www.aclweb.org/anthology/P18-1079/">“Semantically equivalent adversarial rules for debugging NLP models.”</a> ACL 2018.</p>
<p>[10] Eric Wallace et al. <a href="https://arxiv.org/abs/1908.07125">“Universal Adversarial Triggers for Attacking and Analyzing NLP.”</a> EMNLP 2019. [<a href="https://github.com/Eric-Wallace/universal-triggers">code</a>]</p>
<p>[11] Taylor Shin et al. <a href="https://arxiv.org/abs/2010.15980">“AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts.”</a> EMNLP 2020. [<a href="http://ucinlp.github.io/autoprompt">code</a>]</p>
<p>[12] Zhengbao Jiang et al. <a href="https://arxiv.org/abs/1911.12543">“How Can We Know What Language Models Know?”</a> TACL 2020.</p>
<p>[13] Nanyun Peng et al. <a href="https://www.aclweb.org/anthology/W18-1505/">“Towards Controllable Story Generation.”</a> NAACL 2018.</p>
<p>[14] Nitish Shirish Keskar, et al. <a href="https://arxiv.org/abs/1909.05858">“CTRL: A Conditional Transformer Language Model for Controllable Generation”</a> arXiv preprint arXiv:1909.05858 (2019).[<a href="https://github.com/salesforce/ctrl">code</a>]</p>
<p>[15] Marc’Aurelio Ranzato et al. <a href="https://arxiv.org/abs/1511.06732">“Sequence Level Training with Recurrent Neural Networks.”</a> ICLR 2016.</p>
<p>[16] Yonghui Wu et al. <a href="https://arxiv.org/abs/1609.08144">“Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.”</a> CoRR 2016.</p>
<p>[17] Romain Paulus et al. <a href="https://arxiv.org/abs/1705.04304">“A Deep Reinforced Model for Abstractive Summarization.”</a> ICLR 2018.</p>
<p>[18] Paul Christiano et al. <a href="https://arxiv.org/abs/1706.03741">“Deep Reinforcement Learning from Human Preferences.”</a> NIPS 2017.</p>
<p>[19] Sanghyun Yi et al. <a href="https://arxiv.org/abs/1904.13015">“Towards coherent and engaging spoken dialog response generation using automatic conversation evaluators.”</a> INLG 2019.</p>
<p>[20] Florian Böhm et al. <a href="https://arxiv.org/abs/1909.01214">“Better rewards yield better summaries: Learning to summarise without references.”</a> EMNLP 2019. [<a href="https://github.com/yg211/summary-reward-no-reference">code</a>]</p>
<p>[21] Daniel M Ziegler et al. <a href="https://arxiv.org/abs/1909.08593">“Fine-tuning language models from human preferences.”</a> arXiv preprint arXiv:1909.08593 (2019). [<a href="https://github.com/openai/lm-human-preferences">code</a>]</p>
<p>[22] Nisan Stiennon, et al. <a href="https://arxiv.org/abs/2009.01325">“Learning to summarize from human feedback.”</a> arXiv preprint arXiv:2009.01325 (2020).</p>
<p>[23] Sumanth Dathathri et al. <a href="https://arxiv.org/abs/1912.02164">“Plug and play language models: a simple approach to controlled text generation.”</a> ICLR 2020. [<a href="https://github.com/uber-research/PPLM">code</a>]</p>
<p>[24] Jeffrey O Zhang et al. <a href="https://arxiv.org/abs/1912.13503">“Side-tuning: Network adaptation via additive side networks”</a> ECCV 2020.</p>
<p>[25] Ben Krause et al. <a href="https://arxiv.org/abs/2009.06367">“GeDi: Generative Discriminator Guided Sequence Generation.”</a> arXiv preprint arXiv:2009.06367.</p>
<p>[26] Yoel Zeldes et al. <a href="https://arxiv.org/abs/2006.16823">“Technical Report: Auxiliary Tuning and its Application to Conditional Text Generation.”</a> arXiv preprint arXiv:2006.16823.</p>
<p>[27] Thomas Scialom, et al. <a href="https://arxiv.org/abs/2002.10375">“Discriminative Adversarial Search for Abstractive Summarization”</a> ICML 2020.</p>
<p>[28] Clara Meister, et al. <a href="https://arxiv.org/abs/2010.02650">“If beam search is the answer, what was the question?”</a> EMNLP 2020.</p>
<p>[29] Xiang Lisa Li and Percy Liang. <a href="https://arxiv.org/abs/2101.00190">“Prefix-Tuning: Optimizing Continuous Prompts for Generation.”</a> arXiv preprint arXiv:2101.00190 (2021).</p>
<p>[30] Lianhui Qin, et al. <a href="https://arxiv.org/abs/2010.05906">“Back to the Future: Unsupervised Backprop-based Decoding for Counterfactual and Abductive Commonsense Reasoning.”</a> arXiv preprint arXiv:2010.05906 (2020).</p>
<p>[31] Muhammad Khalifa, et al. <a href="https://arxiv.org/abs/2012.11635">“A Distributional Approach to Controlled Text Generation”</a> Accepted by ICLR 2021.</p>
<p>[32] Aditya Grover, et al. <a href="https://arxiv.org/abs/1906.09531">“Bias correction of learned generative models using likelihood-free importance weighting.”</a> NeurIPS 2019.</p>
<p>[33] Yuntian Deng et al. <a href="https://arxiv.org/abs/2004.11714">“Residual Energy-Based Models for Text Generation.”</a> ICLR 2020.</p>
<p>[34] Brian Lester et al. <a href="https://arxiv.org/abs/2104.08691">“The Power of Scale for Parameter-Efficient Prompt Tuning.”</a> arXiv preprint arXiv:2104.08691 (2021).</p>
<p>[35] Xiao Liu et al. <a href="https://arxiv.org/abs/2103.10385">“GPT Understands, Too.”</a> arXiv preprint arXiv:2103.10385 (2021).</p>
<p>[36] Welleck & Kulikov et al. <a href="https://arxiv.org/abs/1908.04319">“Neural Text Generation with Unlikelihood Training”</a> arXiv:1908.04319 (2019).</p>Lilian WengThe modern language model with SOTA results on many NLP tasks is trained on large scale free text on the Internet. It is challenging to steer such a model to generate content with desired attributes. Although still not perfect, there are several approaches for controllable text generation, such as guided or learned decoding strategy, smart prompt design, or fine-tuning the model with various methods.How to Build an Open-Domain Question Answering System?2020-10-29T12:00:00+00:002020-10-29T12:00:00+00:00https://lilianweng.github.io/lil-log/2020/10/29/open-domain-question-answering<blockquote>
<p>A model that is capable of answering any question with regard to factual knowledge can enable many useful applications. This post delves into how we can build an Open-Domain Question Answering (ODQA) system, assuming we have access to a powerful pretrained language model. Both closed-book and open-book approaches are discussed.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-11-12: add <a href="#openai-api-example">an example</a> on closed-book factual QA using OpenAI API (beta).]</span></p>
<p>A model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as working as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.</p>
<p>Disclaimers given so many papers in the wild:</p>
<ul>
<li>Assume we have access to a powerful pretrained <a href="/lil-log/2019/01/31/generalized-language-models.html">language model</a>.</li>
<li>We do not cover how to use structured knowledge base (e.g. Freebase, WikiData) here.</li>
<li>We only focus on a single-turn QA instead of a multi-turn conversation style QA.</li>
<li>We mostly focus on QA models that contain neural networks, especially Transformer-based language models.</li>
<li>I admit that I missed a lot of papers with architectures designed specifically for QA tasks between 2017-2019😔</li>
</ul>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-open-domain-question-answering" id="markdown-toc-what-is-open-domain-question-answering">What is Open-Domain Question Answering?</a> <ul>
<li><a href="#notation" id="markdown-toc-notation">Notation</a></li>
<li><a href="#concerns-of-qa-data-fine-tuning" id="markdown-toc-concerns-of-qa-data-fine-tuning">Concerns of QA data fine-tuning</a></li>
</ul>
</li>
<li><a href="#open-book-qa-retriever-reader" id="markdown-toc-open-book-qa-retriever-reader">Open-book QA: Retriever-Reader</a> <ul>
<li><a href="#retriever-model" id="markdown-toc-retriever-model">Retriever Model</a></li>
<li><a href="#reader-model" id="markdown-toc-reader-model">Reader Model</a></li>
<li><a href="#end-to-end-joint-training" id="markdown-toc-end-to-end-joint-training">End-to-end Joint Training</a></li>
</ul>
</li>
<li><a href="#open-book-qa-retriever-generator" id="markdown-toc-open-book-qa-retriever-generator">Open-book QA: Retriever-Generator</a></li>
<li><a href="#closed-book-qa-generative-language-model" id="markdown-toc-closed-book-qa-generative-language-model">Closed-book QA: Generative Language Model</a></li>
<li><a href="#related-techniques" id="markdown-toc-related-techniques">Related Techniques</a> <ul>
<li><a href="#fast-maximum-inner-product-search-mips" id="markdown-toc-fast-maximum-inner-product-search-mips">Fast Maximum Inner Product Search (MIPS)</a></li>
<li><a href="#language-model-pre-training" id="markdown-toc-language-model-pre-training">Language Model Pre-training</a></li>
</ul>
</li>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
<li><a href="#appendix-qa-datasets" id="markdown-toc-appendix-qa-datasets">Appendix: QA Datasets</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-open-domain-question-answering">What is Open-Domain Question Answering?</h2>
<p><strong>Open-domain Question Answering (ODQA)</strong> is a type of language task that asks a model to produce answers to factoid questions in natural language. The true answer is objective, so it is simple to evaluate model performance.</p>
<p>For example,</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Question: What did Albert Einstein win the Nobel Prize for?
Answer: The law of the photoelectric effect.
</code></pre></div></div>
<p>The “open-domain” part refers to the lack of the relevant context for any arbitrarily asked factual question. In the above case, the model only takes as the input the question but no article about “why Einstein didn’t win a Nobel Prize for the theory of relativity” is provided, where the term “the law of the photoelectric effect” is likely mentioned. In the case when both the question and the context are provided, the task is known as <strong>Reading comprehension (RC)</strong>.</p>
<p>An ODQA model may work with or without <em>access to an external source of knowledge</em> (e.g. Wikipedia) and these two conditions are referred to as <em>open-book</em> or <em>closed-book</em> question answering, respectively.</p>
<p>When considering different types of open-domain questions, I like the classification by <a href="https://arxiv.org/abs/2008.02637">Lewis, et al., 2020</a>, in increasing order of difficulty:</p>
<ol>
<li>A model is able to correctly memorize and respond with the answer to a question that has been seen at training time.</li>
<li>A model is able to answer novel questions at test time and choose an answer from the set of answers it has seen during training.</li>
<li>A model is able to answer novel questions which have answers not contained in the training dataset.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-summary.png" alt="QA-summary" /></p>
<p class="image-caption"><em>Fig. 1. Overview of three frameworks discussed in this post.</em></p>
<h3 id="notation">Notation</h3>
<p>Given a question \(x\) and a ground truth answer span \(y\), the context passage containing the true answer is labelled as \(z \in \mathcal{Z}\), where \(\mathcal{Z}\) is an external knowledge corpus. Wikipedia is a common choice for such an external knowledge source.</p>
<h3 id="concerns-of-qa-data-fine-tuning">Concerns of QA data fine-tuning</h3>
<p>Before we dive into the details of the many models below, I would like to point out one concern of fine-tuning a model with common QA datasets, which appears as one fine-tuning step in several ODQA models. It could be concerning, because there is a significant overlap between questions in the train and test sets in several public QA datasets.</p>
<p><a href="https://arxiv.org/abs/2008.02637">Lewis, et al., (2020)</a> (<a href="https://github.com/facebookresearch/QA-Overlap">code</a>) found that 58-71% of test-time answers are also present somewhere in the training sets and 28-34% of test-set questions have a near-duplicate paraphrase in their corresponding training sets. In their experiments, several models performed notably worse when duplicated or paraphrased questions were removed from the training set.</p>
<h2 id="open-book-qa-retriever-reader">Open-book QA: Retriever-Reader</h2>
<p>Given a factoid question, if a language model has no context or is not big enough to memorize the context which exists in the training dataset, it is unlikely to guess the correct answer. In an open-book exam, students are allowed to refer to external resources like notes and books while answering test questions. Similarly, an ODQA system can be paired with a rich knowledge base to identify relevant documents as evidence of answers.</p>
<p>We can decompose the process of finding answers to given questions into two stages,</p>
<ol>
<li>Find the related context in an external repository of knowledge;</li>
<li>Process the retrieved context to <em>extract</em> an answer.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-retriever-reader.png" alt="retriever + reader QA system" /></p>
<p class="image-caption"><em>Fig. 2. The retriever-reader QA framework combines information retrieval with machine reading comprehension.</em></p>
<p>Such a retriever + reader framework was first proposed in <strong>DrQA</strong> (“Document retriever Question-Answering” by <a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>; <a href="https://github.com/facebookresearch/DrQA">code</a>). The retriever and the reader components can be set up and trained independently, or jointly trained <a href="#end-to-end-joint-training">end-to-end</a>.</p>
<h3 id="retriever-model">Retriever Model</h3>
<p>Two popular approaches for implementing the retriever are to use an information retrieval (IR) system that depends on (1) the classic non-learning-based <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">TF-IDF</a> features (“classic IR”) or (2) dense embedding vectors of text produced by neural networks (“neural IR”).</p>
<h4 id="classic-ir">Classic IR</h4>
<p><strong>DrQA</strong> (<a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>) adopts an efficient non-learning-based search engine based on the <a href="https://en.wikipedia.org/wiki/Vector_space_model">vector space model</a>. Every query and document is modelled as a bag-of-words vector, where each term is weighted by TF-IDF (term frequency \(\times\) inverse document frequency).</p>
\[\begin{aligned}
\text{tf-idf}(t, d, \mathcal{D}) &= \text{tf}(t, d) \times \text{idf}(t, \mathcal{D}) \\
\text{tf}(t, d) &= \log(1 + \text{freq}(t, d)) \\
\text{idf}(t, \mathcal{D}) &= \log \Big( \frac{\vert\mathcal{D}\vert}{\vert \{d\in\mathcal{D}: t\in d\} \vert} \Big)
\end{aligned}\]
<p>where \(t\) is a unigram or bigram term in a document \(d\) from a collection of documents \(\mathcal{D}\). \(\text{freq}(t, d)\) measures how many times a term \(t\) appears in \(d\). Note that the term-frequency here includes bigram counts too, which is found to be very helpful because the local word order is taken into consideration via bigrams. As part of the implementation, DrQA maps the bigrams to \(2^{24}\) bins using an unsigned murmur3 hash.</p>
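<p>The weighting scheme above can be sketched with the standard library alone. This is an illustration of the three formulas rather than DrQA's code: the helper names <code>terms</code> and <code>tf_idf</code> are hypothetical, the corpus is a three-sentence toy, and the hashing of bigrams into bins is left out.</p>

```python
import math
from collections import Counter


def terms(tokens):
    """Unigram and bigram terms of a tokenized document."""
    return list(tokens) + [" ".join(pair) for pair in zip(tokens, tokens[1:])]


def tf_idf(term, doc_terms, corpus_terms):
    """tf-idf(t, d, D) with tf = log(1 + freq) and idf = log(|D| / df)."""
    tf = math.log(1 + Counter(doc_terms)[term])
    # document frequency: number of documents that contain the term
    df = sum(1 for d in corpus_terms if term in d)
    idf = math.log(len(corpus_terms) / df) if df else 0.0
    return tf * idf


# Toy corpus (an assumption for illustration).
toy_docs = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "cat", "ran"]]
toy_corpus = [terms(d) for d in toy_docs]
cat_weight = tf_idf("cat", toy_corpus[0], toy_corpus)
```

<p>A term that appears in every document, such as “the” in the toy corpus, gets an idf of \(\log 1 = 0\) and therefore contributes nothing to the score.</p>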
<p>Precisely, DrQA used Wikipedia as its knowledge source and this choice has become a default setting for many ODQA studies since then. The non-ML document retriever returns the top \(k=5\) most relevant Wikipedia articles given a question.</p>
<p><strong>BERTserini</strong> (<a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>) pairs the open-source <a href="https://github.com/castorini/anserini"><em>Anserini</em></a> IR toolkit as the retriever with a fine-tuned pre-trained BERT model as the reader. The top \(k\) documents (\(k=10\)) are retrieved via the <code class="language-plaintext highlighter-rouge">post-v3.0</code> branch of Anserini with the query treated as a bag of words. The retrieved text segments are ranked by <a href="https://en.wikipedia.org/wiki/Okapi_BM25">BM25</a>, a classic TF-IDF-based retrieval scoring function. In terms of the effect of text granularity on performance, they found that paragraph retrieval > sentence retrieval > article retrieval.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERTserini-arch.png" alt="BERTserini" /></p>
<p class="image-caption"><em>Fig. 3. An illustration of BERTserini architecture. (Image source: <a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>)</em></p>
<p><em>ElasticSearch + BM25</em> is used by the <strong>Multi-passage BERT</strong> QA model (<a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a>). They found that splitting articles into passages with the length of 100 words by <em>sliding window</em> brings 4% improvements, since splitting documents into passages without overlap may cause some near-boundary evidence to lose useful contexts.</p>
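<p>Splitting with an overlapping window takes only a few lines. The sketch below is a generic illustration, not the Multi-passage BERT code: the function name <code>sliding_passages</code> is hypothetical and the stride of 50 words is an assumed value, chosen so that consecutive 100-word passages overlap by half.</p>

```python
def sliding_passages(words, size=100, stride=50):
    """Split a list of words into overlapping passages of at most `size` words."""
    if len(words) <= size:
        return [words]
    passages = []
    # the last window starts at or after len(words) - size, so the tail is covered
    for start in range(0, len(words) - size + stride, stride):
        passages.append(words[start:start + size])
    return passages


# Toy article of 250 "words" (an assumption for illustration).
toy_words = [str(i) for i in range(250)]
toy_passages = sliding_passages(toy_words)
```

<p>Because every word away from the article's edges lands in two passages, evidence that sits near one passage boundary is always in the interior of a neighbouring passage.</p>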
<h4 id="neural-ir">Neural IR</h4>
<p>There is a long history in learning a low-dimensional representation of text, denser than raw term-based vectors (<a href="http://lsa.colorado.edu/papers/JASIS.lsi.90.pdf">Deerwester et al., 1990</a>; <a href="https://www.aclweb.org/anthology/W11-0329/">Yih, et al., 2011</a>). Dense representations can be learned through matrix decomposition or some neural network architectures (e.g. MLP, LSTM, bidirectional LSTM, etc). When neural networks are involved, such approaches are referred to as “Neural IR”. Neural IR is a new category of methods for retrieval problems, but it does not necessarily perform better than classic IR (<a href="https://sigir.org/wp-content/uploads/2019/01/p040.pdf">Lin, 2018</a>).</p>
<p>After the success of many large-scale <a href="/lil-log/2019/01/31/generalized-language-models.html">general language models</a>, many QA models embrace the following approach:</p>
\[h_x = E_x(x)\quad
h_z = E_z(z)\quad
\text{score}(x, z) = h_x^\top h_z\]
<ol>
<li>Extract the dense representations of a question \(x\) and a context passage \(z\) by feeding them into a language model;</li>
<li>Use the dot-product of these two representations as the retrieval score to rank and select most relevant passages.</li>
</ol>
<p>ORQA, REALM and DPR all use such a scoring function for context retrieval, which will be described in detail in a <a href="#end-to-end-joint-training">later section</a> on the end-to-end QA model.</p>
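<p>The two steps above amount to a brute-force inner product search once the embeddings exist. In the sketch below the encoders \(E_x\) and \(E_z\) are not modelled at all: the vectors are hand-written stand-ins for their outputs, and the function names <code>dot</code> and <code>retrieve</code> are hypothetical.</p>

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def retrieve(h_x, passage_vectors, k=2):
    """Rank passages by score(x, z) = h_x^T h_z and return the top-k ids."""
    ranked = sorted(passage_vectors.items(),
                    key=lambda item: dot(h_x, item[1]),
                    reverse=True)
    return [passage_id for passage_id, _ in ranked[:k]]


# Hand-written stand-ins for encoder outputs (assumptions for illustration).
toy_question = [1.0, 0.2]
toy_passage_vectors = {"p1": [1.0, 0.0], "p2": [0.5, 0.5], "p3": [0.0, 1.0]}
toy_top = retrieve(toy_question, toy_passage_vectors, k=2)
```

<p>Since the passage vectors do not depend on the question, they can be computed once offline; at query time only the question has to be encoded, and the exhaustive sort can be swapped for an approximate maximum inner product search index.</p>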
<p>An extreme approach, investigated by <strong>DenSPI</strong> (“Dense-Sparse Phrase Index”; <a href="https://arxiv.org/abs/1906.05807">Seo et al., 2019</a>), is to encode all the text in the knowledge corpus at the <em>phrase</em> level and then only rely on the retriever to identify the most relevant phrase as the predicted answer. In this way, the retriever+reader pipeline is reduced to only retriever. Of course, the index would be much larger and the retrieval problem is more challenging.</p>
<p>DenSPI introduces a <em>query-agnostic</em> indexable representation of document phrases. Precisely it encodes query-agnostic representations of text spans in Wikipedia offline and looks for the answer at inference time by performing nearest neighbor search. It can drastically speed up the inference time, because there is no need to re-encode documents for every new query, which is often required by a reader model.</p>
<p>Consider a question \(x\) and a fixed set of (Wikipedia) documents \(z_1, \dots, z_K\), where each document \(z_k\) contains \(N_k\) words, \(z_k = \langle z_k^{(1)}, \dots, z_k^{(N_k)}\rangle\). An ODQA model is a scoring function \(F\) for each candidate phrase span \(z_k^{(i:j)}, 1 \leq i \leq j \leq N_k\), such that the true answer is the phrase with maximum score: \(y = {\arg\max}_{k,i,j} F(x, z_k^{(i:j)})\).</p>
<p>The phrase representation \(z_k^{(i:j)}\) combines both dense and sparse vectors, \(z_k^{(i:j)} = [d_k^{(i:j)}, s_k^{(i:j)}] \in \mathbb{R}^{d^d + d^s}\) (note that \(d^d \ll d^s\)):</p>
<ul>
<li>The dense vector \(d_k^{(i:j)}\) is effective for encoding local <em>syntactic</em> and <em>semantic</em> cues, as what can be learned by a pretrained language model.</li>
<li>The sparse vector \(s_k^{(i:j)}\) is superior at encoding precise <em>lexical</em> information. The sparse vector is a term-frequency-based encoding. DenSPI uses 2-gram term-frequency, the same as DrQA, resulting in a highly sparse representation (\(d^s \approx 16\)M).</li>
</ul>
<p>The dense vector \(d^{(i:j)}\) is further decomposed into three parts, \(d^{(i:j)} = [a_i, b_j, c_{ij}] \in \mathbb{R}^{2d^b + 1}\) where \(2d^b + 1 = d^d\). All three components are learned based on different columns of the fine-tuned BERT representations.</p>
<ul>
<li>A vector \(a_i\) encodes the <em>start</em> position for the \(i\)-th word of the document;</li>
<li>A vector \(b_j\) encodes the <em>end</em> position for the \(j\)-th word of the document;</li>
<li>A scalar \(c_{ij}\) measures the <em>coherency</em> between the start and the end vectors, helping avoid non-constituent phrases during inference.</li>
</ul>
<p>For all possible \((i,j,k)\) tuples where \(j-i < J\), the text span embeddings are precomputed and stored as a <em>phrase index</em>. The maximum span length \(J\) is a predefined scalar constant.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/DenSPI-arch.png" alt="DenseSPI" /></p>
<p class="image-caption"><em>Fig. 4. An illustration of Dense-Sparse Phrase Index (DenSPI) architecture. (Image source: <a href="https://arxiv.org/abs/1906.05807">Seo et al., 2019</a>)</em></p>
<p>At the inference time, the question is mapped into the same vector space \(x=[d', s'] \in \mathbb{R}^{d^d + d^s}\), where the dense vector \(d'\) is extracted from the BERT embedding of the special <code class="language-plaintext highlighter-rouge">[CLS]</code> symbol. The same BERT model is shared for encoding both questions and phrases. The final answer is predicted by \(k^*, i^*, j^* = \arg\max x^\top z_k^{(i:j)}\).</p>
<h3 id="reader-model">Reader Model</h3>
<p>The reader model learns to solve the reading comprehension task — extract an answer for a given question from a given context document. Here we only discuss approaches for machine comprehension using neural networks.</p>
<h4 id="bi-directional-lstm">Bi-directional LSTM</h4>
<p>The reader model for answer detection of <strong>DrQA</strong> (<a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>) is a 3-layer bidirectional LSTM with hidden size 128. Every relevant paragraph of retrieved Wikipedia articles is encoded by a sequence of feature vectors, \(\{\tilde{\mathbf{z}}_1, \dots, \tilde{\mathbf{z}}_m \}\). Each feature vector \(\tilde{\mathbf{z}}_i \in \mathbb{R}^{d_z}\) is expected to capture useful contextual information around one token \(z_i\). The feature vector consists of several categories of features:</p>
<ol>
<li>Word embeddings: A 300d <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">GloVe</a> word embedding trained from 840B Web crawl data, \(f_\text{embed} = E_g(z_i)\).</li>
<li>Exact match: Whether a word \(z_i\) appears in the question \(x\), \(f_\text{match} = \mathbb{I}(z_i \in x)\).</li>
<li>Token features: This includes POS (part-of-speech) tagging, NER (named entity recognition), and TF (term-frequency), \(f_\text{token}(z_i) = (\text{POS}(z_i), \text{NER}(z_i), \text{TF}(z_i))\).</li>
<li>Aligned question embedding: The attention score \(y_{i,j}\) is designed to capture inter-sentence matching and similarity between the paragraph token \(z_i\) and the question word \(x_j\). This feature adds soft alignments between similar but non-identical words.</li>
</ol>
\[\begin{aligned}
f_\text{align}(z_i) &= \sum_j y_{i,j} E_g(x_j) \\
y_{i,j} &= \frac{\exp(\alpha(E_g(z_i))^\top \alpha(E_g(x_j)) )}{\sum_{j'} \exp(\alpha(E_g(z_i))^\top \alpha(E_g(x_{j'})) ) }
\end{aligned}\]
<p>where \(\alpha\) is a single dense layer with ReLU and \(E_g(.)\) is the glove word embedding.</p>
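<p>To make the aligned question embedding concrete, here is a small pure-Python sketch of the two equations above. The function names, the identity weight matrix and the 2-d toy embeddings are assumptions chosen only to keep the example readable; they are not taken from the DrQA code.</p>

```python
import math
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def relu_dense(v: Vector, W: List[Vector], b: Vector) -> Vector:
    """alpha(.): a single dense layer followed by ReLU."""
    return [max(0.0, dot(row, v) + bias) for row, bias in zip(W, b)]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def aligned_question_embedding(z_emb: Vector, question_embs: List[Vector],
                               W: List[Vector], b: Vector) -> Vector:
    """f_align(z_i) = sum_j y_{i,j} E_g(x_j), with y_{i,j} a softmax over question words."""
    az = relu_dense(z_emb, W, b)
    scores = [dot(az, relu_dense(q, W, b)) for q in question_embs]
    y = softmax(scores)
    dim = len(question_embs[0])
    return [sum(y_j * q[d] for y_j, q in zip(y, question_embs)) for d in range(dim)]


# Toy setup: identity dense layer and two orthogonal question word embeddings.
W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.0, 0.0]
f_align = aligned_question_embedding([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], W, b)
```

<p>The paragraph token puts most of its attention weight on the question word with the most similar projected embedding, so the output is a soft copy of that word's vector.</p>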
<p>The feature vector of a paragraph of \(m\) tokens is fed into LSTM to obtain the final paragraph vectors:</p>
\[\begin{aligned}
\mathbf{z} = \{\mathbf{z}_1, \dots, \mathbf{z}_m\} &= \text{LSTM}(\{\tilde{\mathbf{z}}_1, \dots, \tilde{\mathbf{z}}_m\}) \\
\text{where } \tilde{\mathbf{z}}_i &= \{f_\text{embed}, f_\text{match}, f_\text{token}, f_\text{align}\}
\end{aligned}\]
<p>The question is encoded as a weighted sum of the embeddings of every word in the question:</p>
\[\mathbf{x} = \sum_j b_j E(x_j) \quad b_j = \text{softmax}(\mathbf{w}^\top E(x_j))\]
<p>where \(\mathbf{w}\) is a weight vector to learn.</p>
<p>Once the feature vectors are constructed for the question and all the related paragraphs, the reader needs to predict the probability of each position in a paragraph being the start and the end of an answer span, \(p_\text{start}(i_s)\) and \(p_\text{end}(i_e)\), respectively. Across all the paragraphs, the span with the maximum \(p_\text{start}(i_s) \times p_\text{end}(i_e)\) is returned as the final answer.</p>
\[\begin{aligned}
p_\text{start}(i_s) \propto \exp(\mathbf{z}_{i_s} \mathbf{W}_s \mathbf{x}) \\
p_\text{end}(i_e) \propto \exp(\mathbf{z}_{i_e} \mathbf{W}_e \mathbf{x}) \\
\text{ s.t. } i_s \leq i_e \leq i_s + 15
\end{aligned}\]
<p>where \(\mathbf{W}_s\) and \(\mathbf{W}_e\) are learned parameters.</p>
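<p>The constrained span search can be sketched as an exhaustive scan over valid (start, end) pairs. This is a minimal illustration that assumes the start and end probabilities have already been computed; the helper name <code class="language-plaintext highlighter-rouge">best_span</code> and the toy probabilities are hypothetical.</p>

```python
from typing import List, Tuple


def best_span(p_start: List[float], p_end: List[float],
              max_len: int = 15) -> Tuple[Tuple[int, int], float]:
    """Maximize p_start[i_s] * p_end[i_e] subject to i_s <= i_e <= i_s + max_len."""
    best, best_score = (0, 0), -1.0
    for i_s, ps in enumerate(p_start):
        last = min(i_s + max_len, len(p_end) - 1)
        for i_e in range(i_s, last + 1):
            score = ps * p_end[i_e]
            if score > best_score:
                best, best_score = (i_s, i_e), score
    return best, best_score


# Position 0 has the highest end probability, but it precedes the best start
# position, so the constraint rules it out as the end of that span.
p_start = [0.1, 0.7, 0.2]
p_end = [0.6, 0.1, 0.3]
span, score = best_span(p_start, p_end)
```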
<h4 id="bert-universe">BERT-universe</h4>
<p>Following the success of <a href="/lil-log/2019/01/31/generalized-language-models.html#bert">BERT</a> (<a href="https://arxiv.org/abs/1810.04805">Devlin et al., 2018</a>), many QA models develop the machine comprehension component based on BERT. Let’s define the BERT model as a function that can take one or multiple strings (concatenated by <code class="language-plaintext highlighter-rouge">[SEP]</code>) as input and output a set of BERT encoding vectors for the special <code class="language-plaintext highlighter-rouge">[CLS]</code> token and every input token:</p>
\[\text{BERT}(s_1, s_2, \dots) = [\mathbf{h}^\texttt{[CLS]}, \mathbf{h}^{(1)}, \mathbf{h}^{(2)}, \dots]\]
<p>where \(\mathbf{h}^\texttt{[CLS]}\) is the embedding vector for the special <code class="language-plaintext highlighter-rouge">[CLS]</code> token and \(\mathbf{h}^{(i)}\) is the embedding vector for the \(i\)-th token.</p>
<p>To use BERT for reading comprehension, it learns two additional weights, \(\mathbf{W}_s\) and \(\mathbf{W}_e\), and \(\text{softmax}(\mathbf{h}^{(i)}\mathbf{W}_s)\) and \(\text{softmax}(\mathbf{h}^{(i)}\mathbf{W}_e)\) define two probability distributions of start and end position of the predicted span per token.</p>
<p><strong>BERTserini</strong> (<a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>) utilizes a pre-trained BERT model to work as the reader. Their experiments showed that <em>fine-tuning</em> pretrained BERT with SQuAD is sufficient to achieve high accuracy in identifying answer spans.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/BERT-RC.png" alt="BERT for reading comprehension" /></p>
<p class="image-caption"><em>Fig. 5. How BERT is used to solve question-answering tasks. (Image source: <a href="https://arxiv.org/abs/1810.04805">Devlin et al., 2018</a>)</em></p>
<p>The key difference of the BERTserini reader from the original BERT is: to allow comparison and aggregation of results from different segments, the final softmax layer over different answer spans is removed. The pre-trained BERT model is fine-tuned on the training set of SQuAD, where all inputs to the reader are padded to 384 tokens with the learning rate 3e-5.</p>
<p>When ranking all the extracted answer spans, the retriever score (BM25) and the reader score (probability of a token being the start position \(\times\) probability of the same token being the end position) are combined via linear interpolation.</p>
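<p>A linear interpolation of the two scores might look like the sketch below. The weight <code class="language-plaintext highlighter-rouge">mu</code>, its value and the candidate scores are made-up assumptions for illustration; in practice such a weight would be tuned on held-out data because BM25 scores and reader probabilities live on different scales.</p>

```python
from typing import List, Tuple


def combined_score(retriever_score: float, reader_score: float, mu: float = 0.5) -> float:
    """Linear interpolation: (1 - mu) * retriever score + mu * reader score."""
    return (1.0 - mu) * retriever_score + mu * reader_score


# (answer span, BM25 score, reader score) with hypothetical numbers.
candidates: List[Tuple[str, float, float]] = [
    ("span A", 12.0, 0.2),
    ("span B", 9.0, 0.9),
    ("span C", 11.0, 0.1),
]
ranked = sorted(candidates,
                key=lambda c: combined_score(c[1], c[2], mu=0.9),
                reverse=True)
```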
<p>The original BERT normalizes the probability distributions of start and end position per token for every passage independently. Differently, the <strong>Multi-passage BERT</strong> (<a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a>) normalizes answer scores across all the retrieved passages of one question <a href="https://arxiv.org/abs/1710.10723">globally</a>. Precisely, multi-passage BERT removes the final normalization layer per passage in BERT for QA (same as in BERTserini) and then adds a global <code class="language-plaintext highlighter-rouge">softmax</code> over all the word positions of all the passages. Global normalization makes the reader model more stable while pin-pointing answers from a large number of passages.</p>
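<p>The difference between per-passage and global normalization is easy to see on a toy example. In this sketch the logits are invented numbers and each passage has only two token positions; it only illustrates the normalization step, not the full multi-passage BERT model.</p>

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Start-position logits for two passages (hypothetical values).
# The second passage contains a much stronger candidate.
passage_logits = [[2.0, 1.0], [5.0, 0.0]]

# Per-passage normalization: each passage sums to 1 on its own.
per_passage = [softmax(logits) for logits in passage_logits]

# Global normalization: one softmax over all positions of all passages.
flat_logits = [s for logits in passage_logits for s in logits]
global_probs = softmax(flat_logits)
```

<p>With per-passage normalization the best token of the weak passage still receives a high probability, which makes scores hard to compare across passages. Under the global softmax the same token gets only a small share of the probability mass.</p>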
<p>In addition, multi-passage BERT implemented an independent <em>passage ranker</em> model via another BERT model and the rank score for \((x, z)\) is generated by a <code class="language-plaintext highlighter-rouge">softmax</code> over the representation vectors of the first <code class="language-plaintext highlighter-rouge">[CLS]</code> token. The passage ranker brings an extra 2% improvement. A similar idea of re-ranking passages with BERT was discussed in <a href="https://arxiv.org/abs/1901.04085">Nogueira & Cho, 2019</a>, too.</p>
<p>Interestingly, <a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a> found that <em>explicit inter-sentence matching</em> does not seem to be critical for RC tasks with BERT; check the original paper for how the experiments were designed. One possible reason is that the multi-head self-attention layers in BERT have already embedded the inter-sentence matching.</p>
<h3 id="end-to-end-joint-training">End-to-end Joint Training</h3>
<p>The retriever and reader components can be jointly trained. This section covers R^3, ORQA, REALM and DPR. There are a lot of common designs, such as BERT-based dense vectors for retrieval and the loss function on maximizing the marginal likelihood of obtaining true answers.</p>
<p>The retriever and reader models in the <strong>R^3</strong> (“Reinforced Ranker-Reader”; <a href="https://arxiv.org/abs/1709.00023">Wang, et al., 2017</a>) QA system are jointly trained via <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>. (Note that to keep the term consistent between papers in this section, the “ranker” model in the original R^3 paper is referred to as the “retriever” model here.) Both components are variants of <a href="https://arxiv.org/abs/1512.08849">Match-LSTM</a>, which relies on an attention mechanism to compute word similarities between the passage and question sequences.</p>
<p><strong>How does the Match-LSTM module work?</strong> Given a question \(\mathbf{X}\) of \(d_x\) words and a passage \(\mathbf{Z}\) of \(d_z\) words, both representations use fixed <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">Glove</a> word embeddings,</p>
\[\begin{aligned}
\mathbf{H}^x &= \text{BiLSTM}(\mathbf{X}) \in \mathbb{R}^{l \times d_x} \\
\mathbf{H}^z &= \text{BiLSTM}(\mathbf{Z}) \in \mathbb{R}^{l \times d_z} \\
\mathbf{G} &= \text{softmax}((\mathbf{W}^g \mathbf{H}^x + \mathbf{b}^g \otimes \mathbf{e}_{d_x})^\top \mathbf{H}^z) \in \mathbb{R}^{d_x \times d_z} & \text{; an attention matrix}\\
\bar{\mathbf{H}}^x &= \mathbf{H}^x \mathbf{G} \in \mathbb{R}^{l \times d_z} \\
\mathbf{M} &= \text{ReLU} \Big( \mathbf{W}^m \begin{bmatrix}
\mathbf{H}^z \\
\bar{\mathbf{H}}^x \\
\mathbf{H}^z \odot \bar{\mathbf{H}}^x \\
\mathbf{H}^z - \bar{\mathbf{H}}^x
\end{bmatrix} \Big) \in \mathbb{R}^{2l \times d_z} \\
\mathbf{H}^m &= \text{BiLSTM}(\mathbf{M}) \in \mathbb{R}^{l \times d_z}
\end{aligned}\]
<p>where \(l\) is the hidden dimension of the bidirectional LSTM module. \(\mathbf{W}^g \in \mathbb{R}^{l\times l}\), \(\mathbf{b}^g \in \mathbb{R}^l\), and \(\mathbf{W}^m \in \mathbb{R}^{2l \times 4l}\) are parameters to learn. The operator \(\otimes \mathbf{e}_{d_x}\) is the outer product to repeat the column vector \(\mathbf{b}^g\) \(d_x\) times.</p>
<p>The ranker and reader components share the same Match-LSTM module with two separate prediction heads in the last layer, resulting in \(\mathbf{H}^\text{rank}\) and \(\mathbf{H}^\text{read}\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/R%5E3-arch.png" alt="R^3 QA" /></p>
<p class="image-caption"><em>Fig. 6. The overview of R^3 (reinforced ranker-reader) architecture. Both components share the same Match-LSTM module. (Image source: <a href="https://arxiv.org/abs/1709.00023">Wang, et al., 2017</a>)</em></p>
<p>The retriever runs a max-pooling operation per passage and then aggregates to output a probability of each passage entailing the answer.</p>
\[\begin{aligned}
\mathbf{u}_i &= \text{max-pooling}(\mathbf{H}^\text{rank}_i) \in \mathbb{R}^l \\
\mathbf{C} &= \text{tanh}(\mathbf{W}^c[\mathbf{u}_1;\dots;\mathbf{u}_N] + \mathbf{b}^c \otimes \mathbf{e}_N) \in \mathbb{R}^{l \times N} \\
\gamma &= \text{softmax}(\mathbf{w}^c \mathbf{C}) \in \mathbb{R}^N
\end{aligned}\]
<p>Finally, the retriever is viewed as a <em>policy</em> whose action is to sample a passage according to the predicted \(\gamma\),</p>
\[\pi(z \vert x; \theta^\gamma) = \gamma_z\]
<p>The reader predicts the start position \(\beta^s\) and the end position \(\beta^e\) of the answer span. Two positions are computed in the same way, with independent parameters to learn. There are \(V\) words in all the passages involved.</p>
\[\begin{aligned}
\mathbf{H}^\text{read} &= [\mathbf{H}^\text{read}_\tau; \mathbf{H}^\text{read}_{\text{neg}_1}; \dots; \mathbf{H}^\text{read}_{\text{neg}_n}] \\
\mathbf{F}^s &= \text{tanh}(\mathbf{W}^s \mathbf{H}^\text{read} + \mathbf{b}^s \otimes \mathbf{e}_V) \quad
\beta^s = \text{softmax}(\mathbf{w}^s \mathbf{F}^s) \in \mathbb{R}^V \\
\mathbf{F}^e &= \text{tanh}(\mathbf{W}^e \mathbf{H}^\text{read} + \mathbf{b}^e \otimes \mathbf{e}_V) \quad
\beta^e = \text{softmax}(\mathbf{w}^e \mathbf{F}^e) \in \mathbb{R}^V \\
L(y \vert z, x) &= -\log(\beta^s_{y_z^s})-\log(\beta^e_{y_z^e})
\end{aligned}\]
<p>where \(y\) is the ground-truth answer and the passage \(z\) is sampled by the retriever. \(\beta^s_{y_z^s}\) and \(\beta^e_{y_z^e}\) represent the probabilities of the start and end positions of \(y\) in passage \(z\).</p>
<p>The training objective for the end-to-end R^3 QA system is to minimize the negative log-likelihood of obtaining the correct answer \(y\) given a question \(x\),</p>
\[\begin{aligned}
\mathcal{J}(\theta) &= -\mathbb{E}_{z\sim\pi(.\vert x)} [L(y \vert z, x)] \\
\nabla \mathcal{J}(\theta)
&= - \nabla_\theta \sum_z \pi(z \vert x) L(y \vert z, x) \\
&= - \sum_z \big( L(y \vert z, x) \nabla_\theta\pi(z \vert x) + \pi(z \vert x) \nabla_\theta L(y \vert z, x) \big) \\
&= - \mathbb{E}_{z\sim\pi(.\vert x)} \big( \color{red}{L(y \vert z, x)\nabla_\theta\log\pi(z \vert x)} + \nabla_\theta L(y \vert z, x) \big) \\
&\approx - \mathbb{E}_{z\sim\pi(.\vert x)} \big( \underbrace{\color{red}{R(y \vert z, x)\nabla_\theta\log\pi(z \vert x)}}_\text{REINFORCE} + \nabla_\theta L(y \vert z, x) \big)
\end{aligned}\]
<p>Essentially in training, given a passage \(z\) sampled by the retriever, the reader is trained by gradient descent while the retriever is trained by <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a> using \(L(y \vert z, x)\) as the reward function. However, \(L(y \vert z, x)\) is not bounded and may introduce a lot of variance. The paper replaces the reward with a customized scoring function by comparing the ground truth \(y\) and the answer extracted by the reader \(\hat{y}\):</p>
\[R(y, \hat{y} \vert z) = \begin{cases}
2 & \text{if } y = \hat{y}\\
f1(y, \hat{y}) & \text{if } y \cap \hat{y} \neq \varnothing \\
-1 & \text{otherwise}
\end{cases}\]
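<p>A minimal sketch of such a reward function is given below, assuming that the F1 score is computed over whitespace-separated tokens and is used whenever the prediction partially overlaps the ground truth. The tokenization and the function names are assumptions made for illustration.</p>

```python
from collections import Counter


def token_f1(gold: str, pred: str) -> float:
    """Token-level F1 between two strings, using whitespace tokenization."""
    gold_tokens, pred_tokens = gold.split(), pred.split()
    overlap = sum((Counter(gold_tokens) & Counter(pred_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def r3_reward(gold: str, pred: str) -> float:
    """2 for an exact match, F1 for a partial overlap, -1 otherwise."""
    if gold == pred:
        return 2.0
    f1 = token_f1(gold, pred)
    return f1 if f1 > 0 else -1.0
```

<p>Unlike the raw log-likelihood, this reward is bounded in \([-1, 2]\), which helps keep the variance of the REINFORCE gradient under control.</p>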
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/R%5E3-reward-flow.png" alt="R^3 reward flow" /></p>
<p class="image-caption"><em>Fig. 7. The workflow of R^3 training process. (Image source: <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part4-retriever-reader.pdf">acl2020-openqa-tutorial/slides/part4</a>)</em></p>
<p><a name="ORQA"></a><strong>ORQA</strong> (“Open-Retrieval Question-Answering”; <a href="https://arxiv.org/abs/1906.00300">Lee et al., 2019</a>) jointly learns a retriever + reader QA model to optimize marginal log-likelihood of obtaining correct answers in a supervised manner. No explicit “black-box” IR system is involved. Instead, it is capable of retrieving any text in an open corpus. During training, ORQA does not need ground-truth context passages (i.e. reading comprehension datasets) but only needs (question, answer) string pairs. Both retriever and reader components are based on BERT, but not shared.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/ORQA-retriever.png" alt="ORQA-retriever" /></p>
<p class="image-caption"><em>Fig. 8. An illustration of the retriever component in ORQA. (Image source: replotted based on one slide in <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part5-dense-retriever-e2e-training.pdf">acl2020-openqa-tutorial/slides/part5</a>)</em></p>
<p>All the evidence blocks are ranked by a retrieval score, defined as the inner product of BERT embedding vectors of the <code class="language-plaintext highlighter-rouge">[CLS]</code> token of the question \(x\) and the evidence block \(z\). Note that the encoders for questions and context are independent.</p>
\[\begin{aligned}
h_x &= \mathbf{W}_x \text{BERT}_x(x)^{\mathtt{[CLS]}} \\
h_z &= \mathbf{W}_z \text{BERT}_z(z)^{\mathtt{[CLS]}} \\
S_\text{retr}(z, x) &= h_x^\top h_z
\end{aligned}\]
<p><a name="ICT-loss"></a>The retriever module is pretrained with <em>Inverse Cloze Task (ICT)</em>, which is to predict the context given a sentence, opposite to the standard <a href="https://en.wikipedia.org/wiki/Cloze_test">Cloze Task</a>. The ICT objective is to maximize the retrieval score of the correct context \(z\) given a random sentence \(x\):</p>
\[L_\text{ICT} = p_\text{early}(z \vert x) = \frac{\exp(S_\text{retr}(z, x))}{\sum_{z'\in\text{BATCH}(\mathcal{Z})} \exp(S_\text{retr}(z', x))}\]
<p>where \(\text{BATCH}(\mathcal{Z})\) is the set of evidence blocks in the same batch used as sampled negatives.</p>
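<p>The in-batch softmax can be sketched in a few lines. The example assumes that sentence and context embeddings are already computed, that the \(i\)-th context is the positive for the \(i\)-th sentence, and that training minimizes the negative log of the probability above; the function name and toy vectors are hypothetical.</p>

```python
import math
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def in_batch_nll(sentence_vecs: List[Vector], context_vecs: List[Vector]) -> float:
    """Mean negative log-likelihood of the matching context within the batch.

    Context i is the positive for sentence i; all other contexts in the
    batch act as sampled negatives.
    """
    losses = []
    for i, x in enumerate(sentence_vecs):
        scores = [dot(x, z) for z in context_vecs]  # S_retr(z', x) for every z' in the batch
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_norm - scores[i])
    return sum(losses) / len(losses)


sentences = [[1.0, 0.0], [0.0, 1.0]]
aligned_loss = in_batch_nll(sentences, [[1.0, 0.0], [0.0, 1.0]])
mismatched_loss = in_batch_nll(sentences, [[0.0, 1.0], [1.0, 0.0]])
```

<p>The loss is lower when each sentence embedding is closest to its own context, which is the behavior the ICT pre-training rewards.</p>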
<p>After such pretraining, the BERT retriever is expected to have representations good enough for evidence retrieval. Only the question encoder needs to be fine-tuned for answer extraction. In other words, the evidence block encoder (i.e., \(\mathbf{W}_z\) and \(\text{BERT}_z\)) is fixed and thus all the evidence block encodings can be pre-computed with support for <a href="#fast-maximum-inner-product-search-mips">fast Maximum Inner Product Search (MIPS)</a>.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/ORQA-reader.png" alt="ORQA-reader" /></p>
<p class="image-caption"><em>Fig. 9. An illustration of the reader component in ORQA. (Image source: <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part5-dense-retriever-e2e-training.pdf">acl2020-openqa-tutorial/slides/part5</a>)</em></p>
<p>The reader follows the same design as in the original <a href="/lil-log/2019/01/31/generalized-language-models.html#use-bert-in-downstream-tasks">BERT RC</a> experiments. It learns in a supervised manner, while the parameters of the evidence block encoder are fixed and all other parameters are fine-tuned. Given a question \(x\) and a gold answer string \(y\), the reader loss contains two parts:</p>
\[\mathcal{L}(x, y) = \mathcal{L}_\text{early}(x, y) + \mathcal{L}_\text{full}(x, y)\]
<p>(1) Find all correct text spans within top \(k\) evidence blocks and optimize for the marginal likelihood of a text span \(s\) that matches the true answer \(y\):</p>
\[\begin{aligned}
h_s &= \text{BERT}_R(x, z)^{(\text{START}(s))} \\
h_e &= \text{BERT}_R(x, z)^{(\text{END}(s))} \\
S_\text{read}(z, s, x) &= \text{MLP}([h_s; h_e]) \\
p(z, s \vert x) &= \frac{\exp(S_\text{read}(z, s, x))}{\sum_{z'\in\text{TOP}(k)} \sum_{s'\in z'} \exp(S_\text{read}(z', s', x))} \\
L_\text{full}(x, y) &= - \log \sum_{\substack{z \in \text{TOP}(k)\\ s \in z}} \sum_{y=\text{TEXT}(s)} p(z, s \vert x)
\end{aligned}\]
<p>where \(y=\text{TEXT}(s)\) indicates whether the answer \(y\) matches the text span \(s\). \(\text{TOP}(k)\) is the top \(k\) retrieved blocks according to \(S_\text{retr}(z, x)\). The paper sets \(k=5\).</p>
<p>(2) At the early stage of learning, when the retriever is not strong enough, it is possible none of the top \(k\) blocks contains the answer. To avoid such sparse learning signals, ORQA considers a larger set of \(c\) evidence blocks for more aggressive learning. The paper has \(c=5000\).</p>
\[L_\text{early}(x, y)
= -\log \sum_{\substack{z\in \text{TOP}(c)\\y\in\text{TEXT}(z)}} p_\text{early}(z\vert x)
= -\log \sum_{\substack{z\in \text{TOP}(c)\\y\in\text{TEXT}(z)}} \frac{\exp(S_\text{retr}(z, x))}{\sum_{z'\in\text{TOP}(c)} \exp(S_\text{retr}(z', x))}\]
<p>Some issues in SQuAD dataset were discussed in the ORQA paper:</p>
<blockquote>
<p>” The notable drop between development and test accuracy for SQuAD is a reflection of an artifact in the dataset—its 100k questions are derived from only 536 documents. Therefore, good retrieval targets are highly correlated between training examples, violating the IID assumption, and making it unsuitable for learned retrieval. We strongly suggest that those who are interested in end-to-end open-domain QA models no longer train and evaluate with SQuAD for this reason.”</p>
</blockquote>
<p><a name="REALM"></a><strong>REALM</strong> (“Retrieval-Augmented Language Model pre-training”; <a href="https://arxiv.org/abs/2002.08909">Guu et al., 2020</a>) also jointly trains retriever + reader by optimizing the marginal likelihood of obtaining the true answer:</p>
\[p(y \vert x)
= \sum_{z \in \mathcal{Z}} \underbrace{p(y \vert x, z)}_\text{reader} \underbrace{p(z \vert x)}_\text{retriever}
\approx \sum_{z \in \text{TOP}_k(\mathcal{Z})} p(y \vert x, z) p(z \vert x)\]
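<p>The top-\(k\) approximation of the marginal likelihood can be illustrated numerically. In this sketch the retriever scores and reader probabilities are invented, and the retriever distribution is renormalized over the top-\(k\) documents only, which is a simplifying assumption of the example.</p>

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_likelihood(retriever_scores: List[float],
                        reader_probs: List[float], k: int) -> float:
    """Approximate p(y|x) = sum_z p(y|x,z) p(z|x) using the top-k documents."""
    order = sorted(range(len(retriever_scores)),
                   key=lambda i: retriever_scores[i], reverse=True)
    top = order[:k]
    p_z = softmax([retriever_scores[i] for i in top])  # p(z|x) restricted to top-k
    return sum(pz * reader_probs[i] for pz, i in zip(p_z, top))


# Hypothetical retriever scores and reader probabilities p(y|x,z) for 3 documents.
retriever_scores = [2.0, 0.0, -5.0]
reader_probs = [0.9, 0.1, 0.5]
p_y = marginal_likelihood(retriever_scores, reader_probs, k=2)
```

<p>Since the retriever probability multiplies the reader probability, the gradient of \(\log p(y \vert x)\) rewards the retriever for scoring documents that help the reader predict the correct answer.</p>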
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/REALM-train.png" alt="REALM" /></p>
<p class="image-caption"><em>Fig. 10. REALM is first unsupervised pre-trained with salient spans masking and then fine-tuned with QA data. (Image source: <a href="https://arxiv.org/abs/2002.08909">Guu et al., 2020</a>).</em></p>
<p>REALM computes two probabilities, \(p(z \vert x)\) and \(p(y \vert x, z)\), same as ORQA. However, different from ICT in ORQA, REALM upgrades the unsupervised pre-training step with several new design decisions, leading towards better retrievals. REALM pre-trains the model with Wikipedia or CC-News corpus.</p>
<ol>
<li><a name="ssm"></a>Use <em>salient span masking</em>. Named entities and dates are identified. Then one of these “salient spans” is selected and masked. Salient span masking is a special case of MLM and works out well for QA tasks.</li>
<li>Add an <em>empty null document</em>. Because not every question demands a context document.</li>
<li>No trivial retrieval. The context document should not be the same as the selected sentence with a masked span.</li>
<li>Apply the same ICT loss as in ORQA to encourage learning when the retrieval quality is still poor at the early stage of training.</li>
</ol>
<blockquote>
<p>“Among all systems, the most direct comparison with REALM is ORQA (Lee et al., 2019), where the fine-tuning setup, hyperparameters and training data are identical. The improvement of REALM over ORQA is purely due to better pre-training methods.” — from REALM paper.</p>
</blockquote>
<p>Both unsupervised pre-training and supervised fine-tuning optimize the same log-likelihood \(\log p(y \vert x)\). Because the parameters of the retriever encoder for evidence documents are also updated in the process, the index for MIPS is changing. REALM asynchronously refreshes the index with the updated encoder parameters every several hundred training steps.</p>
<p><a href="https://arxiv.org/abs/2104.08710">Balachandran, et al. (2021)</a> found that REALM is significantly undertrained and REALM++ achieves great EM accuracy improvement (3-5%) by scaling up the model training with larger batch size and more retrieved documents for the reader to process.</p>
<p><a name="DPR"></a><strong>DPR</strong> (“Dense Passage Retriever”; <a href="https://arxiv.org/abs/2004.04906">Karpukhin et al., 2020</a>, <a href="https://github.com/facebookresearch/DPR">code</a>) argues that ICT pre-training could be too computationally expensive and that ORQA’s context encoder might be sub-optimal because it is not fine-tuned with question-answer pairs. DPR aims to resolve these two issues by training a dense dual-encoder architecture for retrieval from only a small number of Q/A pairs, without any pre-training.</p>
<p>Same as previous work, DPR uses the dot-product (L2 distance or cosine similarity also work) of BERT representations as the retrieval score. The loss function for training the dual-encoder is the NLL of the positive passage, which essentially takes the same formulation as the <a href="#ICT-loss">ICT loss</a> of ORQA. Note that both of them consider other passages in the same batch as the negative samples, named <em>in-batch negative sampling</em>. The main difference is that DPR relies on supervised QA data, while ORQA trains with ICT on an unlabeled corpus. At the inference time, DPR uses <a href="https://github.com/facebookresearch/faiss">FAISS</a> to run fast MIPS.</p>
<p>DPR did a set of comparison experiments involving several different types of negatives:</p>
<ol>
<li>Random: any random passage from the corpus;</li>
<li>BM25: top passages returned by BM25 which don’t contain the answer but match most question tokens;</li>
<li>In-batch negative sampling (“gold”): positive passages paired with other questions which appear in the training set.</li>
</ol>
<p>DPR found that using gold passages from the same mini-batch and one negative passage with high BM25 score works the best. To further improve the retrieval results, DPR also explored a setting where a BM25 score and a dense embedding retrieval score are linearly combined to serve as a new ranking function.</p>
<h2 id="open-book-qa-retriever-generator">Open-book QA: Retriever-Generator</h2>
<p>Compared to the retriever-reader approach, the retriever-generator also has 2 stages but the second stage is to generate free text directly to answer the question rather than to extract start/end positions in a retrieved passage. Some papers also refer to this as <em>generative question answering</em>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-retiever-generator.png" alt="retriever + text generator" /></p>
<p class="image-caption"><em>Fig. 11. The retriever + generator QA framework combines a document retrieval system with a general language model.</em></p>
<p>A pretrained LM has a great capacity for memorizing knowledge in its parameters, as shown above. However, such models cannot easily modify or expand their memory, cannot straightforwardly provide insights into their predictions, and may produce hallucinated content that does not exist.</p>
<p><a href="https://arxiv.org/abs/2005.04611">Petroni et al. (2020)</a> studied how the retrieved relevant context can help a generative language model produce better answers. They found:</p>
<ol>
<li>Augmenting queries with relevant contexts dramatically improves the pretrained LM on unsupervised machine reading capabilities.</li>
<li>An off-the-shelf IR system is sufficient for BERT to match the performance of a supervised ODQA baseline.</li>
<li>BERT’s <a href="/lil-log/2019/01/31/generalized-language-models.html#pre-training-tasks">NSP</a> pre-training strategy is a highly effective unsupervised mechanism in dealing with noisy and irrelevant contexts.</li>
</ol>
<p>They pair the BERT model with different types of context, including adversarial (unrelated context), retrieved (by BM25), and generative (by an autoregressive language model of 1.4B parameters, trained on CC-NEWS). The model is found to be robust to adversarial context, but only when the question and the context are provided as two segments (e.g. separated by <code class="language-plaintext highlighter-rouge">[SEP]</code>). One hypothesis is related to the NSP task: “BERT might learn to not condition across segments for masked token prediction if the NSP score is low, thereby implicitly detecting irrelevant and noisy contexts.”</p>
<p><strong>RAG</strong> (“Retrieval-Augmented Generation”; <a href="https://arxiv.org/abs/2005.11401">Lewis et al., 2020</a>) combines pre-trained parametric (language model) and non-parametric memory (external knowledge index) together for language generation. RAG can be fine-tuned on any seq2seq task, whereby both the retriever and the sequence generator are jointly learned. They found that unconstrained generation outperforms previous extractive approaches.</p>
<p>RAG consists of a retriever model \(p_\eta(z \vert x)\) and a generator model \(p_\theta(y_i \vert x, z, y_{1:i-1})\):</p>
<ul>
<li>The retriever uses the input sequence \(x\) to retrieve text passages \(z\), implemented as a <a href="#DPR">DPR</a> retriever. \(p_\eta(z \vert x) \propto \exp(E_z(z)^\top E_x(x))\).</li>
<li>The generator uses \(z\) as additional context when generating the target sequence \(y\), where the context and the question are simply concatenated.</li>
</ul>
<p>Depending on whether using the same or different retrieved documents for each token generation, there are two versions of RAG:</p>
\[\begin{aligned}
p_\text{RAG-seq}(y \vert x) &= \sum_{z \in \text{TOP}_k(p_\eta(.\vert x))} p_\eta(z \vert x) \prod_{i=1}^N p_\theta(y_i \vert x, z, y_{1:i-1}) \\
p_\text{RAG-token}(y \vert x) &= \prod_{i=1}^N \sum_{z \in \text{TOP}_k(p_\eta(.\vert x))} p_\eta(z \vert x) p_\theta(y_i \vert x, z, y_{1:i-1})
\end{aligned}\]
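<p>The two marginalization schemes differ only in whether the sum over documents sits outside or inside the product over tokens. The sketch below evaluates both on a toy table of per-token probabilities; the numbers are invented, and the table treats \(p_\theta(y_i \vert x, z, y_{1:i-1})\) as precomputed, which is a simplification made for illustration.</p>

```python
from typing import List


def rag_sequence(p_z: List[float], token_probs: List[List[float]]) -> float:
    """One document for the whole sequence: sum_z p(z|x) * prod_i p(y_i | x, z, y_{<i})."""
    total = 0.0
    for pz, probs in zip(p_z, token_probs):
        seq_prob = 1.0
        for p in probs:
            seq_prob *= p
        total += pz * seq_prob
    return total


def rag_token(p_z: List[float], token_probs: List[List[float]]) -> float:
    """A document mixture per token: prod_i sum_z p(z|x) * p(y_i | x, z, y_{<i})."""
    n_tokens = len(token_probs[0])
    result = 1.0
    for i in range(n_tokens):
        result *= sum(pz * probs[i] for pz, probs in zip(p_z, token_probs))
    return result


# Two retrieved documents, two target tokens (hypothetical numbers).
# Document 0 supports the first token, document 1 supports the second.
p_z = [0.5, 0.5]
token_probs = [[0.9, 0.1], [0.1, 0.9]]
p_seq = rag_sequence(p_z, token_probs)
p_tok = rag_token(p_z, token_probs)
```

<p>In this toy case RAG-token assigns a higher likelihood because it can draw each token from a different document, while RAG-seq has to commit to a single document for the whole answer.</p>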
<p>The retriever + generator in RAG is jointly trained to minimize the NLL loss, \(\mathcal{L}_\text{RAG} = \sum_j -\log p(y_j \vert x_j)\). Updating the passage encoder \(E_z(.)\) is expensive as it requires the model to re-index the documents for fast MIPS. RAG does not find fine-tuning \(E_z(.)\) necessary (like in <a href="#ORQA">ORQA</a>) and only updates the query encoder + generator.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RAG.png" alt="RAG" /></p>
<p class="image-caption"><em>Fig. 12. An illustration of retrieval-augmented generation (RAG) architecture. (Image source: <a href="https://arxiv.org/abs/2005.11401">Lewis et al., 2020</a>)</em></p>
<p>At decoding/test time, RAG-token can be evaluated via a <a href="https://d2l.ai/chapter_recurrent-modern/beam-search.html#id1">beam search</a>. RAG-seq cannot be broken down into a set of per-token likelihoods, so it runs beam search for each candidate document \(z\) and picks the one with optimal \(p_\theta(y_i \vert x, z, y_{1:i-1})\).</p>
<p>The <em>Fusion-in-Decoder</em> approach, proposed by <a href="https://arxiv.org/abs/2007.01282">Izacard & Grave (2020)</a>, is also based on a pre-trained T5. It works similarly to RAG but differs in how the context is integrated into the decoder.</p>
<ol>
<li>Retrieve top \(k\) related passages of 100 words each, using BM25 or DPR.</li>
<li>Each retrieved passage and its title are concatenated with the question using special tokens like <code class="language-plaintext highlighter-rouge">question:</code>, <code class="language-plaintext highlighter-rouge">title:</code> and <code class="language-plaintext highlighter-rouge">context:</code> to indicate the content differences.</li>
<li>Each retrieved passage is processed independently by the encoder and the encoded passages are later combined in the decoder. Processing passages independently in the encoder allows us to parallelize the computation. On the other hand, processing them jointly in the decoder encourages better aggregation of multiple pieces of evidence. The aggregation part is missing in extractive approaches.</li>
</ol>
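<p>The input formatting in step 2 can be sketched as follows. The dictionary keys <code class="language-plaintext highlighter-rouge">title</code> and <code class="language-plaintext highlighter-rouge">text</code>, the helper name and the sample passages are assumptions for illustration; only the <code class="language-plaintext highlighter-rouge">question:</code>, <code class="language-plaintext highlighter-rouge">title:</code> and <code class="language-plaintext highlighter-rouge">context:</code> markers come from the description above.</p>

```python
from typing import Dict, List


def build_fid_inputs(question: str, passages: List[Dict[str, str]]) -> List[str]:
    """Create one encoder input string per retrieved passage."""
    return [
        f"question: {question} title: {p['title']} context: {p['text']}"
        for p in passages
    ]


# Hypothetical retrieved passages.
passages = [
    {"title": "Sacramento, California", "text": "Sacramento is the capital city of California."},
    {"title": "California", "text": "California is a state in the Western United States."},
]
encoder_inputs = build_fid_inputs("What is the capital of California?", passages)
```

<p>Each string would be encoded on its own, and the decoder would then attend over the concatenation of all encoded passages when generating the answer.</p>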
<p>Note that they did fine-tune the pretrained LM independently for each dataset.</p>
<h2 id="closed-book-qa-generative-language-model">Closed-book QA: Generative Language Model</h2>
<p>Big language models have been pre-trained on large collections of unlabeled text. Given enough parameters, these models are able to memorize some factual knowledge within parameter weights. Therefore, we can use these models to do question-answering without explicit context, just like in a closed-book exam. The pre-trained language models produce <em>free text</em> to respond to questions, without explicit reading comprehension.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/LM-compute.png" alt="LM compute" /></p>
<p class="image-caption"><em>Fig. 13. The amount of computation used for training big language models of different sizes is getting big. (Image source: <a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>).</em></p>
<p><a href="https://arxiv.org/abs/2002.08910">Roberts et al. (2020)</a> measured the practical utility of a language model by fine-tuning a pre-trained model to answer questions without access to any external context or knowledge. They fine-tuned the <a href="https://arxiv.org/abs/1910.10683">T5</a> language model (same architecture as the original Transformer) to answer questions without inputting any additional information or context. Such a setup forces the language model to answer questions based on “knowledge” that it internalized during pre-training.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/T5_SSM.png" alt="T5+SSM" /></p>
<p class="image-caption"><em>Fig. 14. T5 is first pre-trained with salient span masking and then fine-tuned for each QA dataset to produce answers in free text. (Image source: <a href="https://arxiv.org/abs/2002.08910">Roberts et al. 2020</a>)</em></p>
<p>The original T5 models were pre-trained on a multi-task mixture including an unsupervised <a href="/lil-log/2019/01/31/generalized-language-models.html#use-bert-in-downstream-tasks">“masked language modeling”</a> (MLM) task on the C4 (“Colossal Clean Crawled Corpus”) dataset as well as fine-tuned altogether with supervised translation, summarization, classification, and reading comprehension tasks. <a href="https://arxiv.org/abs/2002.08910">Roberts, et al. (2020)</a> took a pre-trained T5 model and continued pre-training with <a href="#ssm">salient span masking</a> over Wikipedia corpus, which has been found to substantially boost the performance for ODQA. Then they fine-tuned the model for each QA dataset independently.</p>
<p>With a pre-trained T5 language model + continued pre-training with salient span masking + fine-tuning for each QA dataset,</p>
<ul>
<li>It can attain competitive results in open-domain question answering without access to external knowledge.</li>
<li>A larger model can obtain better performance. For example, a T5 with 11B parameters is able to match the performance of <a href="#DPR">DPR</a> with 3 BERT-base models, about 330M parameters in total.</li>
</ul>
<p>Interestingly, fine-tuning is not strictly necessary. GPT3 (<a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>) has been evaluated on the closed book question answering task <em>without any gradient updates or fine-tuning</em>. During evaluation, the few-shot, one-shot and zero-shot settings here only refer to how many demonstrations are provided as context in the text input:</p>
<ol>
<li>“few-shot learning”: GPT3 is allowed to take as many demonstrations as what can fit into the model’s context window (typically 10 to 100).</li>
<li>“one-shot learning”: only one demonstration is provided.</li>
<li>“zero-shot learning”: no demonstrations are allowed and only an instruction in natural language is given to the model.</li>
</ol>
<p>The performance grows with the model size. On the TriviaQA dataset, GPT3 evaluation with demonstrations can match or exceed the performance of SOTA baseline with fine-tuning.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/GPT3-triviaqa.png" alt="GPT3 on TriviaQA" /></p>
<p class="image-caption"><em>Fig. 15. GPT3’s performance on TriviaQA grows smoothly with the model size. More demonstrations lead to better performance. (Image source: <a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>).</em></p>
<p><a name="openai-api-example"></a>Check out this cool example in the OpenAI API <a href="https://beta.openai.com/playground/p/HMoho4552EHXrPLbmOIxpX4X">playground viewer</a>. The model is able to answer factual questions with short answers and does not make things up when it does not know the answer. I added the last two questions and asked the model to respond with <code class="language-plaintext highlighter-rouge">A:</code>. The API is still in beta, so you might need to <a href="https://beta.openai.com/">apply</a> to get on the wait list.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Q: Who is Batman?
A: Batman is a fictional comic book character.
###
Q: What is torsalplexity?
A: ?
###
Q: What is Devz9?
A: ?
###
Q: Who is George Lucas?
A: George Lucas is American film director and producer famous for creating Star Wars.
###
Q: What is the capital of California?
A: Sacramento.
###
Q: What orbits the Earth?
A: The Moon.
###
Q: Who is Fred Rickerson?
A: ?
###
Q: What is an atom?
A: An atom is a tiny particle that makes up everything.
###
Q: Who is Alvan Muntz?
A: ?
###
Q: What is Kozar-09?
A: ?
###
Q: How many moons does Mars have?
A: Two, Phobos and Deimos.
###
Q: What is COVID-19?
A: ?
###
Q: What is H1N1?
A: H1N1 is a strain of influenza.
</code></pre></div></div>
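<p>The prompt format above can be assembled programmatically. The following is a minimal sketch, assuming demonstrations are given as (question, answer) pairs and that blocks are separated by <code class="language-plaintext highlighter-rouge">###</code> as in the example; the function name <code class="language-plaintext highlighter-rouge">build_few_shot_prompt</code> is hypothetical and no API call is made here.</p>

```python
def build_few_shot_prompt(demonstrations, question, separator="###"):
    """Format (question, answer) demonstrations followed by a new question."""
    blocks = [f"Q: {q}\nA: {a}" for q, a in demonstrations]
    # The final block ends with "A:" so that the model completes the answer.
    blocks.append(f"Q: {question}\nA:")
    return f"\n{separator}\n".join(blocks)


# A few demonstrations taken from the example above, including one where
# the expected answer is "?" to signal "I don't know".
demos = [
    ("Who is Batman?", "Batman is a fictional comic book character."),
    ("What is torsalplexity?", "?"),
    ("What is the capital of California?", "Sacramento."),
]
prompt = build_few_shot_prompt(demos, "What orbits the Earth?")
print(prompt)
```

<p>With an empty list of demonstrations the same function produces a zero-shot prompt, and with a single pair it produces a one-shot prompt.</p>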
<h2 id="related-techniques">Related Techniques</h2>
<h3 id="fast-maximum-inner-product-search-mips">Fast Maximum Inner Product Search (MIPS)</h3>
<p>MIPS (maximum inner product search) is a crucial component in many open-domain question answering models. In the retriever + reader/generator framework, a large number of passages from the knowledge source are encoded and stored in a memory. A retrieval model queries the memory to identify the top relevant passages, i.e. those which have the maximum inner product with the question’s embedding.</p>
<p>We need fast MIPS because the number of precomputed passage representations can be gigantic. There are several ways to achieve fast MIPS at run time, such as <a href="https://papers.nips.cc/paper/5329-asymmetric-lsh-alsh-for-sublinear-time-maximum-inner-product-search-mips.pdf">asymmetric LSH</a>, <a href="https://arxiv.org/abs/1501.01062">data-dependent hashing</a>, and <a href="https://github.com/facebookresearch/faiss">FAISS</a>.</p>
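<p>For reference, the exact (slow) version of the search that these methods approximate can be sketched as a brute-force scan. This is only an illustrative baseline in pure Python with toy 3-dimensional embeddings; the function names are hypothetical and it is not how any of the libraries above work internally.</p>

```python
import heapq


def inner_product(u, v):
    """Plain dot product between two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def brute_force_mips(query, passages, k=2):
    """Return the k (score, index) pairs with the largest inner product."""
    scored = ((inner_product(query, p), idx) for idx, p in enumerate(passages))
    return heapq.nlargest(k, scored)


# Toy passage embeddings (made-up numbers for illustration).
passage_embeddings = [
    [0.1, 0.9, 0.0],
    [0.8, 0.1, 0.1],
    [0.4, 0.4, 0.2],
    [2.0, 0.0, 0.0],
]
question_embedding = [1.0, 0.0, 0.0]
top = brute_force_mips(question_embedding, passage_embeddings, k=2)
print(top)
```

<p>The scan costs time linear in the number of passages for every question, which is why sublinear approximate methods matter when the memory holds millions of vectors.</p>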
<h3 id="language-model-pre-training">Language Model Pre-training</h3>
<p>Two pre-training tasks are especially helpful for QA tasks, as we have discussed above.</p>
<ul>
<li>
<p><strong>Inverse Cloze Task</strong> (proposed by <a href="#ORQA">ORQA</a>): The goal of <a href="https://en.wikipedia.org/wiki/Cloze_test">Cloze Task</a> is to predict masked-out text based on its context. The prediction of Inverse Cloze Task (ICT) is in the reverse direction, aiming to predict the context given a sentence. In the context of QA tasks, a random sentence can be treated as a pseudo-question, and its context can be treated as pseudo-evidence.</p>
</li>
<li>
<p><strong>Salient Spans Masking</strong> (proposed by <a href="#REALM">REALM</a>): Salient span masking is a special case of the MLM task in language model training. First, we find <em>salient spans</em> by using a tagger to identify named entities and a regular expression to identify dates. Then one of the detected salient spans is selected and masked. The task is to predict this masked salient span.</p>
</li>
</ul>
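<p>Both tasks amount to simple data construction steps. Below is a rough sketch under simplifying assumptions: the passage is already split into sentences, the only salient spans considered are dates matched by a hypothetical “Month YYYY” regular expression (no named-entity tagger is used), and the helper names are made up for illustration.</p>

```python
import random
import re

MONTHS = ("January|February|March|April|May|June|July|"
          "August|September|October|November|December")
# Hypothetical date pattern; a real pipeline would also run an NER tagger.
DATE_PATTERN = re.compile(r"\b(?:" + MONTHS + r") \d{4}\b")


def inverse_cloze_example(sentences, rng):
    """Pick one sentence as pseudo-question; the rest is pseudo-evidence."""
    idx = rng.randrange(len(sentences))
    pseudo_question = sentences[idx]
    pseudo_evidence = " ".join(s for i, s in enumerate(sentences) if i != idx)
    return pseudo_question, pseudo_evidence


def mask_salient_span(text, rng, mask_token="[MASK]"):
    """Mask one randomly chosen date span; return (masked_text, target)."""
    spans = [m.span() for m in DATE_PATTERN.finditer(text)]
    if not spans:
        return text, None
    start, end = rng.choice(spans)
    return text[:start] + mask_token + text[end:], text[start:end]


passage = [
    "Star Wars was released in May 1977.",
    "It was directed by George Lucas.",
    "A sequel followed in May 1980.",
]
rng = random.Random(0)
question, evidence = inverse_cloze_example(passage, rng)
masked, target = mask_salient_span(passage[0], rng)
print(question, "|", evidence)
print(masked, "->", target)
```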
<h2 id="summary">Summary</h2>
<table class="info">
<thead>
<tr>
<th>Model</th>
<th>Retriever</th>
<th>Reader / Generator</th>
<th>Pre-training / Fine-tuning</th>
<th>End2end</th>
</tr>
</thead>
<tbody>
<tr>
<td>DrQA</td>
<td>TF-IDF</td>
<td>Bi-directional LSTM</td>
<td>–</td>
<td>No</td>
</tr>
<tr>
<td>BERTserini</td>
<td>Anserini + BM25</td>
<td>BERT without softmax layer</td>
<td>Fine-tune with SQuAD</td>
<td>No</td>
</tr>
<tr>
<td>Multi-passage BERT</td>
<td>ElasticSearch + BM25</td>
<td>Multi-passage BERT + Passage ranker</td>
<td> </td>
<td>No</td>
</tr>
<tr>
<td>R^3</td>
<td>Classic IR + Match-LSTM</td>
<td>Match-LSTM</td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>ORQA</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>Inverse cloze task</td>
<td>Yes</td>
</tr>
<tr>
<td>REALM</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>Salient span masking</td>
<td>Yes</td>
</tr>
<tr>
<td>DPR</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>supervised training with QA pairs</td>
<td>Yes</td>
</tr>
<tr>
<td>DenSPI</td>
<td>Classic + Neural IR</td>
<td>–</td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>T5 + SSM</td>
<td>–</td>
<td>T5</td>
<td>SSM on Wikipedia (after T5 pre-training on C4, derived from <a href="https://commoncrawl.org/the-data/get-started/">CommonCrawl</a> data) + Fine-tuning on QA data</td>
<td>Yes</td>
</tr>
<tr>
<td>GPT3</td>
<td>–</td>
<td>GPT3</td>
<td>NSP on <a href="https://commoncrawl.org/the-data/get-started/">CommonCrawl</a> data</td>
<td>Yes</td>
</tr>
<tr>
<td>RAG</td>
<td>DPR retriever</td>
<td><a href="https://arxiv.org/abs/1910.13461">BART</a></td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>Fusion-in-Decoder</td>
<td>BM25 / DPR retriever</td>
<td>Transformer</td>
<td> </td>
<td>No</td>
</tr>
</tbody>
</table>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/QA-results.png" alt="SOTA-comparison" /></p>
<p class="image-caption"><em>Fig. 16. A comparison of performance of several QA models on common QA datasets. On TriviaQA, two columns of results are reported, on the open domain test set (left) and on the hidden test set (right). (Image source: <a href="https://arxiv.org/abs/2007.01282">Izacard & Grave, 2020</a>).</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020odqa,
title = "How to Build an Open-Domain Question Answering System?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/10/29/open-domain-question-answering.html"
}
</code></pre></div></div>
<h2 id="appendix-qa-datasets">Appendix: QA Datasets</h2>
<ul>
<li><a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD 2.0</a>: the Stanford QA dataset.</li>
<li><a href="http://www.qizhexie.com/data/RACE_leaderboard">RACE</a>: a reading comprehension dataset collected from English Examinations that are created for middle school and high school students.</li>
<li><a href="https://trec.nist.gov/data/qa.html">TREC QA</a>: the TREC QA collections.</li>
<li><a href="https://microsoft.github.io/msmarco/">MS MARCO</a>: a QA dataset featuring 100,000 real Bing questions, each with a human-generated answer.</li>
<li><a href="https://github.com/brmson/dataset-factoid-curated">CuratedTREC</a>: based on the benchmarks from the TREC QA tasks that have been curated by <a href="https://link.springer.com/chapter/10.1007%2F978-3-319-24027-5_20">Baudis & Sedivy (2015)</a>.</li>
<li><a href="https://ai.google.com/research/NaturalQuestions/dataset">Google Natural Questions</a>: contains real user questions issued to Google search, and answers found from Wikipedia by annotators.</li>
<li><a href="https://github.com/brmson/dataset-factoid-webquestions">WebQuestions</a>: designed for knowledge-base QA with answers restricted to Freebase entities.</li>
<li><a href="https://www.microsoft.com/en-us/research/publication/wikiqa-a-challenge-dataset-for-open-domain-question-answering/">WikiQA</a>: Bing query logs were used as the source of questions. Each question is then linked to a Wikipedia page that potentially contains the answer.</li>
<li><a href="https://research.fb.com/downloads/babi/">WikiMovies</a>: contains movie-related questions from the OMDb and MovieLens databases, where the questions can be answered using Wikipedia pages.</li>
<li><a href="https://github.com/google-research-datasets/wiki-reading">WikiReading</a>: to predict textual values from the structured knowledge base Wikidata by reading the text of the corresponding Wikipedia articles.</li>
<li><a href="https://nlp.cs.washington.edu/triviaqa/">TriviaQA</a>: a reading comprehension dataset containing 95K question-answer pairs authored by trivia enthusiasts and independently gathered multiple evidence documents per question.</li>
<li><a href="https://www.kaggle.com/tunguz/200000-jeopardy-questions"> Jeopardy! Questions</a>: contains 200,000+ <a href="https://en.wikipedia.org/wiki/Jeopardy!">Jeopardy!</a> questions.</li>
<li><a href="https://cs.nyu.edu/~kcho/DMQA/">DeepMind Q&A Dataset</a>: question/answer pairs from CNN and Daily Mail articles.</li>
<li><a href="https://research.fb.com/downloads/babi/">bAbi</a>: a rich collection of datasets for text understanding by Facebook.</li>
<li><a href="https://fever.ai/data.html">FEVER</a>: for fact extraction and verification.</li>
<li><a href="https://github.com/nyu-dl/dl4ir-searchQA">SearchQA</a>: question-answer pairs were crawled from <a href="https://j-archive.com/">J! Archive</a>, and then augmented with text snippets from Google.</li>
<li><a href="https://github.com/bdhingra/quasar">Quasar-T</a>: a collection of open-domain trivia questions and their answers obtained from various internet sources.</li>
<li><a href="https://people.cs.umass.edu/~miyyer/qblearn/index.html">Quiz bowl</a>: contains data from a trivia competition called quiz bowl.</li>
<li><a href="https://nlp.cs.washington.edu/ambigqa/">AmbigNQ</a>: ambiguous questions selected from NQ-OPEN dataset.</li>
<li><a href="https://github.com/facebookresearch/QA-Overlap">QA-Overlap</a>: a collection of overlapped answers/questions between train and test set for Natural Questions, TriviaQA, and WebQuestions.</li>
</ul>
<h2 id="references">References</h2>
<p>[1] Danqi Chen & Scott Yih. <a href="https://github.com/danqi/acl2020-openqa-tutorial">“ACL2020 Tutorial: Open-Domain Question Answering”</a> July 2020.</p>
<p>[2] Danqi Chen, et al. <a href="https://arxiv.org/abs/1704.00051">“Reading Wikipedia to Answer Open-Domain Questions”</a> ACL 2017. | <a href="https://github.com/facebookresearch/DrQA">code</a></p>
<p>[3] Shuohang Wang, et al. <a href="https://arxiv.org/abs/1709.00023">“R^3: Reinforced Ranker-Reader for Open-Domain Question Answering”</a> AAAI 2018.</p>
<p>[4] Jimmy Lin. <a href="https://sigir.org/wp-content/uploads/2019/01/p040.pdf">“The neural hype and comparisons against weak baselines.”</a> ACM SIGIR Forum. Vol. 52. No. 2. 2019.</p>
<p>[5] Wei Yang, et al. <a href="https://arxiv.org/abs/1902.01718">“End-to-End Open-Domain Question Answering with BERTserini”</a> NAACL 2019.</p>
<p>[6] Christopher Clark & Matt Gardner. <a href="https://arxiv.org/abs/1710.10723">“Simple and Effective Multi-Paragraph Reading Comprehension.”</a> arXiv:1710.10723 (2017).</p>
<p>[7] Rodrigo Nogueira & Kyunghyun Cho. <a href="https://arxiv.org/abs/1901.04085">“Passage Re-ranking with BERT.”</a> arXiv preprint arXiv:1901.04085 (2019). | <a href="https://github.com/nyu-dl/dl4marco-bert">code</a></p>
<p>[8] Zhiguo Wang, et al. <a href="https://arxiv.org/abs/1908.08167">“Multi-passage BERT: A globally normalized BERT model for open-domain question answering.”</a> EMNLP 2019.</p>
<p>[9] Minjoon Seo et al. <a href="https://arxiv.org/abs/1906.05807">“Real-time open-domain question answering with dense-sparse phrase index.”</a> ACL 2019.</p>
<p>[10] Kenton Lee, et al. <a href="https://arxiv.org/abs/1906.00300">“Latent Retrieval for Weakly Supervised Open Domain Question Answering”</a> ACL 2019.</p>
<p>[11] Kelvin Guu, et al. <a href="https://arxiv.org/abs/2002.08909">“REALM: Retrieval-Augmented Language Model Pre-Training”</a> arXiv:2002.08909 (2020).</p>
<p>[12] Vladimir Karpukhin et al. <a href="https://arxiv.org/abs/2004.04906">“Dense passage retrieval for open-domain question answering.”</a> EMNLP 2020. | <a href="https://github.com/facebookresearch/DPR">code</a></p>
<p>[13] Patrick Lewis et al. <a href="https://arxiv.org/abs/2005.11401">“Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks”</a> arXiv:2005.11401 (2020).</p>
<p>[14] Adam Roberts, et al. <a href="https://arxiv.org/abs/2002.08910">“How Much Knowledge Can You Pack Into the Parameters of a Language Model?”</a> EMNLP 2020.</p>
<p>[15] Tom Brown, et al. <a href="https://arxiv.org/abs/2005.14165">“Language models are few-shot learners.”</a> arXiv:2005.14165 (2020).</p>
<p>[16] Fabio Petroni, et al. <a href="https://arxiv.org/abs/2005.04611">“How Context Affects Language Models’ Factual Predictions”</a> AKBC 2020.</p>
<p>[17] Gautier Izacard & Edouard Grave. <a href="https://arxiv.org/abs/2007.01282">“Leveraging passage retrieval with generative models for open domain question answering.”</a> arXiv:2007.01282 (2020).</p>
<p>[18] <a href="https://d2l.ai/chapter_recurrent-modern/beam-search.html">“Dive into deep learning: Beam search”</a></p>
<p>[19] Patrick Lewis, et al. <a href="https://arxiv.org/abs/2008.02637">“Question and Answer Test-Train Overlap in Open-Domain Question Answering Datasets”</a> arXiv:2008.02637 (2020). | <a href="https://github.com/facebookresearch/QA-Overlap">data</a></p>
<p>[20] Hervé Jegou, et al. <a href="https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/">“Faiss: A library for efficient similarity search”</a> Mar 2017.</p>
<p>[21] Vidhisha Balachandran, et al. <a href="https://arxiv.org/abs/2104.08710">“Simple and Efficient ways to Improve REALM.”</a> arXiv:2104.08710 (2021).</p>
Lilian Weng
A model that is capable of answering any question with regard to factual knowledge can enable many useful applications. This post delves into how we can build an Open-Domain Question Answering (ODQA) system, assuming we have access to a powerful pretrained language model. Both closed-book and open-book approaches are discussed.
Neural Architecture Search
2020-08-06T12:00:00+00:00
https://lilianweng.github.io/lil-log/2020/08/06/neural-architecture-search
<blockquote>
<p>Neural Architecture Search (NAS) automates network architecture engineering. It aims to learn a network topology that can achieve the best performance on a certain task. By dissecting the methods for NAS into three components: search space, search algorithm and child model evaluation strategy, this post reviews many interesting ideas for better, faster and more cost-efficient automatic neural architecture search.</p>
</blockquote>
<!--more-->
<p>Although most popular and successful model architectures are designed by human experts, it doesn’t mean we have explored the entire network architecture space and settled on the best option. We would have a better chance of finding the optimal solution if we adopted a systematic and automatic way of learning high-performance model architectures.</p>
<p>Automatically learning and evolving network topologies is not a new idea (<a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen, 2002</a>). In recent years, the pioneering work by <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a> and <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a> has attracted a lot of attention into the field of Neural Architecture Search (NAS), leading to many interesting ideas for better, faster and more cost-efficient NAS methods.</p>
<p>As I started looking into NAS, I found this survey by <a href="https://arxiv.org/abs/1808.05377">Elsken, et al 2019</a> very helpful. They characterize NAS as a system with three major components, a framing that is clean and concise and also commonly adopted in other NAS papers.</p>
<ol>
<li><strong>Search space</strong>: The NAS search space defines a set of operations (e.g. convolution, fully-connected, pooling) and how operations can be connected to form valid network architectures. The design of the search space usually involves human expertise and, unavoidably, human biases.</li>
<li><strong>Search algorithm</strong>: A NAS search algorithm samples a population of network architecture candidates. It receives the child model performance metrics as rewards (e.g. high accuracy, low latency) and optimizes to generate high-performance architecture candidates.</li>
<li><strong>Evaluation strategy</strong>: We need to measure, estimate, or predict the performance of a large number of proposed child models in order to obtain feedback for the search algorithm to learn. The process of candidate evaluation could be very expensive and many new methods have been proposed to save time or computation resources.</li>
</ol>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NAS-high-level.png" alt="High-level categorization of NAS" /></p>
<p class="image-caption"><em>Fig. 1. Three main components of Neural Architecture Search (NAS) models. (Image source: <a href="https://arxiv.org/abs/1808.05377">Elsken, et al. 2019</a> with customized annotation in red)</em></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#search-space" id="markdown-toc-search-space">Search Space</a> <ul>
<li><a href="#sequential-layer-wise-operations" id="markdown-toc-sequential-layer-wise-operations">Sequential Layer-wise Operations</a></li>
<li><a href="#cell-based-representation" id="markdown-toc-cell-based-representation">Cell-based Representation</a></li>
<li><a href="#hierarchical-structure" id="markdown-toc-hierarchical-structure">Hierarchical Structure</a></li>
<li><a href="#memory-bank-representation" id="markdown-toc-memory-bank-representation">Memory-bank Representation</a></li>
</ul>
</li>
<li><a href="#search-algorithms" id="markdown-toc-search-algorithms">Search Algorithms</a> <ul>
<li><a href="#random-search" id="markdown-toc-random-search">Random Search</a></li>
<li><a href="#reinforcement-learning" id="markdown-toc-reinforcement-learning">Reinforcement Learning</a></li>
<li><a href="#evolutionary-algorithms" id="markdown-toc-evolutionary-algorithms">Evolutionary Algorithms</a></li>
<li><a href="#progressive-decision-process" id="markdown-toc-progressive-decision-process">Progressive Decision Process</a></li>
<li><a href="#gradient-descent" id="markdown-toc-gradient-descent">Gradient descent</a></li>
</ul>
</li>
<li><a href="#evaluation-strategy" id="markdown-toc-evaluation-strategy">Evaluation Strategy</a> <ul>
<li><a href="#training-from-scratch" id="markdown-toc-training-from-scratch">Training from Scratch</a></li>
<li><a href="#proxy-task-performance" id="markdown-toc-proxy-task-performance">Proxy Task Performance</a></li>
<li><a href="#parameter-sharing" id="markdown-toc-parameter-sharing">Parameter Sharing</a></li>
<li><a href="#prediction-based" id="markdown-toc-prediction-based">Prediction-Based</a></li>
</ul>
</li>
<li><a href="#one-shot-approach-search--evaluation" id="markdown-toc-one-shot-approach-search--evaluation">One-Shot Approach: Search + Evaluation</a></li>
<li><a href="#whats-the-future" id="markdown-toc-whats-the-future">What’s the Future?</a></li>
<li><a href="#appendix-summary-of-nas-papers" id="markdown-toc-appendix-summary-of-nas-papers">Appendix: Summary of NAS Papers</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="search-space">Search Space</h2>
<p>The NAS search space defines a set of basic network operations and how operations can be connected to construct valid network architectures.</p>
<h3 id="sequential-layer-wise-operations">Sequential Layer-wise Operations</h3>
<p>The most naive way to design the search space for neural network architectures is to depict network topologies, either CNN or RNN, with a list of <em>sequential layer-wise operations</em>, as seen in the early work of <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a> & <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>. The serialization of network representation requires a decent amount of expert knowledge, since each operation is associated with different layer-specific parameters and such associations need to be hardcoded. For example, after predicting a <code class="language-plaintext highlighter-rouge">conv</code> op, the model should output kernel size, stride size, etc; or after predicting an <code class="language-plaintext highlighter-rouge">FC</code> op, we need to see the number of units as the next prediction.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NAS-search-space.png" alt="The sequential layer-wise operation search space" /></p>
<p class="image-caption"><em>Fig. 2. (Top) A sequential representation of CNN. (Bottom) A sequential representation of the tree structure of a recurrent cell. (Image source: <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>)</em></p>
<p>To make sure the generated architecture is valid, additional rules might be needed (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>):</p>
<ul>
<li>If a layer is not connected to any input layer then it is used as the input layer;</li>
<li>At the final layer, take all layer outputs that have not been connected and concatenate them;</li>
<li>If one layer has many input layers, then all input layers are concatenated in the depth dimension;</li>
<li>If input layers to be concatenated have different sizes, we pad the small layers with zeros so that the concatenated layers have the same sizes.</li>
</ul>
<p>The skip connection can be predicted as well, using an <a href="/lil-log/2018/06/24/attention-attention.html">attention</a>-style mechanism. At layer \(i\), an anchor point is added with \(i-1\) content-based sigmoids to indicate which of the previous layers should be connected. Each sigmoid takes as input the hidden states of the current node \(h_i\) and \(i-1\) previous nodes \(h_j, j=1, \dots, i-1\).</p>
\[P(\text{Layer j is an input to layer i}) = \text{sigmoid}(v^\top \tanh(\mathbf{W}_\text{prev} h_j + \mathbf{W}_\text{curr} h_i))\]
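<p>The formula above can be evaluated directly. Here is a small sketch in pure Python, assuming toy 2-dimensional hidden states and hand-picked weights; in the actual controller these parameters are learned, and all names below are chosen for illustration only.</p>

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def skip_connection_prob(h_j, h_i, W_prev, W_curr, v):
    """Compute sigmoid(v^T tanh(W_prev h_j + W_curr h_i))."""
    pre = [a + b for a, b in zip(matvec(W_prev, h_j), matvec(W_curr, h_i))]
    score = sum(v_k * math.tanh(p) for v_k, p in zip(v, pre))
    return 1.0 / (1.0 + math.exp(-score))


# Toy hidden states for three anchor points (made-up values).
hidden = [[0.5, -0.2], [0.1, 0.3], [0.0, 1.0]]
W_prev = [[1.0, 0.0], [0.0, 1.0]]
W_curr = [[0.5, 0.5], [-0.5, 0.5]]
v = [1.0, -1.0]

i = 2  # current layer (0-based); it may connect to layers 0 and 1
probs = [skip_connection_prob(hidden[j], hidden[i], W_prev, W_curr, v)
         for j in range(i)]
print(probs)
```

<p>Each probability is independent of the others, so a layer can end up with zero, one, or several incoming skip connections, which is why the validity rules listed earlier are needed.</p>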
<p>The sequential search space has a lot of representation power, but it is very large and consumes a ton of computation resources to exhaustively cover the search space. In the experiments by <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>, they were running 800 GPUs in parallel for 28 days and <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a> restricted the search space to contain at most 2 <code class="language-plaintext highlighter-rouge">FC</code> layers.</p>
<h3 id="cell-based-representation">Cell-based Representation</h3>
<p>Inspired by the design of using repeated modules in successful vision model architectures (e.g. Inception, ResNet), the <em>NASNet search space</em> (<a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>) defines the architecture of a conv net as the same cell getting repeated multiple times and each cell contains several operations predicted by the NAS algorithm. A well-designed cell module enables transferability between datasets. It is also easy to scale down or up the model size by adjusting the number of cell repeats.</p>
<p>Precisely, the NASNet search space learns two types of cells for network construction:</p>
<ol>
<li><em>Normal Cell</em>: The input and output feature maps have the same dimension.</li>
<li><em>Reduction Cell</em>: The output feature map has its width and height reduced by half.</li>
</ol>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NASNet-search-space.png" alt="NASNet search space" /></p>
<p class="image-caption"><em>Fig. 3. The NASNet search space constrains the architecture as a repeated stack of cells. The cell architecture is optimized via NAS algorithms. (Image source: <a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>)</em></p>
<p>The predictions for each cell are grouped into \(B\) blocks (\(B=5\) in the NASNet paper), where each block has 5 prediction steps made by 5 distinct softmax classifiers corresponding to discrete choices of the elements of a block. Note that the NASNet search space does not have residual connections between cells and the model only learns skip connections on its own within blocks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/cell-prediction-steps.png" alt="5 prediction steps in one block" /></p>
<p class="image-caption"><em>Fig. 4. (a) Each cell consists of \(B\) blocks and each block is predicted by 5 discrete decisions. (b) A concrete example of what operations can be chosen in each decision step.</em></p>
<p><a name="ScheduledDropPath"></a>During the experiments, they discovered that a modified version of <a href="https://arxiv.org/abs/1605.07648"><em>DropPath</em></a>, named <em>ScheduledDropPath</em>, significantly improves the final performance of NASNet experiments. DropPath stochastically drops out paths (i.e. edges with operations attached in NASNet) with a fixed probability. ScheduledDropPath is DropPath with a linearly increasing probability of path dropping during training time.</p>
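<p>A rough sketch of the schedule follows. It assumes that the drop probability starts at 0 and grows linearly to a final value over training, and that surviving paths are rescaled as in inverted dropout; both details are assumptions made for illustration, and the function names are hypothetical.</p>

```python
import random


def scheduled_drop_prob(step, total_steps, final_drop_prob):
    """Linearly increase the drop probability from 0 to final_drop_prob."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final_drop_prob * progress


def drop_paths(path_outputs, drop_prob, rng):
    """Zero out each path independently with probability drop_prob."""
    result = []
    for value in path_outputs:
        if rng.random() < drop_prob:
            result.append(0.0)
        else:
            # Assumption: rescale survivors to keep the expected value.
            result.append(value / (1.0 - drop_prob))
    return result


rng = random.Random(0)
for step in (0, 500, 1000):
    p = scheduled_drop_prob(step, total_steps=1000, final_drop_prob=0.5)
    print(step, p, drop_paths([1.0, 2.0, 3.0], p, rng))
```

<p>Plain DropPath corresponds to calling <code class="language-plaintext highlighter-rouge">drop_paths</code> with the same fixed probability at every step.</p>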
<p><a href="https://arxiv.org/abs/1808.05377">Elsken, et al (2019)</a> point out three major advantages of the NASNet search space:</p>
<ol>
<li>The search space size is reduced drastically;</li>
<li>The <a href="https://en.wikipedia.org/wiki/Network_motif">motif</a>-based architecture can be more easily transferred to different datasets.</li>
<li>It demonstrates a strong proof of a useful design pattern of repeatedly stacking modules in architecture engineering. For example, we can build strong models by stacking residual blocks in CNN or stacking multi-headed attention blocks in Transformer.</li>
</ol>
<h3 id="hierarchical-structure">Hierarchical Structure</h3>
<p>To take advantage of already discovered well-designed network <a href="https://en.wikipedia.org/wiki/Network_motif">motifs</a>, the NAS search space can be constrained as a hierarchical structure, as in <em>Hierarchical NAS</em> (<strong>HNAS</strong>; <a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>). It starts with a small set of primitives, including individual operations like convolution operation, pooling, identity, etc. Then small sub-graphs (or “motifs”) that consist of primitive operations are recursively used to form higher-level computation graphs.</p>
<p>A computation motif at level \(\ell=1, \dots, L\) can be represented by \((G^{(\ell)}, \mathcal{O}^{(\ell)})\), where:</p>
<ul>
<li>\(\mathcal{O}^{(\ell)}\) is a set of operations, \(\mathcal{O}^{(\ell)} = \{ o^{(\ell)}_1, o^{(\ell)}_2, \dots \}\)</li>
<li>\(G^{(\ell)}\) is an adjacency matrix, where the entry \(G_{ij}=k\) indicates that operation \(o^{(\ell)}_k\) is placed between node \(i\) and \(j\). The node indices follow <a href="https://en.wikipedia.org/wiki/Topological_sorting">topological ordering</a> in DAG, where the index \(1\) is the source and the maximal index is the sink node.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/hierarchical-NAS-search-space.png" alt="Hierarchical search space" /></p>
<p class="image-caption"><em>Fig. 5. (Top) Three level-1 primitive operations are composed into a level-2 motif. (Bottom) Three level-2 motifs are plugged into a base network structure and assembled into a level-3 motif. (Image source: <a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>)</em></p>
<p>To build a network according to the hierarchical structure, we start from the lowest level \(\ell=1\) and recursively define the \(m\)-th motif operation at level \(\ell\) as</p>
\[o^{(\ell)}_m = \text{assemble}\Big( G_m^{(\ell)}, \mathcal{O}^{(\ell-1)} \Big)\]
<p>A hierarchical representation becomes \(\Big( \big\{ \{ G_m^{(\ell)} \}_{m=1}^{M_\ell} \big\}_{\ell=2}^L, \mathcal{O}^{(1)} \Big), \forall \ell=2, \dots, L\), where \(\mathcal{O}^{(1)}\) contains a set of primitive operations.</p>
<p>The \(\text{assemble}()\) process is equivalent to sequentially computing the feature map of node \(i\) by aggregating all the feature maps of its predecessor nodes \(j\) following the topological ordering:</p>
\[x_i = \text{merge} \big[ \{ o^{(\ell)}_{G^{(\ell)}_{ij}}(x_j) \}_{j < i} \big], i = 2, \dots, \vert G^{(\ell)} \vert\]
<p>where \(\text{merge}[]\) is implemented as depth-wise concatenation in the <a href="https://arxiv.org/abs/1711.00436">paper</a>.</p>
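<p>The recursive assembly can be sketched in a few lines. In this toy version, feature maps are plain Python lists, “depth-wise concatenation” is list concatenation, node indices are 0-based, and a missing edge is encoded as <code class="language-plaintext highlighter-rouge">None</code>; these encodings and the toy primitives are assumptions for illustration rather than the paper’s implementation.</p>

```python
def assemble(G, ops):
    """Build a motif from adjacency matrix G and lower-level operations.

    G[i][j] = k (for j < i) places operation ops[k] on the edge j -> i;
    G[i][j] = None means there is no edge.
    """
    n = len(G)

    def motif(x):
        features = [x]  # node 0 is the source
        for i in range(1, n):
            merged = []
            for j in range(i):
                k = G[i][j]
                if k is not None:
                    # merge[] implemented as concatenation
                    merged.extend(ops[k](features[j]))
            features.append(merged)
        return features[-1]  # the last node is the sink

    return motif


# Level-1 primitives (toy stand-ins for conv, pooling, identity, ...).
primitives = [
    lambda x: list(x),             # identity
    lambda x: [2 * v for v in x],  # "double"
    lambda x: [-v for v in x],     # "negate"
]

# A level-2 motif built from primitives.
G2 = [[None, None, None],
      [0,    None, None],
      [1,    2,    None]]
motif_a = assemble(G2, primitives)

# A level-3 motif that uses the level-2 motif as one of its operations.
G3 = [[None, None, None],
      [0,    None, None],
      [1,    0,    None]]
level3 = assemble(G3, [motif_a, primitives[0]])

print(motif_a([1.0]))
print(level3([1.0]))
```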
<p>Same as NASNet, experiments in <a href="https://arxiv.org/abs/1711.00436">Liu et al (2017)</a> focused on discovering good cell architecture within a predefined “macro” structure with repeated modules. They showed that the power of simple search methods (e.g. random search or evolutionary algorithms) can be substantially enhanced using well-designed search spaces.</p>
<p><a href="https://arxiv.org/abs/1806.02639">Cai et al (2018b)</a> propose a tree-structure search space using path-level network transformation. Each node in a tree structure defines an <em>allocation</em> scheme for splitting inputs for child nodes and a <em>merge</em> scheme for combining results from child nodes. The path-level network transformation allows replacing a single layer with a multi-branch motif if its corresponding merge scheme is add or concat.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/path-level-network-transformations.png" alt="Path-level network transformation" /></p>
<p class="image-caption"><em>Fig. 6. An illustration of transforming a single layer to a tree-structured motif via path-level transformation operations. (Image source: <a href="https://arxiv.org/abs/1806.02639">Cai et al. 2018b</a>)</em></p>
<h3 id="memory-bank-representation">Memory-bank Representation</h3>
<p>A memory-bank representation of feed-forward networks is proposed by <a href="https://arxiv.org/abs/1708.05344">Brock et al. (2017)</a> in <a href="#prediction-based">SMASH</a>. Instead of a graph of operations, they view a neural network as a system with multiple memory blocks which can be read from and written to. Each layer operation is designed to: (1) read from a subset of memory blocks; (2) compute results; and finally (3) write the results into another subset of blocks. For example, in a sequential model, a single memory block would get read and overwritten consistently.</p>
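<p>The read-compute-write view can be illustrated with a toy interpreter. In this sketch each memory block holds a single number and every write overwrites the target block; these simplifications, along with the function name, are assumptions made purely for illustration.</p>

```python
def run_memory_bank(memory, layers):
    """Apply layers given as (read_ids, op, write_ids) to a memory bank."""
    memory = list(memory)  # work on a copy
    for read_ids, op, write_ids in layers:
        inputs = [memory[i] for i in read_ids]   # (1) read
        result = op(inputs)                      # (2) compute
        for i in write_ids:                      # (3) write
            memory[i] = result
    return memory


# A sequential model: one block is read and overwritten by every layer.
sequential = [
    ([0], lambda xs: xs[0] + 1.0, [0]),
    ([0], lambda xs: xs[0] * 3.0, [0]),
]
print(run_memory_bank([1.0], sequential))

# A branching model: block 0 feeds two branches that are summed at the end.
branching = [
    ([0], lambda xs: xs[0] * 2.0, [1]),
    ([0], lambda xs: xs[0] + 5.0, [2]),
    ([1, 2], lambda xs: sum(xs), [0]),
]
print(run_memory_bank([1.0, 0.0, 0.0], branching))
```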
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NAS-memory-bank-view-representation.png" alt="Memory-bank representation" /></p>
<p class="image-caption"><em>Fig. 7. Memory-bank representation of several popular network architecture blocks. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<h2 id="search-algorithms">Search Algorithms</h2>
<p>NAS search algorithms sample a population of child networks. They receive the child models’ performance metrics as rewards and learn to generate high-performance architecture candidates. You may find a lot in common with the field of hyperparameter search.</p>
<h3 id="random-search">Random Search</h3>
<p>Random search is the most naive baseline. It samples a valid architecture candidate from the search space <em>at random</em> and no learning model is involved. Random search has proved to be quite useful in hyperparameter search (<a href="http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf">Bergstra & Bengio 2012</a>). With a well-designed search space, random search could be a very challenging baseline to beat.</p>
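<p>A minimal sketch of this baseline is shown below. The search space, the layer encoding, and the scoring function <code class="language-plaintext highlighter-rouge">toy_evaluate</code> are all hypothetical; in practice the evaluation step would train and validate a child model.</p>

```python
import random

# Hypothetical per-layer search space.
SEARCH_SPACE = {
    "op": ["conv3x3", "conv5x5", "maxpool"],
    "filters": [16, 32, 64],
}


def sample_architecture(num_layers, rng):
    """Sample every layer's choices uniformly at random."""
    return [
        {name: rng.choice(choices) for name, choices in SEARCH_SPACE.items()}
        for _ in range(num_layers)
    ]


def random_search(evaluate, num_trials, num_layers, seed=0):
    """Keep the best of num_trials randomly sampled architectures."""
    rng = random.Random(seed)
    best_arch, best_score = None, float("-inf")
    for _ in range(num_trials):
        arch = sample_architecture(num_layers, rng)
        score = evaluate(arch)
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score


def toy_evaluate(arch):
    """Stand-in for child model training: fraction of conv layers."""
    return sum(layer["op"].startswith("conv") for layer in arch) / len(arch)


best_arch, best_score = random_search(toy_evaluate, num_trials=20, num_layers=3)
print(best_score, best_arch)
```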
<h3 id="reinforcement-learning">Reinforcement Learning</h3>
<p>The initial design of <strong>NAS</strong> (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>) involves an RL-based controller for proposing child model architectures for evaluation. The controller is implemented as an RNN, outputting a variable-length sequence of tokens used for configuring a network architecture.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/NAS.png" alt="NAS" /></p>
<p class="image-caption"><em>Fig. 8. A high level overview of NAS, containing an RNN controller and a pipeline for evaluating child models. (Image source: <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>)</em></p>
<p>The controller is trained as an <em>RL task</em> using <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>.</p>
<ul>
<li><strong>Action space</strong>: The action space is a list of tokens for defining a child network predicted by the controller (See more in the above <a href="#sequential-layer-wise-operations">section</a>). The controller outputs <em>action</em>, \(a_{1:T}\), where \(T\) is the total number of tokens.</li>
<li><strong>Reward</strong>: The accuracy of a child network that can be achieved at convergence is the reward for training the controller, \(R\).</li>
<li><strong>Loss</strong>: NAS optimizes the controller parameters \(\theta\) with a REINFORCE loss. We want to maximize the expected reward (high accuracy) with the gradient as follows. The nice thing here with policy gradient is that it works even when the reward is non-differentiable.</li>
</ul>
\[\nabla_{\theta} J(\theta) = \sum_{t=1}^T \mathbb{E}[\nabla_{\theta} \log P(a_t \vert a_{1:(t-1)}; \theta) R ]\]
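<p>The gradient above can be illustrated with a toy controller. This is only a sketch under strong assumptions: instead of an RNN, the controller keeps one independent softmax per token position, the reward is a made-up function rather than a child network's accuracy, and a moving-average baseline is subtracted from the reward as a common variance-reduction trick that does not appear in the formula above. All function names are hypothetical.</p>

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_actions(theta, rng):
    """Sample one token per position from the per-position categorical policy."""
    return [
        rng.choices(range(len(logits)), weights=softmax(logits))[0]
        for logits in theta
    ]


def reinforce_update(theta, actions, reward, baseline, lr):
    """Gradient ascent on sum_t log P(a_t) * (R - baseline), in place."""
    for logits, action in zip(theta, actions):
        probs = softmax(logits)
        for i in range(len(logits)):
            # d log softmax(logits)[action] / d logits[i] = 1{i == action} - p_i
            grad_log_prob = (1.0 if i == action else 0.0) - probs[i]
            logits[i] += lr * grad_log_prob * (reward - baseline)


def train_controller(reward_fn, num_tokens=3, vocab=4, steps=3000, lr=0.1, seed=0):
    """Train the toy controller and return its logits (one list per position)."""
    rng = random.Random(seed)
    theta = [[0.0] * vocab for _ in range(num_tokens)]
    baseline = 0.0
    for _ in range(steps):
        actions = sample_actions(theta, rng)
        reward = reward_fn(actions)  # may be non-differentiable
        reinforce_update(theta, actions, reward, baseline, lr)
        baseline = 0.9 * baseline + 0.1 * reward  # assumed moving-average baseline
    return theta


def toy_reward(actions):
    """Made-up reward: fraction of positions that emit token 2."""
    return sum(a == 2 for a in actions) / len(actions)
```

<p>The reward only enters as a scalar multiplier on the log-probability gradient, which is why it does not need to be differentiable.</p>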
<p><strong>MetaQNN</strong> (<a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>) trains an agent to sequentially choose CNN layers using <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#q-learning-off-policy-td-control"><em>Q-learning</em></a> with an <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#%CE%B5-greedy-algorithm">\(\epsilon\)-greedy</a> exploration strategy and experience replay. The reward is the validation accuracy as well.</p>
\[Q^{(t+1)}(s_t, a_t) = (1 - \alpha)Q^{(t)}(s_t, a_t) + \alpha (R_t + \gamma \max_{a' \in \mathcal{A}} Q^{(t)}(s_{t+1}, a'))\]
<p>where a state \(s_t\) is a tuple of layer operation and related parameters. An action \(a\) determines the connectivity between operations. The Q-value is proportional to how confident we are in two connected operations leading to high accuracy.</p>
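<p>A tabular version of the update rule and the \(\epsilon\)-greedy choice can be sketched as follows. The state and action encodings are placeholders (any hashable values work), and the function names are hypothetical rather than taken from MetaQNN.</p>

```python
import random
from collections import defaultdict


def q_update(Q, state, action, reward, next_state, next_actions, alpha=0.1, gamma=1.0):
    """Apply one tabular Q-learning update in place and return the new value."""
    # A terminal state has no next actions, so its bootstrap value is 0.
    best_next = max((Q[(next_state, a)] for a in next_actions), default=0.0)
    target = reward + gamma * best_next
    Q[(state, action)] = (1 - alpha) * Q[(state, action)] + alpha * target
    return Q[(state, action)]


def epsilon_greedy(Q, state, actions, epsilon, rng):
    """Explore with probability epsilon, otherwise pick the highest-valued action."""
    if rng.random() < epsilon:
        return rng.choice(actions)
    return max(actions, key=lambda a: Q[(state, a)])


# Unseen (state, action) pairs default to a Q-value of 0.
Q = defaultdict(float)
```

<p>Because the update only needs stored transitions, it combines naturally with experience replay: past (state, action, reward, next state) tuples can be replayed through <code class="language-plaintext highlighter-rouge">q_update</code> in any order.</p>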
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/MetaQNN.png" alt="MetaQNN" /></p>
<p class="image-caption"><em>Fig. 9. Overview of MetaQNN - designing CNN models with Q-Learning. (Image source: <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>)</em></p>
<h3 id="evolutionary-algorithms">Evolutionary Algorithms</h3>
<p><strong>NEAT</strong> (short for <em>NeuroEvolution of Augmenting Topologies</em>) is an approach for evolving neural network topologies with a <a href="https://en.wikipedia.org/wiki/Genetic_algorithm">genetic algorithm (GA)</a>, proposed by <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen</a> in 2002. NEAT evolves both connection weights and network topology together. Each gene encodes the full information for configuring a network, including node weights and edges. The population grows by applying mutation of both weights and connections, as well as crossover between two parent genes. For more on neuroevolution, please refer to the in-depth <a href="https://www.nature.com/articles/s42256-018-0006-z">survey</a> by Stanley et al. (2019).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NEAT-mutations.png" alt="Mutation operations in NEAT" /></p>
<p class="image-caption"><em>Fig. 10. Mutations in the NEAT algorithm. (Image source: Fig 3 & 4 in <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen, 2002</a>)</em></p>
<p><a href="https://arxiv.org/abs/1802.01548">Real et al. (2018)</a> adopt the evolutionary algorithms (EA) as a way to search for high-performance network architectures, named <strong>AmoebaNet</strong>. They apply the <a href="https://en.wikipedia.org/wiki/Tournament_selection">tournament selection</a> method, which at each iteration picks a best candidate out of a random set of samples and places its mutated offspring back into the population. When the tournament size is \(1\), it is equivalent to random selection.</p>
<p><a href="aging-evolutionary-algorithms"></a>AmoebaNet modified the tournament selection to favor <em>younger</em> genotypes and always discard the oldest models within each cycle. Such an approach, named <em>aging evolution</em>, allows AmoebaNet to cover and explore more search space, rather than to narrow down on good performance models too early.</p>
<p>Precisely, in every cycle of the tournament selection with aging regularization (See Figure 11):</p>
<ol>
<li>Sample \(S\) models from the population and the one with highest accuracy is chosen as <em>parent</em>.</li>
<li>A <em>child</em> model is produced by mutating <em>parent</em>.</li>
<li>Then the child model is trained, evaluated and added back into the population.</li>
<li>The oldest model is removed from the population.</li>
</ol>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/aging-evolution-algorithm.png" alt="Aging evolution algorithm" /></p>
<p class="image-caption"><em>Fig. 11. The algorithm of aging evolution. (Image source: <a href="https://arxiv.org/abs/1802.01548">Real et al. 2018</a>)</em></p>
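<p>The four steps of one aging-evolution cycle can be sketched with a queue, so that the oldest model is always at the front. The architecture encoding (a list of bits), the mutation and the fitness function below are toy stand-ins chosen for illustration; only the loop structure follows the steps listed above.</p>

```python
import random
from collections import deque


def aging_evolution(random_arch, mutate, evaluate, population_size, sample_size, cycles, seed=0):
    """Tournament selection with aging; returns the best (arch, score) ever seen."""
    rng = random.Random(seed)
    population = deque()  # oldest model on the left, youngest on the right
    history = []

    # Initialize the population with random architectures.
    while len(population) < population_size:
        arch = random_arch(rng)
        model = (arch, evaluate(arch))
        population.append(model)
        history.append(model)

    for _ in range(cycles):
        # 1. Sample S models; the most accurate one becomes the parent.
        candidates = rng.sample(list(population), sample_size)
        parent_arch, _ = max(candidates, key=lambda m: m[1])
        # 2. Produce a child by mutating the parent.
        child_arch = mutate(parent_arch, rng)
        # 3. Train/evaluate the child and add it to the population.
        child = (child_arch, evaluate(child_arch))
        population.append(child)
        history.append(child)
        # 4. Remove the oldest model, regardless of its accuracy.
        population.popleft()

    return max(history, key=lambda m: m[1])


def random_bits(rng, length=8):
    """Toy architecture: a list of 0/1 genes."""
    return [rng.randint(0, 1) for _ in range(length)]


def flip_one_bit(arch, rng):
    """Toy mutation: flip a single random gene in a copy of the parent."""
    child = arch[:]
    i = rng.randrange(len(child))
    child[i] = 1 - child[i]
    return child


def fraction_of_ones(arch):
    """Toy fitness standing in for validation accuracy."""
    return sum(arch) / len(arch)
```

<p>Note that removal is by age, not by fitness: even the best model eventually leaves the population, and a good architecture only survives if its descendants keep winning tournaments.</p>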
<p>Two types of mutations are applied:</p>
<ol>
<li><em>Hidden state mutation</em>: randomly chooses a pairwise combination and rewires a random end such that there is no loop in the graph.</li>
<li><em>Operation mutation</em>: randomly replaces an existing operation with a random one.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/AmoebaNet-mutations.png" alt="Mutations in AmoebaNet" /></p>
<p class="image-caption"><em>Fig. 12. Two types of mutations in AmoebaNet. (Image source: <a href="https://arxiv.org/abs/1802.01548">Real et al. 2018</a>)</em></p>
<p>In their experiments, EA and RL work equally well in terms of the final validation accuracy, but EA has better anytime performance and is able to find smaller models. Here using EA in NAS is still expensive in terms of computation, as each experiment took 7 days with 450 GPUs.</p>
<p><strong>HNAS</strong> (<a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>) also employs the evolutionary algorithms (the original tournament selection) as their search strategy. In the <a href="#hierarchical-structure">hierarchical structure</a> search space, each edge is an operation. Thus genotype mutation in their experiments is applied by replacing a random edge with a different operation. The replacement set includes a <code class="language-plaintext highlighter-rouge">none</code> op, so it can alter, remove and add an edge. The initial set of genotypes is created by applying a large number of random mutations on “trivial” motifs (all identity mappings).</p>
<h3 id="progressive-decision-process">Progressive Decision Process</h3>
<p>Constructing a model architecture is a sequential process. Every additional operator or layer brings extra complexity. If we guide the search model to start the investigation from simple models and gradually evolve to more complex architectures, it is like introducing a <a href="/lil-log/2020/01/29/curriculum-for-reinforcement-learning.html">“curriculum”</a> into the search model’s learning process.</p>
<p><em>Progressive NAS</em> (<strong>PNAS</strong>; <a href="https://arxiv.org/abs/1712.00559">Liu, et al 2018</a>) frames the problem of NAS as a progressive procedure for searching models of increasing complexity. Instead of RL or EA, PNAS adopts a Sequential Model-based Bayesian Optimization (SMBO) as the search strategy. PNAS works similarly to A* search, as it searches for models from simple to complex while simultaneously learning a surrogate function to guide the search.</p>
<blockquote>
<p><a href="https://en.wikipedia.org/wiki/A*_search_algorithm">A* search algorithm</a> (“best-first search”) is a popular algorithm for path finding. The problem is framed as finding a path of smallest cost from a specific starting node to a given target node in a weighted graph. At each iteration, A* finds a path to extend by minimizing: \(f(n)=g(n)+h(n)\), where \(n\) is the next node, \(g(n)\) is the cost from start to \(n\), and \(h(n)\) is the heuristic function that estimates the minimum cost of going from node \(n\) to the goal.</p>
</blockquote>
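<p>For reference, a compact A* implementation over a small weighted graph is sketched below. The graph format (a dict mapping each node to a list of neighbor and cost pairs) and the function name are assumptions made for this example; the heuristic must never overestimate the remaining cost for the returned path to be optimal.</p>

```python
import heapq


def a_star(graph, start, goal, heuristic):
    """Return (cost, path) of a cheapest path from start to goal, or None.

    graph: dict mapping node -> list of (neighbor, edge_cost) pairs.
    heuristic: function h(n) estimating the remaining cost from n to goal.
    """
    # Each frontier entry is (f = g + h, g, node, path so far).
    frontier = [(heuristic(start), 0, start, [start])]
    best_g = {start: 0}
    while frontier:
        _, g, node, path = heapq.heappop(frontier)
        if node == goal:
            return g, path
        if g > best_g.get(node, float("inf")):
            continue  # stale entry: a cheaper route to this node was found
        for neighbor, cost in graph.get(node, []):
            new_g = g + cost
            if new_g < best_g.get(neighbor, float("inf")):
                best_g[neighbor] = new_g
                f = new_g + heuristic(neighbor)
                heapq.heappush(frontier, (f, new_g, neighbor, path + [neighbor]))
    return None


# Toy graph and an admissible heuristic (values never exceed the true cost to D).
toy_graph = {
    "A": [("B", 1), ("C", 4)],
    "B": [("C", 2), ("D", 5)],
    "C": [("D", 1)],
}
toy_h = {"A": 3, "B": 3, "C": 1, "D": 0}
result = a_star(toy_graph, "A", "D", lambda n: toy_h[n])
```

<p>In the PNAS analogy, \(g(n)\) corresponds to what has already been evaluated for simpler models, while the learned surrogate plays the role of the heuristic that decides which candidate to expand next.</p>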
<p>PNAS uses the <a href="#cell-based-representation">NASNet</a> search space. Each block is specified as a 5-element tuple and PNAS only considers the element-wise addition as the step 5 combination operator, no concatenation. Differently, instead of setting the number of blocks \(B\) at a fixed number, PNAS starts with \(B=1\), a model with only one block in a cell, and gradually increases \(B\).</p>
<p>The performance on a validation set is used as feedback to train a <em>surrogate</em> model for <em>predicting</em> the performance of novel architectures. With this predictor, we can thus decide which models should be prioritized to be evaluated next. Since the performance predictor should be able to handle various-sized inputs, be accurate, and be sample-efficient, they ended up using an RNN model.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/progressive-NAS-algorithm.png" alt="Progressive NAS" /></p>
<p class="image-caption"><em>Fig. 13. The algorithm of Progressive NAS. (Image source: <a href="https://arxiv.org/abs/1712.00559">Liu, et al 2018</a>)</em></p>
<h3 id="gradient-descent">Gradient descent</h3>
<p>Using gradient descent to update the architecture search model requires an effort to make the process of choosing discrete operations differentiable. These approaches usually combine the learning of both architecture parameters and network weights together into one model. See more in the <a href="#one-shot-approach-search--evaluation">section</a> on the <em>“one-shot”</em> approach.</p>
<h2 id="evaluation-strategy">Evaluation Strategy</h2>
<p>We need to measure, estimate or predict the performance of every child model in order to obtain feedback for optimizing the search algorithm. The process of candidate evaluation could be very expensive and many new evaluation methods have been proposed to save time or computation. When evaluating a child model, we mostly care about its performance measured as accuracy on a validation set. Recent work has started looking into other factors of a model, such as model size and latency, as certain devices may have limitations on memory or demand fast response time.</p>
<h3 id="training-from-scratch">Training from Scratch</h3>
<p>The most naive approach is to train every child network independently from scratch until <em>convergence</em> and then measure its accuracy on a validation set (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>). It provides solid performance numbers, but one complete train-converge-evaluate loop only generates a single data sample for training the RL controller (not to mention that RL is known to be sample-inefficient in general). Thus it is very expensive in terms of computation consumption.</p>
<h3 id="proxy-task-performance">Proxy Task Performance</h3>
<p>There are several approaches for using a proxy task performance as the performance estimator of a child network, which is generally cheaper and faster to calculate:</p>
<ul>
<li>Train on a smaller dataset.</li>
<li>Train for fewer epochs.</li>
<li>Train and evaluate a down-scaled model in the search stage. For example, once a cell structure is learned, we can play with the number of cell repeats or scale up the number of filters (<a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>).</li>
<li>Predict the learning curve. <a href="https://arxiv.org/abs/1705.10823">Baker et al (2018)</a> model the prediction of validation accuracies as a time-series regression problem. The features for the regression model (\(\nu\)-support vector machine regressions; \(\nu\)-SVR) include the early sequences of accuracy per epoch, architecture parameters, and hyperparameters.</li>
</ul>
<h3 id="parameter-sharing">Parameter Sharing</h3>
<p>Instead of training every child model independently from scratch, you may ask: what if we fabricate dependency between them and find a way to reuse weights? Some researchers have succeeded in making such approaches work.</p>
<p>Inspired by <a href="https://arxiv.org/abs/1511.05641">Net2net</a> transformation, <a href="https://arxiv.org/abs/1707.04873">Cai et al (2017)</a> proposed <em>Efficient Architecture Search</em> (<strong>EAS</strong>). EAS sets up an RL agent, known as a meta-controller, to predict function-preserving network transformation so as to grow the network depth or layer width. Because the network is growing incrementally, the weights of previously validated networks can be <em>reused</em> for further exploration. With inherited weights, newly constructed networks only need some light-weighted training.</p>
<p>A meta-controller learns to generate <em>network transformation actions</em> given the current network architecture, which is specified with a variable-length string. In order to handle architecture configuration of a variable length, the meta-controller is implemented as a bi-directional recurrent network. Multiple actor networks output different transformation decisions:</p>
<ol>
<li><em>Net2WiderNet</em> operation replaces a layer with a wider layer, meaning more units for fully-connected layers, or more filters for convolutional layers, while preserving the functionality.</li>
<li><em>Net2DeeperNet</em> operation inserts a new layer that is initialized as an identity mapping between two layers so as to preserve the functionality.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/EAS-meta-controller.png" alt="EAS meta-controller" /></p>
<p class="image-caption"><em>Fig. 14. Overview of the RL based meta-controller in Efficient Architecture Search (EAS). After encoding the architecture configuration, it outputs net2net transformation actions through two separate actor networks. (Image source: <a href="https://arxiv.org/abs/1707.04873">Cai et al 2017</a>)</em></p>
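<p>To make "function-preserving" concrete, here is a small sketch of a Net2WiderNet-style widening on a two-layer ReLU network written with plain Python lists. It assumes the usual recipe of copying randomly chosen hidden units and dividing their outgoing weights by the number of copies; the function names and the tiny network are hypothetical, and biases are omitted for brevity.</p>

```python
import random


def forward(W1, W2, x):
    """Two-layer network: ReLU(W1 x) followed by a linear layer W2."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in W2]


def net2wider(W1, W2, new_width, rng):
    """Widen the hidden layer to new_width units without changing the output.

    W1 has one row per hidden unit; W2 has one column per hidden unit.
    """
    old_width = len(W1)
    if new_width < old_width:
        raise ValueError("new_width must be at least the current width")
    # Each extra unit is a copy of a randomly chosen existing unit.
    sources = [rng.randrange(old_width) for _ in range(new_width - old_width)]
    copies = [1] * old_width
    for s in sources:
        copies[s] += 1
    # Incoming weights are duplicated as-is.
    wider_W1 = [row[:] for row in W1] + [W1[s][:] for s in sources]
    # Outgoing weights are split evenly among all copies of the same unit.
    wider_W2 = [
        [row[j] / copies[j] for j in range(old_width)]
        + [row[s] / copies[s] for s in sources]
        for row in W2
    ]
    return wider_W1, wider_W2
```

<p>Since every copy of a hidden unit produces the same activation and the outgoing weights of all copies sum to the original weight, the widened network computes the same function and can continue training from the inherited weights.</p>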
<p><a name="ENAS"></a>With similar motivation, <em>Efficient NAS</em> (<strong>ENAS</strong>; <a href="https://arxiv.org/abs/1802.03268">Pham et al. 2018</a>) speeds up NAS (i.e. about 1000x less expensive in GPU hours) by aggressively sharing parameters among child models. The core motivation behind ENAS is the observation that all of the sampled architecture graphs can be viewed as <em>sub-graphs</em> of a larger <em>supergraph</em>. All the child networks are sharing weights of this supergraph.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ENAS-example.png" alt="ENAS example" /></p>
<p class="image-caption"><em>Fig. 15. (Left) The graph represents the entire search space for a 4-node recurrent cell, but only connections in red are active. (Middle) An example of how the left active sub-graph can be translated into a child model architecture. (Right) The network parameters produced by an RNN controller for the architecture in the middle. (Image source: <a href="https://arxiv.org/abs/1802.03268">Pham et al. 2018</a>)</em></p>
<p>ENAS alternates between training the shared model weights \(\omega\) and training the controller \(\theta\):</p>
<ol>
<li>The parameters of the controller LSTM \(\theta\) are trained with <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>, where the reward \(R(\mathbf{m}, \omega)\) is computed on the validation set.</li>
<li>The shared parameters of the child models \(\omega\) are trained with standard supervised learning loss. Note that different operators associated with the same node in the supergraph would have their own distinct parameters.</li>
</ol>
<h3 id="prediction-based">Prediction-Based</h3>
<p>A routine child model evaluation loop is to update model weights via standard gradient descent. SMASH (<a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>) proposes a different and interesting idea: <em>Can we predict the model weights directly based on the network architecture parameters?</em></p>
<p>They employ a <a href="https://blog.otoro.net/2016/09/28/hyper-networks/">HyperNet</a> (<a href="https://arxiv.org/abs/1609.09106">Ha et al 2016</a>) to directly generate the weights of a model conditioned on an encoding of its architecture configuration. Then the model with HyperNet-generated weights is validated directly. Note that we don’t need extra training for every child model but we do need to train the HyperNet.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SMASH-algorithm.png" alt="SMASH algorithm" /></p>
<p class="image-caption"><em>Fig. 16. The algorithm of SMASH. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<p>The correlation between model performance with SMASH-generated weights and true validation errors suggests that predicted weights can be used for model comparison, to some extent. We do need a HyperNet of large enough capacity, as the correlation would be corrupted if the HyperNet model is too small compared to the child model size.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SMASH-error-correlation.png" alt="SMASH error correlation" /></p>
<p class="image-caption"><em>Fig. 17. The correlation between validation errors with SMASH-generated weights and true validation errors. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<p>SMASH can be viewed as another way to implement the idea of <a href="#parameter-sharing">parameter sharing</a>. One problem of SMASH as pointed out by <a href="https://arxiv.org/abs/1802.03268">Pham et al. (2018)</a> is: The usage of HyperNet restricts the weights of SMASH child models to a <em>low-rank space</em>, because weights are generated via tensor products. In comparison, <a href="#ENAS">ENAS</a> has no such restrictions.</p>
<h2 id="one-shot-approach-search--evaluation">One-Shot Approach: Search + Evaluation</h2>
<p>Running search & evaluation independently for a large population of child models is expensive. We have seen promising approaches like <a href="https://arxiv.org/abs/1708.05344">Brock et al. (2017)</a> or <a href="https://arxiv.org/abs/1802.03268">Pham et al. (2018)</a>, where training a single model is enough for emulating any child model in the search space.</p>
<p>The <strong>one-shot</strong> architecture search extends the idea of weight sharing and further combines the learning of architecture generation together with weight parameters. The following approaches all treat child architectures as different sub-graphs of a supergraph with shared weights between common edges in the supergraph.</p>
<p><a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al (2018)</a> construct a single large over-parameterized network, known as the <strong>One-Shot model</strong>, such that it contains every possible operation in the search space. With <a href="#ScheduledDropPath">ScheduledDropPath</a> (the dropout rate is increased over time, which is \(r^{1/k}\) at the end of training, where \(0 < r < 1\) is a hyperparam and \(k\) is the number of incoming paths) and some carefully designed tricks (e.g. ghost batch normalization, L2 regularization only on the active architecture), the training of such a giant model can be stabilized enough and used for evaluating any child model sampled from the supergraph.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/one-shot-model-architecture.png" alt="One-Shot model architecture" /></p>
<p class="image-caption"><em>Fig. 18. The architecture of the One-Shot model in <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>. Each cell has \(N\) choice blocks and each choice block can select up to 2 operations. Solid edges are used in every architecture, while dashed edges are optional. (Image source: <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>)</em></p>
<p>Once the one-shot model is trained, it is used for evaluating the performance of many different architectures sampled at random by zeroing out or removing some operations. This sampling process can be replaced by RL or evolution.</p>
<p>They observed that the difference between the accuracy measured with the one-shot model and the accuracy of the same architecture after a small fine-tuning could be very large. Their hypothesis is that the one-shot model automatically learns to focus on the <em>most useful</em> operations in the network and comes to <em>rely on</em> these operations when they are available. Thus zeroing out useful operations leads to a big reduction in model accuracy, while removing less important components only causes a small impact. Therefore, we see a larger variance in scores when using the one-shot model for evaluation.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/one-shot-model-accuracy-correlation.png" alt="One-shot accuracy" /></p>
<p class="image-caption"><em>Fig. 19. A stratified sample of models with different one-shot model accuracy versus their true validation accuracy as stand-alone models. (Image source: <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>)</em></p>
<p>Clearly designing such a search graph is not a trivial task, but it demonstrates a strong potential with the one-shot approach. It works well with only gradient descent, and no additional algorithm like RL or EA is required.</p>
<p>Some believe that one main cause for inefficiency in NAS is treating the architecture search as a <em>black-box optimization</em>, and thus we fall into methods like RL, evolution, SMBO, etc. If we shift to rely on standard gradient descent, we could potentially make the search process more efficient. As a result, <a href="https://arxiv.org/abs/1806.09055">Liu et al (2019)</a> propose <em>Differentiable Architecture Search</em> (<strong>DARTS</strong>). DARTS introduces a continuous relaxation on each path in the search supergraph, making it possible to jointly train architecture parameters and weights via gradient descent.</p>
<p>Let’s use the directed acyclic graph (DAG) representation here. A cell is a DAG consisting of a topologically ordered sequence of \(N\) nodes. Each node has a latent representation \(x_i\) to be learned. Each edge \((i, j)\) is tied to some operation \(o^{(i,j)} \in \mathcal{O}\) that transforms \(x_j\) to compose \(x_i\):</p>
\[x_i = \sum_{j < i} o^{(i,j)}(x_j)\]
<p>To make the search space continuous, DARTS relaxes the categorical choice of a particular operation as a softmax over all the operations, and the task of architecture search is reduced to learning a set of mixing probabilities \(\alpha = \{ \alpha^{(i,j)} \}\).</p>
\[\bar{o}^{(i,j)}(x) = \sum_{o\in\mathcal{O}} \frac{\exp(\alpha_{ij}^o)}{\sum_{o'\in\mathcal{O}} \exp(\alpha^{o'}_{ij})} o(x)\]
<p>where \(\alpha_{ij}\) is a vector of dimension \(\vert \mathcal{O} \vert\), containing weights between nodes \(i\) and \(j\) over different operations.</p>
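<p>The relaxed edge operation \(\bar{o}^{(i,j)}\) can be sketched in a few lines. The candidate operations below are toy scalar functions and the names are made up for this example; the argmax-based discretization at the end is the usual way a single operation is read off each edge after the search, stated here as an assumption.</p>

```python
import math


def softmax(alpha):
    """Numerically stable softmax over the architecture weights of one edge."""
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x, alpha, ops):
    """Softmax-weighted sum of all candidate operations applied to x."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, ops))


def discretize(alpha, op_names):
    """Pick the operation with the largest architecture weight on this edge."""
    best = max(range(len(alpha)), key=alpha.__getitem__)
    return op_names[best]


# Toy candidate set: identity, doubling, and a "zero" op that drops the edge.
toy_ops = [lambda x: x, lambda x: 2.0 * x, lambda x: 0.0]
toy_names = ["identity", "double", "zero"]
uniform_output = mixed_op(3.0, [0.0, 0.0, 0.0], toy_ops)
```

<p>With equal architecture weights every operation contributes one third of the output; as training sharpens \(\alpha\), the mixture moves toward a single dominant operation.</p>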
<p>The bilevel optimization exists as we want to optimize both the network weights \(w\) and the architecture representation \(\alpha\):</p>
\[\begin{aligned}
\min_\alpha & \mathcal{L}_\text{val} (w^*(\alpha), \alpha) \\
\text{s.t.} & w^*(\alpha) = \arg\min_w \mathcal{L}_\text{train} (w, \alpha)
\end{aligned}\]
<p>At step \(k\), given the current architecture parameters \(\alpha_{k-1}\), we first optimize weights \(w_k\) by moving \(w_{k-1}\) in the direction of minimizing the training loss \(\mathcal{L}_\text{train}(w_{k-1}, \alpha_{k-1})\) with a learning rate \(\xi\). Next, while keeping the newly updated weights \(w_k\) fixed, we update the mixing probabilities so as to minimize the validation loss <em>after a single step of gradient descent w.r.t. the weights</em>:</p>
\[J_\alpha = \mathcal{L}_\text{val}(w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}), \alpha_{k-1})\]
<p>The motivation here is that we want to find an architecture with a low validation loss when its weights are optimized by gradient descent, and the one-step unrolled weights serve as the <em>surrogate</em> for \(w^*(\alpha)\).</p>
<blockquote>
<p>Side note: Earlier we have seen similar formulation in <a href="/lil-log/2018/11/30/meta-learning.html#maml">MAML</a> where the two-step optimization happens between task losses and the meta-learner update, as well as framing <a href="/lil-log/2019/05/05/domain-randomization.html#dr-as-optimization">Domain Randomization</a> as a bilevel optimization for better transfer in the real environment.</p>
</blockquote>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DARTS-illustration.png" alt="DARTS" /></p>
<p class="image-caption"><em>Fig. 20. An illustration of how DARTS applies continuous relaxation on edges in DAG supergraph and identifies the final model. (Image source: <a href="https://arxiv.org/abs/1806.09055">Liu et al 2019</a>)</em></p>
\[\begin{aligned}
\text{Let }w'_k &= w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}) & \\
J_\alpha &= \mathcal{L}_\text{val}(w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}), \alpha_{k-1}) = \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) & \\
\nabla_\alpha J_\alpha
&= \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) \nabla_\alpha \alpha_{k-1} + \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\nabla_\alpha w'_k & \\& \text{; multivariable chain rule}\\
&= \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) - \xi \color{red}{\nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) \nabla^2_{\alpha, w} \mathcal{L}_\text{train}(w_k, \alpha_{k-1})} & \\
&\approx \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) - \xi \color{red}{\frac{\nabla_\alpha \mathcal{L}_\text{train}(w_k^+, \alpha_{k-1}) - \nabla_\alpha \mathcal{L}_\text{train}(w_k^-, \alpha_{k-1}) }{2\epsilon}} & \\
& \text{; apply numerical differentiation approximation}
\end{aligned}\]
<p>where the red part is using numerical differentiation approximation where \(w_k^+ = w_k + \epsilon \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\) and \(w_k^- = w_k - \epsilon \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\).</p>
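<p>The finite-difference trick can be checked on a toy problem. The sketch below assumes scalar \(w\) and \(\alpha\) and a made-up training loss \(\mathcal{L}_\text{train}(w, \alpha) = w^3 \alpha + w \alpha^2\), for which the mixed second derivative is available in closed form. The symbol <code class="language-plaintext highlighter-rouge">v</code> stands for \(\nabla_{w'_k} \mathcal{L}_\text{val}\) and is simply given a fixed value here; all function names are hypothetical.</p>

```python
def grad_alpha_train(w, alpha):
    """d L_train / d alpha for the toy loss L_train = w^3 * alpha + w * alpha^2."""
    return w ** 3 + 2.0 * w * alpha


def exact_second_order_term(w, alpha, v):
    """v * d^2 L_train / (d alpha d w), computed analytically for the toy loss."""
    return v * (3.0 * w ** 2 + 2.0 * alpha)


def finite_difference_term(w, alpha, v, eps=1e-2):
    """Approximate the same product with two gradient evaluations w.r.t. alpha."""
    w_plus = w + eps * v   # weights perturbed along the validation gradient
    w_minus = w - eps * v
    return (grad_alpha_train(w_plus, alpha) - grad_alpha_train(w_minus, alpha)) / (2.0 * eps)


exact = exact_second_order_term(0.7, -1.3, 0.4)
approx = finite_difference_term(0.7, -1.3, 0.4)
```

<p>The point of the approximation is cost: it only needs two extra evaluations of \(\nabla_\alpha \mathcal{L}_\text{train}\) instead of forming the full matrix of mixed second derivatives. For this toy loss the error is exactly \(\epsilon^2 v^3\), so it shrinks quadratically with \(\epsilon\).</p>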
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DARTS-algorithm.png" alt="DARTS algorithm" /></p>
<p class="image-caption"><em>Fig. 21. The algorithm overview of DARTS. (Image source: <a href="https://arxiv.org/abs/1806.09055">Liu et al 2019</a>)</em></p>
<p>As another idea similar to DARTS, Stochastic NAS (<a href="https://arxiv.org/abs/1812.09926">Xie et al., 2019</a>) applies a continuous relaxation by employing the concrete distribution (CONCRETE = CONtinuous relaxations of disCRETE random variables; <a href="https://arxiv.org/abs/1611.00712">Maddison et al 2017</a>) and reparametrization tricks. The goal is the same as in DARTS: to make the discrete distribution differentiable and thus enable optimization by gradient descent.</p>
<p>DARTS is able to greatly reduce the cost of GPU hours. Their experiments for searching for CNN cells have \(N=7\) and only took 1.5 days with a single GPU. However, it suffers from the high GPU memory consumption issue due to its continuous representation of network architecture. In order to fit the model into the memory of a single GPU, they picked a small \(N\).</p>
<p>To constrain the GPU memory consumption, <strong>ProxylessNAS</strong> (<a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>) views NAS as a path-level pruning process in DAG and binarizes the architecture parameters to force only one path to be active between two nodes at a time. The probabilities for an edge being either masked out or not are then learned by sampling a few binarized architectures and using <em>BinaryConnect</em> (<a href="https://arxiv.org/abs/1511.00363">Courbariaux et al., 2015</a>) to update the corresponding probabilities. ProxylessNAS demonstrates a strong connection between NAS and model compression. By using path-level compression, it is able to save memory consumption by one order of magnitude.</p>
<p>Let’s continue with the graph representation. Consider a DAG adjacency matrix \(G\), where \(G_{ij}\) represents an edge between nodes \(i\) and \(j\) and its value can be chosen from the set of \(\vert \mathcal{O} \vert\) candidate primitive operations, \(\mathcal{O} = \{ o_1, \dots \}\). The One-Shot model, DARTS and ProxylessNAS all consider each edge as a mixture of operations, \(m_\mathcal{O}\), but with different tweaks.</p>
<p>In One-Shot, \(m_\mathcal{O}(x)\) is the sum of all the operations. In DARTS, it is a weighted sum where weights are softmax over a real-valued architecture weighting vector \(\alpha\) of length \(\vert \mathcal{O} \vert\). ProxylessNAS transforms the softmax probabilities of \(\alpha\) into a binary gate and uses the binary gate to keep only one operation active at a time.</p>
\[\begin{aligned}
m^\text{one-shot}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} o_i(x) \\
m^\text{DARTS}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} p_i o_i(x) = \sum_{i=1}^{\vert \mathcal{O} \vert} \frac{\exp(\alpha_i)}{\sum_j \exp(\alpha_j)} o_i(x) \\
m^\text{binary}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} g_i o_i(x) = \begin{cases}
o_1(x) & \text{with probability }p_1, \\
\dots &\\
o_{\vert \mathcal{O} \vert}(x) & \text{with probability }p_{\vert \mathcal{O} \vert}
\end{cases} \\
\text{ where } g &= \text{binarize}(p_1, \dots, p_{\vert \mathcal{O} \vert}) = \begin{cases}
[1, 0, \dots, 0] & \text{with probability }p_1, \\
\dots & \\
[0, 0, \dots, 1] & \text{with probability }p_{\vert \mathcal{O} \vert}. \\
\end{cases}
\end{aligned}\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/proxylessNAS-training.png" alt="Training steps of ProxylessNAS" /></p>
<p class="image-caption"><em>Fig. 22. ProxylessNAS has two training steps running alternately. (Image source: <a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>)</em></p>
<p>ProxylessNAS runs two training steps alternately:</p>
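<p>The binary-gate mixture \(m^\text{binary}_\mathcal{O}\) can be sketched as below: a one-hot gate is sampled from the softmax probabilities and only the selected operation is evaluated. The toy scalar operations and all function names are assumptions for illustration.</p>

```python
import math
import random


def softmax(alpha):
    """Numerically stable softmax over architecture parameters."""
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]


def binarize(probs, rng):
    """Sample a one-hot gate vector g, with g_i = 1 with probability p_i."""
    active = rng.choices(range(len(probs)), weights=probs)[0]
    return [1 if i == active else 0 for i in range(len(probs))]


def binary_mixed_op(x, gates, ops):
    """Evaluate only the operation whose gate is 1; the others are never run."""
    active = gates.index(1)
    return ops[active](x)


def sample_path_output(x, alpha, ops, rng):
    """One stochastic forward pass through a binarized mixed operation."""
    gates = binarize(softmax(alpha), rng)
    return gates, binary_mixed_op(x, gates, ops)
```

<p>Because inactive operations are skipped entirely, only one path's activations need to be kept in memory per edge, in contrast to the weighted sum in DARTS where every candidate is computed.</p>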
<ol>
<li>When training weight parameters \(w\), it freezes the architecture parameters \(\alpha\) and stochastically samples binary gates \(g\) according to the above \(m^\text{binary}_\mathcal{O}(x)\). The weight parameters can be updated with standard gradient descent.</li>
<li>When training architecture parameters \(\alpha\), it freezes \(w\), resets the binary gates and then updates \(\alpha\) on the validation set. Following the idea of <em>BinaryConnect</em>, the gradient w.r.t. architecture parameters can be approximately estimated using \(\partial \mathcal{L} / \partial g_i\) in replacement for \(\partial \mathcal{L} / \partial p_i\):</li>
</ol>
\[\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \alpha_i}
&= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial p_j} \frac{\partial p_j}{\partial \alpha_i}
\approx \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\partial p_j}{\partial \alpha_i}
= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\partial \frac{e^{\alpha_j}}{\sum_k e^{\alpha_k}}}{\partial \alpha_i} \\
&= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\sum_k e^{\alpha_k} (\mathbf{1}_{i=j} e^{\alpha_j}) - e^{\alpha_j} e^{\alpha_i} }{(\sum_k e^{\alpha_k})^2}
= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} p_j (\mathbf{1}_{i=j} -p_i)
\end{aligned}\]
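<p>The last line of the derivation is easy to check numerically. The sketch below is only an illustration: the values of \(\alpha\) and of \(\partial \mathcal{L} / \partial g_j\) are made up, and <code>arch_grad</code> is a hypothetical helper name, not part of ProxylessNAS.</p>

```python
import math

def softmax(alpha):
    """Path probabilities p_j = exp(alpha_j) / sum_k exp(alpha_k)."""
    m = max(alpha)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in alpha]
    z = sum(exps)
    return [e / z for e in exps]

def arch_grad(alpha, dL_dg):
    """dL/dalpha_i ~= sum_j dL/dg_j * p_j * (1[i == j] - p_i)."""
    p = softmax(alpha)
    n = len(alpha)
    return [
        sum(dL_dg[j] * p[j] * ((1.0 if i == j else 0.0) - p[i]) for j in range(n))
        for i in range(n)
    ]

# Hypothetical values: three candidate operations and made-up gate gradients.
alpha = [0.5, -1.0, 2.0]
dL_dg = [0.3, -0.7, 1.1]
grad = arch_grad(alpha, dL_dg)
print(grad)
```

<p>Because the softmax probabilities always sum to one, the components of this gradient sum to zero: raising the weight of one path necessarily lowers the others.</p>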
<p>Instead of BinaryConnect, REINFORCE can also be used for parameter updates with the goal of maximizing the reward, with no RNN meta-controller involved.</p>
<p>Computing \(\partial \mathcal{L} / \partial g_i\) requires calculating and storing every \(o_i(x)\), which takes \(\vert \mathcal{O} \vert\) times as much GPU memory as training a single path. To resolve this issue, they factorize the task of choosing one path out of \(N\) into multiple binary selection tasks (Intuition: “if a path is the best choice, it should be better than any other path”). At every update step, only two paths are sampled while the others are masked. The architecture parameters of these two paths are updated according to the above equation and then rescaled so that the weights of the other paths are unchanged. After this process, one of the sampled paths is enhanced (path weight increases) and the other is attenuated (path weight decreases), while all other paths stay unaltered.</p>
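<p>A minimal sketch of one such two-path step is given below. It is one possible reading of “scaled properly”, not the authors’ code: after the gradient step, both sampled logits are shifted by the same constant so that their combined softmax mass is preserved. The callback <code>dL_dg_pair</code>, the learning rate and all numbers are hypothetical.</p>

```python
import math
import random

def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    z = sum(exps)
    return [e / z for e in exps]

def two_path_update(alpha, dL_dg_pair, lr, rng):
    """One architecture step that only touches two sampled paths.

    dL_dg_pair(i, j) -> (dL/dg_i, dL/dg_j) is a hypothetical callback that
    would run the network with only paths i and j active.
    """
    p = softmax(alpha)
    # Sample two distinct paths according to the current path probabilities.
    i = rng.choices(range(len(alpha)), weights=p)[0]
    rest = [k for k in range(len(alpha)) if k != i]
    j = rng.choices(rest, weights=[p[k] for k in rest])[0]
    # Probabilities renormalised over the sampled pair.
    q_i = p[i] / (p[i] + p[j])
    q_j = 1.0 - q_i
    g_i, g_j = dL_dg_pair(i, j)
    # Softmax gradient formula restricted to the two active paths.
    grad_i = g_i * q_i * (1.0 - q_i) + g_j * q_j * (0.0 - q_i)
    grad_j = g_i * q_i * (0.0 - q_j) + g_j * q_j * (1.0 - q_j)
    mass_before = math.exp(alpha[i]) + math.exp(alpha[j])
    new_alpha = list(alpha)
    new_alpha[i] -= lr * grad_i
    new_alpha[j] -= lr * grad_j
    # Rescale: shift both logits so exp(alpha_i) + exp(alpha_j) is unchanged,
    # which leaves the softmax weight of every other path untouched.
    mass_after = math.exp(new_alpha[i]) + math.exp(new_alpha[j])
    shift = math.log(mass_before / mass_after)
    new_alpha[i] += shift
    new_alpha[j] += shift
    return new_alpha, (i, j)

rng = random.Random(0)
alpha = [0.2, -0.5, 1.0, 0.3]
new_alpha, (i, j) = two_path_update(alpha, lambda a, b: (0.8, -0.4), lr=0.1, rng=rng)
p_old, p_new = softmax(alpha), softmax(new_alpha)
print((i, j), p_old, p_new)
```

<p>Only two operation outputs need to be kept in memory per step, which is the point of the factorization.</p>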
<p>Besides accuracy, ProxylessNAS also considers <em>latency</em> as an important metric to optimize, as different devices might have very different requirements on inference time latency (e.g. GPU, CPU, mobile). To make latency differentiable, they model latency as a continuous function of the network dimensions. The expected latency of a mixed operation can be written as \(\mathbb{E}[\text{latency}] = \sum_j p_j F(o_j)\), where \(F(.)\) is a latency prediction model:</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/proxylessNAS-latency.png" alt="proxylessNAS latency" /></p>
<p class="image-caption"><em>Fig. 23. Add a differentiable latency loss into the training of ProxylessNAS. (Image source: <a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>)</em></p>
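<p>As a rough sketch of why the expected latency is differentiable in \(\alpha\): with \(p = \text{softmax}(\alpha)\), the gradient of \(\sum_j p_j F(o_j)\) w.r.t. \(\alpha_i\) works out to \(p_i (F(o_i) - \mathbb{E}[\text{latency}])\). The latency numbers below are invented stand-ins for the output of the latency prediction model \(F\).</p>

```python
import math

def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    z = sum(exps)
    return [e / z for e in exps]

def expected_latency(alpha, latencies):
    """E[latency] = sum_j p_j * F(o_j)."""
    p = softmax(alpha)
    return sum(pj * fj for pj, fj in zip(p, latencies))

def expected_latency_grad(alpha, latencies):
    """d E[latency] / d alpha_i = p_i * (F(o_i) - E[latency])."""
    p = softmax(alpha)
    e = sum(pj * fj for pj, fj in zip(p, latencies))
    return [pi * (fi - e) for pi, fi in zip(p, latencies)]

# Hypothetical predicted latencies (in ms) of three candidate operations.
latencies = [1.0, 4.0, 7.0]
alpha = [0.0, 0.0, 0.0]
print(expected_latency(alpha, latencies))
print(expected_latency_grad(alpha, latencies))
```

<p>Slower-than-average operations get a positive gradient, so a gradient descent step on a latency loss lowers their path weight.</p>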
<h2 id="whats-the-future">What’s the Future?</h2>
<p>So far we have seen many interesting new ideas on automating network architecture engineering through neural architecture search, and many have achieved very impressive performance. However, it is still hard to reason about <em>why</em> some architectures work well and how we can develop modules that generalize across tasks rather than being very dataset-specific.</p>
<p>As also noted in <a href="https://arxiv.org/abs/1808.05377">Elsken, et al (2019)</a>:</p>
<blockquote>
<p>“…, so far it provides little insights into why specific architectures work well and how similar the architectures derived in independent runs would be. Identifying common motifs, providing an understanding why those motifs are important for high performance, and investigating if these motifs generalize over different problems would be desirable.”</p>
</blockquote>
<p>Meanwhile, focusing purely on improving validation accuracy might not be enough (<a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>). Everyday devices such as mobile phones generally have limited memory and computation power. As AI applications increasingly affect our daily life, NAS will unavoidably need to become more <em>device-specific</em>.</p>
<p>Another interesting direction is to consider <em>unlabelled datasets</em> and <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised learning</a> for NAS. The size of a labelled dataset is always limited, and it is not easy to tell whether such a dataset is biased or deviates substantially from the real-world data distribution.</p>
<p><a href="https://arxiv.org/abs/2003.12056">Liu et al (2020)</a> delved into the question <em>“Can we find high-quality neural architecture without human-annotated labels?”</em> and proposed a new setup called <em>Unsupervised Neural Architecture Search</em> (<strong>UnNAS</strong>). The quality of the architecture needs to be estimated in an unsupervised fashion during the search phase. The paper experimented with three unsupervised <a href="/lil-log/2019/11/10/self-supervised-learning.html#images-based">pretext tasks</a>: image rotation prediction, colorization, and solving jigsaw puzzles.</p>
<p>They observed in a set of UnNAS experiments that:</p>
<ol>
<li>High rank correlation between supervised accuracy and pretext accuracy <em>on the same dataset</em>. Typically the rank correlation is higher than 0.8, regardless of the dataset, the search space, and the pretext task.</li>
<li>High rank correlation between supervised accuracy and pretext accuracy <em>across datasets</em>.</li>
<li>Better pretext accuracy translates to better supervised accuracy.</li>
<li>The performance of UnNAS architectures is comparable to that of their supervised counterparts, though not better yet.</li>
</ol>
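<p>To make “rank correlation” concrete, here is a small sketch that computes Spearman’s rank correlation between pretext and supervised accuracies. Spearman’s coefficient is used as one common rank statistic; whether it matches the exact metric in the paper is an assumption here, and the accuracy values are invented for illustration.</p>

```python
def ranks(values):
    """Rank positions (1 = smallest). Assumes there are no ties."""
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    result = [0.0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = float(rank)
    return result

def spearman(x, y):
    """Spearman rank correlation for tie-free data."""
    rx, ry = ranks(x), ranks(y)
    n = len(x)
    d2 = sum((a - b) ** 2 for a, b in zip(rx, ry))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))

# Hypothetical accuracies of five architectures (not results from the paper).
pretext_acc = [0.62, 0.70, 0.66, 0.75, 0.58]
supervised_acc = [0.91, 0.93, 0.94, 0.95, 0.90]
rho = spearman(pretext_acc, supervised_acc)
print(rho)
```

<p>A value close to 1 means that ranking architectures by pretext accuracy gives almost the same ordering as ranking them by supervised accuracy, which is what makes the pretext task usable as a search signal.</p>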
<p>One hypothesis is that the architecture quality is correlated with image statistics. Because CIFAR-10 and ImageNet both consist of natural images, they are comparable and the results are transferable. UnNAS could potentially bring a much larger amount of unlabelled data into the search phase, which would capture image statistics better.</p>
<p>Hyperparameter search is a long-standing topic in the ML community, and NAS automates architecture engineering. Gradually we are trying to automate processes in ML which usually demand a lot of human effort. Taking one step further, is it possible to automatically discover ML algorithms? <strong>AutoML-Zero</strong> (<a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>) investigates this idea. Using <a href="#aging-evolutionary-algorithms">aging evolutionary algorithms</a>, AutoML-Zero automatically searches for whole ML algorithms with little restriction on their form, using only simple mathematical operations as building blocks.</p>
<p>It learns three component functions. Each function only adopts very basic operations.</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">Setup</code>: initialize memory variables (weights).</li>
<li><code class="language-plaintext highlighter-rouge">Learn</code>: modify memory variables.</li>
<li><code class="language-plaintext highlighter-rouge">Predict</code>: make a prediction from an input \(x\).</li>
</ul>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/AutoML-zero-evaluation.png" alt="AutoML-zero evaluation" /></p>
<p class="image-caption"><em>Fig. 24. Algorithm evaluation on one task (Image source: <a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>)</em></p>
<p>Three types of operations are considered when mutating a parent genotype:</p>
<ol>
<li>Insert a random instruction or remove an instruction at a random location in a component function;</li>
<li>Randomize all the instructions in a component function;</li>
<li>Modify one of the arguments of an instruction by replacing it with a random choice (e.g. “swap the output address” or “change the value of a constant”).</li>
</ol>
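<p>The three mutation types can be sketched on a toy genotype as follows. This is a deliberately simplified illustration: the instruction format <code>(op, in1, in2, out)</code>, the tiny op set and the address space are all assumptions made for this example and are much smaller than what AutoML-Zero actually searches over.</p>

```python
import random

OPS = ["add", "sub", "mul"]  # hypothetical tiny op set
N_ADDR = 4                   # hypothetical number of memory addresses

def random_instruction(rng):
    """An instruction is (op, input address 1, input address 2, output address)."""
    return (rng.choice(OPS), rng.randrange(N_ADDR),
            rng.randrange(N_ADDR), rng.randrange(N_ADDR))

def mutate(genotype, rng):
    """Return a mutated copy of the parent genotype and the mutation type used."""
    child = {name: list(instrs) for name, instrs in genotype.items()}
    name = rng.choice(sorted(child))  # pick one component function
    instrs = child[name]
    kind = rng.choice(["insert_or_remove", "randomize", "modify_argument"])
    if kind == "insert_or_remove":
        if instrs and rng.random() < 0.5:
            del instrs[rng.randrange(len(instrs))]
        else:
            instrs.insert(rng.randrange(len(instrs) + 1), random_instruction(rng))
    elif kind == "randomize":
        child[name] = [random_instruction(rng) for _ in instrs]
    elif instrs:  # modify_argument; a no-op on an empty function
        pos = rng.randrange(len(instrs))
        instr = list(instrs[pos])
        instr[rng.randrange(1, 4)] = rng.randrange(N_ADDR)  # replace one address
        instrs[pos] = tuple(instr)
    return child, kind

rng = random.Random(0)
parent = {
    "Setup": [random_instruction(rng)],
    "Predict": [random_instruction(rng) for _ in range(2)],
    "Learn": [random_instruction(rng) for _ in range(3)],
}
child, kind = mutate(parent, rng)
print(kind, child)
```

<p>The parent is copied rather than modified in place, so it can stay in the population while the child is evaluated.</p>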
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/AutoML-zero-progress.png" alt="Progress of AutoML-zero experiment" /></p>
<p class="image-caption"><em>Fig. 25. An illustration of evolutionary progress on projected binary CIFAR-10 with example code. (Image source: <a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020nas,
title = "Neural Architecture Search",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/08/06/neural-architecture-search.html"
}
</code></pre></div></div>
<h2 id="appendix-summary-of-nas-papers">Appendix: Summary of NAS Papers</h2>
<table class="info">
<thead>
<tr>
<th>Model name</th>
<th>Search space</th>
<th>Search algorithms</th>
<th>Child model evaluation</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (2002)</a></td>
<td>-</td>
<td>Evolution (Genetic algorithm)</td>
<td>-</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1611.01578">NAS (2017)</a></td>
<td>Sequential layer-wise ops</td>
<td>RL (REINFORCE)</td>
<td>Train from scratch until convergence</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1611.02167">MetaQNN (2017)</a></td>
<td>Sequential layer-wise ops</td>
<td>RL (Q-learning with $\epsilon$-greedy)</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1711.00436">HNAS (2017)</a></td>
<td>Hierarchical structure</td>
<td>Evolution (Tournament selection)</td>
<td>Train for a fixed number of iterations</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1707.07012">NASNet (2018)</a></td>
<td>Cell-based</td>
<td>RL (PPO)</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1802.01548">AmoebaNet (2018)</a></td>
<td>NASNet search space</td>
<td>Evolution (Tournament selection with aging regularization)</td>
<td>Train for 25 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1707.04873">EAS (2018a)</a></td>
<td>Network transformation</td>
<td>RL (REINFORCE)</td>
<td>2-stage training</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1712.00559">PNAS (2018)</a></td>
<td>Reduced version of NASNet search space</td>
<td>SMBO; Progressive search for architectures of increasing complexity</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1802.03268">ENAS (2018)</a></td>
<td>Both sequential and cell-based search space</td>
<td>RL (REINFORCE)</td>
<td>Train one model with shared weights</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1708.05344">SMASH (2017)</a></td>
<td>Memory-bank representation</td>
<td>Random search</td>
<td>HyperNet predicts weights of evaluated architectures.</td>
</tr>
<tr>
<td><a href="http://proceedings.mlr.press/v80/bender18a.html">One-Shot (2018)</a></td>
<td>An over-parameterized one-shot model</td>
<td>Random search (zero out some paths at random)</td>
<td>Train the one-shot model</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1806.09055">DARTS (2019)</a></td>
<td>NASNet search space</td>
<td colspan="2">Gradient descent (Softmax weights over operations)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1812.00332">ProxylessNAS (2019)</a></td>
<td>Tree structure architecture</td>
<td colspan="2">Gradient descent (BinaryConnect) or REINFORCE</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1812.09926">SNAS (2019)</a></td>
<td>NASNet search space</td>
<td colspan="2">Gradient descent (concrete distribution)</td>
</tr>
</tbody>
</table>
<h2 id="reference">Reference</h2>
<p>[1] Thomas Elsken, Jan Hendrik Metzen, Frank Hutter. <a href="https://arxiv.org/abs/1808.05377">“Neural Architecture Search: A Survey”</a> JMLR 20 (2019) 1-21.</p>
<p>[2] Kenneth O. Stanley, et al. <a href="https://www.nature.com/articles/s42256-018-0006-z">“Designing neural networks through neuroevolution”</a> Nature Machine Intelligence volume 1, pages 24–35 (2019).</p>
<p>[3] Kenneth O. Stanley & Risto Miikkulainen. <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">“Evolving Neural Networks through Augmenting Topologies”</a> Evolutionary Computation 10(2): 99-127 (2002).</p>
<p>[4] Barret Zoph, Quoc V. Le. <a href="https://arxiv.org/abs/1611.01578">“Neural architecture search with reinforcement learning”</a> ICLR 2017.</p>
<p>[5] Bowen Baker, et al. <a href="https://arxiv.org/abs/1611.02167">“Designing Neural Network Architectures using Reinforcement Learning”</a> ICLR 2017.</p>
<p>[6] Bowen Baker, et al. <a href="https://arxiv.org/abs/1705.10823">“Accelerating neural architecture search using performance prediction”</a> ICLR Workshop 2018.</p>
<p>[7] Barret Zoph, et al. <a href="https://arxiv.org/abs/1707.07012">“Learning transferable architectures for scalable image recognition”</a> CVPR 2018.</p>
<p>[8] Hanxiao Liu, et al. <a href="https://arxiv.org/abs/1711.00436">“Hierarchical representations for efficient architecture search.”</a> ICLR 2018.</p>
<p>[9] Esteban Real, et al. <a href="https://arxiv.org/abs/1802.01548">“Regularized Evolution for Image Classifier Architecture Search”</a> arXiv:1802.01548 (2018).</p>
<p>[10] Han Cai, et al. <a href="https://arxiv.org/abs/1707.04873">“Efficient architecture search by network transformation”</a> AAAI 2018a.</p>
<p>[11] Han Cai, et al. <a href="https://arxiv.org/abs/1806.02639">“Path-Level Network Transformation for Efficient Architecture Search”</a> ICML 2018b.</p>
<p>[12] Han Cai, Ligeng Zhu & Song Han. <a href="https://arxiv.org/abs/1812.00332">“ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware”</a> ICLR 2019.</p>
<p>[13] Chenxi Liu, et al. <a href="https://arxiv.org/abs/1712.00559">“Progressive neural architecture search”</a> ECCV 2018.</p>
<p>[14] Hieu Pham, et al. <a href="https://arxiv.org/abs/1802.03268">“Efficient neural architecture search via parameter sharing”</a> ICML 2018.</p>
<p>[15] Andrew Brock, et al. <a href="https://arxiv.org/abs/1708.05344">“SMASH: One-shot model architecture search through hypernetworks.”</a> ICLR 2018.</p>
<p>[16] Gabriel Bender, et al. <a href="http://proceedings.mlr.press/v80/bender18a.html">“Understanding and simplifying one-shot architecture search.”</a> ICML 2018.</p>
<p>[17] Hanxiao Liu, Karen Simonyan, Yiming Yang. <a href="https://arxiv.org/abs/1806.09055">“DARTS: Differentiable Architecture Search”</a> ICLR 2019.</p>
<p>[18] Sirui Xie, Hehui Zheng, Chunxiao Liu, Liang Lin. <a href="https://arxiv.org/abs/1812.09926">“SNAS: Stochastic Neural Architecture Search”</a> ICLR 2019.</p>
<p>[19] Chenxi Liu et al. <a href="https://arxiv.org/abs/2003.12056">“Are Labels Necessary for Neural Architecture Search?”</a> ECCV 2020.</p>
<p>[20] Esteban Real, et al. <a href="https://arxiv.org/abs/2003.03384">“AutoML-Zero: Evolving Machine Learning Algorithms From Scratch”</a> ICML 2020.</p>
<p>Lilian Weng. Neural Architecture Search (NAS) automates network architecture engineering. It aims to learn a network topology that can achieve best performance on a certain task. By dissecting the methods for NAS into three components: search space, search algorithm and child model evaluation strategy, this post reviews many interesting ideas for better, faster and more cost-efficient automatic neural architecture search.</p>
<p>Exploration Strategies in Deep Reinforcement Learning (2020-06-07T12:00:00+00:00) https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning</p>
<blockquote>
<p>Exploitation versus exploration is a critical topic in reinforcement learning. This post introduces several common approaches for better exploration in Deep RL.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-06-17: Add <a href="#exploration-via-disagreement">“exploration via disagreement”</a> in the “Forward Dynamics” <a href="#forward-dynamics">section</a>.]</span></p>
<p><a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html">Exploitation versus exploration</a> is a critical topic in Reinforcement Learning. We’d like the RL agent to find the best solution as fast as possible. At the same time, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">RL</a> <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html">algorithms</a> that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more of an open problem.</p>
<p>I would like to discuss several common exploration strategies in Deep RL here. As this is a very big topic, my post by no means can cover all the important subtopics. I plan to update it periodically and keep enriching the content over time.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-exploration-strategies" id="markdown-toc-classic-exploration-strategies">Classic Exploration Strategies</a></li>
<li><a href="#key-exploration-problems" id="markdown-toc-key-exploration-problems">Key Exploration Problems</a> <ul>
<li><a href="#the-hard-exploration-problem" id="markdown-toc-the-hard-exploration-problem">The Hard-Exploration Problem</a></li>
<li><a href="#the-noisy-tv-problem" id="markdown-toc-the-noisy-tv-problem">The Noisy-TV Problem</a></li>
</ul>
</li>
<li><a href="#intrinsic-rewards-as-exploration-bonuses" id="markdown-toc-intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</a> <ul>
<li><a href="#count-based-exploration" id="markdown-toc-count-based-exploration">Count-based Exploration</a></li>
<li><a href="#prediction-based-exploration" id="markdown-toc-prediction-based-exploration">Prediction-based Exploration</a></li>
</ul>
</li>
<li><a href="#memory-based-exploration" id="markdown-toc-memory-based-exploration">Memory-based Exploration</a> <ul>
<li><a href="#episodic-memory" id="markdown-toc-episodic-memory">Episodic Memory</a></li>
<li><a href="#direct-exploration" id="markdown-toc-direct-exploration">Direct Exploration</a></li>
</ul>
</li>
<li><a href="#q-value-exploration" id="markdown-toc-q-value-exploration">Q-Value Exploration</a></li>
<li><a href="#varitional-options" id="markdown-toc-varitional-options">Variational Options</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="classic-exploration-strategies">Classic Exploration Strategies</h2>
<p>As a quick recap, let’s first go through several classic exploration algorithms that work out pretty well in the multi-armed bandit problem or simple tabular RL.</p>
<ul>
<li><strong>Epsilon-greedy</strong>: The agent explores randomly with probability \(\epsilon\) and takes the currently best-known action the rest of the time, with probability \(1-\epsilon\).</li>
<li><strong>Upper confidence bounds</strong>: The agent selects the greediest action to maximize the upper confidence bound \(\hat{Q}_t(a) + \hat{U}_t(a)\), where \(\hat{Q}_t(a)\) is the average reward associated with action \(a\) up to time \(t\) and \(\hat{U}_t(a)\) is a function that decreases with the number of times action \(a\) has been taken. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#upper-confidence-bounds">here</a> for more details.</li>
<li><strong>Boltzmann exploration</strong>: The agent draws actions from a <a href="https://en.wikipedia.org/wiki/Boltzmann_distribution">Boltzmann distribution</a> (softmax) over the learned Q values, regulated by a temperature parameter \(\tau\).</li>
<li><strong>Thompson sampling</strong>: The agent keeps track of a belief over the probability of optimal actions and samples from this distribution. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">here</a> for more details.</li>
</ul>
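<p>For a bandit with a handful of arms, the four strategies above can be sketched in a few lines each. The snippet is a toy illustration: the UCB bonus uses the UCB1 form \(\sqrt{2 \log t / N_t(a)}\) and Thompson sampling assumes Bernoulli rewards with a uniform Beta prior; both choices are assumptions of this example rather than the only options.</p>

```python
import math
import random

def argmax(values):
    return max(range(len(values)), key=lambda a: values[a])

def epsilon_greedy(q, epsilon, rng):
    """Random action with probability epsilon, greedy action otherwise."""
    if rng.random() < epsilon:
        return rng.randrange(len(q))
    return argmax(q)

def ucb1(q, counts, t):
    """Maximise Q(a) + sqrt(2 log t / N(a)); untried actions go first."""
    for a, n in enumerate(counts):
        if n == 0:
            return a
    return argmax([q[a] + math.sqrt(2.0 * math.log(t) / counts[a])
                   for a in range(len(q))])

def boltzmann(q, tau, rng):
    """Sample an action from softmax(Q / tau)."""
    m = max(q)
    weights = [math.exp((x - m) / tau) for x in q]
    return rng.choices(range(len(q)), weights=weights)[0]

def thompson_bernoulli(successes, failures, rng):
    """Sample a success rate per arm from its Beta posterior, then act greedily."""
    samples = [rng.betavariate(s + 1, f + 1) for s, f in zip(successes, failures)]
    return argmax(samples)

rng = random.Random(0)
q_values = [0.2, 0.5, 0.1]  # hypothetical action-value estimates
print(epsilon_greedy(q_values, 0.1, rng))
print(ucb1(q_values, counts=[5, 5, 1], t=11))
print(boltzmann(q_values, tau=0.5, rng=rng))
print(thompson_bernoulli([3, 8, 1], [7, 2, 1], rng))
```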
<p>The following strategies could be used for better exploration in deep RL training when neural networks are used for function approximation:</p>
<ul>
<li><strong>Entropy loss term</strong>: Add an entropy term \(H(\pi(a \vert s))\) into the loss function, encouraging the policy to take diverse actions.</li>
<li><strong>Noise-based Exploration</strong>: Add noise into the observation, action or even parameter space (<a href="https://arxiv.org/abs/1706.10295">Fortunato, et al. 2017</a>, <a href="https://arxiv.org/abs/1706.01905">Plappert, et al. 2017</a>).</li>
</ul>
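<p>As a small illustration of the entropy term, the sketch below subtracts a weighted entropy from a policy loss so that minimizing the loss favors more spread-out action distributions. The coefficient value and the function names are assumptions for this example.</p>

```python
import math

def entropy(probs):
    """H(pi(.|s)) = -sum_a pi(a|s) log pi(a|s), with 0 log 0 treated as 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def loss_with_entropy_bonus(policy_loss, probs, coef):
    """Subtract coef * entropy so that higher entropy gives a lower loss."""
    return policy_loss - coef * entropy(probs)

uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.97, 0.01, 0.01, 0.01]
# 0.01 is a hypothetical entropy coefficient.
print(loss_with_entropy_bonus(1.0, uniform, coef=0.01))
print(loss_with_entropy_bonus(1.0, peaked, coef=0.01))
```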
<h2 id="key-exploration-problems">Key Exploration Problems</h2>
<p>Good exploration becomes especially hard when the environment rarely provides rewards as feedback or the environment has distracting noise. Many exploration strategies are proposed to solve one or both of the following problems.</p>
<h3 id="the-hard-exploration-problem">The Hard-Exploration Problem</h3>
<p>The “hard-exploration” problem refers to exploration in an environment with very sparse or even deceptive reward. It is difficult because random exploration in such scenarios can rarely discover successful states or obtain meaningful feedback.</p>
<p><a href="https://en.wikipedia.org/wiki/Montezuma%27s_Revenge_(video_game)">Montezuma’s Revenge</a> is a concrete example of the hard-exploration problem. It remains one of the few Atari games that are still challenging for DRL to solve. Many papers use Montezuma’s Revenge to benchmark their results.</p>
<h3 id="the-noisy-tv-problem">The Noisy-TV Problem</h3>
<p>The “Noisy-TV” problem started as a thought experiment in <a href="https://arxiv.org/abs/1810.12894">Burda, et al (2018)</a>. Imagine that an RL agent is rewarded for seeking novel experiences: a TV with uncontrollable and unpredictable random noise outputs would be able to attract the agent’s attention forever. The agent keeps obtaining new rewards from the noisy TV, but it fails to make any meaningful progress and becomes a “couch potato”.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/the-noisy-TV-problem.gif" alt="The noisy-TV problem" /></p>
<p class="image-caption"><em>Fig. 1. An agent is rewarded for novel experience in the experiment. If a maze has a noisy TV set up, the agent would be attracted to it and stop moving in the maze. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<h2 id="intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</h2>
<p>One common approach to better exploration, especially for solving the <a href="#the-hard-exploration-problem">hard-exploration</a> problem, is to augment the environment reward with an additional bonus signal to encourage extra exploration. The policy is thus trained with a reward composed of two terms, \(r_t = r^e_t + \beta r^i_t\), where \(\beta\) is a hyperparameter adjusting the balance between exploitation and exploration.</p>
<ul>
<li>\(r^e_t\) is an <em>extrinsic</em> reward from the environment at time \(t\), defined according to the task at hand.</li>
<li>\(r^i_t\) is an <em>intrinsic</em> exploration bonus at time \(t\).</li>
</ul>
<p>This intrinsic reward is somewhat inspired by <em>intrinsic motivation</em> in psychology (<a href="https://www.researchgate.net/profile/Pierre-Yves_Oudeyer/publication/29614795_How_can_we_define_intrinsic_motivation/links/09e415107f1b4c8041000000/How-can-we-define-intrinsic-motivation.pdf">Oudeyer & Kaplan, 2008</a>). Exploration driven by curiosity might be an important way for children to grow and learn. In other words, exploratory activities should be rewarding intrinsically in the human mind to encourage such behavior. The intrinsic rewards could be correlated with curiosity, surprise, familiarity of the state, and many other factors.</p>
<p>The same ideas can be applied to RL algorithms. In the following sections, methods of bonus-based exploration rewards are roughly grouped into two categories:</p>
<ol>
<li>Discovery of novel states</li>
<li>Improvement of the agent’s knowledge about the environment.</li>
</ol>
<h3 id="count-based-exploration">Count-based Exploration</h3>
<p>If we consider intrinsic rewards as rewarding conditions that surprise us, we need a way to measure whether a state is novel or appears often. One intuitive way is to count how many times a state has been encountered and to assign a bonus accordingly. The bonus guides the agent’s behavior to prefer rarely visited states to common states. This is known as the <strong>count-based exploration</strong> method.</p>
<p>Let \(N_n(s)\) be the <em>empirical count</em> function that tracks the real number of visits of a state \(s\) in the sequence of \(s_{1:n}\). Unfortunately, using \(N_n(s)\) for exploration directly is not practical, because most of the states would have \(N_n(s)=0\), especially considering that the state space is often continuous or high-dimensional. We need a non-zero count for most states, even when they haven’t been seen before.</p>
<h4 id="counting-by-density-model">Counting by Density Model</h4>
<p><a href="https://arxiv.org/abs/1606.01868">Bellemare, et al. (2016)</a> used a <strong>density model</strong> to approximate the frequency of state visits and a novel algorithm for deriving a <em>pseudo-count</em> from this density model. Let’s first define a conditional probability over the state space, \(\rho_n(s) = \rho(s \vert s_{1:n})\) as the probability of the \((n+1)\)-th state being \(s\) given the first \(n\) states are \(s_{1:n}\). To measure this empirically, we can simply use \(N_n(s)/n\).</p>
<p>Let’s also define a <em>recoding probability</em> of a state \(s\) as the probability assigned by the density model to \(s\) <em>after observing a new occurrence of</em> \(s\), \(\rho'_n(s) = \rho(s \vert s_{1:n}s)\).</p>
<p>The paper introduced two concepts to better regulate the density model, a <em>pseudo-count</em> function \(\hat{N}_n(s)\) and a <em>pseudo-count total</em> \(\hat{n}\). As they are designed to imitate an empirical count function, we would have:</p>
\[\rho_n(s) = \frac{\hat{N}_n(s)}{\hat{n}} \leq \rho'_n(s) = \frac{\hat{N}_n(s) + 1}{\hat{n} + 1}\]
<p>The relationship between \(\rho_n(s)\) and \(\rho'_n(s)\) requires the density model to be <em>learning-positive</em>: for all \(s_{1:n} \in \mathcal{S}^n\) and all \(s \in \mathcal{S}\), \(\rho_n(s) \leq \rho'_n(s)\). In other words, after observing one instance of \(s\), the density model’s prediction of that same \(s\) should increase. Apart from being learning-positive, the density model should be trained completely <em>online</em> with non-randomized mini-batches of experienced states, so naturally we have \(\rho'_n = \rho_{n+1}\).</p>
<p>The pseudo-count can be computed from \(\rho_n(s)\) and \(\rho'_n(s)\) after solving the above linear system:</p>
\[\hat{N}_n(s) = \hat{n} \rho_n(s) = \frac{\rho_n(s)(1 - \rho'_n(s))}{\rho'_n(s) - \rho_n(s)}\]
<p>Or estimated by the <em>prediction gain (PG)</em>:</p>
\[\hat{N}_n(s) \approx (e^{\text{PG}_n(s)} - 1)^{-1} = (e^{\log \rho'_n(s) - \log \rho_n(s)} - 1)^{-1}\]
<p>A common choice of a count-based intrinsic bonus is \(r^i_t = N(s_t, a_t)^{-1/2}\) (as in MBIE-EB; <a href="https://www.ics.uci.edu/~dechter/courses/ics-295/fall-2019/papers/2008-littman-aij-main.pdf">Strehl & Littman, 2008</a>). The pseudo-count-based exploration bonus is shaped in a similar form, \(r^i_t = \big(\hat{N}_n(s_t, a_t) + 0.01 \big)^{-1/2}\).</p>
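<p>As a sanity check of the pseudo-count formula, plugging in an exact empirical density (\(\rho_n = N/n\) and \(\rho'_n = (N+1)/(n+1)\)) should give back the true count \(N\). The sketch below does this with made-up numbers; for simplicity it works with state-only counts and the function names are hypothetical.</p>

```python
import math

def pseudo_count(rho, rho_prime):
    """N_hat = rho * (1 - rho') / (rho' - rho)."""
    if rho_prime <= rho:
        raise ValueError("the density model must be learning-positive")
    return rho * (1.0 - rho_prime) / (rho_prime - rho)

def pseudo_count_from_pg(rho, rho_prime):
    """Approximation via prediction gain: N_hat ~= 1 / (exp(PG) - 1)."""
    pg = math.log(rho_prime) - math.log(rho)
    return 1.0 / math.expm1(pg)

def exploration_bonus(n_hat):
    """Bonus of the form (N_hat + 0.01)^(-1/2)."""
    return (n_hat + 0.01) ** -0.5

# Empirical density for a state seen N = 3 times among n = 10 observations.
N, n = 3, 10
rho, rho_prime = N / n, (N + 1) / (n + 1)
print(pseudo_count(rho, rho_prime))          # recovers the true count
print(pseudo_count_from_pg(rho, rho_prime))  # PG-based approximation
print(exploration_bonus(pseudo_count(rho, rho_prime)))
```

<p>The PG-based estimate drops the factor \(1 - \rho'_n(s)\), so it is slightly larger than the exact pseudo-count; the gap shrinks when \(\rho'_n(s)\) is small, as is typical in large state spaces.</p>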
<p>Experiments in <a href="https://arxiv.org/abs/1606.01868">Bellemare et al., (2016)</a> adopted a simple <a href="http://proceedings.mlr.press/v32/bellemare14.html">CTS</a> (Context Tree Switching) density model to estimate pseudo-counts. The CTS model takes as input a 2D image and assigns to it a probability according to the product of location-dependent L-shaped filters, where the prediction of each filter is given by a CTS algorithm trained on past images. The CTS model is simple but limited in expressiveness, scalability, and data efficiency. In a follow-up paper, <a href="https://arxiv.org/abs/1703.01310">Ostrovski, et al. (2017)</a> improved the approach by training a PixelCNN (<a href="https://arxiv.org/abs/1606.05328">van den Oord et al., 2016</a>) as the density model.</p>
<p>The density model can also be a Gaussian Mixture Model as in <a href="https://arxiv.org/abs/1902.08039">Zhao & Tresp (2018)</a>. They used a variational GMM to estimate the density of trajectories (e.g. concatenation of a sequence of states) and used its predicted probabilities to guide prioritization in experience replay in the off-policy setting.</p>
<h4 id="counting-after-hashing">Counting after Hashing</h4>
<p>Another idea to make it possible to count high-dimensional states is to map states into <strong>hash codes</strong> so that the occurrences of states become trackable (<a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>). The state space is discretized with a hash function \(\phi: \mathcal{S} \mapsto \mathbb{Z}^k\). An exploration bonus \(r^{i}: \mathcal{S} \mapsto \mathbb{R}\) is added to the reward function, defined as \(r^{i}(s) = {N(\phi(s))}^{-1/2}\), where \(N(\phi(s))\) is an empirical count of occurrences of \(\phi(s)\).</p>
<p><a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> proposed to use <em>Locality-Sensitive Hashing</em> (<a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing"><em>LSH</em></a>) to convert continuous, high-dimensional data to discrete hash codes. LSH is a popular class of hash functions for querying nearest neighbors based on certain similarity metrics. A hashing scheme \(x \mapsto h(x)\) is locality-sensitive if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. (See how LSH is used in <a href="/lil-log/2020/04/07/the-transformer-family.html#LSH">Transformer improvement</a> if interested.) <a href="https://www.cs.princeton.edu/courses/archive/spr04/cos598B/bib/CharikarEstim.pdf">SimHash</a> is a type of computationally efficient LSH and it measures similarity by angular distance:</p>
\[\phi(s) = \text{sgn}(A g(s)) \in \{-1, 1\}^k\]
<p>where \(A \in \mathbb{R}^{k \times D}\) is a matrix with each entry drawn i.i.d. from a standard Gaussian and \(g: \mathcal{S} \mapsto \mathbb{R}^D\) is an optional preprocessing function. The dimension of binary codes is \(k\), controlling the granularity of the state space discretization. A higher \(k\) leads to higher granularity and fewer collisions.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/count-hashing-exploration.png" alt="#Exploration" /></p>
<p class="image-caption"><em>Fig. 2. Algorithm of count-based exploration through hashing high-dimensional states by SimHash. (Image source: <a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>)</em></p>
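<p>The SimHash-based counting scheme can be sketched in pure Python as below. This is a minimal illustration, assuming the preprocessing function \(g\) is the identity and mapping \(\text{sgn}(0)\) to \(+1\); the class name and the example state are made up.</p>

```python
import random
from collections import Counter

class SimHashCounter:
    """Count-based bonus r(s) = N(phi(s))^(-1/2) with phi(s) = sgn(A s)."""

    def __init__(self, dim, k, seed=0):
        rng = random.Random(seed)
        # A is a k x dim matrix with i.i.d. standard Gaussian entries.
        self.A = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(k)]
        self.counts = Counter()

    def code(self, state):
        """k-dimensional hash code in {-1, 1}^k."""
        return tuple(
            1 if sum(a * x for a, x in zip(row, state)) >= 0 else -1
            for row in self.A
        )

    def bonus(self, state):
        """Increment the visit count of the state's hash code, return the bonus."""
        key = self.code(state)
        self.counts[key] += 1
        return self.counts[key] ** -0.5

counter = SimHashCounter(dim=3, k=8, seed=0)
state = [0.5, -1.0, 2.0]  # hypothetical state features
print(counter.bonus(state))                   # first visit
print(counter.bonus([2 * x for x in state]))  # same direction, same hash bucket
```

<p>Since the hash only depends on the sign of each projection, states pointing in the same direction share a bucket, which matches the angular-distance view of SimHash.</p>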
<p>For high-dimensional images, SimHash may not work well on the raw pixel level. <a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> designed an autoencoder (AE) which takes as input states \(s\) to learn hash codes. It has one special dense layer composed of \(k\) sigmoid functions as the latent state in the middle and then the sigmoid activation values \(b(s)\) of this layer are binarized by rounding to their closest binary numbers \(\lfloor b(s)\rceil \in \{0, 1\}^k\) as the binary hash codes for state \(s\). The AE loss over \(N\) states includes two terms:</p>
\[\mathcal{L}(\{s_n\}_{n=1}^N) = \underbrace{-\frac{1}{N} \sum_{n=1}^N \log p(s_n)}_\text{reconstruction loss} + \underbrace{\frac{1}{N} \frac{\lambda}{k} \sum_{n=1}^N\sum_{i=1}^k \min \big \{ (1-b_i(s_n))^2, b_i(s_n)^2 \big\}}_\text{sigmoid activation being closer to binary}\]
<p>One problem with this approach is that dissimilar inputs \(s_i, s_j\) may be mapped to identical hash codes but the AE still reconstructs them perfectly. One can imagine replacing the bottleneck layer \(b(s)\) with the hash codes \(\lfloor b(s)\rceil\), but then gradients cannot be back-propagated through the rounding function. Injecting uniform noise could mitigate this effect, as the AE has to learn to push the latent variable far apart to counteract the noise.</p>
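<p>The binarization-related pieces of this autoencoder can be sketched without any neural network code. The snippet below only illustrates the rounding step, the second loss term (normalized by the batch size and the code length, written \(k\) here) and uniform noise injection on the latent activations. The noise range <code>a</code>, the value of \(\lambda\) and the tie-breaking rule for rounding are assumptions.</p>

```python
import random

def binarize(b):
    """Round sigmoid activations to the nearest bit (0.5 rounds up here)."""
    return tuple(1 if x >= 0.5 else 0 for x in b)

def binary_penalty(batch_b, lam):
    """lam / (N * k) * sum_n sum_i min((1 - b_i)^2, b_i^2)."""
    n, k = len(batch_b), len(batch_b[0])
    total = sum(min((1.0 - x) ** 2, x ** 2) for b in batch_b for x in b)
    return lam * total / (n * k)

def add_uniform_noise(b, a, rng):
    """Inject U(-a, a) noise into the latent activations during training."""
    return [x + rng.uniform(-a, a) for x in b]

rng = random.Random(0)
activations = [[0.9, 0.1, 0.6], [0.5, 0.5, 0.5]]  # hypothetical b(s) values
print([binarize(b) for b in activations])
print(binary_penalty(activations, lam=1.0))
print(add_uniform_noise(activations[0], a=0.3, rng=rng))
```

<p>The penalty is zero when every activation is exactly 0 or 1 and largest when all of them sit at 0.5, so minimizing it pushes the latent layer toward binary values.</p>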
<h3 id="prediction-based-exploration">Prediction-based Exploration</h3>
<p>The second category of intrinsic exploration bonuses rewards improvement of the agent’s knowledge about the environment. The agent’s familiarity with the environment dynamics can be estimated through a prediction model. This idea of using a prediction model to measure <em>curiosity</em> was actually proposed quite a long time ago (<a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">Schmidhuber, 1991</a>).</p>
<h4 id="forward-dynamics">Forward Dynamics</h4>
<p>Learning a <strong>forward dynamics prediction model</strong> is a great way to approximate how much knowledge our model has obtained about the environment and the task MDPs. It captures an agent’s capability of predicting the consequence of its own behavior, \(f: (s_t, a_t) \mapsto s_{t+1}\). Since such a model cannot be perfect (e.g. due to partial observation), the prediction error \(e(s_t, a_t) = \| f(s_t, a_t) - s_{t+1} \|^2_2\) can be used to provide intrinsic exploration rewards. The higher the prediction error, the less familiar we are with that state. The faster the error rate drops, the more learning progress signals we acquire.</p>
<p><em>Intelligent Adaptive Curiosity</em> (<strong>IAC</strong>; <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>) sketched an idea of using a forward dynamics prediction model to estimate learning progress and assigned intrinsic exploration reward accordingly.</p>
<p>IAC relies on a memory which stores all the experiences encountered by the robot, \(M=\{(s_t, a_t, s_{t+1})\}\) and a forward dynamics model \(f\). IAC incrementally splits the state space (i.e. sensorimotor space in the context of robotics, as discussed in the paper) into separate regions based on the transition samples, using a process similar to how a decision tree is split: The split happens when the number of samples is larger than a threshold, and the variance of states in each leaf should be minimal. Each tree node is characterized by its exclusive set of samples and has its own forward dynamics predictor \(f\), named “expert”.</p>
<p>The prediction error \(e_t\) of an expert is pushed into a list associated with each region. The <em>learning progress</em> is then measured as the difference between the mean error rate of a moving window with offset \(\tau\) and the current moving window. The intrinsic reward is defined for tracking the learning progress: \(r^i_t = \frac{1}{k}\sum_{i=0}^{k-1}(e_{t-i-\tau} - e_{t-i})\), where \(k\) is the moving window size. So the larger the decrease in prediction error we can achieve, the higher the intrinsic reward we assign to the agent. In other words, the agent is encouraged to take actions to quickly learn about the environment.</p>
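<p>The learning-progress reward above can be sketched in a few lines. This is an illustrative helper (the name <code>learning_progress</code> is made up here), assuming the per-region error list is stored oldest first and \(\tau \geq 1\).</p>

```python
import numpy as np

def learning_progress(errors, k, tau):
    """r_t = (1/k) * sum_{i=0}^{k-1} (e_{t-i-tau} - e_{t-i}) for the latest step t."""
    errors = np.asarray(errors, dtype=float)
    if tau < 1 or len(errors) < k + tau:
        raise ValueError("need tau >= 1 and at least k + tau recorded errors")
    current = errors[-k:]             # e_{t-k+1}, ..., e_t
    delayed = errors[-k - tau:-tau]   # the same window shifted back by tau steps
    return float(np.mean(delayed - current))

# Errors shrinking over time yield a positive reward.
print(learning_progress([1.0, 0.9, 0.8, 0.7, 0.6, 0.5], k=2, tau=3))
```

<p>A flat error curve gives zero reward and a rising one gives a negative reward, which is what steers the agent toward regions where its predictor is still improving.</p>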
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/IAC.png" alt="IAC" /></p>
<p class="image-caption"><em>Fig. 3. Architecture of the IAC (Intelligent Adaptive Curiosity) module: the intrinsic reward is assigned w.r.t the learning progress in reducing prediction error of the dynamics model. (Image source: <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>)</em></p>
<p><a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a> trained a forward dynamics model in the encoding space defined by \(\phi\), \(f_\phi: (\phi(s_t), a_t) \mapsto \phi(s_{t+1})\). The model’s prediction error at time \(t\) is normalized by the maximum error up to time \(t\), \(\bar{e}_t = \frac{e_t}{\max_{i \leq t} e_i}\), so it is always between 0 and 1. The intrinsic reward is defined accordingly: \(r^i_t = (\frac{\bar{e}_t(s_t, a_t)}{t \cdot C})\), where \(C > 0\) is a decay constant.</p>
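<p>A small sketch of this normalized, decaying bonus is shown below. The class name and the value of \(C\) are placeholders for illustration, not taken from the paper.</p>

```python
class NormalizedErrorBonus:
    """Tracks the running max error and returns e_bar_t / (t * C)."""

    def __init__(self, decay_c):
        self.decay_c = decay_c
        self.max_error = 0.0
        self.t = 0

    def __call__(self, error):
        self.t += 1
        self.max_error = max(self.max_error, error)
        # e_bar is in [0, 1] because we divide by the largest error seen so far.
        e_bar = error / self.max_error if self.max_error > 0 else 0.0
        return e_bar / (self.t * self.decay_c)

bonus = NormalizedErrorBonus(decay_c=0.5)
r1 = bonus(2.0)   # first error is also the max, so e_bar = 1
r2 = bonus(1.0)   # half of the max, and t has grown
print(r1, r2)
```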
<p>Encoding the state space via \(\phi(.)\) is necessary, as experiments in the paper have shown that a dynamics model trained directly on raw pixels has <em>very poor</em> behavior: it assigns the same exploration bonus to all the states. In <a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a>, the encoding function \(\phi\) is learned via an autoencoder (AE) and \(\phi(.)\) is the output of one of the hidden layers of the AE. The AE can be statically trained using a set of images collected by a random agent, or dynamically trained together with the policy where the early frames are gathered using <a href="#classic-exploration-strategies">\(\epsilon\)-greedy</a> exploration.</p>
<p><a name="ICM"></a>Instead of an autoencoder, <em>Intrinsic Curiosity Module</em> (<strong>ICM</strong>; <a href="https://arxiv.org/abs/1705.05363">Pathak, et al., 2017</a>) learns the state space encoding \(\phi(.)\) with a self-supervised <strong>inverse dynamics</strong> model. Predicting the next state given the agent’s own action is not easy, especially considering that some factors in the environment cannot be controlled by the agent or do not affect the agent. ICM argues that a good state feature space should exclude such factors because <em>they cannot influence the agent’s behavior and thus the agent has no incentive for learning them</em>. By learning an inverse dynamics model \(g: (\phi(s_t), \phi(s_{t+1})) \mapsto a_t\), the feature space only captures those changes in the environment related to the actions of our agent, and ignores the rest.</p>
<p>Given a forward model \(f\), an inverse dynamics model \(g\) and an observation \((s_t, a_t, s_{t+1})\):</p>
\[g_{\psi_I}(\phi(s_t), \phi(s_{t+1})) = \hat{a}_t \quad
f_{\psi_F}(\phi(s_t), a_t) = \hat{\phi}(s_{t+1}) \quad
r_t^i = \| \hat{\phi}(s_{t+1}) - \phi(s_{t+1}) \|_2^2\]
<p>Such \(\phi(.)\) is expected to be robust to uncontrollable aspects of the environment.</p>
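<p>To make the three quantities above concrete, here is a toy NumPy sketch with untrained, randomly initialized linear maps standing in for \(\phi\), \(f\) and \(g\). All dimensions and weight names are assumptions for illustration; in ICM the encoder is trained through the inverse model's action-prediction loss, which is omitted here.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, feat_dim, n_actions = 8, 4, 3

W_phi = rng.normal(size=(feat_dim, obs_dim))             # encoder phi
W_f = rng.normal(size=(feat_dim, feat_dim + n_actions))  # forward model f
W_g = rng.normal(size=(n_actions, 2 * feat_dim))         # inverse model g

def phi(s):
    return np.tanh(W_phi @ s)

def one_hot(a):
    return np.eye(n_actions)[a]

def forward_model(feat, a):
    """Predict phi(s_{t+1}) from phi(s_t) and the action."""
    return W_f @ np.concatenate([feat, one_hot(a)])

def inverse_model(feat, next_feat):
    """Return action logits given phi(s_t) and phi(s_{t+1})."""
    return W_g @ np.concatenate([feat, next_feat])

def icm_reward(s, a, s_next):
    """Squared forward prediction error in feature space."""
    pred = forward_model(phi(s), a)
    return float(np.sum((pred - phi(s_next)) ** 2))

s, s_next = rng.normal(size=obs_dim), rng.normal(size=obs_dim)
print(icm_reward(s, 1, s_next), inverse_model(phi(s), phi(s_next)).shape)
```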
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ICM.png" alt="ICM" /></p>
<p class="image-caption"><em>Fig. 4. ICM (Intrinsic Curiosity Module) assigns the forward dynamics prediction error to the agent as the intrinsic reward. This dynamics model operates in a state encoding space learned through an inverse dynamics model to exclude environmental factors that do not affect the agent’s behavior. (Image source: <a href="https://arxiv.org/abs/1705.05363">Pathak, et al. 2017</a>)</em></p>
<p><a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. (2018)</a> did a set of large-scale comparison experiments on purely curiosity-driven learning, meaning that only intrinsic rewards are provided to the agent. In this study, the reward is \(r_t = r^i_t = \| f(s_t, a_t) - \phi(s_{t+1})\|_2^2\). A good choice of \(\phi\) is crucial to learning forward dynamics: the feature space is expected to be <em>compact</em>, <em>sufficient</em> and <em>stable</em>, making the prediction task more tractable and filtering out irrelevant observations.</p>
<p>They compared 4 encoding functions:</p>
<ol>
<li>Raw image pixels: No encoding, \(\phi(x) = x\).</li>
<li><a name="random-feature"></a>Random features (RF): Each state is compressed through a fixed random neural network.</li>
<li><a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#vae-variational-autoencoder">VAE</a>: The probabilistic encoder is used for encoding, \(\phi(x) = q(z \vert x)\).</li>
<li>Inverse dynamic features (IDF): The same feature space as used in <a href="#ICM">ICM</a>.</li>
</ol>
<p>All the experiments have the reward signals normalized by a running estimate of the standard deviation of the cumulative returns. All the experiments also run in an infinite-horizon setting to prevent the “done” flag from leaking information.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/large-scale-curiosity-learning.png" alt="Large-scale curiosity learning" /></p>
<p class="image-caption"><em>Fig. 5. The mean reward in different games when training with only curiosity signals, generated by different state encoding functions.
(Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>Interestingly <em>random features</em> turn out to be quite competitive, but in feature transfer experiments (i.e. train an agent in Super Mario Bros level 1-1 and then test it in another level), learned IDF features can generalize better.</p>
<p>They also compared RF and IDF in an environment with a <a href="#the-noisy-tv-problem">noisy TV</a> on. Unsurprisingly, the noisy TV drastically slows down learning, and extrinsic rewards stay much lower over the course of training.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/noisy-TV-experiment.png" alt="Noisy TV experiment" /></p>
<p class="image-caption"><em>Fig. 6. Experiments using RF and IDF feature encoding in an environment with noisy TV on or off. The plot tracks extrinsic reward per episode as the training progresses. (Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>The forward dynamics optimization can be modeled via variational inference as well. <strong>VIME</strong> (short for <em>“Variational information maximizing exploration”</em>; <a href="https://arxiv.org/abs/1605.09674">Houthooft, et al. 2017</a>) is an exploration strategy based on maximization of <em>information gain</em> about the agent’s belief of environment dynamics. How much additional information has been obtained about the forward dynamics can be measured as the reduction in entropy.</p>
<p>Let \(\mathcal{P}\) be the environment transition function, \(p(s_{t+1}\vert s_t, a_t; \theta)\) be the forward prediction model, parameterized by \(\theta \in \Theta\), and \(\xi_t = \{s_1, a_1, \dots, s_t\}\) be the trajectory history. We would like to reduce the entropy after taking a new action and observing the next state, which is to maximize the following:</p>
\[\begin{aligned}
&\sum_t H(\Theta \vert \xi_t, a_t) - H(\Theta \vert S_{t+1}, \xi_t, a_t) \\
=& I(\Theta; S_{t+1} \vert \xi_t, a_t) \quad \scriptstyle{\text{; because } I(X; Y) = H(X) - H(X \vert Y)} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t, a_t))] \quad \scriptstyle{\text{; because } I(X; Y) = \mathbb{E}_Y [D_\text{KL} (p_{X \vert Y} \| p_X)]} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t))] \quad \scriptstyle{\text{; because } \theta \text{ does not depend on } a_t}
\end{aligned}\]
<p>While taking expectation over the new possible states, the agent is expected to take a new action to increase the KL divergence (<em>“information gain”</em>) between its new belief over the prediction model and the old one. This term can be added into the reward function as an intrinsic reward: \(r^i_t = D_\text{KL} [p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t)]\).</p>
<p>However, computing the posterior \(p(\theta \vert \xi_t, a_t, s_{t+1})\) is generally intractable.</p>
\[\begin{aligned}
p(\theta \vert \xi_t, a_t, s_{t+1})
&= \frac{p(\theta \vert \xi_t, a_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} \\
&= \frac{p(\theta \vert \xi_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} & \scriptstyle{\text{; because action doesn't affect the belief.}} \\
&= \frac{\color{red}{p(\theta \vert \xi_t)} p(s_{t+1} \vert \xi_t, a_t; \theta)}{\int_\Theta p(s_{t+1}\vert\xi_t, a_t; \theta) \color{red}{p(\theta \vert \xi_t)} d\theta} & \scriptstyle{\text{; red part is hard to compute directly.}}
\end{aligned}\]
<p>Since it is difficult to compute \(p(\theta\vert\xi_t)\) directly, a natural choice is to approximate it with an alternative distribution \(q_\phi(\theta)\). By the variational lower bound, fitting \(q_\phi(\theta)\) amounts to maximizing the expected log-likelihood \(\mathbb{E}_{\theta \sim q_\phi}[\log p(\xi_t\vert\theta)]\) while minimizing \(D_\text{KL}[q_\phi(\theta) \| p(\theta)]\).</p>
<p>Using the approximation distribution \(q\), the intrinsic reward becomes:</p>
\[r^i_t = D_\text{KL} [q_{\phi_{t+1}}(\theta) \| q_{\phi_t}(\theta)]\]
<p>where \(\phi_{t+1}\) represents \(q\)’s parameters associated with the new belief after seeing \(a_t\) and \(s_{t+1}\). When used as an exploration bonus, it is normalized by division by the moving median of this KL divergence value.</p>
<p>Here the dynamics model is parameterized as a <a href="https://link.springer.com/book/10.1007/978-1-4612-0745-0">Bayesian neural network</a> (BNN), as it maintains a distribution over its weights. The BNN weight distribution \(q_\phi(\theta)\) is modeled as a fully <em>factorized</em> Gaussian with \(\phi = \{\mu, \sigma\}\) and we can easily sample \(\theta \sim q_\phi(.)\). After applying a second-order Taylor expansion, the KL term \(D_\text{KL}[q_{\phi + \lambda \Delta\phi}(\theta) \| q_{\phi}(\theta)]\) can be estimated using <a href="/lil-log/2019/09/05/evolution-strategies.html#estimation-using-fisher-information-matrix">Fisher Information Matrix</a> \(\mathbf{F}_\phi\), which is easy to compute, because \(q_\phi\) is factorized Gaussian and thus the covariance matrix is only a diagonal matrix. See more details in <a href="https://arxiv.org/abs/1605.09674">the paper</a>, especially section 2.3-2.5.</p>
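<p>For a fully factorized Gaussian, the KL divergence between the new and old beliefs has a standard closed form, which the sketch below evaluates directly. Note this is the exact KL rather than the second-order Taylor / Fisher approximation used in the paper, and the helper names are hypothetical.</p>

```python
import numpy as np

def kl_diag_gaussians(mu_new, sigma_new, mu_old, sigma_old):
    """D_KL[N(mu_new, diag(sigma_new^2)) || N(mu_old, diag(sigma_old^2))]."""
    mu_new, sigma_new = np.asarray(mu_new, float), np.asarray(sigma_new, float)
    mu_old, sigma_old = np.asarray(mu_old, float), np.asarray(sigma_old, float)
    var_ratio = (sigma_new / sigma_old) ** 2
    mean_shift = ((mu_old - mu_new) / sigma_old) ** 2
    return float(0.5 * np.sum(var_ratio + mean_shift - 1.0 - np.log(var_ratio)))

def normalized_bonus(kl, past_kls):
    """Divide by the median of previously observed KL values, if there are any."""
    return kl / float(np.median(past_kls)) if len(past_kls) > 0 else kl

print(kl_diag_gaussians([1.0], [1.0], [0.0], [1.0]))
```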
<p><a name="exploration-via-disagreement"></a>All the methods above depend on a single prediction model. If we have multiple such models, we could use the disagreement among models to set the exploration bonus (<a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>). High disagreement indicates low confidence in prediction and thus requires more exploration. <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. (2019)</a> proposed to train a set of forward dynamics models and to use the variance over the ensemble of model outputs as \(r_t^i\). Precisely, they encode the state space with <a href="#random-feature">random feature</a> and learn 5 models in the ensemble.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/exploration-via-disagreement.png" alt="Disagreement" /></p>
<p class="image-caption"><em>Fig. 7. Illustration of training architecture for self-supervised exploration via disagreement. (Image source: <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>)</em></p>
<p>Because \(r^i_t\) is differentiable, the intrinsic reward in the model could be directly optimized through gradient descent so as to inform the policy agent to change actions. This differentiable exploration approach is very efficient but limited by having a short exploration horizon.</p>
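<p>The ensemble-disagreement bonus described above boils down to a variance across model predictions. A minimal sketch, assuming the per-dimension variances are summed into one scalar (that reduction is my choice for illustration):</p>

```python
import numpy as np

def disagreement_bonus(predictions):
    """predictions: shape (n_models, feat_dim), one predicted next-state feature per model."""
    predictions = np.asarray(predictions, dtype=float)
    # Variance across the ensemble for every feature dimension, then summed.
    return float(predictions.var(axis=0).sum())

agree = [[1.0, 2.0]] * 5                 # all 5 models predict the same thing
disagree = [[0.0, 0.0], [2.0, 2.0]]      # two models far apart
print(disagreement_bonus(agree), disagreement_bonus(disagree))
```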
<h4 id="random-networks">Random Networks</h4>
<p>But, what if the prediction task is not about the environment dynamics at all? It turns out when the prediction is for a random task, it still can help exploration.</p>
<p><strong>DORA</strong> (short for <em>“Directed Outreaching Reinforcement Action-Selection”</em>; <a href="https://arxiv.org/abs/1804.04012">Fox & Choshen, et al. 2018</a>) is a novel framework that injects exploration signals based on a newly introduced, <strong>task-independent</strong> MDP. The idea of DORA depends on two parallel MDPs:</p>
<ul>
<li>One is the original task MDP;</li>
<li>The other is an identical MDP but with <em>no reward attached</em>: Rather, every state-action pair is designed to have value 0. The Q-value learned for the second MDP is called <em>E-value</em>. If the model cannot perfectly predict E-value to be zero, it is still missing information.</li>
</ul>
<p>Initially E-value is assigned with value 1. Such positive initialization can encourage directed exploration for better E-value prediction. State-action pairs with high E-value estimation don’t have enough information gathered yet, at least not enough to exclude their high E-values. To some extent, the logarithm of E-values can be considered as a generalization of <em>visit counters</em>.</p>
<p>When using a neural network to do function approximation for E-value, another value head is added to predict E-value and it is simply expected to predict zero. Given a predicted E-value \(E(s_t, a_t)\), the exploration bonus is \(r^i_t = \frac{1}{\sqrt{-\log E(s_t, a_t)}}\).</p>
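<p>A quick sketch of the bonus formula, plus a toy illustration of why \(-\log E\) behaves like a counter: if an E-value starting at 1 is repeatedly moved toward a target of 0 with learning rate \(\alpha\), it equals \((1-\alpha)^n\) after \(n\) updates. The clipping constant and the function names are assumptions added here to keep the formula well-defined.</p>

```python
import math

def dora_bonus(e_value, eps=1e-8):
    """r = 1 / sqrt(-log E), with E clipped into (0, 1) to avoid division by zero."""
    e_value = min(max(e_value, eps), 1.0 - eps)
    return 1.0 / math.sqrt(-math.log(e_value))

def e_value_after(n_updates, alpha=0.1):
    """E <- (1 - alpha) * E + alpha * 0, applied n times starting from E = 1."""
    return (1.0 - alpha) ** n_updates

# Rarely updated pairs keep a high E-value and therefore a larger bonus.
print(dora_bonus(e_value_after(1)), dora_bonus(e_value_after(100)))
```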
<p><a name="RND"></a>Similar to DORA, <strong>Random Network Distillation</strong> (<strong>RND</strong>; <a href="https://arxiv.org/abs/1810.12894">Burda, et al. 2018</a>) introduces a prediction task <em>independent of the main task</em>. The RND exploration bonus is defined as the error of a neural network \(\hat{f}(s_t)\) predicting features of the observations given by a <em>fixed randomly initialized</em> neural network \(f(s_t)\). The motivation is that given a new state, if similar states have been visited many times in the past, the prediction should be easier and thus has lower error. The exploration bonus is \(r^i(s_t) = \|\hat{f}(s_t; \theta) - f(s_t) \|_2^2\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RND.png" alt="RND" /></p>
<p class="image-caption"><em>Fig. 8. How RND (Random Network Distillation) works for providing an intrinsic reward. The features \(O_{i+1} \mapsto f_{i+1}\) are generated by a fixed random neural network. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<p>Two factors are important in RND experiments:</p>
<ol>
<li>Non-episodic setting results in better exploration, especially when not using any extrinsic rewards. It means that the return is not truncated at “Game over” and intrinsic return can spread across multiple episodes.</li>
<li>Normalization is important since the scale of the reward is tricky to adjust given a random neural network as a prediction target. The intrinsic reward is normalized by division by a running estimate of the standard deviations of the intrinsic return.</li>
</ol>
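<p>The RND bonus itself can be illustrated with a tiny NumPy toy: a fixed random target network and a linear predictor trained by gradient descent. Everything here (dimensions, learning rate, a linear predictor, one-hot "states") is an assumption chosen to keep the example small, and the reward normalization from the list above is omitted.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, feat_dim = 4, 8
W_target = rng.normal(size=(feat_dim, obs_dim))  # fixed random network, never trained
W_pred = np.zeros((feat_dim, obs_dim))           # trainable predictor

def target(s):
    return np.tanh(W_target @ s)

def rnd_bonus(s):
    """Squared error between predictor output and the random target features."""
    return float(np.sum((W_pred @ s - target(s)) ** 2))

def train_predictor(s, lr=0.25):
    """One gradient step on the squared prediction error at state s."""
    err = W_pred @ s - target(s)
    W_pred[:] = W_pred - lr * 2.0 * np.outer(err, s)

seen, unseen = np.eye(obs_dim)[0], np.eye(obs_dim)[1]
for _ in range(50):
    train_predictor(seen)   # the agent keeps visiting the same state

print(rnd_bonus(seen), rnd_bonus(unseen))
```

<p>The frequently visited state ends up with a near-zero bonus, while the state never used for training keeps a clearly positive one.</p>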
<p>The RND setup works well for resolving the hard-exploration problem. For example, maximizing the RND exploration bonus consistently finds more than half of the rooms in Montezuma’s Revenge.</p>
<h4 id="physical-properties">Physical Properties</h4>
<p>Different from games in simulators, some RL applications like robotics need to understand objects and intuitive reasoning in the physical world. Some prediction tasks require the agent to perform a sequence of interactions with the environment and to observe the corresponding consequences, such as estimating some hidden properties in physics (e.g. mass, friction, etc).</p>
<p>Motivated by such ideas, <a href="https://arxiv.org/abs/1611.01843">Denil, et al. (2017)</a> found that DRL agents can learn to perform necessary exploration to discover such hidden properties. Precisely they considered two experiments:</p>
<ol>
<li><em>“Which is heavier?”</em> — The agent has to interact with the blocks and infer which one is heavier.</li>
<li><em>“Towers”</em> — The agent needs to infer how many rigid bodies a tower is composed of by knocking it down.</li>
</ol>
<p>The agent in the experiments first goes through an exploration phase to interact with the environment and to collect information. Once the exploration phase ends, the agent is asked to output a <em>labeling</em> action to answer the question. Then a positive reward is assigned to the agent if the answer is correct; otherwise a negative one is assigned. Because the answer requires a decent amount of interactions with items in the scene, the agent has to learn to efficiently play around so as to figure out the physics and the correct answer. The exploration naturally happens.</p>
<p>In their experiments, the agent is able to learn both tasks, with performance varying by the difficulty of the task. The paper didn’t use the physics prediction task to provide an intrinsic reward bonus along with the extrinsic reward of another learning task; rather, it focused on the exploration tasks themselves. Still, I do enjoy the idea of encouraging sophisticated exploration behavior by predicting hidden physics properties in the environment.</p>
<h2 id="memory-based-exploration">Memory-based Exploration</h2>
<p>Reward-based exploration suffers from several drawbacks:</p>
<ul>
<li>Function approximation is slow to catch up.</li>
<li>Exploration bonus is non-stationary.</li>
<li>Knowledge fading, meaning that states cease to be novel and cannot provide intrinsic reward signals in time.</li>
</ul>
<p>Methods in this section rely on external memory to resolve disadvantages of reward bonus-based exploration.</p>
<h3 id="episodic-memory">Episodic Memory</h3>
<p>As mentioned above, <a href="#RND">RND</a> works better in a non-episodic setting, meaning the prediction knowledge is accumulated across multiple episodes. The exploration strategy, <strong>Never Give Up</strong> (<strong>NGU</strong>; <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>), combines an episodic novelty module that can rapidly adapt within one episode with RND as a lifelong novelty module.</p>
<p>Precisely, the intrinsic reward in NGU consists of two exploration bonuses from two modules, <em>within one episode</em> and <em>across multiple episodes</em>, respectively.</p>
<p>The short-term per-episode reward is provided by an <em>episodic novelty module</em>. It contains an episodic memory \(M\), a dynamically-sized slot-based memory, and an IDF (inverse dynamics features) embedding function \(\phi\), same as the feature encoding in <a href="#ICM">ICM</a>.</p>
<ol>
<li>At every step the current state embedding \(\phi(s_t)\) is added into \(M\).</li>
<li>The intrinsic bonus is determined by comparing how similar the current observation is to the content of \(M\). A larger difference results in a larger bonus.
<br />
\(r^\text{episodic}_t \approx \frac{1}{\sqrt{\sum_{\phi_i \in N_k} K(\phi(x_t), \phi_i)} + c}\)
<br />
where \(K(x, y)\) is a kernel function for measuring the similarity between two samples. \(N_k\) is a set of \(k\) nearest neighbors in \(M\) according to \(K(., .)\). \(c\) is a small constant to keep the denominator non-zero. In the paper, \(K(x, y)\) is configured to be the inverse kernel:
<br />
\(K(x, y) = \frac{\epsilon}{\frac{d^2(x, y)}{d^2_m} + \epsilon}\)
<br />
where \(d(.,.)\) is the Euclidean distance between two samples and \(d^2_m\) is a running average of the squared Euclidean distance of the k-th nearest neighbors for better robustness. \(\epsilon\) is a small constant.</li>
</ol>
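<p>The episodic bonus in the steps above can be sketched as follows. This is a simplified illustration: <code>d2_mean</code> is a fixed stand-in for the running average \(d^2_m\), the constants are arbitrary, and other details of the paper's implementation are left out.</p>

```python
import numpy as np

def episodic_bonus(embedding, memory, k=3, eps=1e-3, c=1e-3, d2_mean=1.0):
    """1 / (sqrt(sum of inverse-kernel similarities to the k nearest neighbors) + c)."""
    if len(memory) == 0:
        return 1.0 / c                      # nothing to compare against yet
    d2 = np.sum((np.asarray(memory) - np.asarray(embedding)) ** 2, axis=1)
    nearest = np.sort(d2)[:k]               # squared distances to the k nearest neighbors
    kernel = eps / (nearest / d2_mean + eps)
    return float(1.0 / (np.sqrt(kernel.sum()) + c))

memory = [np.zeros(2), np.zeros(2), np.zeros(2)]
near = episodic_bonus(np.zeros(2), memory)           # already visited: small bonus
far = episodic_bonus(np.array([10.0, 10.0]), memory) # far from memory: large bonus
print(near, far)
```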
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NGU.png" alt="NGU" /></p>
<p class="image-caption"><em>Fig. 9. The architecture of NGU’s embedding function (left) and reward generator (right). (Image source: <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>)</em></p>
<p>The long-term across-episode novelty relies on RND prediction error in <em>life-long novelty module</em>. The exploration bonus is \(\alpha_t = 1 + \frac{e^\text{RND}(s_t) - \mu_e}{\sigma_e}\) where \(\mu_e\) and \(\sigma_e\) are running mean and std dev for RND error \(e^\text{RND}(s_t)\).</p>
<blockquote>
<p>However in the conclusion section of the <a href="https://arxiv.org/abs/1810.12894">RND paper</a>, I noticed the following statement:</p>
<p>“We find that the RND exploration bonus is sufficient to deal with local exploration, i.e. exploring the consequences of short-term decisions, like whether to interact with a particular object, or avoid it. However global exploration that involves coordinated decisions over long time horizons is beyond the reach of our method.”</p>
<p>And this confuses me a bit how RND can be used as a good life-long novelty bonus provider. If you know why, feel free to leave a comment below.</p>
</blockquote>
<p>The final combined intrinsic reward is \(r^i_t = r^\text{episodic}_t \cdot \text{clip}(\alpha_t, 1, L)\), where \(L\) is a constant maximum reward scalar.</p>
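<p>Putting the two modules together is a one-liner. In this sketch the default <code>max_scale</code> (the constant \(L\)) is only an illustrative value, and both function names are hypothetical.</p>

```python
def lifelong_modulator(rnd_error, mean_error, std_error):
    """alpha_t = 1 + (e_RND - mu_e) / sigma_e."""
    return 1.0 + (rnd_error - mean_error) / std_error

def ngu_intrinsic_reward(r_episodic, alpha, max_scale=5.0):
    """r_t = r_episodic * clip(alpha, 1, L)."""
    return r_episodic * min(max(alpha, 1.0), max_scale)

# A state with below-average RND error leaves the episodic bonus unchanged,
# while a very surprising state scales it by at most L.
familiar = ngu_intrinsic_reward(0.3, lifelong_modulator(0.1, 0.5, 0.2))
surprising = ngu_intrinsic_reward(0.3, lifelong_modulator(9.0, 0.5, 0.2))
print(familiar, surprising)
```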
<p>The design of NGU enables it to have two nice properties:</p>
<ol>
<li><em>Rapidly discourages</em> revisiting the same state <em>within</em> the same episode;</li>
<li><em>Slowly discourages</em> revisiting states that have been visited many times <em>across</em> episodes.</li>
</ol>
<p>Later, built on top of NGU, DeepMind proposed “Agent57” (<a href="https://arxiv.org/abs/2003.13350">Badia, et al. 2020b</a>), the first deep RL agent that outperforms the standard human benchmark on <em>all</em> 57 Atari games. Two major improvements in Agent57 over NGU are:</p>
<ol>
<li>A <em>population</em> of policies are trained in Agent57, each equipped with a different exploration parameter pair \(\{(\beta_j, \gamma_j)\}_{j=1}^N\). Recall that given \(\beta_j\), the reward is constructed as \(r_{j,t} = r_t^e + \beta_j r^i_t\) and \(\gamma_j\) is the reward discounting factor. It is natural to expect policies with higher \(\beta_j\) and lower \(\gamma_j\) to make more progress early in training, while the opposite would be expected as training progresses. A meta-controller (<a href="https://arxiv.org/pdf/0805.3415.pdf">sliding-window UCB bandit algorithm</a>) is trained to select which policies should be prioritized.</li>
<li>The second improvement is a new parameterization of Q-value function that decomposes the contributions of the intrinsic and extrinsic rewards in a similar form as the bundled reward: \(Q(s, a; \theta_j) = Q(s, a; \theta_j^e) + \beta_j Q(s, a; \theta_j^i)\). During training, \(Q(s, a; \theta_j^e)\) and \(Q(s, a; \theta_j^i)\) are optimized separately with rewards \(r_j^e\) and \(r_j^i\), respectively.</li>
</ol>
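<p>The effect of the split Q-value parameterization can be seen in a toy example: the same pair of extrinsic and intrinsic value estimates leads to different greedy actions depending on \(\beta_j\). The numbers and the helper name below are made up for illustration.</p>

```python
import numpy as np

def combined_q(q_extrinsic, q_intrinsic, beta):
    """Q(s, a) = Q_e(s, a) + beta * Q_i(s, a), evaluated for all actions at once."""
    return np.asarray(q_extrinsic, float) + beta * np.asarray(q_intrinsic, float)

q_e = [1.0, 0.5]   # action 0 looks better for the task reward
q_i = [0.0, 2.0]   # action 1 looks more novel

exploit_action = int(np.argmax(combined_q(q_e, q_i, beta=0.0)))
explore_action = int(np.argmax(combined_q(q_e, q_i, beta=0.5)))
print(exploit_action, explore_action)
```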
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/agent57.png" alt="Agent57" /></p>
<p class="image-caption"><em>Fig. 10. A pretty cool illustration of techniques developed in time since DQN in 2015, eventually leading to Agent57. (Image source: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">DeepMind Blog: “Agent57: Outperforming the human Atari benchmark”</a>)</em></p>
<p>Instead of using the Euclidean distance to measure closeness of states in episodic memory, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> took the transition between states into consideration and proposed a method to measure the number of steps needed to visit one state from other states in memory, named <strong>Episodic Curiosity (EC)</strong> module. The novelty bonus depends on reachability between states.</p>
<ol>
<li>At the beginning of each episode, the agent starts with an empty episodic memory \(M\).</li>
<li>At every step, the agent compares the current state with saved states in memory to determine novelty bonus: If the current state is novel (i.e., takes more steps to reach from observations in memory than a threshold), the agent gets a bonus.</li>
<li>The current state is added into the episodic memory if the novelty bonus is high enough. (Imagine that if all the states were added into memory, any new state could be reached within 1 step.)</li>
<li>Repeat 2-3 until the end of this episode.</li>
</ol>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/transition-graph.png" alt="Transition graph" /></p>
<p class="image-caption"><em>Fig. 11. The nodes in the graph are states, the edges are possible transitions. The blue nodes are states in memory. The green nodes are reachable from the memory within \(k = 2\) steps (not novel). The orange nodes are further away, so they are considered as novel states. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<p>In order to estimate reachability between states, we need to access the transition graph, which is unfortunately not entirely known. Thus, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> trained a <a href="/lil-log/2018/11/30/meta-learning.html#convolutional-siamese-neural-network">siamese</a> neural network to predict how many steps separate two states. It contains one embedding network \(\phi: \mathcal{S} \mapsto \mathbb{R}^n\) to first encode the states to feature vectors and then one comparator network \(C: \mathbb{R}^n \times \mathbb{R}^n \mapsto [0, 1]\) to output a binary label on whether two states are close enough (i.e., reachable within \(k\) steps) in the transition graph, \(C(\phi(s_i), \phi(s_j)) \mapsto [0, 1]\).</p>
<p>An episodic memory buffer \(M\) stores embeddings of some past observations within the same episode. A new observation will be compared with existing state embeddings via \(C\) and the results are aggregated (e.g. max, 90th percentile) to provide a reachability score \(C^M(\phi(s_t))\). The exploration bonus is \(r^i_t = \big(C' - C^M(\phi(s_t))\big)\), where \(C'\) is a predefined threshold for determining the sign of the reward (e.g. \(C'=0.5\) works well for fixed-duration episodes). A high bonus is awarded to new states when they are not easily reachable from states in the memory buffer.</p>
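<p>Here is a rough sketch of the bonus and the memory-insertion rule. The comparator below is a hand-written stand-in (similarity decaying with distance), not a trained siamese network, and <code>novelty_threshold</code> is an assumed parameter.</p>

```python
import numpy as np

def ec_bonus(similarities, threshold=0.5, percentile=90):
    """C' minus the aggregated reachability score; empty memory counts as score 0."""
    score = np.percentile(similarities, percentile) if len(similarities) else 0.0
    return threshold - float(score)

def step_memory(memory, embedding, comparator, novelty_threshold=0.0):
    """Compute the bonus for one step and store the embedding if it is novel enough."""
    sims = [comparator(m, embedding) for m in memory]
    bonus = ec_bonus(sims)
    if bonus > novelty_threshold:
        memory.append(embedding)
    return bonus

def toy_comparator(x, y):
    # Stand-in for C(phi(s_i), phi(s_j)): close to 1 when embeddings are close.
    return float(np.exp(-np.linalg.norm(x - y)))

memory = []
b1 = step_memory(memory, np.zeros(2), toy_comparator)            # first state: novel
b2 = step_memory(memory, np.zeros(2), toy_comparator)            # same state again
b3 = step_memory(memory, np.array([10.0, 10.0]), toy_comparator) # far-away state
print(b1, b2, b3, len(memory))
```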
<p>They claimed that the EC module can overcome the <a href="#the-noisy-tv-problem">noisy-TV</a> problem.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/episodic-memory-overview.png" alt="EC module" /></p>
<p class="image-caption"><em>Fig. 12. The architecture of episodic curiosity (EC) module for intrinsic reward generation. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<h3 id="direct-exploration">Direct Exploration</h3>
<p><strong>Go-Explore</strong> (<a href="https://arxiv.org/abs/1901.10995">Ecoffet, et al., 2019</a>) is an algorithm aiming to solve the “hard-exploration” problem. It is composed of the following two phases.</p>
<p><strong>Phase 1 (“Explore until solved”)</strong> feels quite like <a href="https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm">Dijkstra’s algorithm</a> for finding shortest paths in a graph. Indeed, no neural network is involved in phase 1. By maintaining a memory of interesting states as well as trajectories leading to them, the agent can go back (given that the simulator is <em>deterministic</em>) to promising states and continue doing <em>random</em> exploration from there. The state is mapped into a short discretized code (named “cell”) in order to be memorized. The memory is updated if a new state appears or a better/shorter trajectory is found. When selecting which past states to return to, the agent might select one in the memory uniformly or according to heuristics like recency, visit count, count of neighbors in the memory, etc. This process is repeated until the task is solved and at least one solution trajectory is found.</p>
<p>The high-performance trajectories found above would not work well in evaluation environments with any stochasticity. Thus, <strong>Phase 2 (“Robustification”)</strong> is needed to robustify the solution via imitation learning. They adopted the <a href="https://arxiv.org/abs/1812.03381">Backward Algorithm</a>, in which the agent is started near the last state in the trajectory and then runs RL optimization from there.</p>
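<p>The phase 1 loop can be sketched on a toy deterministic environment. Everything below (a walk on a line, the cell size, uniform cell selection, the iteration budget) is an assumption made for illustration, not the setup from the paper.</p>

```python
import random

GOAL = 20

def step(pos, action):
    """Deterministic toy dynamics on a line: action is -1 or +1, positions stay >= 0."""
    return max(0, pos + action)

def cell_of(pos, cell_size=2):
    """Map a state to a coarse discretized code (a "cell")."""
    return pos // cell_size

def go_explore_phase1(n_iters=5000, explore_steps=5, seed=0):
    rng = random.Random(seed)
    archive = {cell_of(0): (0, [])}   # cell -> (state, action trajectory reaching it)
    for _ in range(n_iters):
        # "Go": pick a remembered cell uniformly and restore its state.
        state, traj = rng.choice(list(archive.values()))
        traj = list(traj)
        # "Explore": take a few random actions from there.
        for _ in range(explore_steps):
            action = rng.choice([-1, 1])
            state = step(state, action)
            traj.append(action)
            cell = cell_of(state)
            # Keep new cells, or shorter trajectories to known cells.
            if cell not in archive or len(traj) < len(archive[cell][1]):
                archive[cell] = (state, list(traj))
            if state == GOAL:
                return traj
    return None

solution = go_explore_phase1()
print(None if solution is None else len(solution))
```

<p>Restoring a state here is just reusing a stored integer, which is exactly the luxury a resettable deterministic simulator provides.</p>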
<p>One important note in phase 1 is: In order to go back to a state deterministically without exploration, Go-Explore depends on a resettable and deterministic simulator, which is a big disadvantage.</p>
<p>To make the algorithm more generally useful to environments with stochasticity, an enhanced version of Go-Explore (<a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>), named <strong>policy-based Go-Explore</strong> was proposed later.</p>
<ul>
<li>Instead of resetting the simulator state effortlessly, the policy-based Go-Explore learns a <em>goal-conditioned policy</em> and uses that to access a known state in memory repeatedly. The goal-conditioned policy is trained to follow the best trajectory that previously led to the selected states in memory. They include a <strong>Self-Imitation Learning</strong> (<strong>SIL</strong>; <a href="https://arxiv.org/abs/1806.05635">Oh, et al. 2018</a>) loss to help extract as much information as possible from successful trajectories.</li>
<li>Also, they found sampling from policy works better than random actions when the agent returns to promising states to continue exploration.</li>
<li>Another improvement in policy-based Go-Explore is to make the downscaling function of images to cells adjustable. It is optimized so that there would be neither too many nor too few cells in the memory.</li>
</ul>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/policy-based-Go-Explore.png" alt="Policy-based Go-Explore" /></p>
<p class="image-caption"><em>Fig. 13. An overview of the Go-Explore algorithm. (Image source: <a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>)</em></p>
<p>After vanilla Go-Explore, <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. (2019)</a> proposed <strong>DTSIL</strong> (Diverse Trajectory-conditioned Self-Imitation Learning), which shares a similar idea with policy-based Go-Explore above. DTSIL maintains a memory of diverse demonstrations collected during training and uses them to train a trajectory-conditioned policy via <a href="https://arxiv.org/abs/1806.05635">SIL</a>. They prioritize trajectories that end with a rare state during sampling.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DTSIL-algo.png" alt="DTSIL" /></p>
<p class="image-caption"><em>Fig. 14. Algorithm of DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning). (Image source: <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. 2019</a>)</em></p>
<p>A similar approach is also seen in <a href="https://arxiv.org/abs/1906.07805">Guo, et al. (2019)</a>. The main idea is to store goals with <em>high uncertainty</em> in memory so that later the agent can revisit these goal states with a goal-conditioned policy repeatedly. In each episode, the agent flips a coin (probability 0.5) to decide whether it will act greedily w.r.t. the policy or do directed exploration by sampling goals from the memory.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/directed-exploration.png" alt="Directed exploration" /></p>
<p class="image-caption"><em>Fig. 15. Different components in directed exploration with function approximation. (Image source: <a href="https://arxiv.org/abs/1906.07805">Guo, et al. 2019</a>)</em></p>
<p>The uncertainty measure of a state can be something simple like count-based bonuses or something complex like density or Bayesian models. The paper trained a forward dynamics model and took its prediction error as the uncertainty metric.</p>
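<p>The per-episode decision above can be sketched roughly as follows. This is only an illustration under my own assumptions, not the paper's implementation: the <code>GoalMemory</code> class, its capacity-based eviction and the uncertainty-proportional sampling are all hypothetical choices, and the forward dynamics model is reduced to a squared prediction error between two state vectors.</p>

```python
import random


class GoalMemory:
    """Hypothetical buffer that keeps the goal states with the highest uncertainty."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []  # list of (uncertainty, goal) pairs

    def add(self, goal, uncertainty):
        self.items.append((uncertainty, goal))
        # Keep only the `capacity` most uncertain goals (assumed eviction rule).
        self.items.sort(key=lambda pair: pair[0], reverse=True)
        del self.items[self.capacity:]

    def sample(self, rng):
        # Assumed rule: sample a goal with probability proportional to its uncertainty.
        weights = [u for u, _ in self.items]
        goals = [g for _, g in self.items]
        return rng.choices(goals, weights=weights, k=1)[0]


def prediction_error(predicted_next_state, next_state):
    """Squared error of a forward dynamics model, used as the uncertainty score."""
    return sum((p - s) ** 2 for p, s in zip(predicted_next_state, next_state))


def choose_episode_mode(memory, rng, p_explore=0.5):
    """Flip a coin: act greedily, or pursue a goal sampled from memory."""
    if memory.items and rng.random() < p_explore:
        return "explore", memory.sample(rng)
    return "greedy", None
```

<p>In the "explore" branch the returned goal would be handed to the goal-conditioned policy; in the "greedy" branch the agent simply follows its current policy.</p>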
<h2 id="q-value-exploration">Q-Value Exploration</h2>
<p>Inspired by <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">Thompson sampling</a>, <strong>Bootstrapped DQN</strong> (<a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>) introduces a notion of uncertainty in Q-value approximation in classic <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#deep-q-network">DQN</a> by using the <a href="https://en.wikipedia.org/wiki/Bootstrapping_(statistics)">bootstrapping</a> method. Bootstrapping approximates a distribution by sampling with replacement from the same population multiple times and then aggregating the results.</p>
<p>Multiple Q-value heads are trained in parallel but each only consumes a bootstrapped sub-sampled set of data and each has its own corresponding target network. All the Q-value heads share the same backbone network.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bootstrapped-DQN-algo.png" alt="Bootstrapped DQN" /></p>
<p class="image-caption"><em>Fig. 16. The algorithm of Bootstrapped DQN. (Image source: <a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>)</em></p>
<p>At the beginning of one episode, one Q-value head is sampled uniformly and acts for collecting experience data in this episode. Then a binary mask is sampled from the masking distribution \(m \sim \mathcal{M}\) and decides which heads can use this data for training. The choice of masking distribution \(\mathcal{M}\) determines how bootstrapped samples are generated. For example,</p>
<ul>
<li>If \(\mathcal{M}\) is an independent Bernoulli distribution with \(p=0.5\), this corresponds to the double-or-nothing bootstrap.</li>
<li>If \(\mathcal{M}\) always returns an all-one mask, the algorithm reduces to an ensemble method.</li>
</ul>
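<p>To make the masking step concrete, here is a small sketch of how the acting head and the per-transition masks could be sampled. The function names, the list-based replay buffer and the independent Bernoulli mask are assumptions made for illustration; the networks themselves are left out.</p>

```python
import random


def sample_mask(num_heads, p, rng):
    """Bernoulli(p) mask: entry k is 1 if head k may train on this transition."""
    return [1 if rng.random() < p else 0 for _ in range(num_heads)]


def run_episode(transitions, num_heads, p, rng):
    """Pick one head to act for the whole episode, then tag each transition with a mask."""
    acting_head = rng.randrange(num_heads)  # sampled uniformly, fixed for the episode
    tagged = [(tr, sample_mask(num_heads, p, rng)) for tr in transitions]
    return acting_head, tagged


def batch_for_head(replay_buffer, head):
    """Transitions that head `head` is allowed to learn from."""
    return [tr for tr, mask in replay_buffer if mask[head] == 1]
```

<p>With <code>p=1.0</code> every head sees every transition, which matches the ensemble case in the list above.</p>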
<p>However, this kind of exploration is still restricted, because uncertainty introduced by bootstrapping fully relies on the training data. It is better to inject some prior information independent of the data. This “noisy” prior is expected to drive the agent to keep exploring when the reward is sparse. The algorithm for adding a random prior to Bootstrapped DQN for better exploration (<a href="https://arxiv.org/abs/1806.03335">Osband, et al. 2018</a>) builds on Bayesian linear regression. The core idea of Bayesian regression is that we can <em>“generate posterior samples by training on noisy versions of the data, together with some random regularization”</em>.</p>
<p>Let \(\theta\) be the parameters of the Q function and \(\theta^-\) those of the target Q function. The loss function using a randomized prior function \(p\) is:</p>
\[\mathcal{L}(\theta, \theta^{-}, p, \mathcal{D}; \gamma) = \sum_{t\in\mathcal{D}}\Big( r_t + \gamma \max_{a'\in\mathcal{A}} \underbrace{(Q_{\theta^-} + p)}_\text{target Q}(s'_t, a') - \underbrace{(Q_\theta + p)}_\text{Q to optimize}(s_t, a_t) \Big)^2\]
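<p>As a sanity check on the loss above, here is a tiny sketch that evaluates it for plain Python callables. Treating <code>q</code>, <code>q_target</code> and <code>prior</code> as functions of a (state, action) pair is my simplification; terminal states and gradient computation are deliberately ignored.</p>

```python
def randomized_prior_loss(batch, q, q_target, prior, actions, gamma=0.99):
    """Sum of squared TD errors where a fixed prior function is added to both Q-values.

    `q`, `q_target` and `prior` map (state, action) to a float; only `q` would be trained.
    `batch` is an iterable of (s, a, r, s_next) tuples.
    """
    loss = 0.0
    for s, a, r, s_next in batch:
        # Target: r + gamma * max_a' (Q_target + p)(s', a')
        target = r + gamma * max(q_target(s_next, b) + prior(s_next, b) for b in actions)
        # Prediction: (Q + p)(s, a)
        prediction = q(s, a) + prior(s, a)
        loss += (target - prediction) ** 2
    return loss
```

<p>Because the prior \(p\) is never updated, it keeps contributing the same data-independent offset to both the target and the prediction throughout training.</p>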
<h2 id="varitional-options">Variational Options</h2>
<p>Options are policies with termination conditions. There is a large set of options available in the search space and they are independent of an agent’s intentions. By explicitly including intrinsic options into modeling, the agent can obtain intrinsic rewards for exploration.</p>
<p><strong>VIC</strong> (short for <em>“Variational Intrinsic Control”</em>; <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>) is such a framework for providing the agent with intrinsic exploration bonuses based on modeling options and learning policies conditioned on options. Let \(\Omega\) represent an option which starts from \(s_0\) and ends at \(s_f\). An environment probability distribution \(p^J(s_f \vert s_0, \Omega)\) defines where an option \(\Omega\) terminates given a starting state \(s_0\). A controllability distribution \(p^C(\Omega \vert s_0)\) defines the probability distribution of options we can sample from. And by definition we have \(p(s_f, \Omega \vert s_0) = p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0)\).</p>
<p>While choosing options, we would like to achieve two goals:</p>
<ul>
<li>Achieve a diverse set of the final states from \(s_0\) ⇨ Maximization of \(H(s_f \vert s_0)\).</li>
<li>Know precisely which state a given option \(\Omega\) can end with ⇨ Minimization of \(H(s_f \vert s_0, \Omega)\).</li>
</ul>
<p>Combining them, we get mutual information \(I(\Omega; s_f \vert s_0)\) to maximize:</p>
\[\begin{aligned}
I(\Omega; s_f \vert s_0)
&= H(s_f \vert s_0) - H(s_f \vert s_0, \Omega) \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p(s_f, \Omega \vert s_0) \log \frac{p(s_f, \Omega \vert s_0)}{p^C(\Omega \vert s_0)} \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log p^J(s_f \vert s_0, \Omega) \\
\end{aligned}\]
<p>Because mutual information is symmetric, we can switch \(s_f\) and \(\Omega\) in several places without breaking the equivalence. Also because \(p(\Omega \vert s_0, s_f)\) is difficult to observe, let us replace it with an approximation distribution \(q\). According to the variational lower bound, we would have \(I(\Omega; s_f \vert s_0) \geq I^{VB}(\Omega; s_f \vert s_0)\).</p>
\[\begin{aligned}
I(\Omega; s_f \vert s_0)
&= I(s_f; \Omega \vert s_0) \\
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{p(\Omega \vert s_0, s_f)}\\
I^{VB}(\Omega; s_f \vert s_0)
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{q(\Omega \vert s_0, s_f)} \\
I(\Omega; s_f \vert s_0) &\geq I^{VB}(\Omega; s_f \vert s_0)
\end{aligned}\]
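<p>The inequality can be checked numerically on a toy problem with two options and two final states. The distributions below are made-up numbers for a single fixed \(s_0\), and the function names are mine; the point is only that plugging the true posterior into the bound recovers the mutual information, while any other \(q\) gives a smaller value.</p>

```python
import math


def mutual_information_bound(p_c, p_j, q):
    """Evaluate the variational bound for a fixed s_0.

    p_c[o]    : prior over options, p^C(o | s_0)
    p_j[o][s] : termination distribution, p^J(s_f = s | s_0, o)
    q[s][o]   : option inference distribution, q(o | s_0, s_f = s)
    """
    entropy = -sum(p * math.log(p) for p in p_c if p > 0)
    expected_log_q = sum(
        p_c[o] * p_j[o][s] * math.log(q[s][o])
        for o in range(len(p_c))
        for s in range(len(p_j[o]))
        if p_c[o] * p_j[o][s] > 0
    )
    return entropy + expected_log_q


def true_posterior(p_c, p_j):
    """Bayes rule: p(o | s_0, s_f) is proportional to p^J(s_f | s_0, o) p^C(o | s_0).

    Assumes every final state has non-zero probability.
    """
    posterior = []
    for s in range(len(p_j[0])):
        joint = [p_c[o] * p_j[o][s] for o in range(len(p_c))]
        total = sum(joint)
        posterior.append([j / total for j in joint])
    return posterior


# Toy example (made-up numbers): two options, two final states.
p_c = [0.5, 0.5]
p_j = [[0.9, 0.1], [0.2, 0.8]]
exact = mutual_information_bound(p_c, p_j, true_posterior(p_c, p_j))
approx = mutual_information_bound(p_c, p_j, [[0.6, 0.4], [0.4, 0.6]])
print(exact >= approx)
```

<p>The gap between <code>exact</code> and <code>approx</code> is the expected KL divergence between the true posterior and \(q\), which is why training \(q\) to predict \(\Omega\) tightens the bound.</p>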
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VIC-explicit-options.png" alt="VIC" /></p>
<p class="image-caption"><em>Fig. 17. The algorithm for VIC (Variational Intrinsic Control). (Image source: <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>)</em></p>
<p>Here \(\pi(a \vert \Omega, s)\) can be optimized with any RL algorithm. The option inference function \(q(\Omega \vert s_0, s_f)\) is doing supervised learning. The prior \(p^C\) is updated so that it tends to choose \(\Omega\) with higher rewards. Note that \(p^C\) can also be fixed (e.g. a Gaussian). Various \(\Omega\) will result in different behavior through learning. Additionally, <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. (2017)</a> observed that it is difficult to make VIC with explicit options work in practice with function approximation and therefore they also proposed another version of VIC with implicit options.</p>
<p>Different from VIC which models \(\Omega\) conditioned only on the start and end states, <strong>VALOR</strong> (short for <em>“Variational Auto-encoding Learning of Options by Reinforcement”</em>; <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>) relies on the whole trajectory to extract the option context \(c\), which is sampled from a fixed Gaussian distribution. In VALOR:</p>
<ul>
<li>A policy acts as an encoder, translating contexts from a noise distribution into trajectories</li>
<li>A decoder attempts to recover the contexts from the trajectories, and rewards the policies for making contexts easier to distinguish. The decoder never sees the actions during training, so the agent has to interact with the environment in a way that facilitates communication with the decoder for better prediction. Also, the decoder recurrently takes in a sequence of steps in one trajectory to better model the correlation between timesteps.</li>
</ul>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/VALOR-decoder.png" alt="VALOR" /></p>
<p class="image-caption"><em>Fig. 18. The decoder of VALOR is a biLSTM which takes \(N = 11\) equally spaced observations from one trajectory as inputs. (Image source: <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>)</em></p>
<p>DIAYN (“Diversity is all you need”; <a href="https://arxiv.org/abs/1802.06070">Eysenbach, et al. 2018</a>) follows the same direction under a different name: it models the policies conditioned on a latent <em>skill</em> variable. See my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#learning-with-random-rewards">previous post</a> for more details.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020exploration,
title = "Exploration Strategies in Deep Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Pierre-Yves Oudeyer & Frederic Kaplan. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.567.6524&rep=rep1&type=pdf">“How can we define intrinsic motivation?”</a> Conf. on Epigenetic Robotics, 2008.</p>
<p>[2] Marc G. Bellemare, et al. <a href="https://arxiv.org/abs/1606.01868">“Unifying Count-Based Exploration and Intrinsic Motivation”</a>. NIPS 2016.</p>
<p>[3] Georg Ostrovski, et al. <a href="https://arxiv.org/abs/1703.01310">“Count-Based Exploration with Neural Density Models”</a>. PMLR 2017.</p>
<p>[4] Rui Zhao & Volker Tresp. <a href="https://arxiv.org/abs/1902.08039">“Curiosity-Driven Experience Prioritization via
Density Estimation”</a>. NIPS 2018.</p>
<p>[5] Haoran Tang, et al. <a href="https://arxiv.org/abs/1611.04717">“#Exploration: A Study of Count-Based Exploration for Deep Reinforcement Learning”</a>. NIPS 2017.</p>
<p>[6] Jürgen Schmidhuber. <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">“A possibility for implementing curiosity and boredom in model-building neural controllers”</a> 1991.</p>
<p>[7] Pierre-Yves Oudeyer, et al. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">“Intrinsic Motivation Systems for Autonomous Mental Development”</a> IEEE Transactions on Evolutionary Computation, 2007.</p>
<p>[8] Bradly C. Stadie, et al. <a href="https://arxiv.org/abs/1507.00814">“Incentivizing Exploration In Reinforcement Learning With Deep Predictive Models”</a>. ICLR 2016.</p>
<p>[9] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1705.05363">“Curiosity-driven Exploration by Self-supervised Prediction”</a>. CVPR 2017.</p>
<p>[10] Yuri Burda, Harri Edwards & Deepak Pathak, et al. <a href="https://arxiv.org/abs/1808.04355">“Large-Scale Study of Curiosity-Driven Learning”</a>. arXiv 1808.04355 (2018).</p>
<p>[11] Joshua Achiam & Shankar Sastry. <a href="https://arxiv.org/abs/1703.01732">“Surprise-Based Intrinsic Motivation for Deep Reinforcement Learning”</a> NIPS 2016 Deep RL Workshop.</p>
<p>[12] Rein Houthooft, et al. <a href="https://arxiv.org/abs/1605.09674">“VIME: Variational information maximizing exploration”</a>. NIPS 2016.</p>
<p>[13] Leshem Choshen, Lior Fox & Yonatan Loewenstein. <a href="https://arxiv.org/abs/1804.04012">“DORA the explorer: Directed outreaching reinforcement action-selection”</a>. ICLR 2018</p>
<p>[14] Yuri Burda, et al. <a href="https://arxiv.org/abs/1810.12894">“Exploration by Random Network Distillation”</a> ICLR 2019.</p>
<p>[15] OpenAI Blog: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">“Reinforcement Learning with
Prediction-Based Rewards”</a> Oct, 2018.</p>
<p>[16] Misha Denil, et al. <a href="https://arxiv.org/abs/1611.01843">“Learning to Perform Physics Experiments via Deep Reinforcement Learning”</a>. ICLR 2017.</p>
<p>[17] Ian Osband, et al. <a href="https://arxiv.org/abs/1602.04621">“Deep Exploration via Bootstrapped DQN”</a>. NIPS 2016.</p>
<p>[18] Ian Osband, John Aslanides & Albin Cassirer. <a href="https://arxiv.org/abs/1806.03335">“Randomized Prior Functions for Deep Reinforcement Learning”</a>. NIPS 2018.</p>
<p>[19] Karol Gregor, Danilo Jimenez Rezende & Daan Wierstra. <a href="https://arxiv.org/abs/1611.07507">“Variational Intrinsic Control”</a>. ICLR 2017.</p>
<p>[20] Joshua Achiam, et al. <a href="https://arxiv.org/abs/1807.10299">“Variational Option Discovery Algorithms”</a>. arXiv 1807.10299 (2018).</p>
<p>[21] Benjamin Eysenbach, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a>. ICLR 2019.</p>
<p>[22] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2002.06038">“Never Give Up (NGU): Learning Directed Exploration Strategies”</a> ICLR 2020.</p>
<p>[23] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2003.13350">“Agent57: Outperforming the Atari Human Benchmark”</a>. arXiv 2003.13350 (2020).</p>
<p>[24] DeepMind Blog: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">“Agent57: Outperforming the human Atari benchmark”</a> Mar 2020.</p>
<p>[25] Nikolay Savinov, et al. <a href="https://arxiv.org/abs/1810.02274">“Episodic Curiosity through Reachability”</a> ICLR 2019.</p>
<p>[26] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/1901.10995">“Go-Explore: a New Approach for Hard-Exploration Problems”</a>. arXiv 1901.10995 (2019).</p>
<p>[27] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/2004.12919">“First return then explore”</a>. arXiv 2004.12919 (2020).</p>
<p>[28] Junhyuk Oh, et al. <a href="https://arxiv.org/abs/1806.05635">“Self-Imitation Learning”</a>. ICML 2018.</p>
<p>[29] Yijie Guo, et al. <a href="https://arxiv.org/abs/1907.10247">“Self-Imitation Learning via Trajectory-Conditioned Policy for Hard-Exploration Tasks”</a>. arXiv 1907.10247 (2019).</p>
<p>[30] Zhaohan Daniel Guo & Emma Brunskill. <a href="https://arxiv.org/abs/1906.07805">“Directed Exploration for Reinforcement Learning”</a>. arXiv 1906.07805 (2019).</p>
<p>[31] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1906.04161">“Self-Supervised Exploration via Disagreement.”</a> ICML 2019.</p>
Lilian Weng
Exploitation versus exploration is a critical topic in reinforcement learning. This post introduces several common approaches for better exploration in Deep RL.
The Transformer Family
2020-04-07T12:00:00+00:00
2020-04-07T12:00:00+00:00
https://lilianweng.github.io/lil-log/2020/04/07/the-transformer-family
<blockquote>
<p>Inspired by recent progress on various enhanced versions of Transformer models, this post presents how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving, etc.</p>
</blockquote>
<!--more-->
<p>It has been almost two years since my last post on <a href="/lil-log/2018/06/24/attention-attention.html">attention</a>. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#notations" id="markdown-toc-notations">Notations</a></li>
<li><a href="#attention-and-self-attention" id="markdown-toc-attention-and-self-attention">Attention and Self-Attention</a></li>
<li><a href="#multi-head-self-attention" id="markdown-toc-multi-head-self-attention">Multi-Head Self-Attention</a></li>
<li><a href="#transformer" id="markdown-toc-transformer">Transformer</a></li>
<li><a href="#adaptive-computation-time-act" id="markdown-toc-adaptive-computation-time-act">Adaptive Computation Time (ACT)</a></li>
<li><a href="#improved-attention-span" id="markdown-toc-improved-attention-span">Improved Attention Span</a> <ul>
<li><a href="#longer-attention-span-transformer-xl" id="markdown-toc-longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</a></li>
<li><a href="#adaptive-attention-span" id="markdown-toc-adaptive-attention-span">Adaptive Attention Span</a></li>
<li><a href="#localized-attention-span-image-transformer" id="markdown-toc-localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</a></li>
</ul>
</li>
<li><a href="#less-time-and-memory-cost" id="markdown-toc-less-time-and-memory-cost">Less Time and Memory Cost</a> <ul>
<li><a href="#sparse-attention-matrix-factorization-sparse-transformers" id="markdown-toc-sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</a></li>
<li><a href="#locality-sensitive-hashing-reformer" id="markdown-toc-locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</a></li>
</ul>
</li>
<li><a href="#make-it-recurrent-universal-transformer" id="markdown-toc-make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</a></li>
<li><a href="#stabilization-for-rl-gtrxl" id="markdown-toc-stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h3 id="notations">Notations</h3>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(d\)</td>
<td>The model size / hidden state dimension / positional encoding size.</td>
</tr>
<tr>
<td>\(h\)</td>
<td>The number of heads in multi-head attention layer.</td>
</tr>
<tr>
<td>\(L\)</td>
<td>The segment length of input sequence.</td>
</tr>
<tr>
<td>\(\mathbf{X} \in \mathbb{R}^{L \times d}\)</td>
<td>The input sequence where each element has been mapped into an embedding vector of size \(d\), same as the model size.</td>
</tr>
<tr>
<td>\(\mathbf{W}^k \in \mathbb{R}^{d \times d_k}\)</td>
<td>The key weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{W}^q \in \mathbb{R}^{d \times d_k}\)</td>
<td>The query weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{W}^v \in \mathbb{R}^{d \times d_v}\)</td>
<td>The value weight matrix. Often we have \(d_k = d_v = d\).</td>
</tr>
<tr>
<td>\(\mathbf{W}^k_i, \mathbf{W}^q_i \in \mathbb{R}^{d \times d_k/h}; \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}\)</td>
<td>The weight matrices per head.</td>
</tr>
<tr>
<td>\(\mathbf{W}^o \in \mathbb{R}^{d_v \times d}\)</td>
<td>The output weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{Q} = \mathbf{X}\mathbf{W}^q \in \mathbb{R}^{L \times d_k}\)</td>
<td>The query embedding inputs.</td>
</tr>
<tr>
<td>\(\mathbf{K} = \mathbf{X}\mathbf{W}^k \in \mathbb{R}^{L \times d_k}\)</td>
<td>The key embedding inputs.</td>
</tr>
<tr>
<td>\(\mathbf{V} = \mathbf{X}\mathbf{W}^v \in \mathbb{R}^{L \times d_v}\)</td>
<td>The value embedding inputs.</td>
</tr>
<tr>
<td>\(S_i\)</td>
<td>A collection of key positions for the \(i\)-th query \(\mathbf{q}_i\) to attend to.</td>
</tr>
<tr>
<td>\(\mathbf{A} \in \mathbb{R}^{L \times L}\)</td>
<td>The self-attention matrix between an input sequence of length \(L\) and itself. \(\mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})\).</td>
</tr>
<tr>
<td>\(a_{ij} \in \mathbf{A}\)</td>
<td>The scalar attention score between query \(\mathbf{q}_i\) and key \(\mathbf{k}_j\).</td>
</tr>
<tr>
<td>\(\mathbf{P} \in \mathbb{R}^{L \times d}\)</td>
<td>The positional encoding matrix, where the \(i\)-th row \(\mathbf{p}_i\) is the positional encoding for input \(\mathbf{x}_i\).</td>
</tr>
</tbody>
</table>
<h2 id="attention-and-self-attention">Attention and Self-Attention</h2>
<p><em>Attention</em> is a mechanism by which a neural network learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.</p>
<p><em>Self-attention</em> is a type of attention mechanism where the model makes predictions for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to <a href="https://en.wikipedia.org/wiki/Non-local_means">non-local means</a>. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.</p>
<p>There are various forms of attention / self-attention. The Transformer (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al., 2017</a>) relies on the <em>scaled dot-product attention</em>: given a query matrix \(\mathbf{Q}\), a key matrix \(\mathbf{K}\) and a value matrix \(\mathbf{V}\), the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:</p>
\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q} {\mathbf{K}}^\top}{\sqrt{d_k}})\mathbf{V}\]
<p>And for a query and a key vector \(\mathbf{q}_i, \mathbf{k}_j \in \mathbb{R}^{d_k}\) (row vectors in query and key matrices), we have a scalar score:</p>
\[a_{ij} = \text{softmax}(\frac{\mathbf{q}_i {\mathbf{k}_j}^\top}{\sqrt{d_k}})
= \frac{\exp(\mathbf{q}_i {\mathbf{k}_j}^\top / \sqrt{d_k})}{ \sum_{r \in S_i} \exp(\mathbf{q}_i {\mathbf{k}_r}^\top / \sqrt{d_k}) }\]
<p>where \(S_i\) is a collection of key positions for the \(i\)-th query to attend to.</p>
<p>See my old <a href="/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms">post</a> for other types of attention if interested.</p>
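<p>For reference, the scaled dot-product attention above fits in a few lines of NumPy. This is a minimal single-sequence sketch without batching; the optional boolean <code>mask</code> argument is my own way of expressing the sets \(S_i\) and is not part of the formula.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability; the result is unchanged.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: (L, d_k), K: (L, d_k), V: (L, d_v). Returns the output (L, d_v) and A (L, L).

    `mask` is an optional boolean (L, L) array; mask[i, j] is True when key j is in S_i.
    Every row of the mask is assumed to contain at least one True entry.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if mask is not None:
        # Positions outside S_i get zero weight after the softmax.
        scores = np.where(mask, scores, -np.inf)
    A = softmax(scores, axis=-1)
    return A @ V, A
```

<p>Passing a lower-triangular mask such as <code>np.tril(np.ones((L, L), dtype=bool))</code> gives the causal variant in which each position only attends to itself and earlier positions.</p>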
<h2 id="multi-head-self-attention">Multi-Head Self-Attention</h2>
<p>The <em>multi-head self-attention</em> module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.</p>
\[\begin{aligned}
\text{MultiHeadAttention}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) &= [\text{head}_1; \dots; \text{head}_h] \mathbf{W}^o \\
\text{where head}_i &= \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i)
\end{aligned}\]
<p>where \([.;.]\) is a concatenation operation. \(\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}, \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}\) are weight matrices to map input embeddings of size \(L \times d\) into query, key and value matrices. And \(\mathbf{W}^o \in \mathbb{R}^{d_v \times d}\) is the output linear transformation. All the weights should be learned during training.</p>
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/multi-head-attention.png" alt="Multi-head scaled dot-product attention" /></p>
<p class="image-caption"><em>Fig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>)</em></p>
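<p>Below is a rough NumPy sketch of the multi-head computation for self-attention, where queries, keys and values all come from the same \(\mathbf{X}\). Two choices here are my assumptions: the per-head matrices \(\mathbf{W}^q_i, \mathbf{W}^k_i, \mathbf{W}^v_i\) are stored as column blocks of full-size matrices, and the scores are scaled by the per-head key size \(d_k/h\).</p>

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_self_attention(X, W_q, W_k, W_v, W_o, h):
    """X: (L, d); W_q, W_k: (d, d_k); W_v: (d, d_v); W_o: (d_v, d); h divides d_k and d_v."""
    L = X.shape[0]
    # Project once, then split the last axis into h heads: (h, L, d_k / h) or (h, L, d_v / h).
    Q = (X @ W_q).reshape(L, h, -1).transpose(1, 0, 2)
    K = (X @ W_k).reshape(L, h, -1).transpose(1, 0, 2)
    V = (X @ W_v).reshape(L, h, -1).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # (h, L, L)
    heads = softmax(scores) @ V                               # (h, L, d_v / h)
    concat = heads.transpose(1, 0, 2).reshape(L, -1)          # (L, d_v)
    return concat @ W_o                                       # (L, d)
```

<p>Splitting one large projection into heads is equivalent to applying \(h\) separate small projections, but it keeps everything in a few batched matrix multiplications.</p>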
<h2 id="transformer">Transformer</h2>
<p>The <strong>Transformer</strong> (which will be referred to as “vanilla Transformer” to distinguish it from other enhanced versions; <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>) model has an encoder-decoder architecture, as commonly used in many <a href="/lil-log/2018/06/24/attention-attention.html#born-for-translation">NMT</a> models. Later, simplified Transformers were shown to achieve great performance in language modeling tasks, as in the decoder-only <a href="/lil-log/2019/01/31/generalized-language-models.html#openai-gpt">GPT and the encoder-only BERT</a>.</p>
<p><strong>Encoder-Decoder Architecture</strong></p>
<p>The <strong>encoder</strong> generates an attention-based representation with the capability to locate a specific piece of information from a large context. It consists of a stack of 6 identical modules, each containing two submodules, a <em>multi-head self-attention</em> layer and a <em>point-wise</em> fully connected feed-forward network. Point-wise means that it applies the same linear transformation (with the same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension \(d\).</p>
<p>The function of Transformer <strong>decoder</strong> is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is <em>masked</em> to prevent positions from attending to the future.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer.png" alt="Transformer" /></p>
<p class="image-caption"><em>Fig. 2. The architecture of the vanilla Transformer model. (Image source: <a href="/lil-log/2018/06/24/attention-attention.html#full-architecture">Figure 17</a>)</em></p>
<p><strong>Positional Encoding</strong></p>
<p>Because the self-attention operation is permutation invariant, it is important to use proper <strong>positional encoding</strong> to provide <em>order information</em> to the model. The positional encoding \(\mathbf{P} \in \mathbb{R}^{L \times d}\) has the same dimension as the input embedding, so it can be added to the input directly. The vanilla Transformer considered two types of encodings:</p>
<p>(1) <em>Sinusoidal positional encoding</em> is defined as follows, given the token position \(i=1,\dots,L\) and the dimension \(\delta=1,\dots,d\):</p>
\[\text{PE}(i,\delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases}\]
<p>In this way each dimension of the positional encoding corresponds to a sinusoid, and the wavelengths vary across dimensions, from \(2\pi\) to \(10000 \cdot 2\pi\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sinoidual-positional-encoding.png" alt="Transformer" /></p>
<p class="image-caption"><em>Fig. 3. Sinusoidal positional encoding with \(L=32\) and \(d=128\). The value is between -1 (black) and 1 (white) and the value 0 is in gray.</em></p>
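<p>A matrix like the one in Fig. 3 can be generated with the short NumPy sketch below. Note one assumption: positions and dimensions are indexed from 0 here for convenience, whereas the formula above counts from 1.</p>

```python
import numpy as np


def sinusoidal_positional_encoding(L, d):
    """Return P of shape (L, d); even dimensions use sine and odd dimensions use cosine."""
    positions = np.arange(L)[:, None]        # (L, 1), zero-based token positions
    pair_index = np.arange(d)[None, :] // 2  # delta' for delta = 2 delta' or 2 delta' + 1
    angles = positions / np.power(10000.0, 2 * pair_index / d)
    P = np.empty((L, d))
    P[:, 0::2] = np.sin(angles[:, 0::2])
    P[:, 1::2] = np.cos(angles[:, 1::2])
    return P


P = sinusoidal_positional_encoding(32, 128)  # same L and d as in Fig. 3
print(P.shape)
```

<p>Since \(\mathbf{P}\) has the same shape as \(\mathbf{X}\), it can be added to the input embeddings element-wise.</p>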
<p>(2) <em>Learned positional encoding</em>, as its name suggests, assigns each element a learned column vector which encodes its <em>absolute</em> position (<a href="https://arxiv.org/abs/1705.03122">Gehring, et al. 2017</a>).</p>
<p><strong>Quick Follow-ups</strong></p>
<p>Following the vanilla Transformer, <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a> added a set of auxiliary losses to enable training a deep Transformer model on character-level language modeling which outperformed LSTMs. Several types of auxiliary tasks are used:</p>
<ul>
<li>Instead of producing only one prediction at the sequence end, every <em>intermediate position</em> is also asked to make a correct prediction, forcing the model to predict given smaller contexts (e.g. the first couple of tokens at the beginning of a context window).</li>
<li>Each intermediate Transformer layer is used for making predictions as well. Lower layers are weighted to contribute less and less to the total loss as training progresses.</li>
<li>Each position in the sequence can predict multiple targets, i.e. two or more predictions of the future tokens.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-aux-losses.png" alt="Transformer" /></p>
<p class="image-caption"><em>Fig. 4. Auxiliary prediction tasks used in deep Transformer for character-level language modeling. (Image source: <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a>)</em></p>
<h2 id="adaptive-computation-time-act">Adaptive Computation Time (ACT)</h2>
<p><strong>Adaptive Computation Time</strong> (short for <strong>ACT</strong>; <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>) is a mechanism for dynamically deciding how many computational steps are needed in a recurrent neural network. Here is a cool <a href="https://distill.pub/2016/augmented-rnns/#adaptive-computation-time">tutorial</a> on ACT from distill.pub.</p>
<p>Let’s say we have an RNN model \(\mathcal{R}\) composed of input weights \(W_x\), a parametric state transition function \(\mathcal{S}(.)\), a set of output weights \(W_y\) and an output bias \(b_y\). Given an input sequence \((x_1, \dots, x_L)\), the output sequence \((y_1, \dots, y_L)\) is computed by:</p>
\[s_t = \mathcal{S}(s_{t-1}, W_x x_t), \quad y_t = W_y s_t + b_y\quad\text{for }t=1, \dots, L\]
<p>ACT enables the above RNN setup to perform a variable number of steps at each input element. Multiple computational steps lead to a sequence of intermediate states \((s_t^1, \dots, s_t^{N(t)})\) and outputs \((y_t^1, \dots, y_t^{N(t)})\) — they all share the same state transition function \(\mathcal{S}(.)\), as well as the same output weights \(W_y\) and bias \(b_y\):</p>
\[\begin{aligned}
s_t^0 &= s_{t-1} \\
s_t^n &= \mathcal{S}(s_{t}^{n-1}, x_t^n) = \mathcal{S}(s_{t}^{n-1}, x_t + \delta_{n,1}) \text{ for } n=1, \dots, N(t)\\
y_t^n &= W_y s_t^n + b_y
\end{aligned}\]
<p>where \(\delta_{n,1}\) is a binary flag indicating whether the input step has been incremented.</p>
<p>The number of steps \(N(t)\) is determined by an extra sigmoidal halting unit \(h\), with associated weight matrix \(W_h\) and bias \(b_h\), outputting a halting probability \(p_t^n\) at intermediate step \(n\) for \(t\)-th input element:</p>
\[h_t^n = \sigma(W_h s_t^n + b_h)\]
<p>In order to allow the computation to halt after a single step, ACT introduces a small constant \(\epsilon\) (e.g. 0.01), so that whenever the cumulative probability goes above \(1-\epsilon\), the computation stops.</p>
\[\begin{aligned}
N(t) &= \min(\min\{n': \sum_{n=1}^{n'} h_t^n \geq 1 -\epsilon\}, M) \\
p_t^n &= \begin{cases}
h_t^n & \text{if }n < N(t) \\
R(t) = 1 - \sum_{n=1}^{N(t)-1} h_t^n & \text{if }n= N(t)\\
\end{cases}
\end{aligned}\]
<p>where \(M\) is an upper limit for the number of intermediate steps allowed.</p>
<p>The final state and output are mean-field updates:</p>
\[s_t = \sum_{n=1}^{N(t)} p_t^n s_t^n,\quad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n\]
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/ACT-computation-graph.png" alt="ACT computation graph" /></p>
<p class="image-caption"><em>Fig. 5. The computation graph of an RNN with ACT mechanism. (Image source: <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>)</em></p>
<p>To avoid unnecessary pondering over each input, ACT adds a <em>ponder cost</em> \(\mathcal{P}(x) = \sum_{t=1}^L N(t) + R(t)\) in the loss function to encourage a smaller number of intermediate computational steps.</p>
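<p>The halting rule is easier to follow with numbers. The sketch below takes a precomputed list of halting unit outputs \(h_t^1, h_t^2, \dots\) for one input element and returns \(N(t)\), the probabilities \(p_t^n\) and the remainder \(R(t)\). The function names are mine, and treating the length of the list as the default upper limit \(M\) is an assumption made to keep the example short.</p>

```python
def act_halting(halting_values, eps=0.01, max_steps=None):
    """Given h_t^1, h_t^2, ... for one input element, return (N, p, R).

    N is the number of steps taken, p the list of halting probabilities
    p_t^1..p_t^N, and R the remainder assigned to the final step.
    """
    M = len(halting_values) if max_steps is None else min(max_steps, len(halting_values))
    cumulative = 0.0
    N = M  # fall back to the upper limit if the threshold is never reached
    for n in range(1, M + 1):
        cumulative += halting_values[n - 1]
        if cumulative >= 1.0 - eps:
            N = n
            break
    R = 1.0 - sum(halting_values[: N - 1])
    p = list(halting_values[: N - 1]) + [R]
    return N, p, R


def mean_field(p, values):
    """Weighted combination sum_n p_t^n * value_t^n of intermediate states or outputs."""
    return sum(w * v for w, v in zip(p, values))


N, p, R = act_halting([0.3, 0.4, 0.5, 0.9])
ponder_cost_t = N + R  # contribution of this input element to the ponder cost
print(N, ponder_cost_t)
```

<p>In this example the cumulative halting value first exceeds \(1 - \epsilon\) at the third step, so the last step receives the leftover probability mass and the \(p_t^n\) sum to 1.</p>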
<h2 id="improved-attention-span">Improved Attention Span</h2>
<p>The goal of improving attention span is to make the context that can be used in self-attention longer, more efficient and flexible.</p>
<h3 id="longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</h3>
<p>The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments.</p>
<p>This <em>context segmentation</em> causes several issues:</p>
<ul>
<li>The model cannot capture very long term dependencies.</li>
<li>It is hard to predict the first few tokens in each segment given no or thin context.</li>
<li>The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens.</li>
</ul>
<p><strong>Transformer-XL</strong> (<a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>; “XL” means “extra long”) solves the context segmentation problem with two main modifications:</p>
<ol>
<li>Reusing hidden states between segments.</li>
<li>Adopting a new positional encoding that is suitable for reused states.</li>
</ol>
<p><strong>Hidden State Reuse</strong></p>
<p>The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-XL-training.png" alt="Training phase of Transformer-XL" /></p>
<p class="image-caption"><em>Fig. 6. A comparison between the training phase of vanilla Transformer and Transformer-XL with a segment length of 4. (Image source: left part of Figure 2 in <a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>).</em></p>
<p>Let’s label the hidden state of the \(n\)-th layer for the \((\tau + 1)\)-th segment in the model as \(\mathbf{h}_{\tau+1}^{(n)} \in \mathbb{R}^{L \times d}\). In addition to the hidden state of the previous layer for the same segment \(\mathbf{h}_{\tau+1}^{(n-1)}\), it also depends on the hidden state of the previous layer for the previous segment \(\mathbf{h}_{\tau}^{(n-1)}\), which is cached and reused without passing gradients through it. By incorporating information from the previous hidden states, the model extends the attention span much further into the past, over multiple segments.</p>
\[\begin{aligned}
\color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} &= [\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}] \\
\mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)}\mathbf{W}^q \\
\mathbf{K}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^k \\
\mathbf{V}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^v \\
\mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)})
\end{aligned}\]
<p>Note that both the key and the value rely on the extended hidden state, while the query only consumes the hidden state of the current segment. The concatenation operation \([. \circ .]\) is along the sequence length dimension.</p>
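<p>To make the asymmetry between queries and keys/values concrete, here is a heavily simplified sketch in plain Python. It assumes identity projections (\(\mathbf{W}^q = \mathbf{W}^k = \mathbf{W}^v = \mathbf{I}\)), a single head, no scaling and no causal mask, and it stands in for the whole transformer layer with one attention step; the helper names <code>attend</code> and <code>xl_layer</code> are invented for illustration.</p>

```python
import math


def attend(queries, keys, values):
    """Plain dot-product attention on lists of row vectors (no mask, no scaling)."""
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w / total * v[d] for w, v in zip(weights, values))
                    for d in range(len(values[0]))])
    return out


def xl_layer(h_prev_segment, h_cur_segment):
    """Queries come from the current segment, keys/values from the extended state."""
    # The cached states are plain constants here, mirroring the stop-gradient.
    h_ext = list(h_prev_segment) + list(h_cur_segment)  # concat along sequence length
    return attend(h_cur_segment, h_ext, h_ext)


memory = [[1.0, 0.0], [0.0, 1.0]]     # cached hidden states of the previous segment
current = [[1.0, 1.0], [0.5, -0.5]]   # hidden states of the current segment
new_hidden = xl_layer(memory, current)
```

<p>The output has one row per position of the current segment, even though each row was computed by attending over twice as many keys.</p>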
<p><strong>Relative Positional Encoding</strong></p>
<p>In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding. If it followed the vanilla Transformer and encoded the absolute position, the previous and current segments would be assigned the same encoding, which is undesirable.</p>
<p>To keep the positional information flowing coherently across segments, Transformer-XL encodes the <em>relative</em> position instead, since knowing the position offset \(i-j\) between one key vector \(\mathbf{k}_{\tau, j}\) and its query \(\mathbf{q}_{\tau, i}\) can be sufficient for making good predictions.</p>
<p>If omitting the scalar \(1/\sqrt{d_k}\) and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position \(i\) and key at position \(j\) as:</p>
\[\begin{aligned}
a_{ij}
&= \mathbf{q}_i {\mathbf{k}_j}^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q ((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k)^\top \\
&= \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top
\end{aligned}\]
<p>Transformer-XL reparameterizes the above four terms as follows:</p>
\[a_{ij}^\text{rel} =
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{content-based addressing} +
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{content-dependent positional bias} +
\underbrace{ \color{red}{\mathbf{u}} \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{global content bias} +
\underbrace{ \color{red}{\mathbf{v}} \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{global positional bias}\]
<ul>
<li>Replace \(\mathbf{p}_j\) with relative positional encoding \(\mathbf{r}_{i-j} \in \mathbb{R}^{d}\);</li>
<li>Replace \(\mathbf{p}_i\mathbf{W}^q\) with two trainable parameters \(\mathbf{u}\) (for content) and \(\mathbf{v}\) (for location) in two different terms;</li>
<li>Split \(\mathbf{W}^k\) into two matrices, \(\mathbf{W}^k_E\) for content information and \(\mathbf{W}^k_R\) for location information.</li>
</ul>
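<p>The four terms of \(a_{ij}^\text{rel}\) can be computed directly for small vectors. The sketch below uses plain Python lists; <code>matvec</code> and <code>rel_score</code> are hypothetical helper names, and the identity matrices in the example are chosen only to keep the arithmetic easy to follow.</p>

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, x):
    """Row vector x times matrix W, with W given as a list of rows."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]


def rel_score(x_i, x_j, r_ij, W_q, W_kE, W_kR, u, v):
    q = matvec(W_q, x_i)              # x_i W^q
    k_content = matvec(W_kE, x_j)     # x_j W_E^k
    k_position = matvec(W_kR, r_ij)   # r_{i-j} W_R^k
    return (dot(q, k_content)         # content-based addressing
            + dot(q, k_position)      # content-dependent positional bias
            + dot(u, k_content)       # global content bias
            + dot(v, k_position))     # global positional bias


I2 = [[1.0, 0.0], [0.0, 1.0]]
score = rel_score(x_i=[1.0, 2.0], x_j=[3.0, 4.0], r_ij=[0.5, -0.5],
                  W_q=I2, W_kE=I2, W_kR=I2, u=[1.0, 0.0], v=[0.0, 1.0])
```

<p>Grouping the terms shows why \(\mathbf{u}\) and \(\mathbf{v}\) act as position-independent query biases: the score equals \((\mathbf{x}_i\mathbf{W}^q + \mathbf{u})\) dotted with the content key plus \((\mathbf{x}_i\mathbf{W}^q + \mathbf{v})\) dotted with the positional key.</p>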
<h3 id="adaptive-attention-span">Adaptive Attention Span</h3>
<p>One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further back at some times than at others, or one attention head may have a different attention pattern from another. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both the computation and memory cost of supporting a longer maximum context size in the model.</p>
<p>This is the motivation for <strong>Adaptive Attention Span</strong>. <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al., (2019)</a> proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 7) and thus the optimal span would be trained separately per head.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/attention-per-head.png" alt="Attention per head" /></p>
<p class="image-caption"><em>Fig. 7. Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>)</em></p>
<p>Given the \(i\)-th token, we need to compute the attention weights between this token and other keys at positions \(j \in S_i\), where \(S_i\) defines the \(i\)-th token’s context window.</p>
\[\begin{aligned}
e_{ij} &= \mathbf{q}_i {\mathbf{k}_j}^\top \\
a_{ij} &= \text{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1} \exp(e_{ir})} \\
\mathbf{y}_i &= \sum_{r=i-s}^{i-1}a_{ir}\mathbf{v}_r = \sum_{r=i-s}^{i-1}a_{ir}\mathbf{x}_r\mathbf{W}^v
\end{aligned}\]
<p>A <em>soft mask function</em> \(m_z\) is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. \(m_z\) is parameterized by \(z \in [0, s]\) and \(z\) is to be learned:</p>
\[m_z(x) = \text{clamp}(\frac{1}{R}(R+z-x), 0, 1)\]
<p>where \(R\) is a hyper-parameter which defines the softness of \(m_z\).</p>
<p style="width: 55%;" class="center"><img src="/lil-log/assets/images/soft-masking-function.png" alt="Soft masking function" /></p>
<p class="image-caption"><em>Fig. 8. The soft masking function used in the adaptive attention span. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>.)</em></p>
<p>The soft mask function is applied to the softmax elements in the attention weights:</p>
\[a_{ij} = \frac{m_z(i-j)\exp(e_{ij})}{\sum_{r=i-s}^{i-1}m_z(i-r) \exp(e_{ir})}\]
<p>In the above equation, \(z\) is differentiable so it is trained jointly with other parts of the model. Parameters \(z^{(i)}, i=1, \dots, h\) are learned <em>separately per head</em>. Moreover, the loss function has an extra L1 penalty on \(\sum_{i=1}^h z^{(i)}\).</p>
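<p>A small numeric sketch of the soft mask and the masked softmax follows. It treats \(z\) as a fixed number instead of a learned parameter, and the function names are illustrative only.</p>

```python
import math


def soft_mask(distance, z, R):
    """m_z(x) = clamp((R + z - x) / R, 0, 1)."""
    return min(1.0, max(0.0, (R + z - distance) / R))


def masked_attention_weights(scores, z, R):
    """scores[d - 1] is the raw score for the key at distance d = i - r >= 1."""
    top = max(scores)
    unnorm = [soft_mask(d, z, R) * math.exp(e - top)
              for d, e in enumerate(scores, start=1)]
    total = sum(unnorm)  # assumed non-zero, i.e. at least one key is unmasked
    return [u / total for u in unnorm]


# Six past keys with equal raw scores, learned span z = 3 and softness R = 2.
weights = masked_attention_weights([0.0] * 6, z=3.0, R=2.0)
```

<p>With these values the three nearest keys are fully visible, the key at distance 4 is half masked, and keys at distance 5 or more receive exactly zero weight, so they would not need to be computed at all.</p>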
<p>Using <a href="#adaptive-computation-time-act">Adaptive Computation Time</a>, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter \(z_t\) of an attention head at time \(t\) is a sigmoidal function, \(z_t = S \sigma(\mathbf{v} \cdot \mathbf{x}_t +b)\), where the vector \(\mathbf{v}\) and the bias scalar \(b\) are learned jointly with other parameters.</p>
<p>In the experiments of Transformer with adaptive attention span, <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. (2019)</a> found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.</p>
<h3 id="localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</h3>
<p>The original, also the most popular, use case for Transformer is to do language modeling. The text sequence is one-dimensional in a clearly defined chronological order and thus the attention span grows linearly with increased context size.</p>
<p>However, if we want to use Transformer on images, it is unclear how to define the scope of context or the order. <strong>Image Transformer</strong> (<a href="https://arxiv.org/abs/1802.05751">Parmar, et al 2018</a>) embraces a formulation of image generation similar to sequence modeling within the Transformer framework. Additionally, Image Transformer restricts the self-attention span to only <em>local</em> neighborhoods, so that the model can scale up to process more images in parallel and keep the likelihood loss tractable.</p>
<p>The encoder-decoder architecture remains for image-conditioned generation:</p>
<ul>
<li>The encoder generates a contextualized, per-pixel-channel representation of the source image;</li>
<li>The decoder <em>autoregressively</em> generates an output image, one channel per pixel at each time step.</li>
</ul>
<p>Let’s label the representation of the current pixel to be generated as the query \(\mathbf{q}\). Other positions whose representations will be used for computing \(\mathbf{q}\) are key vectors \(\mathbf{k}_1, \mathbf{k}_2, \dots\) and they together form a memory matrix \(\mathbf{M}\). The scope of \(\mathbf{M}\) defines the context window for pixel query \(\mathbf{q}\).</p>
<p>Image Transformer introduced two types of localized \(\mathbf{M}\), as illustrated below.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/image-transformer-attention.png" alt="Attention patterns in Image Transformer" /></p>
<p class="image-caption"><em>Fig. 9. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1802.05751">Parmar et al, 2018</a>)</em></p>
<p>(1) <em>1D Local Attention</em>: The input image is flattened in the <a href="https://en.wikipedia.org/wiki/Raster_scan#Scanning_pattern">raster scanning</a> order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as \(\mathbf{q}\) and a fixed number of additional pixels generated before this query block.</p>
<p>(2) <em>2D Local Attention</em>: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.</p>
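<p>For the 1D case, the context window can be expressed as a range of indices in raster-scan order. The sketch below makes two assumptions that are not spelled out above: positions are 0-indexed, and during autoregressive generation the query only sees positions strictly before itself. The name <code>local_1d_context</code> and its parameters are hypothetical.</p>

```python
def local_1d_context(q_index, block_len, extra):
    """Positions in raster-scan order that the query at q_index may attend to.

    block_len is the query block length and extra is the fixed number of
    additional pixels generated before the query block.
    """
    block_start = (q_index // block_len) * block_len
    memory_start = max(0, block_start - extra)
    # Only already-generated positions are visible to the query.
    return list(range(memory_start, q_index))


context = local_1d_context(q_index=9, block_len=4, extra=3)
```

<p>Here the query at position 9 sits in the block starting at position 8, so it sees position 8 from its own block plus the three positions generated just before that block.</p>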
<h2 id="less-time-and-memory-cost">Less Time and Memory Cost</h2>
<p>This section introduces several improvements made on Transformer to reduce the computation time and memory consumption.</p>
<h3 id="sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</h3>
<p>The compute and memory cost of the vanilla Transformer grows quadratically with sequence length, so it is hard to apply to very long sequences.</p>
<p><strong>Sparse Transformer</strong> (<a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a>) introduced <em>factorized self-attention</em>, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.</p>
<p>Given a set of attention connectivity patterns \(\mathcal{S} = \{S_1, \dots, S_n\}\), where each \(S_i\) records the set of key positions that the \(i\)-th query vector attends to, the attention output is defined as:</p>
\[\begin{aligned}
\text{Attend}(\mathbf{X}, \mathcal{S}) &= \Big( a(\mathbf{x}_i, S_i) \Big)_{i \in \{1, \dots, L\}} \\
\text{ where } a(\mathbf{x}_i, S_i) &= \text{softmax}\Big(\frac{(\mathbf{x}_i \mathbf{W}^q)(\mathbf{x}_j \mathbf{W}^k)_{j \in S_i}^\top}{\sqrt{d_k}}\Big) (\mathbf{x}_j \mathbf{W}^v)_{j \in S_i}
\end{aligned}\]
<p>Note that although the size of \(S_i\) is not fixed, \(a(\mathbf{x}_i, S_i)\) is always of size \(d_v\) and thus \(\text{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{L \times d_v}\).</p>
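<p>As a quick check that the output size does not depend on \(|S_i|\), here is a sketch of \(a(\mathbf{x}_i, S_i)\) in plain Python. It assumes identity projection matrices and 0-indexed positions, and <code>attend_one</code> is a made-up name.</p>

```python
import math


def attend_one(X, i, S_i, d_k):
    """a(x_i, S_i) with identity projections: softmax over keys in S_i, then mix values."""
    q = X[i]
    keys = [X[j] for j in S_i]
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    total = sum(w)
    return [sum(wk / total * k[d] for wk, k in zip(w, keys)) for d in range(len(q))]


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
causal = [list(range(i + 1)) for i in range(len(X))]  # S_i = {j : j <= i}
out = [attend_one(X, i, causal[i], d_k=2) for i in range(len(X))]
```

<p>Each row of <code>out</code> has the same length even though the connectivity sets grow from one to four positions.</p>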
<p>In auto-regressive models, one attention span is defined as \(S_i = \{j: j \leq i\}\) as it allows each token to attend to all the positions in the past.</p>
<p>In factorized self-attention, the set \(S_i\) is decomposed into a <em>tree</em> of dependencies, such that for every pair of \((i, j)\) where \(j \leq i\), there is a path connecting \(i\) back to \(j\) and \(i\) can attend to \(j\) either directly or indirectly.</p>
<p>Precisely, the set \(S_i\) is divided into \(p\) <em>non-overlapping</em> subsets, where the \(m\)-th subset is denoted as \(A^{(m)}_i \subset S_i, m = 1,\dots, p\). Therefore the path between the output position \(i\) and any \(j\) has a maximum length \(p + 1\). For example, if \((j, a, b, c, \dots, i)\) is a path of indices between \(i\) and \(j\), we would have \(j \in A_a^{(1)}, a \in A_b^{(2)}, b \in A_c^{(3)}, \dots\), so on and so forth.</p>
<p><strong>Sparse Factorized Attention</strong></p>
<p>Sparse Transformer proposed two types of factorized attention. The concepts are easier to understand as illustrated in Fig. 10, with 2D image inputs as examples.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sparse-attention.png" alt="Sparse attention" /></p>
<p class="image-caption"><em>Fig. 10. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not on the same scale. (Image source: <a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a> + a few extra annotations.)</em></p>
<p>(1) <em>Strided</em> attention with stride \(\ell \sim \sqrt{n}\). This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous \(\ell\) pixels in the raster scanning order (naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).</p>
\[\begin{aligned}
A_i^{(1)} &= \{ t, t+1, \dots, i\} \text{, where } t = \max(0, i - \ell) \\
A_i^{(2)} &= \{j: (i-j) \mod \ell = 0\}
\end{aligned}\]
<p>(2) <em>Fixed</em> attention. A small set of tokens summarize previous locations and propagate that information to all future locations.</p>
\[\begin{aligned}
A_i^{(1)} &= \{j: \lfloor \frac{j}{\ell} \rfloor = \lfloor \frac{i}{\ell} \rfloor \} \\
A_i^{(2)} &= \{j: j \mod \ell \in \{\ell-c, \dots, \ell-1\} \}
\end{aligned}\]
<p>where \(c\) is a hyperparameter. If \(c=1\), it restricts the representation, since many positions then depend on only a few summary positions. The paper chose \(c\in \{ 8, 16, 32 \}\) for \(\ell \in \{ 128, 256 \}\).</p>
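<p>Both pattern types can be enumerated as index sets. The sketch below follows the set definitions above with 0-indexed positions and additionally restricts every set to \(j \leq i\), which is an assumption made here to respect the auto-regressive setting; the function names are illustrative.</p>

```python
def strided_patterns(i, stride):
    """Strided attention subsets for position i, restricted to j <= i."""
    a1 = set(range(max(0, i - stride), i + 1))                # previous `stride` positions and i
    a2 = {j for j in range(i + 1) if (i - j) % stride == 0}   # every stride-th position back from i
    return a1, a2


def fixed_patterns(i, stride, c):
    """Fixed attention subsets for position i, restricted to j <= i."""
    a1 = {j for j in range(i + 1) if j // stride == i // stride}   # same block as i
    a2 = {j for j in range(i + 1) if j % stride >= stride - c}     # last c positions of each block
    return a1, a2


strided = strided_patterns(i=10, stride=4)
fixed = fixed_patterns(i=10, stride=4, c=1)
```

<p>With a stride of 4, position 10 reaches the recent positions 6 to 10 through the first strided subset and positions 2, 6 and 10 through the second, while in the fixed scheme it sees its own block plus the summary positions 3 and 7.</p>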
<p><strong>Use Factorized Self-Attention in Transformer</strong></p>
<p>There are three ways to use sparse factorized attention patterns in Transformer architecture:</p>
<ol>
<li>One attention type per residual block and then interleave them, <br />
\(\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, A^{(n \mod p)}) \mathbf{W}^o\), where \(n\) is the index of the current residual block.</li>
<li>Set up a single head which attends to locations that all the factorized heads attend to, <br />
\(\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, \cup_{m=1}^p A^{(m)}) \mathbf{W}^o\).</li>
<li>Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt one of the patterns presented above, 1 or 2. This option often performs the best.</li>
</ol>
<p>Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention & FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the <a href="https://arxiv.org/abs/1904.10509">paper</a> for more details.</p>
<h3 id="locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</h3>
<p>The improvements proposed by the <strong>Reformer</strong> model (<a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>) aim to solve the following pain points in Transformer:</p>
<ul>
<li>Memory in a model with \(N\) layers is \(N\)-times larger than in a single-layer model because we need to store activations for back-propagation.</li>
<li>The intermediate FF layers are often quite large.</li>
<li>The attention matrix on sequences of length \(L\) often requires \(O(L^2)\) in both memory and time.</li>
</ul>
<p>Reformer proposed two main changes:</p>
<ol>
<li>Replace the dot-product attention with <em>locality-sensitive hashing (LSH) attention</em>, reducing the complexity from \(O(L^2)\) to \(O(L\log L)\).</li>
<li>Replace the standard residual blocks with <em>reversible residual layers</em>, which allows storing activations only once during training instead of \(N\) times (i.e. proportional to the number of layers).</li>
</ol>
<p><a name="LSH"></a><strong>Locality-Sensitive Hashing Attention</strong></p>
<p>In \(\mathbf{Q} \mathbf{K}^\top\) part of the <a href="#attention-and-self-attention">attention formula</a>, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query \(\mathbf{q}_i \in \mathbf{Q}\), we are looking for row vectors in \(\mathbf{K}\) closest to \(\mathbf{q}_i\). In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-Sensitive Hashing (LSH)</a> into its attention mechanism.</p>
<p>A hashing scheme \(x \mapsto h(x)\) is <em>locality-sensitive</em> if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. Reformer adopts such a hashing scheme: given a fixed random matrix \(\mathbf{R} \in \mathbb{R}^{d \times b/2}\) (where \(b\) is a hyperparameter), the hash function is \(h(x) = \arg\max([xR; -xR])\).</p>
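<p>The hash function is short enough to write out. In this sketch the entries of \(\mathbf{R}\) are drawn from a standard Gaussian, which is an assumption (the text above only says the matrix is random), and <code>make_projection</code> and <code>lsh_hash</code> are hypothetical names.</p>

```python
import random


def make_projection(d, b, seed=0):
    """Fixed random matrix R of shape d x (b // 2); Gaussian entries are an assumption."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(b // 2)] for _ in range(d)]


def lsh_hash(x, R):
    """h(x) = argmax([xR; -xR]), a bucket id in {0, ..., b - 1}."""
    xR = [sum(x[r] * R[r][c] for r in range(len(x))) for c in range(len(R[0]))]
    concat = xR + [-value for value in xR]
    return max(range(len(concat)), key=lambda idx: concat[idx])


R = make_projection(d=4, b=8)
x = [0.3, -1.2, 0.5, 2.0]
bucket = lsh_hash(x, R)
```

<p>Because only the argmax matters, rescaling a vector by a positive constant leaves its bucket unchanged, while negating it moves it to the opposite half of the bucket range.</p>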
<!-- If we omit the scalar in self-attention and summarize the denominator into a normalizing term $$Z(.)$$, an normal attention output looks as follows:
$$
\mathbf{o}_i = \sum_{j \in S_i} \exp(\mathbf{q}_i \cdot \mathbf{k}_j - Z(i, S_i)) \mathbf{v}_j \text{, where } S_i = \{j: j \leq i\}
$$
-->
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/LSH-attention-matrix.png" alt="LSH attention matrix" /></p>
<p class="image-caption"><em>Fig. 11. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
<p>In LSH attention, a query can only attend to positions in the same hashing bucket, \(S_i = \{j: h(\mathbf{q}_i) = h(\mathbf{k}_j)\}\). It is carried out in the following process, as illustrated in Fig. 11:</p>
<ul>
<li>(a) The attention matrix for full attention is often sparse.</li>
<li>(b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets.</li>
<li>(c) Set \(\mathbf{Q} = \mathbf{K}\) (precisely \(\mathbf{k}_j = \mathbf{q}_j / \|\mathbf{q}_j\|\)), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this “shared-QK” config does not affect the performance of the Transformer.</li>
<li>(d) Apply batching where chunks of \(m\) consecutive queries are grouped together.</li>
</ul>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/LSH-attention.png" alt="LSH attention" /></p>
<p class="image-caption"><em>Fig. 12. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
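<p>The sorting and chunking steps can be sketched on bucket ids alone. The helper names below are invented, causal masking is left out, and letting each chunk also look one chunk back is taken from my reading of the Reformer paper rather than from the description above, so treat that detail as an assumption.</p>

```python
def sort_and_chunk(buckets, m):
    """Sort positions by (bucket, position) and split the sorted order into chunks of size m."""
    order = sorted(range(len(buckets)), key=lambda i: (buckets[i], i))
    return [order[k:k + m] for k in range(0, len(order), m)]


def allowed_keys(chunks, buckets, c):
    """For each query in chunk c, same-bucket positions in chunk c and chunk c - 1."""
    candidates = chunks[c] + (chunks[c - 1] if c > 0 else [])
    return {i: sorted(j for j in candidates if buckets[j] == buckets[i])
            for i in chunks[c]}


buckets = [1, 0, 1, 0, 2, 1, 0, 2]   # bucket id per position, e.g. from the LSH hash
chunks = sort_and_chunk(buckets, m=2)
keys_for_chunk_1 = allowed_keys(chunks, buckets, c=1)
```

<p>Position 6 shares bucket 0 with positions 1 and 3, which ended up in the previous chunk; looking one chunk back is what lets it still attend to them.</p>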
<p><strong>Reversible Residual Network</strong></p>
<p>Another improvement by Reformer is to use <em>reversible residual layers</em> (<a href="https://arxiv.org/abs/1707.04585">Gomez et al. 2017</a>). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.</p>
<p>Given a layer \(x \mapsto y\), the normal residual layer does \(y = x + F(x)\), but the reversible layer splits both input and output into pairs \((x_1, x_2) \mapsto (y_1, y_2)\) and then executes the following:</p>
\[y_1 = x_1 + F(x_2),\; y_2 = x_2 + G(y_1)\]
<p>and reversing is easy:</p>
\[x_2 = y_2 - G(y_1), \; x_1 = y_1 - F(x_2)\]
<p>Reformer applies the same idea to Transformer by combining attention (\(F\)) and feed-forward layers (\(G\)) within a reversible net block:</p>
\[Y_1 = X_1 + \text{Attention}(X_2), \; Y_2 = X_2 + \text{FeedForward}(Y_1)\]
<p>The memory can be further reduced by chunking the feed-forward computation:
\(Y_2 = [Y_2^{(1)}; \dots; Y_2^{(c)}] = [X_2^{(1)} + \text{FeedForward}(Y_1^{(1)}); \dots; X_2^{(c)} + \text{FeedForward}(Y_1^{(c)})]\)</p>
<p>The resulting reversible Transformer does not need to store activation in every layer.</p>
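<p>The reversible update and its inverse are easy to verify numerically. In the sketch below, \(F\) and \(G\) are arbitrary scalar stand-ins for the attention and feed-forward sub-layers, chosen only for illustration.</p>

```python
def rev_forward(x1, x2, F, G):
    """y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2


def rev_inverse(y1, y2, F, G):
    """Recover the inputs from the outputs using only F and G."""
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2


def F(v):
    return 0.5 * v * v      # stand-in for the attention sub-layer


def G(v):
    return 3.0 * v + 1.0    # stand-in for the feed-forward sub-layer


y1, y2 = rev_forward(2.0, -1.0, F, G)
x1, x2 = rev_inverse(y1, y2, F, G)
```

<p>Neither \(F\) nor \(G\) has to be invertible; the inputs are recovered by re-evaluating the same functions, which is what allows activations to be recomputed during backprop instead of stored.</p>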
<h2 id="make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</h2>
<p>The <strong>Universal Transformer</strong> (<a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN.</p>
<p>Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using <a href="#adaptive-computation-time-act">adaptive computation time</a>. If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.</p>
<p>On a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer-loop.png" alt="Universal Transformer Recurrent Step" /></p>
<p class="image-caption"><em>Fig. 13. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>).</em></p>
<p>Given an input sequence of length \(L\), Universal Transformer iteratively updates the representation \(\mathbf{H}^t \in \mathbb{R}^{L \times d}\) at step \(t\) for an adjustable number of steps. At step 0, \(\mathbf{H}^0\) is initialized to be same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.</p>
\[\begin{aligned}
\mathbf{A}^t &= \text{LayerNorm}(\mathbf{H}^{t-1} + \text{MultiHeadAttention}(\mathbf{H}^{t-1} + \mathbf{P}^t)) \\
\mathbf{H}^t &= \text{LayerNorm}(\mathbf{A}^{t} + \text{Transition}(\mathbf{A}^t))
\end{aligned}\]
<p>where \(\text{Transition}(.)\) is either a <a href="https://arxiv.org/abs/1610.02357">separable convolution</a> or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of \(\mathbf{A}^t\) individually) affine transformation + one ReLU.</p>
<p>The positional encoding \(\mathbf{P}^t\) uses sinusoidal position signal but with an additional time dimension:</p>
\[\text{PE}(i, t, \delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) \oplus \sin(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) \oplus \cos(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases}\]
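<p>A direct transcription of the encoding for one entry is given below. It reads \(\oplus\) as element-wise addition of the position signal and the time signal, which is an assumption on my part, and <code>universal_pe</code> is a hypothetical name.</p>

```python
import math


def universal_pe(i, t, delta, d):
    """Encoding for position i, step t and dimension delta (circled-plus read as addition)."""
    half = delta // 2                     # delta' in the formula
    scale = 10000 ** (2 * half / d)
    fn = math.sin if delta % 2 == 0 else math.cos
    return fn(i / scale) + fn(t / scale)


# Encoding vector for position 3 at recurrent step 2 with model size d = 8.
row = [universal_pe(i=3, t=2, delta=k, d=8) for k in range(8)]
```

<p>Under this reading, two entries at the same position differ across recurrent steps only through the time term, which is how the shared-weight block can tell the steps apart.</p>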
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer.png" alt="Universal Transformer" /></p>
<p class="image-caption"><em>Fig. 14. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation \(\mathbf{H}^T\). (Image source: Figure 2 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>)</em></p>
<p>In the adaptive version of Universal Transformer, the number of recurrent steps \(T\) is dynamically determined by <a href="#adaptive-computation-time-act">ACT</a>. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.</p>
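<p>The per-position halting and copying behaviour can be sketched with a toy loop. This is a strong simplification: the update is applied to each position independently (no self-attention between positions), the final state is the last value rather than an ACT-weighted mean, and all names are invented for this sketch.</p>

```python
def adaptive_steps(states, step_fn, halt_fn, max_steps, eps=0.01):
    """Refine each position until its cumulative halting score reaches 1 - eps."""
    states = list(states)
    cumulative = [0.0] * len(states)
    steps_taken = [0] * len(states)
    for _ in range(max_steps):
        active = [c < 1.0 - eps for c in cumulative]
        if not any(active):
            break  # every position has halted
        for pos, is_active in enumerate(active):
            if is_active:
                states[pos] = step_fn(states[pos])
                cumulative[pos] += halt_fn(states[pos])
                steps_taken[pos] += 1
            # Halted positions simply keep (copy) their current value.
    return states, steps_taken


def halve(s):
    return s / 2.0                       # toy recurrent update


def halting_score(s):
    return 1.0 if s < 1.0 else 0.3       # toy halting unit


final, steps = adaptive_steps([1.0, 8.0], halve, halting_score, max_steps=6)
```

<p>The first position halts after a single update while the second needs four, and the loop ends as soon as both have halted, before the step limit is reached.</p>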
<h2 id="stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</h2>
<p>The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can for sure benefit from these traits. <em>However</em>, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. It could be quite challenging to stabilize and train an LSTM agent by itself, after all.</p>
<p>The <strong>Gated Transformer-XL</strong> (<strong>GTrXL</strong>; <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of <a href="#longer-attention-span-transformer-xl">Transformer-XL</a>:</p>
<ol>
<li>The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer.</li>
<li>The residual connection is replaced with a GRU-style (Gated Recurrent Unit; <a href="https://arxiv.org/abs/1412.3555">Chung et al., 2014</a>) <em>gating</em> mechanism.</li>
</ol>
\[\begin{aligned}
r &= \sigma(W_r^{(l)} y + U_r^{(l)} x) \\
z &= \sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\
\hat{h} &= \tanh(W_g^{(l)} y + U_g^{(l)} (r \odot x)) \\
g^{(l)}(x, y) &= (1-z)\odot x + z\odot \hat{h}
\end{aligned}\]
<p>The gating function parameters are explicitly initialized to be close to an identity map, which is why there is a \(b_g\) term. A \(b_g > 0\) greatly helps with the learning speedup.</p>
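<p>A scalar version of the gate shows how the bias pushes it toward an identity map on the shortcut stream \(x\). The weights and the bias value below are arbitrary illustration values, not the ones used in the paper, and <code>gru_gate</code> is a made-up name.</p>

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def gru_gate(x, y, W_r, U_r, W_z, U_z, W_g, U_g, b_g):
    """Scalar version of g(x, y): x is the shortcut stream, y the sub-layer output."""
    r = sigmoid(W_r * y + U_r * x)
    z = sigmoid(W_z * y + U_z * x - b_g)
    h_hat = math.tanh(W_g * y + U_g * (r * x))
    return (1.0 - z) * x + z * h_hat


weights = dict(W_r=0.1, U_r=0.1, W_z=0.1, U_z=0.1, W_g=0.1, U_g=0.1)
near_identity = gru_gate(x=0.8, y=-0.3, b_g=10.0, **weights)
no_bias = gru_gate(x=0.8, y=-0.3, b_g=0.0, **weights)
```

<p>With a large positive \(b_g\) the update gate \(z\) is close to zero, so the output is almost exactly \(x\); with \(b_g = 0\) and small weights the gate sits near 0.5 and the output is pulled well away from \(x\).</p>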
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/gated-transformer-XL.png" alt="GTrXL" /></p>
<p class="image-caption"><em>Fig. 15. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020transformer,
title = "The Transformer Family",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/03/27/the-transformer-family.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Ashish Vaswani, et al. <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[2] Rami Al-Rfou, et al. <a href="https://arxiv.org/abs/1808.04444">“Character-level language modeling with deeper self-attention.”</a> AAAI 2019.</p>
<p>[3] Olah & Carter, <a href="http://doi.org/10.23915/disti">“Attention and Augmented Recurrent Neural Networks”</a>, Distill, 2016.</p>
<p>[4] Sainbayar Sukhbaatar, et al. <a href="https://arxiv.org/abs/1905.07799">“Adaptive Attention Span in Transformers”</a>. ACL 2019.</p>
<p>[5] Rewon Child, et al. <a href="https://arxiv.org/abs/1904.10509">“Generating Long Sequences with Sparse Transformers”</a> arXiv:1904.10509 (2019).</p>
<p>[6] Nikita Kitaev, et al. <a href="https://arxiv.org/abs/2001.04451">“Reformer: The Efficient Transformer”</a> ICLR 2020.</p>
<p>[7] Alex Graves. <a href="https://arxiv.org/abs/1603.08983">“Adaptive Computation Time for Recurrent Neural Networks”</a> arXiv:1603.08983 (2016).</p>
<p>[8] Niki Parmar, et al. <a href="https://arxiv.org/abs/1802.05751">“Image Transformer”</a> ICML 2018.</p>
<p>[9] Zihang Dai, et al. <a href="https://arxiv.org/abs/1901.02860">“Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.”</a> ACL 2019.</p>
<p>[10] Aidan N. Gomez, et al. <a href="https://arxiv.org/abs/1707.04585">“The Reversible Residual Network: Backpropagation Without Storing Activations”</a> NIPS 2017.</p>
<p>[11] Mostafa Dehghani, et al. <a href="https://arxiv.org/abs/1807.03819">“Universal Transformers”</a> ICLR 2019.</p>
<p>[12] Emilio Parisotto, et al. <a href="https://arxiv.org/abs/1910.06764">“Stabilizing Transformers for Reinforcement Learning”</a> arXiv:1910.06764 (2019).</p>