<h1>How to Build an Open-Domain Question Answering System?</h1>
<p>Lilian Weng, <em>Lil’Log</em> (“Document my learning notes.”), 2020-10-29; <a href="https://lilianweng.github.io/lil-log/2020/10/29/open-domain-question-answering">permalink</a></p>
<blockquote>
<p>A model that is capable of answering any question with regard to factual knowledge can enable many useful applications. This post delves into how we can build an Open-Domain Question Answering (ODQA) system, assuming we have access to a powerful pretrained language model. Both closed-book and open-book approaches are discussed.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-11-12: add <a href="#openai-api-example">an example</a> of closed-book factual QA using the OpenAI API (beta).]</span></p>
<p>A model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.</p>
<p>Some disclaimers, given how many papers are out there:</p>
<ul>
<li>Assume we have access to a powerful pretrained <a href="/lil-log/2019/01/31/generalized-language-models.html">language model</a>.</li>
<li>We do not cover how to use a structured knowledge base (e.g. Freebase, WikiData) here.</li>
<li>We only focus on single-turn QA instead of multi-turn, conversation-style QA.</li>
<li>We mostly focus on QA models that contain neural networks, especially Transformer-based language models.</li>
<li>I admit that I missed a lot of papers with architectures designed specifically for QA tasks between 2017 and 2019😔</li>
</ul>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-open-domain-question-answering" id="markdown-toc-what-is-open-domain-question-answering">What is Open-Domain Question Answering?</a> <ul>
<li><a href="#notation" id="markdown-toc-notation">Notation</a></li>
<li><a href="#concerns-of-qa-data-fine-tuning" id="markdown-toc-concerns-of-qa-data-fine-tuning">Concerns of QA data fine-tuning</a></li>
</ul>
</li>
<li><a href="#open-book-qa-retriever-reader" id="markdown-toc-open-book-qa-retriever-reader">Open-book QA: Retriever-Reader</a> <ul>
<li><a href="#retriever-model" id="markdown-toc-retriever-model">Retriever Model</a> <ul>
<li><a href="#classic-ir" id="markdown-toc-classic-ir">Classic IR</a></li>
<li><a href="#neural-ir" id="markdown-toc-neural-ir">Neural IR</a></li>
</ul>
</li>
<li><a href="#reader-model" id="markdown-toc-reader-model">Reader Model</a> <ul>
<li><a href="#bi-directional-lstm" id="markdown-toc-bi-directional-lstm">Bi-directional LSTM</a></li>
<li><a href="#bert-universe" id="markdown-toc-bert-universe">BERT-universe</a></li>
</ul>
</li>
<li><a href="#end-to-end-joint-training" id="markdown-toc-end-to-end-joint-training">End-to-end Joint Training</a></li>
</ul>
</li>
<li><a href="#open-book-qa-retriever-generator" id="markdown-toc-open-book-qa-retriever-generator">Open-book QA: Retriever-Generator</a></li>
<li><a href="#closed-book-qa-generative-language-model" id="markdown-toc-closed-book-qa-generative-language-model">Closed-book QA: Generative Language Model</a></li>
<li><a href="#related-techniques" id="markdown-toc-related-techniques">Related Techniques</a> <ul>
<li><a href="#fast-maximum-inner-product-search-mips" id="markdown-toc-fast-maximum-inner-product-search-mips">Fast Maximum Inner Product Search (MIPS)</a></li>
<li><a href="#language-model-pre-training" id="markdown-toc-language-model-pre-training">Language Model Pre-training</a></li>
</ul>
</li>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
<li><a href="#appendix-qa-datasets" id="markdown-toc-appendix-qa-datasets">Appendix: QA Datasets</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-open-domain-question-answering">What is Open-Domain Question Answering?</h2>
<p><strong>Open-domain Question Answering (ODQA)</strong> is a type of language task that asks a model to produce answers to factoid questions in natural language. The true answer is objective, so it is relatively simple to evaluate model performance.</p>
<p>For example,</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Question: What did Albert Einstein win the Nobel Prize for?
Answer: The law of the photoelectric effect.
</code></pre></div></div>
<p>The “open-domain” part refers to the lack of relevant context for any arbitrarily asked factual question. In the above case, the model takes only the question as input; no article about “why Einstein didn’t win a Nobel Prize for the theory of relativity”, which would likely mention the term “the law of the photoelectric effect”, is provided. When both the question and the context are provided, the task is known as <strong>Reading comprehension (RC)</strong>.</p>
<p>An ODQA model may work with or without <em>access to an external source of knowledge</em> (e.g. Wikipedia) and these two conditions are referred to as <em>open-book</em> or <em>closed-book</em> question answering, respectively.</p>
<p>When considering different types of open-domain questions, I like the classification by <a href="https://arxiv.org/abs/2008.02637">Lewis, et al., 2020</a>, in increasing order of difficulty:</p>
<ol>
<li>A model is able to correctly memorize and respond with the answer to a question that has been seen at training time.</li>
<li>A model is able to answer novel questions at test time and choose an answer from the set of answers it has seen during training.</li>
<li>A model is able to answer novel questions which have answers not contained in the training dataset.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-summary.png" alt="QA-summary" /></p>
<p><em>Fig. 1. Overview of three frameworks discussed in this post.</em></p>
<h3 id="notation">Notation</h3>
<p>Given a question \(x\) and a ground truth answer span \(y\), the context passage containing the true answer is labelled as \(z \in \mathcal{Z}\), where \(\mathcal{Z}\) is an external knowledge corpus. Wikipedia is a common choice for such an external knowledge source.</p>
<h3 id="concerns-of-qa-data-fine-tuning">Concerns of QA data fine-tuning</h3>
<p>Before we dive into the details of the models below, I would like to point out one concern about fine-tuning a model on common QA datasets, which is a step in several ODQA models. It is concerning because several public QA datasets have a significant overlap between the questions in their train and test sets.</p>
<p><a href="https://arxiv.org/abs/2008.02637">Lewis, et al., (2020)</a> (<a href="https://github.com/facebookresearch/QA-Overlap">code</a>) found that 58-71% of test-time answers are also present somewhere in the training sets and 28-34% of test-set questions have a near-duplicate paraphrase in their corresponding training sets. In their experiments, several models performed notably worse on test questions that have neither an overlapping answer nor a near-duplicate paraphrase in the training set.</p>
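<p>The answer-overlap statistic above boils down to string matching between normalized answers. Below is a minimal sketch, assuming SQuAD-style answer normalization (lowercase, strip punctuation and English articles); the function names and toy data are hypothetical, and the released QA-Overlap code may normalize differently. It also ignores question paraphrase detection, which the paper handles separately.</p>

```python
import re
import string

_PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """SQuAD-style normalization (an assumption here): lowercase, drop punctuation and articles."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def answer_overlap_rate(train_answers, test_answers) -> float:
    """Fraction of test answers whose normalized form also appears among the training answers."""
    if not test_answers:
        return 0.0
    train_set = {normalize_answer(a) for a in train_answers}
    hits = sum(normalize_answer(a) in train_set for a in test_answers)
    return hits / len(test_answers)


train = ["The photoelectric effect", "Paris", "1969"]
test = ["photoelectric effect", "Berlin", "Paris.", "1970"]
print(answer_overlap_rate(train, test))  # 0.5
```

<p>Exact matching after normalization only gives a lower bound on overlap; paraphrased answers (e.g. “JFK” vs. “John F. Kennedy”) would need extra handling.</p>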
<h2 id="open-book-qa-retriever-reader">Open-book QA: Retriever-Reader</h2>
<p>Given a factoid question, if a language model has no context or is not big enough to memorize the context which exists in the training dataset, it is unlikely to guess the correct answer. In an open-book exam, students are allowed to refer to external resources like notes and books while answering test questions. Similarly, an ODQA system can be paired with a rich knowledge base to identify relevant documents as evidence of answers.</p>
<p>We can decompose the process of finding answers to given questions into two stages:</p>
<ol>
<li>Find the related context in an external repository of knowledge;</li>
<li>Process the retrieved context to <em>extract</em> an answer.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-retriever-reader.png" alt="retriever + reader QA system" /></p>
<p><em>Fig. 2. The retriever-reader QA framework combines information retrieval with machine reading comprehension.</em></p>
<p>Such a retriever + reader framework was first proposed in <strong>DrQA</strong> (“Document retriever Question-Answering” by <a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>; <a href="https://github.com/facebookresearch/DrQA">code</a>). The retriever and the reader components can be set up and trained independently, or jointly trained <a href="#end-to-end-joint-training">end-to-end</a>.</p>
<h3 id="retriever-model">Retriever Model</h3>
<p>Two popular approaches to implementing the retriever are information retrieval (IR) systems that depend on (1) the classic non-learning-based <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">TF-IDF</a> features (“classic IR”) or (2) dense embedding vectors of text produced by neural networks (“neural IR”).</p>
<h4 id="classic-ir">Classic IR</h4>
<p><strong>DrQA</strong> (<a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>) adopts an efficient non-learning-based search engine based on the <a href="https://en.wikipedia.org/wiki/Vector_space_model">vector space model</a>. Every query and document is modelled as a bag-of-words vector, where each term is weighted by TF-IDF (term frequency \(\times\) inverse document frequency).</p>
\[\begin{aligned}
\text{tf-idf}(t, d, \mathcal{D}) &= \text{tf}(t, d) \times \text{idf}(t, \mathcal{D}) \\
\text{tf}(t, d) &= \log(1 + \text{freq}(t, d)) \\
\text{idf}(t, \mathcal{D}) &= \log \Big( \frac{\vert\mathcal{D}\vert}{\vert \{d\in\mathcal{D}: t\in d\} \vert} \Big)
\end{aligned}\]
<p>where \(t\) is a unigram or bigram term in a document \(d\) from a collection of documents \(\mathcal{D}\). \(\text{freq}(t, d)\) measures how many times a term \(t\) appears in \(d\). Note that the term frequency here includes bigram counts too, which is found to be very helpful because bigrams take local word order into account. As part of the implementation, DrQA maps the bigrams into \(2^{24}\) bins using an unsigned murmur3 hash.</p>
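<p>A toy version of this retriever can be sketched in a few lines. This is a minimal sketch under assumptions: whitespace tokenization, <code class="language-plaintext highlighter-rouge">zlib.crc32</code> as a stand-in for the unsigned murmur3 hash, the tf and idf formulas exactly as written above, and a plain sparse dot product as the relevance score. DrQA’s actual implementation may differ in these details.</p>

```python
import math
import zlib
from collections import Counter

NUM_BINS = 2 ** 24  # number of hash bins, as in DrQA


def hashed_terms(text):
    """Unigrams + bigrams, each hashed into a bin (crc32 stands in for murmur3)."""
    tokens = text.lower().split()
    grams = tokens + [f"{a} {b}" for a, b in zip(tokens, tokens[1:])]
    return [zlib.crc32(g.encode("utf-8")) % NUM_BINS for g in grams]


def tfidf_vector(counts, idf):
    """tf = log(1 + freq); weight = tf * idf; bins unseen in the corpus get idf 0."""
    return {b: math.log(1 + f) * idf.get(b, 0.0) for b, f in counts.items()}


def build_index(docs):
    counts = [Counter(hashed_terms(d)) for d in docs]
    df = Counter(b for c in counts for b in c)  # number of docs containing each bin
    idf = {b: math.log(len(docs) / n) for b, n in df.items()}
    return [tfidf_vector(c, idf) for c in counts], idf


def score(query, doc_vec, idf):
    """Sparse dot product between the query and document TF-IDF vectors."""
    q_vec = tfidf_vector(Counter(hashed_terms(query)), idf)
    return sum(w * doc_vec.get(b, 0.0) for b, w in q_vec.items())


docs = [
    "einstein won the nobel prize for the photoelectric effect",
    "the theory of relativity was developed by einstein",
    "paris is the capital of france",
]
doc_vecs, idf = build_index(docs)
query = "nobel prize einstein"
best = max(range(len(docs)), key=lambda i: score(query, doc_vecs[i], idf))
print(best, docs[best])
```

<p>The query bigram “nobel prize” contributes its own weight on top of the two unigrams, which is how bigrams inject a bit of local word order into an otherwise orderless bag of words.</p>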
<p>DrQA uses Wikipedia as its knowledge source, and this choice has become a default setting for many ODQA studies since then. The non-ML document retriever returns the top \(k=5\) most relevant Wikipedia articles given a question.</p>
<p><strong>BERTserini</strong> (<a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>) pairs the open-source <a href="https://github.com/castorini/anserini"><em>Anserini</em></a> IR toolkit as the retriever with a fine-tuned pre-trained BERT model as the reader. The top \(k\) documents (\(k=10\)) are retrieved via the <code class="language-plaintext highlighter-rouge">post-v3.0</code> branch of Anserini with the query treated as a bag of words. The retrieved text segments are ranked by <a href="https://en.wikipedia.org/wiki/Okapi_BM25">BM25</a>, a classic TF-IDF-based retrieval scoring function. In terms of the effect of text granularity on performance, they found that paragraph retrieval > sentence retrieval > article retrieval.</p>
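<p>For reference, the BM25 score of a document for a query sums, over the query terms, an IDF weight times a saturated and length-normalized term frequency. The sketch below is a minimal version under assumptions: whitespace tokenization, the Lucene-style non-negative IDF \(\log(1 + \frac{N - n_t + 0.5}{n_t + 0.5})\), and textbook defaults \(k_1 = 1.2, b = 0.75\). Anserini’s configured parameters and tokenization may differ.</p>

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.2, b=0.75):
    """Okapi BM25 score of every document for the query (Lucene-style IDF)."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    df = Counter(term for toks in tokenized for term in set(toks))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        s = 0.0
        for term in query.lower().split():
            f = tf[term]
            if f == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # term-frequency saturation (k1) and document-length normalization (b)
            s += idf * f * (k1 + 1) / (f + k1 * (1 - b + b * len(toks) / avgdl))
        scores.append(s)
    return scores


docs = [
    "einstein won the nobel prize for the photoelectric effect",
    "the theory of relativity was developed by einstein",
    "paris is the capital of france",
]
print(bm25_scores("nobel prize", docs))
```

<p>Unlike raw TF-IDF, repeated occurrences of a term give diminishing returns (controlled by \(k_1\)), and long documents are penalized relative to the average length (controlled by \(b\)).</p>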
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERTserini-arch.png" alt="BERTserini" /></p>
<p><em>Fig. 3. An illustration of BERTserini architecture. (Image source: <a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>)</em></p>
<p><em>ElasticSearch + BM25</em> is used by the <strong>Multi-passage BERT</strong> QA model (<a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a>). They found that splitting articles into passages of 100 words with a <em>sliding window</em> brings a 4% improvement, since splitting documents into passages without overlap may cause evidence near the boundaries to lose useful context.</p>
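<p>Overlapping passages can be produced with a simple sliding window. A minimal sketch, where the 100-word window follows the text and the 50-word stride is an assumed value for illustration:</p>

```python
def sliding_window_passages(text, window=100, stride=50):
    """Split text into passages of `window` words, advancing by `stride` words (stride < window overlaps)."""
    words = text.split()
    if len(words) <= window:
        return [" ".join(words)]
    passages = []
    start = 0
    while True:
        passages.append(" ".join(words[start:start + window]))
        if start + window >= len(words):  # this window already reaches the end
            break
        start += stride
    return passages


article = " ".join(f"w{i}" for i in range(250))
passages = sliding_window_passages(article)
print(len(passages))  # 4
```

<p>Because consecutive windows share 50 words in this setting, any evidence span of at most 50 words that is cut by one passage boundary lies entirely inside a neighboring passage.</p>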
<h4 id="neural-ir">Neural IR</h4>
<p>There is a long history of learning low-dimensional representations of text that are denser than raw term-based vectors (<a href="http://lsa.colorado.edu/papers/JASIS.lsi.90.pdf">Deerwester et al., 1990</a>; <a href="https://www.aclweb.org/anthology/W11-0329/">Yih, et al., 2011</a>). Dense representations can be learned through matrix decomposition or some neural network architectures (e.g. MLP, LSTM, bidirectional LSTM, etc). When neural networks are involved, such approaches are referred to as “Neural IR”. Neural IR is a new category of methods for retrieval problems, but it does not necessarily perform better than classic IR (<a href="https://sigir.org/wp-content/uploads/2019/01/p040.pdf">Lin, 2018</a>).</p>
<p>After the success of many large-scale <a href="/lil-log/2019/01/31/generalized-language-models.html">general language models</a>, many QA models embrace the following approach:</p>
\[h_x = E_x(x)\quad
h_z = E_z(z)\quad
\text{score}(x, z) = h_x^\top h_z\]
<ol>
<li>Extract the dense representations of a question \(x\) and a context passage \(z\) by feeding them into a language model;</li>
<li>Use the dot product of these two representations as the retrieval score to rank and select the most relevant passages.</li>
</ol>
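<p>A minimal sketch of this scoring-and-ranking step, assuming the question and passage embeddings \(h_x, h_z\) have already been produced by some encoders \(E_x, E_z\); the toy vectors below are hypothetical. A real system would precompute all \(h_z\) offline and query an approximate MIPS index instead of running this brute-force loop.</p>

```python
import heapq


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def retrieve_top_k(h_x, passage_embeddings, k=2):
    """Return the k (passage index, score) pairs with the largest score(x, z) = h_x^T h_z."""
    scored = ((i, dot(h_x, h_z)) for i, h_z in enumerate(passage_embeddings))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


h_x = [1.0, 0.0, 1.0]  # hypothetical question embedding E_x(x)
passage_embeddings = [  # hypothetical passage embeddings E_z(z)
    [0.5, 0.5, 0.5],
    [1.0, 0.0, 2.0],
    [-1.0, 1.0, 0.0],
]
print(retrieve_top_k(h_x, passage_embeddings))  # [(1, 3.0), (0, 1.0)]
```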
<p>ORQA, REALM and DPR all use such a scoring function for context retrieval, which will be described in detail in a <a href="#end-to-end-joint-training">later section</a> on the end-to-end QA model.</p>
<p>An extreme approach, investigated by <strong>DenSPI</strong> (“Dense-Sparse Phrase Index”; <a href="https://arxiv.org/abs/1906.05807">Seo et al., 2019</a>), is to encode all the text in the knowledge corpus at the <em>phrase</em> level and then only rely on the retriever to identify the most relevant phrase as the predicted answer. In this way, the retriever+reader pipeline is reduced to a retriever only. Of course, the index would be much larger and the retrieval problem more challenging.</p>
<p>DenSPI introduces a <em>query-agnostic</em> indexable representation of document phrases. Specifically, it encodes query-agnostic representations of text spans in Wikipedia offline and looks for the answer at inference time by performing nearest neighbor search. This can drastically speed up inference, because there is no need to re-encode documents for every new query, as a reader model often requires.</p>
<p>Given a question \(x\) and a fixed set of (Wikipedia) documents \(z_1, \dots, z_K\), where each document \(z_k\) contains \(N_k\) words, \(z_k = \langle z_k^{(1)}, \dots, z_k^{(N_k)}\rangle\), an ODQA model is a scoring function \(F\) for each candidate phrase span \(z_k^{(i:j)}, 1 \leq i \leq j \leq N_k\), such that the ground-truth answer is the phrase with the maximum score: \(y = {\arg\max}_{k,i,j} F(x, z_k^{(i:j)})\).</p>
<p>The phrase representation \(z_k^{(i:j)}\) combines both dense and sparse vectors, \(z_k^{(i:j)} = [d_k^{(i:j)}, s_k^{(i:j)}] \in \mathbb{R}^{d^d + d^s}\) (note that \(d^d \ll d^s\)):</p>
<ul>
<li>The dense vector \(d_k^{(i:j)}\) is effective for encoding local <em>syntactic</em> and <em>semantic</em> cues, such as those learned by a pretrained language model.</li>
<li>The sparse vector \(s_k^{(i:j)}\) is superior at encoding precise <em>lexical</em> information. It is a term-frequency-based encoding: DenSPI uses 2-gram term frequency, the same as DrQA, resulting in a highly sparse representation (\(d^s \approx 16\)M).</li>
</ul>
<p>The dense vector \(d^{(i:j)}\) is further decomposed into three parts, \(d^{(i:j)} = [a_i, b_j, c_{ij}] \in \mathbb{R}^{2d^b + 1}\) where \(2d^b + 1 = d^d\). All three components are learned based on different columns of the fine-tuned BERT representations.</p>
<ul>
<li>A vector \(a_i\) encodes the <em>start</em> position for the \(i\)-th word of the document;</li>
<li>A vector \(b_j\) encodes the <em>end</em> position for the \(j\)-th word of the document;</li>
<li>A scalar \(c_{ij}\) measures the <em>coherency</em> between the start and the end vectors, helping avoid non-constituent phrases during inference.</li>
</ul>
<p>For all possible \((i,j,k)\) tuples where \(j-i < J\), the text span embeddings are precomputed and stored as a <em>phrase index</em>. The maximum span length \(J\) is a predefined scalar constant.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/DenSPI-arch.png" alt="DenSPI" /></p>
<p><em>Fig. 4. An illustration of Dense-Sparse Phrase Index (DenSPI) architecture. (Image source: <a href="https://arxiv.org/abs/1906.05807">Seo et al., 2019</a>)</em></p>
<p>At inference time, the question is mapped into the same vector space \(x=[d', s'] \in \mathbb{R}^{d^d + d^s}\), where the dense vector \(d'\) is extracted from the BERT embedding of the special <code class="language-plaintext highlighter-rouge">[CLS]</code> symbol. The same BERT model is shared for encoding both questions and phrases. The final answer is predicted by \(k^*, i^*, j^* = \arg\max x^\top z_k^{(i:j)}\).</p>
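<p>To make the phrase index concrete, here is a brute-force sketch of the dense part only: every span with \(j-i < J\) gets the vector \([a_i, b_j, c_{ij}]\), and the answer is the span with the largest inner product with the question vector. The sparse component, the BERT encoders and the nearest-neighbor search are omitted, and all numbers are toy values; this is an illustration under those assumptions, not DenSPI’s implementation.</p>

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def build_phrase_index(docs, max_len):
    """docs: list of (starts, ends, coherency) per document, where starts[i] = a_i,
    ends[j] = b_j and coherency[i][j] = c_ij. Indexes every span with j - i < max_len."""
    index = []
    for k, (starts, ends, coherency) in enumerate(docs):
        n = len(starts)
        for i in range(n):
            for j in range(i, min(i + max_len, n)):
                index.append(((k, i, j), starts[i] + ends[j] + [coherency[i][j]]))
    return index


def best_phrase(question_vec, index):
    """argmax over (k, i, j) of the inner product between question and phrase vectors."""
    return max(index, key=lambda item: dot(question_vec, item[1]))[0]


# Toy documents with d^b = 1, so each dense phrase vector has 2 * 1 + 1 = 3 dimensions.
doc0 = ([[1.0], [0.0], [0.0]], [[0.0], [0.0], [1.0]], [[0.0] * 3 for _ in range(3)])
doc1 = ([[0.5], [0.2]], [[0.1], [0.3]], [[0.0] * 2 for _ in range(2)])
index = build_phrase_index([doc0, doc1], max_len=3)
question_vec = [1.0, 1.0, 1.0]  # hypothetical [start part, end part, coherency weight]
print(best_phrase(question_vec, index))  # (0, 0, 2)
```

<p>The index grows roughly as \(N_k \times J\) entries per document, which is why the phrase index is so much larger than a passage index.</p>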
<h3 id="reader-model">Reader Model</h3>
<p>The reader model learns to solve the reading comprehension task — extract an answer for a given question from a given context document. Here we only discuss approaches for machine comprehension using neural networks.</p>
<h4 id="bi-directional-lstm">Bi-directional LSTM</h4>
<p>The reader model for answer detection of <strong>DrQA</strong> (<a href="https://arxiv.org/abs/1704.00051">Chen et al., 2017</a>) is a 3-layer bidirectional LSTM with hidden size 128. Every relevant paragraph of retrieved Wikipedia articles is encoded as a sequence of feature vectors, \(\{\tilde{\mathbf{z}}_1, \dots, \tilde{\mathbf{z}}_m \}\). Each feature vector \(\tilde{\mathbf{z}}_i \in \mathbb{R}^{d_z}\) is expected to capture useful contextual information around one token \(z_i\). The feature vector consists of several categories of features:</p>
<ol>
<li>Word embeddings: A 300d <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">GloVe</a> word embedding trained from 840B Web crawl data, \(f_\text{embed} = E_g(z_i)\).</li>
<li>Exact match: Whether a word \(z_i\) appears in the question \(x\), \(f_\text{match} = \mathbb{I}(z_i \in x)\).</li>
<li>Token features: This includes POS (part-of-speech) tagging, NER (named entity recognition), and TF (term-frequency), \(f_\text{token}(z_i) = (\text{POS}(z_i), \text{NER}(z_i), \text{TF}(z_i))\).</li>
<li>Aligned question embedding: The attention score \(y_{ij}\) is designed to capture inter-sentence matching and similarity between the paragraph token \(z_i\) and the question word \(x_j\). This feature adds soft alignments between similar but non-identical words.</li>
</ol>
\[\begin{aligned}
f_\text{align}(z_i) &= \sum_j y_{i,j} E_g(x_j) \\
y_{i,j} &= \frac{\exp(\alpha(E_g(z_i))^\top \alpha(E_g(x_j)) )}{\sum_{j'} \exp(\alpha(E_g(z_i))^\top \alpha(E_g(x_{j'})) ) }
\end{aligned}\]
<p>where \(\alpha\) is a single dense layer with ReLU and \(E_g(.)\) is the GloVe word embedding.</p>
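<p>A minimal sketch of the aligned question embedding for a single paragraph token, assuming toy 2-d vectors in place of GloVe embeddings and a hypothetical weight matrix for \(\alpha\) (bias omitted). It computes the softmax attention \(y_{i,j}\) over question words and the weighted sum \(f_\text{align}(z_i)\).</p>

```python
import math


def alpha(W, v):
    """Single dense layer with ReLU (no bias in this sketch)."""
    return [max(0.0, sum(w * x for w, x in zip(row, v))) for row in W]


def aligned_question_embedding(z_emb, q_embs, W):
    """Return f_align(z_i) = sum_j y_ij E_g(x_j) and the attention weights y_i."""
    a_z = alpha(W, z_emb)
    logits = [sum(p * q for p, q in zip(a_z, alpha(W, x))) for x in q_embs]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in logits]
    total = sum(exps)
    y = [e / total for e in exps]
    f_align = [sum(y_j * x[d] for y_j, x in zip(y, q_embs)) for d in range(len(q_embs[0]))]
    return f_align, y


W = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weights of alpha
z_emb = [1.0, 0.0]  # toy E_g(z_i)
q_embs = [[1.0, 0.0], [0.0, 1.0]]  # toy E_g(x_1), E_g(x_2)
f_align, y = aligned_question_embedding(z_emb, q_embs, W)
print(f_align, y)
```

<p>The paragraph token attends more to the question word whose projected embedding it resembles, even though neither word needs to match it exactly; this is the “soft alignment” that complements the exact-match feature.</p>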
<p>The feature vectors of a paragraph of \(m\) tokens are fed into the LSTM to obtain the final paragraph vectors:</p>
\[\begin{aligned}
\mathbf{z} = \{\mathbf{z}_1, \dots, \mathbf{z}_m\} &= \text{LSTM}(\{\tilde{\mathbf{z}}_1, \dots, \tilde{\mathbf{z}}_m\}) \\
\text{where } \tilde{\mathbf{z}}_i &= \{f_\text{embed}, f_\text{match}, f_\text{token}, f_\text{align}\}
\end{aligned}\]
<p>The question is encoded as a weighted sum of the embeddings of every word in the question:</p>
\[\mathbf{x} = \sum_j b_j E(x_j) \quad b_j = \text{softmax}(\mathbf{w}^\top E(x_j))\]
<p>where \(\mathbf{w}\) is a weight vector to learn.</p>
<p>Once the feature vectors are constructed for the question and all the related paragraphs, the reader needs to predict the probability of each position in a paragraph being the start and the end of an answer span, \(p_\text{start}(i_s)\) and \(p_\text{end}(i_e)\), respectively. Across all the paragraphs, the span with the maximum \(p_\text{start}(i_s) \times p_\text{end}(i_e)\) is returned as the final answer.</p>
\[\begin{aligned}
p_\text{start}(i_s) \propto \exp(\mathbf{z}_{i_s} \mathbf{W}_s \mathbf{x}) \\
p_\text{end}(i_e) \propto \exp(\mathbf{z}_{i_e} \mathbf{W}_e \mathbf{x}) \\
\text{ s.t. } i_s \leq i_e \leq i_s + 15
\end{aligned}\]
<p>where \(\mathbf{W}_s\) and \(\mathbf{W}_e\) are learned parameters.</p>
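<p>The constrained span selection can be done with a small exhaustive search. A minimal sketch, where the per-position probabilities are toy values standing in for the bilinear scores above and the default maximum span length of 15 follows the constraint \(i_s \leq i_e \leq i_s + 15\):</p>

```python
def best_span(p_start, p_end, max_len=15):
    """argmax of p_start[i_s] * p_end[i_e] subject to i_s <= i_e <= i_s + max_len."""
    best, best_score = None, float("-inf")
    n = len(p_start)
    for i_s in range(n):
        for i_e in range(i_s, min(i_s + max_len, n - 1) + 1):
            score = p_start[i_s] * p_end[i_e]
            if score > best_score:
                best, best_score = (i_s, i_e), score
    return best, best_score


def best_span_across_paragraphs(paragraphs, max_len=15):
    """paragraphs: list of (p_start, p_end); returns (paragraph index, span, score)."""
    candidates = [(k, *best_span(ps, pe, max_len)) for k, (ps, pe) in enumerate(paragraphs)]
    return max(candidates, key=lambda c: c[2])


p_start, p_end = [0.1, 0.7, 0.2], [0.6, 0.1, 0.3]
print(best_span(p_start, p_end))
print(best_span_across_paragraphs([(p_start, p_end), ([0.9, 0.1], [0.1, 0.9])]))
```

<p>Here the unconstrained product would pair start position 1 with end position 0 (0.7 × 0.6), which the constraint rules out because the end would precede the start; the best valid span is (1, 2).</p>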
<h4 id="bert-universe">BERT-universe</h4>
<p>Following the success of <a href="/lil-log/2019/01/31/generalized-language-models.html#bert">BERT</a> (<a href="https://arxiv.org/abs/1810.04805">Devlin et al., 2018</a>), many QA models develop the machine comprehension component based on BERT. Let’s define the BERT model as a function that can take one or multiple strings (concatenated by <code class="language-plaintext highlighter-rouge">[SEP]</code>) as input and output a set of BERT encoding vectors for the special <code class="language-plaintext highlighter-rouge">[CLS]</code> token and every input token:</p>
\[\text{BERT}(s_1, s_2, \dots) = [\mathbf{h}^\texttt{[CLS]}, \mathbf{h}^{(1)}, \mathbf{h}^{(2)}, \dots]\]
<p>where \(\mathbf{h}^\texttt{[CLS]}\) is the embedding vector for the special <code class="language-plaintext highlighter-rouge">[CLS]</code> token and \(\mathbf{h}^{(i)}\) is the embedding vector for the \(i\)-th token.</p>
<p>To use BERT for reading comprehension, the model learns two additional weights, \(\mathbf{W}_s\) and \(\mathbf{W}_e\); \(\text{softmax}(\mathbf{h}^{(i)}\mathbf{W}_s)\) and \(\text{softmax}(\mathbf{h}^{(i)}\mathbf{W}_e)\), with the softmax taken over all tokens \(i\), define the probability distributions of the start and end positions of the predicted span.</p>
<p><strong>BERTserini</strong> (<a href="https://arxiv.org/abs/1902.01718">Yang et al., 2019</a>) utilizes a pre-trained BERT model to work as the reader. Their experiments showed that <em>fine-tuning</em> pretrained BERT with SQuAD is sufficient to achieve high accuracy in identifying answer spans.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/BERT-RC.png" alt="BERT for reading comprehension" /></p>
<p><em>Fig. 5. How BERT is used to solve question-answering tasks. (Image source: <a href="https://arxiv.org/abs/1810.04805">Devlin et al., 2018</a>)</em></p>
<p>The key difference between the BERTserini reader and the original BERT is that the final softmax layer over different answer spans is removed, to allow comparison and aggregation of results from different segments. The pre-trained BERT model is fine-tuned on the training set of SQuAD, with all reader inputs padded to 384 tokens and a learning rate of 3e-5.</p>
<p>When ranking all the extracted answer spans, the retriever score (BM25) and the reader score (the probability of the start token \(\times\) the probability of the end token of the span) are combined via linear interpolation.</p>
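<p>The interpolation itself is a one-liner. A minimal sketch, where the weight \(\mu\) is a hypothetical hyperparameter (one would tune it on held-out data) and the candidate scores are made up. BM25 and reader scores live on different scales, which is exactly why the weight matters:</p>

```python
def interpolated_score(retriever_score, reader_score, mu):
    """S = (1 - mu) * S_retriever + mu * S_reader."""
    return (1 - mu) * retriever_score + mu * reader_score


def rank_answers(candidates, mu):
    """candidates: list of (answer text, BM25 score, reader score); best first."""
    return sorted(candidates, key=lambda c: interpolated_score(c[1], c[2], mu), reverse=True)


candidates = [("the photoelectric effect", 8.0, 0.6), ("relativity", 10.0, 0.1)]
print(rank_answers(candidates, mu=0.5)[0][0])  # relativity
print(rank_answers(candidates, mu=0.9)[0][0])  # the photoelectric effect
```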
<p>The original BERT normalizes the probability distributions of start and end position per token for every passage independently. Differently, the <strong>Multi-passage BERT</strong> (<a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a>) normalizes answer scores across all the retrieved passages of one question <a href="https://arxiv.org/abs/1710.10723">globally</a>. Precisely, multi-passage BERT removes the final normalization layer per passage in BERT for QA (same as in BERTserini) and then adds a global <code class="language-plaintext highlighter-rouge">softmax</code> over all the word positions of all the passages. Global normalization makes the reader model more stable while pinpointing answers from a large number of passages.</p>
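<p>The difference between per-passage and global normalization is easy to see with toy start logits. In this minimal sketch (made-up logits, start positions only), a passage with a single weak candidate gets probability 1.0 under the per-passage softmax, while the global softmax over all positions of all passages prefers the strong candidate in the other passage:</p>

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def per_passage_probs(logits_per_passage):
    """Original BERT-style: normalize within each passage independently."""
    return [softmax(logits) for logits in logits_per_passage]


def global_probs(logits_per_passage):
    """Multi-passage BERT-style: one softmax over all positions of all passages."""
    flat = softmax([x for logits in logits_per_passage for x in logits])
    out, pos = [], 0
    for logits in logits_per_passage:
        out.append(flat[pos:pos + len(logits)])
        pos += len(logits)
    return out


def argmax_position(probs_per_passage):
    """(passage index, position) of the highest probability across all passages."""
    return max(
        ((k, i) for k, probs in enumerate(probs_per_passage) for i in range(len(probs))),
        key=lambda ki: probs_per_passage[ki[0]][ki[1]],
    )


logits = [[-3.0], [2.0, 1.9]]  # passage 0: one weak candidate; passage 1: two strong ones
print(argmax_position(per_passage_probs(logits)))  # (0, 0)
print(argmax_position(global_probs(logits)))  # (1, 0)
```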
<p>In addition, multi-passage BERT implements an independent <em>passage ranker</em> model via another BERT model, and the rank score for \((x, z)\) is generated by a <code class="language-plaintext highlighter-rouge">softmax</code> over the representation vectors of the first <code class="language-plaintext highlighter-rouge">[CLS]</code> token. The passage ranker brings an extra 2% improvement. A similar idea of re-ranking passages with BERT was also discussed in <a href="https://arxiv.org/abs/1901.04085">Nogueira & Cho, 2019</a>.</p>
<p>Interestingly, <a href="https://arxiv.org/abs/1908.08167">Wang et al., 2019</a> found that <em>explicit inter-sentence matching</em> does not seem to be critical for RC tasks with BERT; check the original paper for how the experiments were designed. One possible reason is that the multi-head self-attention layers in BERT have already embedded the inter-sentence matching.</p>
<h3 id="end-to-end-joint-training">End-to-end Joint Training</h3>
<p>The retriever and reader components can be jointly trained. This section covers R^3, ORQA, REALM and DPR. They share many common designs, such as BERT-based dense vectors for retrieval and a loss function that maximizes the marginal likelihood of obtaining true answers.</p>
<p>The retriever and reader models in the <strong>R^3</strong> (“Reinforced Ranker-Reader”; <a href="https://arxiv.org/abs/1709.00023">Wang, et al., 2017</a>) QA system are jointly trained via <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>. (Note that to keep the term consistent between papers in this section, the “ranker” model in the original R^3 paper is referred to as the “retriever” model here.) Both components are variants of <a href="https://arxiv.org/abs/1512.08849">Match-LSTM</a>, which relies on an attention mechanism to compute word similarities between the passage and question sequences.</p>
<p><strong>How does the Match-LSTM module work?</strong> Given a question \(\mathbf{X}\) of \(d_x\) words and a passage \(\mathbf{Z}\) of \(d_z\) words, both are represented with fixed <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">GloVe</a> word embeddings,</p>
\[\begin{aligned}
\mathbf{H}^x &= \text{BiLSTM}(\mathbf{X}) \in \mathbb{R}^{l \times d_x} \\
\mathbf{H}^z &= \text{BiLSTM}(\mathbf{Z}) \in \mathbb{R}^{l \times d_z} \\
\mathbf{G} &= \text{softmax}((\mathbf{W}^g \mathbf{H}^x + \mathbf{b}^g \otimes \mathbf{e}_{d_x})^\top \mathbf{H}^z) \in \mathbb{R}^{d_x \times d_z} & \text{; an attention matrix}\\
\bar{\mathbf{H}}^x &= \mathbf{H}^x \mathbf{G} \in \mathbb{R}^{l \times d_z} \\
\mathbf{M} &= \text{ReLU} \Big( \mathbf{W}^m \begin{bmatrix}
\mathbf{H}^z \\
\bar{\mathbf{H}}^x \\
\mathbf{H}^z \odot \bar{\mathbf{H}}^x \\
\mathbf{H}^z - \bar{\mathbf{H}}^x
\end{bmatrix} \Big) \in \mathbb{R}^{2l \times d_z} \\
\mathbf{H}^m &= \text{BiLSTM}(\mathbf{M}) \in \mathbb{R}^{l \times d_z}
\end{aligned}\]
<p>where \(l\) is the hidden dimension of the bidirectional LSTM module. \(\mathbf{W}^g \in \mathbb{R}^{l\times l}\), \(\mathbf{b}^g \in \mathbb{R}^l\), and \(\mathbf{W}^m \in \mathbb{R}^{2l \times 4l}\) are parameters to learn. The operator \(\otimes \mathbf{e}_{d_x}\) is the outer product to repeat the column vector \(\mathbf{b}^g\) \(d_x\) times.</p>
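<p>The shapes in the attention step are easy to get wrong, so here is a minimal pure-Python sketch of \(\mathbf{G}\), \(\bar{\mathbf{H}}^x\) and the stacked input to \(\mathbf{W}^m\), using toy matrices in place of BiLSTM outputs. It assumes the softmax normalizes each column of \(\mathbf{G}\) over the \(d_x\) question positions, which is what makes \(\mathbf{H}^x \mathbf{G}\) an attention-weighted average of question states for every passage position. The final \(\mathbf{W}^m\) projection, ReLU and BiLSTM are left out.</p>

```python
import math


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A):
    return [list(row) for row in zip(*A)]


def column_softmax(S):
    """Softmax over the rows of S, computed independently for every column."""
    cols = []
    for col in zip(*S):
        m = max(col)
        exps = [math.exp(v - m) for v in col]
        total = sum(exps)
        cols.append([e / total for e in exps])
    return transpose(cols)


def match_attention(Hx, Hz, Wg, bg):
    """Hx: l x d_x, Hz: l x d_z. Returns G (d_x x d_z) and H_bar_x = Hx G (l x d_z)."""
    proj = matmul(Wg, Hx)  # W^g H^x: l x d_x
    proj = [[v + bg[r] for v in row] for r, row in enumerate(proj)]  # + b^g repeated d_x times
    G = column_softmax(matmul(transpose(proj), Hz))  # d_x x d_z
    return G, matmul(Hx, G)


def stack_features(Hz, H_bar):
    """[H^z; H_bar; H^z * H_bar; H^z - H_bar] stacked row-wise: 4l x d_z."""
    prod = [[a * b for a, b in zip(r1, r2)] for r1, r2 in zip(Hz, H_bar)]
    diff = [[a - b for a, b in zip(r1, r2)] for r1, r2 in zip(Hz, H_bar)]
    return Hz + H_bar + prod + diff


Hx = [[1.0, 0.0], [0.0, 1.0]]  # toy question states: l = 2, d_x = 2
Hz = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]  # toy passage states: l = 2, d_z = 3
Wg, bg = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
G, H_bar = match_attention(Hx, Hz, Wg, bg)
features = stack_features(Hz, H_bar)
print(len(G), len(G[0]), len(features), len(features[0]))  # 2 3 8 3
```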
<p>The ranker and reader components share the same Match-LSTM module with two separate prediction heads in the last layer, resulting in \(\mathbf{H}^\text{rank}\) and \(\mathbf{H}^\text{read}\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/R%5E3-arch.png" alt="R^3 QA" /></p>
<p><em>Fig. 6. The overview of R^3 (reinforced ranker-reader) architecture. Both components share the same Match-LSTM module. (Image source: <a href="https://arxiv.org/abs/1709.00023">Wang, et al., 2017</a>)</em></p>
<p>The retriever runs a max-pooling operation per passage and then aggregates to output a probability of each passage entailing the answer.</p>
\[\begin{aligned}
\mathbf{u}_i &= \text{max-pooling}(\mathbf{H}^\text{rank}_i) \in \mathbb{R}^l \\
\mathbf{C} &= \text{tanh}(\mathbf{W}^c[\mathbf{u}_1;\dots;\mathbf{u}_N] + \mathbf{b}^c \otimes \mathbf{e}_N) \in \mathbb{R}^{l \times N} \\
\gamma &= \text{softmax}(\mathbf{w}^c \mathbf{C}) \in \mathbb{R}^N
\end{aligned}\]
<p>Finally, the retriever is viewed as a <em>policy</em> that samples a passage according to the predicted \(\gamma\),</p>
\[\pi(z \vert x; \theta^\gamma) = \gamma_z\]
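<p>Putting the ranker equations together, here is a minimal sketch with toy token states: max-pool each passage into \(\mathbf{u}_i\), apply the shared \(\tanh\) layer to each pooled vector, take a softmax over passages to get \(\gamma\), and sample a passage from that policy. The weights and states are hypothetical values chosen for illustration.</p>

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def max_pool(H):
    """H: list of token vectors (each of length l) for one passage; element-wise max."""
    return [max(values) for values in zip(*H)]


def ranker_policy(H_rank_per_passage, W_c, b_c, w_c):
    """gamma = softmax over passages of w^c tanh(W^c u_i + b^c)."""
    logits = []
    for H in H_rank_per_passage:
        u = max_pool(H)
        c = [math.tanh(sum(w * x for w, x in zip(row, u)) + b) for row, b in zip(W_c, b_c)]
        logits.append(sum(w * v for w, v in zip(w_c, c)))
    return softmax(logits)


def sample_passage(gamma, rng):
    """Draw a passage index z ~ pi(z | x) = gamma_z."""
    return rng.choices(range(len(gamma)), weights=gamma, k=1)[0]


H_rank = [
    [[0.1, 0.2], [0.5, -0.3]],  # passage 0: two token states, l = 2
    [[2.0, 1.0]],  # passage 1: one token state
]
W_c, b_c, w_c = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [1.0, 1.0]
gamma = ranker_policy(H_rank, W_c, b_c, w_c)
print(gamma, sample_passage(gamma, random.Random(0)))
```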
<p>The reader predicts the start position \(\beta^s\) and the end position \(\beta^e\) of the answer span. The two positions are computed in the same way, with independent parameters to learn. There are \(V\) words in total across all the passages involved.</p>
\[\begin{aligned}
\mathbf{H}^\text{read} &= [\mathbf{H}^\text{read}_\tau; \mathbf{H}^\text{read}_{\text{neg}_1}; \dots; \mathbf{H}^\text{read}_{\text{neg}_n}] \\
\mathbf{F}^s &= \text{tanh}(\mathbf{W}^s \mathbf{H}^\text{read} + \mathbf{b}^s \otimes \mathbf{e}_V) \quad
\beta^s = \text{softmax}(\mathbf{w}^s \mathbf{F}^s) \in \mathbb{R}^V \\
\mathbf{F}^e &= \text{tanh}(\mathbf{W}^e \mathbf{H}^\text{read} + \mathbf{b}^e \otimes \mathbf{e}_V) \quad
\beta^e = \text{softmax}(\mathbf{w}^e \mathbf{F}^e) \in \mathbb{R}^V \\
L(y \vert z, x) &= -\log(\beta^s_{y_z^s})-\log(\beta^e_{y_z^e})
\end{aligned}\]
<p>where \(y\) is the ground-truth answer and the passage \(z\) is sampled by the retriever. \(\beta^s_{y_z^s}\) and \(\beta^e_{y_z^e}\) represent the probabilities of the start and end positions of \(y\) in passage \(z\).</p>
<p>The training objective for the end-to-end R^3 QA system is to minimize the expected negative log-likelihood of obtaining the correct answer \(y\) given a question \(x\), i.e., to maximize \(\mathcal{J}(\theta)\) below,</p>
\[\begin{aligned}
\mathcal{J}(\theta) &= -\mathbb{E}_{z\sim\pi(.\vert x)} [L(y \vert z, x)] \\
\nabla \mathcal{J}(\theta)
&= - \nabla_\theta \sum_z \pi(z \vert x) L(y \vert z, x) \\
&= - \sum_z \big( L(y \vert z, x) \nabla_\theta\pi(z \vert x) + \pi(z \vert x) \nabla_\theta L(y \vert z, x) \big) \\
&= - \mathbb{E}_{z\sim\pi(.\vert x)} \big( \color{red}{L(y \vert z, x)\nabla_\theta\log\pi(z \vert x)} + \nabla_\theta L(y \vert z, x) \big) \\
&\approx \mathbb{E}_{z\sim\pi(.\vert x)} \big( \underbrace{\color{red}{R(y \vert z, x)\nabla_\theta\log\pi(z \vert x)}}_\text{REINFORCE} - \nabla_\theta L(y \vert z, x) \big)
\end{aligned}\]
<p>In essence, during training, given a passage \(z\) sampled by the retriever, the reader is trained by gradient descent while the retriever is trained by <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a> using \(-L(y \vert z, x)\) as the reward. However, \(L(y \vert z, x)\) is not bounded and may introduce a lot of variance. The paper therefore replaces the reward with a customized scoring function that compares the ground truth \(y\) and the answer \(\hat{y}\) extracted by the reader:</p>
\[R(y, \hat{y} \vert z) = \begin{cases}
2 & \text{if } y = \hat{y}\\
\text{F1}(y, \hat{y}) & \text{if } y \cap \hat{y} \neq \varnothing \\
-1 & \text{otherwise}
\end{cases}\]
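<p>The reward above can be sketched as follows, assuming whitespace tokenization with lowercasing and the usual token-level F1 from SQuAD-style evaluation; the original implementation may tokenize or normalize differently.</p>

```python
from collections import Counter


def token_f1(pred_tokens, gold_tokens):
    """Token-level F1 between a predicted and a gold answer."""
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def r3_reward(gold, pred):
    """2 for an exact match, F1 for a partial overlap, -1 when the answers share no token."""
    g, p = gold.lower().split(), pred.lower().split()
    if g == p:
        return 2.0
    if set(g) & set(p):
        return token_f1(p, g)
    return -1.0


print(r3_reward("the photoelectric effect", "The photoelectric effect"))  # 2.0
print(r3_reward("photoelectric effect", "the photoelectric effect"))
print(r3_reward("photoelectric effect", "relativity"))  # -1.0
```

<p>The reward always lies in \([-1, 2]\), in contrast with the unbounded loss it replaces.</p>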
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/R%5E3-reward-flow.png" alt="R^3 reward flow" /></p>
<p><em>Fig. 7. The workflow of R^3 training process. (Image source: <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part4-retriever-reader.pdf">acl2020-openqa-tutorial/slides/part4</a>)</em></p>
<p><a name="ORQA"></a><strong>ORQA</strong> (“Open-Retrieval Question-Answering”; <a href="https://arxiv.org/abs/1906.00300">Lee et al., 2019</a>) jointly learns a retriever + reader QA model to optimize the marginal log-likelihood of obtaining correct answers in a supervised manner. No explicit “black-box” IR system is involved. Instead, it is capable of retrieving any text in an open corpus. During training, ORQA does not need ground-truth context passages (i.e. reading comprehension datasets) but only needs (question, answer) string pairs. Both retriever and reader components are based on BERT, but they do not share parameters.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/ORQA-retriever.png" alt="ORQA-retriever" /></p>
<p><em>Fig. 8. An illustration of the retriever component in ORQA. (Image source: replotted based on one slide in <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part5-dense-retriever-e2e-training.pdf">acl2020-openqa-tutorial/slides/part5</a>)</em></p>
<p>All the evidence blocks are ranked by a retrieval score, defined as the inner product of BERT embedding vectors of the <code class="language-plaintext highlighter-rouge">[CLS]</code> token of the question \(x\) and the evidence block \(z\). Note that the encoders for questions and context are independent.</p>
\[\begin{aligned}
h_x &= \mathbf{W}_x \text{BERT}_x(x)^{\mathtt{[CLS]}} \\
h_z &= \mathbf{W}_z \text{BERT}_z(z)^{\mathtt{[CLS]}} \\
S_\text{retr}(z, x) &= h_x^\top h_z
\end{aligned}\]
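<p>As a rough illustration of the retrieval score, the NumPy sketch below uses random vectors as stand-ins for the BERT <code class="language-plaintext highlighter-rouge">[CLS]</code> outputs (an assumption; a real system runs two separately parameterized BERT encoders) and random matrices for \(\mathbf{W}_x\) and \(\mathbf{W}_z\). All sizes are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d_bert, d_proj, num_blocks = 768, 128, 1000   # hypothetical sizes

# Random stand-ins for BERT_x(x)^[CLS] and BERT_z(z)^[CLS].
cls_x = rng.normal(size=d_bert)
cls_z = rng.normal(size=(num_blocks, d_bert))

# Independent linear projections W_x and W_z.
W_x = rng.normal(size=(d_proj, d_bert)) / np.sqrt(d_bert)
W_z = rng.normal(size=(d_proj, d_bert)) / np.sqrt(d_bert)

h_x = W_x @ cls_x              # (d_proj,)
h_z = cls_z @ W_z.T            # (num_blocks, d_proj): W_z applied to each block
scores = h_z @ h_x             # S_retr(z, x) = h_x^T h_z for every block z
best_block = int(np.argmax(scores))
```

<p>Because \(h_z\) depends only on the evidence block, all rows of <code class="language-plaintext highlighter-rouge">h_z</code> can be computed once and reused for every question, which is what makes MIPS-based retrieval practical.</p>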
<p><a name="ICT-loss"></a>The retriever module is pretrained with the <em>Inverse Cloze Task (ICT)</em>, which predicts the context given a sentence, the opposite of the standard <a href="https://en.wikipedia.org/wiki/Cloze_test">Cloze Task</a>. The ICT objective maximizes the probability of retrieving the correct context \(z\) given a random sentence \(x\), i.e. it minimizes the negative log-likelihood:</p>
\[L_\text{ICT} = -\log p_\text{ICT}(z \vert x) = -\log \frac{\exp(S_\text{retr}(z, x))}{\sum_{z'\in\text{BATCH}(\mathcal{Z})} \exp(S_\text{retr}(z', x))}\]
<p>where \(\text{BATCH}(\mathcal{Z})\) is the set of evidence blocks in the same batch used as sampled negatives.</p>
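<p>A minimal NumPy sketch of this in-batch ICT loss, assuming the projected embeddings \(h_x\) and \(h_z\) for a batch of (sentence, context) pairs are already computed and that row \(i\) of each matrix belongs to the same pair. The function name is hypothetical.</p>

```python
import numpy as np


def ict_loss(h_x, h_z):
    """Mean in-batch ICT loss, -log p_ICT(z_i | x_i).

    h_x: (B, d) projected embeddings of the pseudo-question sentences.
    h_z: (B, d) projected embeddings of their contexts; row i is the positive
         for sentence i, and the other B - 1 rows act as sampled negatives.
    """
    scores = h_x @ h_z.T                                  # S_retr(z_j, x_i) for all pairs
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


# Usage: aligned embeddings give a much lower loss than mismatched ones.
emb = 10.0 * np.eye(4)
aligned = ict_loss(emb, emb)
shuffled = ict_loss(emb, emb[::-1])
```

<p>With all-zero embeddings every context in the batch is equally likely, so the loss equals \(\log B\); this is a handy sanity check when wiring up the batch.</p>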
<p>After such pretraining, the BERT retriever is expected to have representations good enough for evidence retrieval. Only the question encoder needs to be fine-tuned for answer extraction. In other words, the evidence block encoder (i.e., \(\mathbf{W}_z\) and \(\text{BERT}_z\)) is fixed and thus all the evidence block encodings can be pre-computed with support for <a href="#fast-maximum-inner-product-search-mips">fast Maximum Inner Product Search (MIPS)</a>.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/ORQA-reader.png" alt="ORQA-reader" /></p>
<p><em>Fig. 9. An illustration of the reader component in ORQA. (Image source: <a href="https://github.com/danqi/acl2020-openqa-tutorial/blob/master/slides/part5-dense-retriever-e2e-training.pdf">acl2020-openqa-tutorial/slides/part5</a>)</em></p>
<p>The reader follows the same design as in the original <a href="/lil-log/2019/01/31/generalized-language-models.html#use-bert-in-downstream-tasks">BERT RC</a> experiments. It learns in a supervised manner, while the parameters of the evidence block encoder are fixed and all other parameters are fine-tuned. Given a question \(x\) and a gold answer string \(y\), the reader loss contains two parts:</p>
\[L(x, y) = L_\text{early}(x, y) + L_\text{full}(x, y)\]
<p>(1) Find all correct text spans within top \(k\) evidence blocks and optimize for the marginal likelihood of a text span \(s\) that matches the true answer \(y\):</p>
\[\begin{aligned}
h_s &= \text{BERT}_R(x, z)^{(\text{START}(s))} \\
h_e &= \text{BERT}_R(x, z)^{(\text{END}(s))} \\
S_\text{read}(z, s, x) &= \text{MLP}([h_s; h_e]) \\
p(z, s \vert x) &= \frac{\exp(S_\text{read}(z, s, x))}{\sum_{z'\in\text{TOP}(k)} \sum_{s'\in z'} \exp(S_\text{read}(z', s', x))} \\
L_\text{full}(x, y) &= - \log \sum_{z \in \text{TOP}(k)} \sum_{\substack{s \in z\\ y=\text{TEXT}(s)}} p(z, s \vert x)
\end{aligned}\]
<p>where \(y=\text{TEXT}(s)\) indicates whether the answer \(y\) matches the text span \(s\). \(\text{TOP}(k)\) is the top \(k\) retrieved blocks according to \(S_\text{retr}(z, x)\). The paper sets \(k=5\).</p>
<p>(2) At the early stage of learning, when the retriever is not yet strong enough, it is possible that none of the top \(k\) blocks contains the answer. To avoid such sparse learning signals, ORQA considers a larger set of \(c\) evidence blocks for more aggressive learning. The paper sets \(c=5000\).</p>
\[L_\text{early}(x, y)
= -\log \sum_{\substack{z\in \text{TOP}(c)\\y\in\text{TEXT}(z)}} p_\text{early}(z\vert x)
= -\log \sum_{\substack{z\in \text{TOP}(c)\\y\in\text{TEXT}(z)}} \frac{\exp(S_\text{retr}(z, x))}{\sum_{z'\in\text{TOP}(c)} \exp(S_\text{retr}(z', x))}\]
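<p>Both reader loss terms have the same shape: the negative log of the softmax mass assigned to the candidates that contain the answer, computed over (block, span) pairs in \(\text{TOP}(k)\) for \(L_\text{full}\) and over blocks in \(\text{TOP}(c)\) for \(L_\text{early}\). The NumPy sketch below shares one hypothetical helper between them. It scores spans with \(S_\text{read}\) only, as in the equations for \(L_\text{full}\), and it skips a term when no candidate matches; both are simplifying assumptions, not necessarily what the ORQA code does.</p>

```python
import numpy as np


def _logsumexp(a):
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def neg_log_marginal(scores, is_correct):
    """-log of the softmax probability mass on all correct candidates.

    scores: candidate scores (any shape), e.g. S_read over (block, span) pairs
            for L_full, or S_retr over the TOP(c) blocks for L_early.
    is_correct: boolean mask of the same shape marking candidates that match
                (L_full) or contain (L_early) the answer string y.
    """
    scores = np.asarray(scores, dtype=float).ravel()
    is_correct = np.asarray(is_correct, dtype=bool).ravel()
    if not is_correct.any():
        return 0.0   # assumption: skip the term when no candidate has the answer
    return float(_logsumexp(scores) - _logsumexp(scores[is_correct]))


# Usage with toy numbers: 2 blocks x 3 candidate spans for L_full,
# and 4 retrieved blocks standing in for TOP(c) in L_early.
span_scores = np.array([[1.0, 0.5, -1.0], [0.2, 2.0, 0.0]])
span_matches = np.array([[False, True, False], [False, True, False]])
block_scores = np.array([3.0, 1.0, 0.5, -2.0])
block_has_answer = np.array([False, True, False, True])

loss = (neg_log_marginal(block_scores, block_has_answer)     # L_early
        + neg_log_marginal(span_scores, span_matches))       # L_full
```

<p>Computing each term as a difference of two log-sum-exp values avoids numerical overflow and underflow when thousands of blocks are involved, as with \(c=5000\).</p>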
<p>The ORQA paper also discusses some issues with the SQuAD dataset:</p>
<blockquote>
<p>“The notable drop between development and test accuracy for SQuAD is a reflection of an artifact in the dataset—its 100k questions are derived from only 536 documents. Therefore, good retrieval targets are highly correlated between training examples, violating the IID assumption, and making it unsuitable for learned retrieval. We strongly suggest that those who are interested in end-to-end open-domain QA models no longer train and evaluate with SQuAD for this reason.”</p>
</blockquote>
<p><a name="REALM"></a><strong>REALM</strong> (“Retrieval-Augmented Language Model pre-training”; <a href="https://arxiv.org/abs/2002.08909">Guu et al., 2020</a>) also jointly trains retriever + reader by optimizing the marginal likelihood of obtaining the true answer:</p>
\[p(y \vert x)
= \sum_{z \in \mathcal{Z}} \underbrace{p(y \vert x, z)}_\text{reader} \underbrace{p(z \vert x)}_\text{retriever}
\approx \sum_{z \in \text{TOP}_k(\mathcal{Z})} p(y \vert x, z) p(z \vert x)\]
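<p>The truncated marginal can be sketched in a few lines of NumPy. This assumes that \(p(z \vert x)\) is a softmax over retrieval scores renormalized within the top \(k\) documents and that the reader probabilities \(p(y \vert x, z)\) are given; the function name is hypothetical.</p>

```python
import numpy as np


def realm_marginal(retr_scores, reader_probs, k):
    """Approximate p(y | x) = sum_{z in TOP_k} p(y | x, z) p(z | x).

    retr_scores: (N,) retrieval scores for all N documents.
    reader_probs: (N,) reader probabilities p(y | x, z); only the top-k entries
                  are used, since the reader only sees the TOP_k documents.
    p(z | x) is renormalized over the top-k documents (an assumption about how
    the truncated sum is implemented).
    """
    retr_scores = np.asarray(retr_scores, dtype=float)
    reader_probs = np.asarray(reader_probs, dtype=float)
    top = np.argsort(-retr_scores)[:k]                 # TOP_k by retrieval score
    logits = retr_scores[top] - retr_scores[top].max()
    p_z = np.exp(logits) / np.exp(logits).sum()         # retriever p(z | x)
    return float(np.sum(p_z * reader_probs[top]))       # sum of reader * retriever


scores = np.log([0.5, 0.3, 0.2])
p_y = realm_marginal(scores, reader_probs=[1.0, 0.0, 0.0], k=3)
```

<p>Because the retrieval scores receive gradients through \(p(z \vert x)\), maximizing this marginal trains the retriever and the reader at the same time.</p>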
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/REALM-train.png" alt="REALM" /></p>
<p><em>Fig. 10. REALM is first unsupervised pre-trained with salient spans masking and then fine-tuned with QA data. (Image source: <a href="https://arxiv.org/abs/2002.08909">Guu et al., 2020</a>).</em></p>
<p>REALM computes the same two probabilities as ORQA, \(p(z \vert x)\) and \(p(y \vert x, z)\). However, compared to ORQA’s ICT pre-training, REALM upgrades the unsupervised pre-training step with several new design decisions, leading to better retrieval. REALM pre-trains the model on the Wikipedia or CC-News corpus.</p>
<ol>
<li><a name="ssm"></a>Use <em>salient span masking</em>. Named entities and dates are identified. Then one of these “salient spans” is selected and masked. Salient span masking is a special case of MLM and works out well for QA tasks.</li>
<li>Add an <em>empty null document</em>, because not every question requires a context document.</li>
<li>No trivial retrieval: the context document should not be the same as the selected sentence with a masked span.</li>
<li>Apply the same ICT loss as in ORQA to encourage learning when the retrieval quality is still poor at the early stage of training.</li>
</ol>
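<p>Salient span masking (item 1) can be sketched as follows. The entity list is assumed to come from an external named entity tagger, and the date regular expression is a hypothetical stand-in for REALM’s date pattern; both are illustrations, not REALM’s actual preprocessing.</p>

```python
import random
import re

# Hypothetical date pattern: "Month D, YYYY" or a bare four-digit year.
_MONTHS = ("January|February|March|April|May|June|July|August|"
           "September|October|November|December")
DATE_RE = re.compile(r"\b(?:" + _MONTHS + r")\s+\d{1,2},\s+\d{4}\b|\b\d{4}\b")


def salient_span_mask(sentence, entities, mask_token="[MASK]", rng=random):
    """Mask one salient span (a named entity or a date) in `sentence`.

    entities: entity strings assumed to come from an external NER tagger.
    Returns (masked_sentence, target_span), or None when nothing is salient.
    """
    spans = []
    for ent in entities:
        start = sentence.find(ent)              # first occurrence only
        if start >= 0:
            spans.append((start, start + len(ent)))
    spans.extend(m.span() for m in DATE_RE.finditer(sentence))
    if not spans:
        return None
    start, end = rng.choice(spans)              # pick one salient span at random
    return sentence[:start] + mask_token + sentence[end:], sentence[start:end]


s = "Barack Obama was born on August 4, 1961 in Honolulu."
masked, target = salient_span_mask(s, ["Barack Obama", "Honolulu"], rng=random.Random(0))
```

<p>The model is then trained to predict <code class="language-plaintext highlighter-rouge">target</code> from <code class="language-plaintext highlighter-rouge">masked</code> (plus any retrieved documents), which forces it to recall the entity or date itself rather than a generic function word.</p>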
<blockquote>
<p>“Among all systems, the most direct comparison with REALM is ORQA (Lee et al., 2019), where the fine-tuning setup, hyperparameters and training data are identical. The improvement of REALM over ORQA is purely due to better pre-training methods.” — from REALM paper.</p>
</blockquote>
<p>Both unsupervised pre-training and supervised fine-tuning optimize the same log-likelihood \(\log p(y \vert x)\). Because the parameters of the retriever encoder for evidence documents are also updated in the process, the index for MIPS is changing. REALM asynchronously refreshes the index with the updated encoder parameters every several hundred training steps.</p>
<p><a name="DPR"></a><strong>DPR</strong> (“Dense Passage Retriever”; <a href="https://arxiv.org/abs/2004.04906">Karpukhin et al., 2020</a>, <a href="https://github.com/facebookresearch/DPR">code</a>) argues that ICT pre-training could be too computationally expensive and that ORQA’s context encoder might be sub-optimal because it is not fine-tuned with question-answer pairs. DPR aims to resolve both issues by training a dense dual-encoder retriever on only a small number of Q/A pairs, without any additional pre-training.</p>
<p>Like previous work, DPR uses the dot product of BERT representations as the retrieval score (L2 distance or cosine similarity also works). The loss function for training the dual encoder is the NLL of the positive passage, which essentially takes the same form as the <a href="#ICT-loss">ICT loss</a> of ORQA. Note that both consider other passages in the same batch as negative samples, known as <em>in-batch negative sampling</em>. The main difference is that DPR relies on supervised QA data, while ORQA trains with ICT on an unsupervised corpus. At inference time, DPR uses <a href="https://github.com/facebookresearch/faiss">FAISS</a> to run fast MIPS.</p>
<p>DPR ran a set of comparison experiments with several different types of negatives:</p>
<ol>
<li>Random: any random passage from the corpus;</li>
<li>BM25: top passages returned by BM25 which don’t contain the answer but match most question tokens;</li>
<li>In-batch negative sampling (“gold”): positive passages paired with other questions which appear in the training set.</li>
</ol>
<p>DPR found that using gold passages from the same mini-batch and one negative passage with high BM25 score works the best. To further improve the retrieval results, DPR also explored a setting where a BM25 score and a dense embedding retrieval score are linearly combined to serve as a new ranking function.</p>
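<p>The best-performing setting above (gold passages from the same mini-batch plus one BM25 negative per question) can be sketched as a single softmax over \(2B\) passages per question. This NumPy version assumes precomputed question and passage embeddings; the hybrid scorer at the end is a generic linear combination whose weight <code class="language-plaintext highlighter-rouge">lam</code> is a hypothetical parameter, not a value taken from the paper.</p>

```python
import numpy as np


def dpr_loss(q, p_pos, p_hard):
    """NLL of the positive passage with in-batch and BM25 hard negatives.

    q: (B, d) question embeddings.
    p_pos: (B, d) gold passage embeddings; row i is the positive for question i
           and a negative for every other question in the batch.
    p_hard: (B, d) one BM25 hard negative per question, shared across the batch.
    """
    passages = np.concatenate([p_pos, p_hard], axis=0)       # (2B, d)
    scores = q @ passages.T                                  # dot-product similarity, (B, 2B)
    scores = scores - scores.max(axis=1, keepdims=True)
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    rows = np.arange(q.shape[0])
    return float(-log_probs[rows, rows].mean())              # positive of q_i is column i


def hybrid_score(dense_score, bm25_score, lam=1.0):
    """Linear combination of dense and BM25 scores; lam is a hypothetical weight."""
    return dense_score + lam * bm25_score
```

<p>Reusing the other questions' gold passages as negatives means each batch yields \(B \times 2B\) training pairs for the cost of encoding only \(2B\) passages.</p>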
<h2 id="open-book-qa-retriever-generator">Open-book QA: Retriever-Generator</h2>
<p>Compared to the retriever-reader approach, the retriever-generator also has 2 stages, but the second stage generates free text directly to answer the question rather than extracting start/end positions in a retrieved passage. Some papers also refer to this as <em>Generative question answering</em>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/QA-retiever-generator.png" alt="retriever + text generator" /></p>
<p><em>Fig. 11. The retriever + generator QA framework combines a document retrieval system with a general language model.</em></p>
<p>A pretrained LM has a great capacity for memorizing knowledge in its parameters, as shown above. However, such models cannot easily modify or expand their memory, cannot straightforwardly provide insights into their predictions, and may hallucinate non-existent facts.</p>
<p><a href="https://arxiv.org/abs/2005.04611">Petroni et al. (2020)</a> studied how the retrieved relevant context can help a generative language model produce better answers. They found:</p>
<ol>
<li>Augmenting queries with relevant contexts dramatically improves the unsupervised machine reading capabilities of the pretrained LM;</li>
<li>An off-the-shelf IR system is sufficient for BERT to match the performance of a supervised ODQA baseline;</li>
<li>BERT’s <a href="/lil-log/2019/01/31/generalized-language-models.html#pre-training-tasks">NSP</a> pre-training strategy is a highly effective unsupervised mechanism for dealing with noisy and irrelevant contexts.</li>
</ol>
<p>They pair the BERT model with different types of context, including adversarial (unrelated context), retrieved (by BM25), and generative (by an autoregressive language model with 1.4B parameters, trained on CC-NEWS). The model is found to be robust to adversarial context, but only when the question and the context are provided as two segments (e.g. separated by <code class="language-plaintext highlighter-rouge">[SEP]</code>). One hypothesis is related to the NSP task: “BERT might learn to not condition across segments for masked token prediction if the NSP score is low, thereby implicitly detecting irrelevant and noisy contexts.”</p>
<p><strong>RAG</strong> (“Retrieval-Augmented Generation”; <a href="https://arxiv.org/abs/2005.11401">Lewis et al., 2020</a>) combines pre-trained parametric (language model) and non-parametric memory (external knowledge index) together for language generation. RAG can be fine-tuned on any seq2seq task, whereby both the retriever and the sequence generator are jointly learned. They found that unconstrained generation outperforms previous extractive approaches.</p>
<p>RAG consists of a retriever model \(p_\eta(z \vert x)\) and a generator model \(p_\theta(y_i \vert x, z, y_{1:i-1})\):</p>
<ul>
<li>The retriever uses the input sequence \(x\) to retrieve text passages \(z\), implemented as a <a href="#DPR">DPR</a> retriever. \(\log p_\eta(z \vert x) \propto E_z(z)^\top E_x(x)\).</li>
<li>The generator uses \(z\) as additional context when generating the target sequence \(y\), where the context and the question are simply concatenated.</li>
</ul>
<p>Depending on whether the same retrieved document or a different one is used for generating each token, there are two versions of RAG:</p>
\[\begin{aligned}
p_\text{RAG-seq}(y \vert x) &= \sum_{z \in \text{TOP}_k(p_\eta(\cdot\vert x))} p_\eta(z \vert x) \prod_i^N p_\theta(y_i \vert x, z, y_{1:i-1}) \\
p_\text{RAG-token}(y \vert x) &= \prod_i^N \sum_{z \in \text{TOP}_k(p_\eta(\cdot\vert x))} p_\eta(z\vert x) p_\theta(y_i \vert x, z, y_{1:i-1})
\end{aligned}\]
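<p>The difference between the two marginals is easiest to see numerically. The NumPy sketch below assumes the top-\(k\) retriever probabilities \(p_\eta(z \vert x)\) and the per-token generator probabilities \(p_\theta(y_i \vert x, z, y_{1:i-1})\) of a fixed target sequence are already available as arrays; the function names are hypothetical.</p>

```python
import numpy as np


def rag_sequence(p_z, p_tok):
    """RAG-seq: one document per sequence, sum_z p(z|x) * prod_i p(y_i | x, z, y_<i).

    p_z: (k,) retriever probabilities over the top-k documents.
    p_tok: (k, N) generator probability of each target token given each document.
    """
    return float(np.sum(p_z * np.prod(p_tok, axis=1)))


def rag_token(p_z, p_tok):
    """RAG-token: marginalize over documents at every position, prod_i sum_z."""
    return float(np.prod(p_z @ p_tok))


# Two equally likely documents, each supporting only one of the two tokens.
p_z = np.array([0.5, 0.5])
p_tok = np.array([[1.0, 0.0],
                  [0.0, 1.0]])
seq = rag_sequence(p_z, p_tok)
tok = rag_token(p_z, p_tok)
```

<p>Here RAG-seq assigns the answer zero probability, because no single document supports both tokens, while RAG-token assigns it 0.25, because it can draw each token from a different document.</p>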
<p>The retriever + generator in RAG is jointly trained to minimize the NLL loss, \(\mathcal{L}_\text{RAG} = \sum_j -\log p(y_j \vert x_j)\). Updating the passage encoder \(E_z(.)\) is expensive as it requires the model to re-index the documents for fast MIPS. RAG does not find fine-tuning \(E_z(.)\) necessary (like in <a href="#ORQA">ORQA</a>) and only updates the query encoder + generator.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RAG.png" alt="RAG" /></p>
<p><em>Fig. 12. An illustration of retrieval-augmented generation (RAG) architecture. (Image source: <a href="https://arxiv.org/abs/2005.11401">Lewis et al., 2020</a>)</em></p>
<p>At decoding/test time, RAG-token can be evaluated via standard <a href="https://d2l.ai/chapter_recurrent-modern/beam-search.html#id1">beam search</a>. RAG-seq cannot be broken down into a set of per-token likelihoods, so it runs beam search for each candidate document \(z\), scoring hypotheses with \(p_\theta(y_i \vert x, z, y_{1:i-1})\), and then picks the most probable answer across all documents.</p>
<p>The <em>Fusion-in-Decoder</em> approach, proposed by <a href="https://arxiv.org/abs/2007.01282">Izacard & Grave (2020)</a>, is based on a pre-trained T5. It works similarly to RAG but differs in how the context is integrated into the decoder.</p>
<ol>
<li>Retrieve the top \(k\) related passages of 100 words each, using BM25 or DPR.</li>
<li>Each retrieved passage and its title are concatenated with the question using special tokens like <code class="language-plaintext highlighter-rouge">question:</code>, <code class="language-plaintext highlighter-rouge">title:</code> and <code class="language-plaintext highlighter-rouge">context:</code> to indicate the content differences.</li>
<li>Each retrieved passage is encoded independently, and the encoded passages are later combined in the decoder. Processing passages independently in the encoder allows us to parallelize the computation, while processing them jointly in the decoder encourages better aggregation of multiple pieces of evidence. This aggregation is missing in extractive approaches.</li>
</ol>
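<p>The input formatting and fusion steps can be sketched as below. The <code class="language-plaintext highlighter-rouge">toy_encoder</code> is a hypothetical stand-in for the T5 encoder that emits one random vector per whitespace token; the point is only the data flow: passages are encoded separately, and their states are concatenated before the decoder attends over them.</p>

```python
import numpy as np


def fid_encoder_inputs(question, passages):
    """One encoder input per retrieved passage, marked with special prefixes.

    passages: list of (title, text) pairs from BM25 or DPR.
    """
    return [f"question: {question} title: {title} context: {text}"
            for title, text in passages]


def fuse_encoder_outputs(per_passage_states):
    """Concatenate per-passage encoder states along the sequence axis.

    Each (seq_len_i, d_model) array is encoded independently (parallelizable);
    the decoder then cross-attends over the concatenation, i.e. all passages jointly.
    """
    return np.concatenate(per_passage_states, axis=0)


def toy_encoder(text, d_model=8):
    """Hypothetical stand-in for the T5 encoder: one random vector per token."""
    rng = np.random.default_rng(len(text))
    return rng.normal(size=(len(text.split()), d_model))


inputs = fid_encoder_inputs(
    "Who wrote Hamlet?",
    [("Hamlet", "Hamlet is a tragedy by William Shakespeare."),
     ("William Shakespeare", "Shakespeare was an English playwright.")])
fused = fuse_encoder_outputs([toy_encoder(t) for t in inputs])
```

<p>Since the decoder's cross-attention cost grows with the total fused length, the encoder cost stays linear in the number of passages while evidence from all of them is still available at every decoding step.</p>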
<p>Note that they did fine-tune the pretrained LM independently for each dataset.</p>
<h2 id="closed-book-qa-generative-language-model">Closed-book QA: Generative Language Model</h2>
<p>Big language models have been pre-trained on large collections of unlabeled text. Given enough parameters, these models are able to memorize some factual knowledge within their parameter weights. Therefore, we can use these models to do question answering without explicit context, just like in a closed-book exam. The pre-trained language models produce <em>free text</em> to respond to questions, without explicit reading comprehension.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/LM-compute.png" alt="LM compute" /></p>
<p><em>Fig. 13. The amount of computation used for training big language models of different sizes keeps growing. (Image source: <a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>).</em></p>
<p><a href="https://arxiv.org/abs/2002.08910">Roberts et al. (2020)</a> measured the practical utility of a language model by fine-tuning a pre-trained model to answer questions without access to any external context or knowledge. They fine-tuned the <a href="https://arxiv.org/abs/1910.10683">T5</a> language model (same architecture as the original Transformer) to answer questions from the question text alone. Such a setup forces the language model to answer questions based on the “knowledge” that it internalized during pre-training.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/T5_SSM.png" alt="T5+SSM" /></p>
<p><em>Fig. 14. T5 is first pre-trained with salient span masking and then fine-tuned for each QA dataset to produce answers in free text. (Image source: <a href="https://arxiv.org/abs/2002.08910">Roberts et al. 2020</a>)</em></p>
<p>The original T5 models were pre-trained on a multi-task mixture including an unsupervised <a href="/lil-log/2019/01/31/generalized-language-models.html#use-bert-in-downstream-tasks">“masked language modeling”</a> (MLM) task on the C4 (“Colossal Clean Crawled Corpus”) dataset, together with supervised translation, summarization, classification, and reading comprehension tasks. <a href="https://arxiv.org/abs/2002.08910">Roberts, et al. (2020)</a> took a pre-trained T5 model and continued pre-training with <a href="#ssm">salient span masking</a> over the Wikipedia corpus, which has been found to substantially boost the performance for ODQA. Then they fine-tuned the model for each QA dataset independently.</p>
<p>With a pre-trained T5 language model, continued pre-training with salient span masking, and fine-tuning for each QA dataset:</p>
<ul>
<li>It can attain competitive results in open-domain question answering without access to external knowledge.</li>
<li>A larger model obtains better performance. For example, a T5 with 11B parameters is able to match the performance of <a href="#DPR">DPR</a>, which uses 3 BERT-base models with about 330M parameters in total.</li>
</ul>
<p>Interestingly, fine-tuning is not strictly necessary. GPT3 (<a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>) has been evaluated on the closed book question answering task <em>without any gradient updates or fine-tuning</em>. During evaluation, the few-shot, one-shot and zero-shot settings here only refer to how many demonstrations are provided as context in the text input:</p>
<ol>
<li>“few-shot learning”: GPT3 is allowed to take as many demonstrations as can fit into the model’s context window (typically 10 to 100).</li>
<li>“one-shot learning”: only one demonstration is provided.</li>
<li>“zero-shot learning”: no demonstrations are allowed and only an instruction in natural language is given to the model.</li>
</ol>
<p>The performance grows with the model size. On the TriviaQA dataset, GPT3 with demonstrations can match or exceed the performance of the fine-tuned SOTA baseline.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/GPT3-triviaqa.png" alt="GPT3 on TriviaQA" /></p>
<p><em>Fig. 15. GPT3’s performance on TriviaQA grows smoothly with the model size. More demonstrations lead to better performance. (Image source: <a href="https://arxiv.org/abs/2005.14165">Brown et al., 2020</a>).</em></p>
<p><a name="openai-api-example"></a>Check out this cool example in the OpenAI API <a href="https://beta.openai.com/playground/p/HMoho4552EHXrPLbmOIxpX4X">playground viewer</a>. The model is able to answer factual questions with short answers and to avoid making things up when it does not know the answer. I added the last two questions and asked the model to respond with <code class="language-plaintext highlighter-rouge">A:</code>. The API is still in beta, so you might need to <a href="https://beta.openai.com/">apply</a> to get on the wait list.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Q: Who is Batman?
A: Batman is a fictional comic book character.
###
Q: What is torsalplexity?
A: ?
###
Q: What is Devz9?
A: ?
###
Q: Who is George Lucas?
A: George Lucas is American film director and producer famous for creating Star Wars.
###
Q: What is the capital of California?
A: Sacramento.
###
Q: What orbits the Earth?
A: The Moon.
###
Q: Who is Fred Rickerson?
A: ?
###
Q: What is an atom?
A: An atom is a tiny particle that makes up everything.
###
Q: Who is Alvan Muntz?
A: ?
###
Q: What is Kozar-09?
A: ?
###
Q: How many moons does Mars have?
A: Two, Phobos and Deimos.
###
Q: What is COVID-19?
A: ?
###
Q: What is H1N1?
A: H1N1 is a strain of influenza.
</code></pre></div></div>
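<p>A prompt in this format can be assembled programmatically. The sketch below is a hypothetical helper: the number of demonstration pairs decides whether the setting is zero-, one- or few-shot, and a <code class="language-plaintext highlighter-rouge">?</code> answer in a demonstration shows the model how to decline. It only builds the text; sending it to a model API is out of scope here.</p>

```python
def build_prompt(demos, question, instruction="", sep="###"):
    """Build a closed-book QA prompt in the Q:/A:/### format shown above.

    demos: list of (question, answer) demonstration pairs. len(demos) == 0 gives
           zero-shot, 1 gives one-shot, more gives few-shot.
    instruction: optional natural-language task description placed first.
    """
    blocks = [f"Q: {q}\nA: {a}" for q, a in demos]
    blocks.append(f"Q: {question}\nA:")          # the model continues after "A:"
    body = f"\n{sep}\n".join(blocks)
    return f"{instruction}\n{body}" if instruction else body


demos = [("Who is Batman?", "Batman is a fictional comic book character."),
         ("What is Devz9?", "?")]
prompt = build_prompt(demos, "How many moons does Mars have?")
```

<p>Ending the prompt with an open <code class="language-plaintext highlighter-rouge">A:</code> leaves the answer slot for the model, and the separator gives a natural stop sequence for truncating its continuation.</p>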
<h2 id="related-techniques">Related Techniques</h2>
<h3 id="fast-maximum-inner-product-search-mips">Fast Maximum Inner Product Search (MIPS)</h3>
<p>MIPS (maximum inner product search) is a crucial component in many open-domain question answering models. In the retriever + reader/generator framework, a large number of passages from the knowledge source are encoded and stored in memory. A retrieval model can query the memory to identify the top relevant passages, i.e. those with the maximum inner product with the question’s embedding.</p>
<p>We need fast MIPS because the number of precomputed passage representations can be gigantic. There are several ways to achieve fast MIPS at run time, such as <a href="https://papers.nips.cc/paper/5329-asymmetric-lsh-alsh-for-sublinear-time-maximum-inner-product-search-mips.pdf">asymmetric LSH</a>, <a href="https://arxiv.org/abs/1501.01062">data-dependent hashing</a>, and <a href="https://github.com/facebookresearch/faiss">FAISS</a>.</p>
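<p>For reference, exact MIPS by brute force is only a matrix-vector product followed by a partial sort, as in the NumPy sketch below (the function name is hypothetical). FAISS offers the same exact search through <code class="language-plaintext highlighter-rouge">IndexFlatIP</code>; the approximate methods listed above exist because this \(O(Nd)\) scan per query becomes the bottleneck when \(N\) reaches millions of passages.</p>

```python
import numpy as np


def mips_top_k(passage_emb, query_emb, k):
    """Exact maximum inner product search by brute force.

    passage_emb: (N, d) precomputed passage embeddings; query_emb: (d,).
    Returns the indices of the k passages with the largest inner product,
    best first. Cost is O(N * d) per query.
    """
    scores = passage_emb @ query_emb
    k = min(k, scores.shape[0])
    top = np.argpartition(-scores, k - 1)[:k]      # unordered top-k in O(N)
    return top[np.argsort(-scores[top])]           # sort only the k winners


rng = np.random.default_rng(0)
passages = rng.normal(size=(10_000, 64))
query = rng.normal(size=64)
hits = mips_top_k(passages, query, k=5)
```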
<h3 id="language-model-pre-training">Language Model Pre-training</h3>
<p>Two pre-training tasks are especially helpful for QA tasks, as we have discussed above.</p>
<ul>
<li>
<p><strong>Inverse Cloze Task</strong> (proposed by <a href="#ORQA">ORQA</a>): The goal of <a href="https://en.wikipedia.org/wiki/Cloze_test">Cloze Task</a> is to predict masked-out text based on its context. The prediction of Inverse Cloze Task (ICT) is in the reverse direction, aiming to predict the context given a sentence. In the context of QA tasks, a random sentence can be treated as a pseudo-question, and its context can be treated as pseudo-evidence.</p>
</li>
<li>
<p><strong>Salient Spans Masking</strong> (proposed by <a href="#REALM">REALM</a>): Salient span masking is a special case of the MLM task in language model training. First, we find <em>salient spans</em> by using a tagger to identify named entities and a regular expression to identify dates. Then one of the detected salient spans is selected and masked. The task is to predict this masked salient span.</p>
</li>
</ul>
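<p>The ICT data construction is simple enough to sketch directly. This hypothetical helper picks one sentence as the pseudo-question and joins the rest as pseudo-evidence; ORQA reportedly also leaves the sentence in the context some of the time so that the model still learns lexical matching, and the default <code class="language-plaintext highlighter-rouge">keep_prob=0.1</code> here is an assumption.</p>

```python
import random


def make_ict_example(sentences, rng=random, keep_prob=0.1):
    """Turn a passage into an ICT (pseudo-question, pseudo-evidence) pair.

    sentences: the passage split into sentences.
    keep_prob: chance of leaving the chosen sentence inside the context
               (hypothetical default).
    """
    i = rng.randrange(len(sentences))
    query = sentences[i]
    if rng.random() < keep_prob:
        context = sentences                           # occasionally keep it
    else:
        context = sentences[:i] + sentences[i + 1:]   # usually remove it
    return query, " ".join(context)


passage = ["Zebras are African equines.",
           "They have distinctive black-and-white stripes.",
           "Each animal's stripe pattern is unique."]
question, evidence = make_ict_example(passage, rng=random.Random(0))
```

<p>Pairs built this way feed directly into the in-batch ICT loss: the contexts of the other examples in the batch serve as negatives, so no labeled data is needed.</p>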
<h2 id="summary">Summary</h2>
<table class="info">
<thead>
<tr>
<th>Model</th>
<th>Retriever</th>
<th>Reader / Generator</th>
<th>Pre-training / Fine-tuning</th>
<th>End2end</th>
</tr>
</thead>
<tbody>
<tr>
<td>DrQA</td>
<td>TF-IDF</td>
<td>Bi-directional LSTM</td>
<td>–</td>
<td>No</td>
</tr>
<tr>
<td>BERTserini</td>
<td>Anserini + BM25</td>
<td>BERT without softmax layer</td>
<td>Fine-tune with SQuAD</td>
<td>No</td>
</tr>
<tr>
<td>Multi-passage BERT</td>
<td>ElasticSearch + BM25</td>
<td>Multi-passage BERT + Passage ranker</td>
<td> </td>
<td>No</td>
</tr>
<tr>
<td>R^3</td>
<td>Classic IR + Match-LSTM</td>
<td>Match-LSTM</td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>ORQA</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>Inverse cloze task</td>
<td>Yes</td>
</tr>
<tr>
<td>REALM</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>Salient span masking</td>
<td>Yes</td>
</tr>
<tr>
<td>DPR</td>
<td>Dot product of BERT embeddings</td>
<td>BERT-RC</td>
<td>supervised training with QA pairs</td>
<td>Yes</td>
</tr>
<tr>
<td>DenSPI</td>
<td>Classic + Neural IR</td>
<td>–</td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>T5 + SSM</td>
<td>–</td>
<td>T5</td>
<td>MLM on C4 (<a href="https://commoncrawl.org/the-data/get-started/">CommonCrawl</a>) data + SSM on Wikipedia + Fine-tuning on QA data</td>
<td>Yes</td>
</tr>
<tr>
<td>GPT3</td>
<td>–</td>
<td>GPT3</td>
<td>Autoregressive LM (next-token prediction) on <a href="https://commoncrawl.org/the-data/get-started/">CommonCrawl</a> data</td>
<td>Yes</td>
</tr>
<tr>
<td>RAG</td>
<td>DPR retriever</td>
<td><a href="https://arxiv.org/abs/1910.13461">BART</a></td>
<td> </td>
<td>Yes</td>
</tr>
<tr>
<td>Fusion-in-Decoder</td>
<td>BM25 / DPR retriever</td>
<td>Transformer</td>
<td> </td>
<td>No</td>
</tr>
</tbody>
</table>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/QA-results.png" alt="SOTA-comparison" /></p>
<p><em>Fig. 16. A comparison of performance of several QA models on common QA datasets. On TriviaQA, two columns of results are reported, on the open domain test set (left) and on the hidden test set (right). (Image source: <a href="https://arxiv.org/abs/2007.01282">Izacard & Grave, 2020</a>).</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020odqa,
title = "How to Build an Open-Domain Question Answering System?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/10/29/open-domain-question-answering.html"
}
</code></pre></div></div>
<h2 id="appendix-qa-datasets">Appendix: QA Datasets</h2>
<ul>
<li><a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD 2.0</a>: the Stanford QA dataset.</li>
<li><a href="http://www.qizhexie.com/data/RACE_leaderboard">RACE</a>: a reading comprehension dataset collected from English Examinations that are created for middle school and high school students.</li>
<li><a href="https://trec.nist.gov/data/qa.html">TREC QA</a>: the TREC QA collections.</li>
<li><a href="https://microsoft.github.io/msmarco/">MS MARCO</a>: a QA dataset featuring 100,000 real Bing questions, each with a human-generated answer.</li>
<li><a href="https://github.com/brmson/dataset-factoid-curated">CuratedTREC</a>: based on the benchmarks from the TREC QA tasks that have been curated by <a href="https://link.springer.com/chapter/10.1007%2F978-3-319-24027-5_20">Baudis & Sedivy (2015)</a>.</li>
<li><a href="https://ai.google.com/research/NaturalQuestions/dataset">Google Natural Questions</a>: contains real user questions issued to Google search, and answers found from Wikipedia by annotators.</li>
<li><a href="https://github.com/brmson/dataset-factoid-webquestions">WebQuestions</a>: designed for knowledge-base QA with answers restricted to Freebase entities.</li>
<li><a href="https://www.microsoft.com/en-us/research/publication/wikiqa-a-challenge-dataset-for-open-domain-question-answering/">WikiQA</a>: Bing query logs were used as the source of questions. Each question is then linked to a Wikipedia page that potentially contains the answer.</li>
<li><a href="https://research.fb.com/downloads/babi/">WikiMovies</a>: contains movie-related questions from the OMDb and MovieLens databases, which can be answered using Wikipedia pages.</li>
<li><a href="https://github.com/google-research-datasets/wiki-reading">WikiReading</a>: to predict textual values from the structured knowledge base Wikidata by reading the text of the corresponding Wikipedia articles.</li>
<li><a href="https://nlp.cs.washington.edu/triviaqa/">TriviaQA</a>: a reading comprehension dataset containing 95K question-answer pairs authored by trivia enthusiasts and independently gathered multiple evidence documents per question.</li>
<li><a href="https://www.kaggle.com/tunguz/200000-jeopardy-questions"> Jeopardy! Questions</a>: contains 200,000+ <a href="https://en.wikipedia.org/wiki/Jeopardy!">Jeopardy!</a> questions.</li>
<li><a href="https://cs.nyu.edu/~kcho/DMQA/">DeepMind Q&A Dataset</a>: question/answer pairs from CNN and Daily Mail articles.</li>
<li><a href="https://research.fb.com/downloads/babi/">bAbi</a>: a rich collection of datasets for text understanding by Facebook.</li>
<li><a href="https://fever.ai/data.html">FEVER</a>: for fact extraction and verification.</li>
<li><a href="https://github.com/nyu-dl/dl4ir-searchQA">SearchQA</a>: question-answer pairs were crawled from <a href="https://j-archive.com/">J! Archive</a>, and then augmented with text snippets from Google.</li>
<li><a href="https://github.com/bdhingra/quasar">Quasar-T</a>: a collection of open-domain trivia questions and their answers obtained from various internet sources.</li>
<li><a href="https://people.cs.umass.edu/~miyyer/qblearn/index.html">Quiz bowl</a>: contains data from a trivia competition called quiz bowl.</li>
<li><a href="https://nlp.cs.washington.edu/ambigqa/">AmbigNQ</a>: ambiguous questions selected from NQ-OPEN dataset.</li>
<li><a href="https://github.com/facebookresearch/QA-Overlap">QA-Overlap</a>: a collection of questions/answers that overlap between the train and test sets of Natural Questions, TriviaQA, and WebQuestions.</li>
</ul>
<h2 id="references">References</h2>
<p>[1] Danqi Chen & Scott Yih. <a href="https://github.com/danqi/acl2020-openqa-tutorial">“ACL2020 Tutorial: Open-Domain Question Answering”</a> July 2020.</p>
<p>[2] Danqi Chen, et al. <a href="https://arxiv.org/abs/1704.00051">“Reading Wikipedia to Answer Open-Domain Questions”</a> ACL 2017. | <a href="https://github.com/facebookresearch/DrQA">code</a></p>
<p>[3] Shuohang Wang, et al. <a href="https://arxiv.org/abs/1709.00023">“R^3: Reinforced Ranker-Reader for Open-Domain Question Answering”</a> AAAI 2018.</p>
<p>[4] Jimmy Lin. <a href="https://sigir.org/wp-content/uploads/2019/01/p040.pdf">“The neural hype and comparisons against weak baselines.”</a> ACM SIGIR Forum. Vol. 52. No. 2. 2019.</p>
<p>[5] Wei Yang, et al. <a href="https://arxiv.org/abs/1902.01718">“End-to-End Open-Domain Question Answering with BERTserini”</a> NAACL 2019.</p>
<p>[6] Christopher Clark & Matt Gardner. <a href="https://arxiv.org/abs/1710.10723">“Simple and Effective Multi-Paragraph Reading Comprehension.”</a> arXiv:1710.10723 (2017).</p>
<p>[7] Rodrigo Nogueira & Kyunghyun Cho. <a href="https://arxiv.org/abs/1901.04085">“Passage Re-ranking with BERT.”</a> arXiv preprint arXiv:1901.04085 (2019). | <a href="https://github.com/nyu-dl/dl4marco-bert">code</a></p>
<p>[8] Zhiguo Wang, et al. <a href="https://arxiv.org/abs/1908.08167">“Multi-passage BERT: A globally normalized BERT model for open-domain question answering.”</a> EMNLP 2019.</p>
<p>[9] Minjoon Seo et al. <a href="https://arxiv.org/abs/1906.05807">“Real-time open-domain question answering with dense-sparse phrase index.”</a> ACL 2019.</p>
<p>[10] Kenton Lee, et al. <a href="https://arxiv.org/abs/1906.00300">“Latent Retrieval for Weakly Supervised Open Domain Question Answering”</a> ACL 2019.</p>
<p>[11] Kelvin Guu, et al. <a href="https://arxiv.org/abs/2002.08909">“REALM: Retrieval-Augmented Language Model Pre-Training”</a> arXiv:2002.08909 (2020).</p>
<p>[12] Vladimir Karpukhin et al. <a href="https://arxiv.org/abs/2004.04906">“Dense passage retrieval for open-domain question answering.”</a> EMNLP 2020. | <a href="https://github.com/facebookresearch/DPR">code</a></p>
<p>[13] Patrick Lewis et al. <a href="https://arxiv.org/abs/2005.11401">“Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks”</a> arXiv:2005.11401 (2020).</p>
<p>[14] Adam Roberts, et al. <a href="https://arxiv.org/abs/2002.08910">“How Much Knowledge Can You Pack Into the Parameters of a Language Model?”</a> EMNLP 2020.</p>
<p>[15] Tom Brown, et al. <a href="https://arxiv.org/abs/2005.14165">“Language models are few-shot learners.”</a> arXiv:2005.14165 (2020).</p>
<p>[16] Fabio Petroni, et al. <a href="https://arxiv.org/abs/2005.04611">“How Context Affects Language Models’ Factual Predictions”</a> AKBC 2020.</p>
<p>[17] Gautier Izacard & Edouard Grave. <a href="https://arxiv.org/abs/2007.01282">“Leveraging passage retrieval with generative models for open domain question answering.”</a> arXiv:2007.01282 (2020).</p>
<p>[18] <a href="https://d2l.ai/chapter_recurrent-modern/beam-search.html">“Dive into deep learning: Beam search”</a></p>
<p>[19] Patrick Lewis, et al. <a href="https://arxiv.org/abs/2008.02637">“Question and Answer Test-Train Overlap in Open-Domain Question Answering Datasets”</a> arXiv:2008.02637 (2020). | <a href="https://github.com/facebookresearch/QA-Overlap">data</a></p>
<p>[20] Hervé Jegou, et al. <a href="https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/">“Faiss: A library for efficient similarity search”</a> Mar 2017.</p>
Neural Architecture Search2020-08-06T12:00:00+00:002020-08-06T12:00:00+00:00https://lilianweng.github.io/lil-log/2020/08/06/neural-architecture-search<blockquote>
<p>Neural Architecture Search (NAS) automates network architecture engineering. It aims to learn a network topology that can achieve the best performance on a certain task. By dissecting the methods for NAS into three components (search space, search algorithm and child model evaluation strategy), this post reviews many interesting ideas for better, faster and more cost-efficient automatic neural architecture search.</p>
</blockquote>
<!--more-->
<p>Although most popular and successful model architectures are designed by human experts, that doesn’t mean we have explored the entire network architecture space and settled on the best option. We would have a better chance of finding the optimal solution if we adopted a systematic and automatic way of learning high-performance model architectures.</p>
<p>Automatically learning and evolving network topologies is not a new idea (<a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen, 2002</a>). In recent years, the pioneering work by <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a> and <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a> has drawn a lot of attention to the field of Neural Architecture Search (NAS), leading to many interesting ideas for better, faster and more cost-efficient NAS methods.</p>
<p>As I started looking into NAS, I found this nice survey by <a href="https://arxiv.org/abs/1808.05377">Elsken, et al 2019</a> very helpful. They characterize NAS as a system with three major components, a clean and concise framing that is also commonly adopted in other NAS papers.</p>
<ol>
<li><strong>Search space</strong>: The NAS search space defines a set of operations (e.g. convolution, fully-connected, pooling) and how operations can be connected to form valid network architectures. The design of the search space usually involves human expertise, and thus, unavoidably, human biases.</li>
<li><strong>Search algorithm</strong>: A NAS search algorithm samples a population of network architecture candidates. It receives the child model performance metrics as rewards (e.g. high accuracy, low latency) and optimizes to generate high-performance architecture candidates.</li>
<li><strong>Evaluation strategy</strong>: We need to measure, estimate, or predict the performance of a large number of proposed child models in order to obtain feedback for the search algorithm to learn. The process of candidate evaluation could be very expensive and many new methods have been proposed to save time or computation resources.</li>
</ol>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NAS-high-level.png" alt="High-level categorization of NAS" /></p>
<p><em>Figure 1. Three main components of Neural Architecture Search (NAS) models. (Image source: <a href="https://arxiv.org/abs/1808.05377">Elsken, et al. 2019</a> with customized annotation in red)</em></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#search-space" id="markdown-toc-search-space">Search Space</a> <ul>
<li><a href="#sequential-layer-wise-operations" id="markdown-toc-sequential-layer-wise-operations">Sequential Layer-wise Operations</a></li>
<li><a href="#cell-based-representation" id="markdown-toc-cell-based-representation">Cell-based Representation</a></li>
<li><a href="#hierarchical-structure" id="markdown-toc-hierarchical-structure">Hierarchical Structure</a></li>
<li><a href="#memory-bank-representation" id="markdown-toc-memory-bank-representation">Memory-bank Representation</a></li>
</ul>
</li>
<li><a href="#search-algorithms" id="markdown-toc-search-algorithms">Search Algorithms</a> <ul>
<li><a href="#random-search" id="markdown-toc-random-search">Random Search</a></li>
<li><a href="#reinforcement-learning" id="markdown-toc-reinforcement-learning">Reinforcement Learning</a></li>
<li><a href="#evolutionary-algorithms" id="markdown-toc-evolutionary-algorithms">Evolutionary Algorithms</a></li>
<li><a href="#progressive-decision-process" id="markdown-toc-progressive-decision-process">Progressive Decision Process</a></li>
<li><a href="#gradient-descent" id="markdown-toc-gradient-descent">Gradient descent</a></li>
</ul>
</li>
<li><a href="#evaluation-strategy" id="markdown-toc-evaluation-strategy">Evaluation Strategy</a> <ul>
<li><a href="#training-from-scratch" id="markdown-toc-training-from-scratch">Training from Scratch</a></li>
<li><a href="#proxy-task-performance" id="markdown-toc-proxy-task-performance">Proxy Task Performance</a></li>
<li><a href="#parameter-sharing" id="markdown-toc-parameter-sharing">Parameter Sharing</a></li>
<li><a href="#prediction-based" id="markdown-toc-prediction-based">Prediction-Based</a></li>
</ul>
</li>
<li><a href="#one-shot-approach-search--evaluation" id="markdown-toc-one-shot-approach-search--evaluation">One-Shot Approach: Search + Evaluation</a></li>
<li><a href="#whats-the-future" id="markdown-toc-whats-the-future">What’s the Future?</a></li>
<li><a href="#appendix-summary-of-nas-papers" id="markdown-toc-appendix-summary-of-nas-papers">Appendix: Summary of NAS Papers</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="search-space">Search Space</h2>
<p>The NAS search space defines a set of basic network operations and how operations can be connected to construct valid network architectures.</p>
<h3 id="sequential-layer-wise-operations">Sequential Layer-wise Operations</h3>
<p>The most naive way to design the search space for neural network architectures is to depict network topologies, either CNN or RNN, with a list of <em>sequential layer-wise operations</em>, as seen in the early work of <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a> & <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>. The serialization of network representation requires a decent amount of expert knowledge, since each operation is associated with different layer-specific parameters and such associations need to be hardcoded. For example, after predicting a <code class="language-plaintext highlighter-rouge">conv</code> op, the model should output the kernel size, stride size, etc.; after predicting an <code class="language-plaintext highlighter-rouge">FC</code> op, the next prediction should be the number of units.</p>
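<p>To make the serialization concrete, here is a minimal Python sketch of decoding a flat token sequence into layer specs. The vocabulary in <code class="language-plaintext highlighter-rouge">OP_PARAMS</code> (which parameter tokens must follow each op) is a hypothetical example, not the one used by Zoph & Le; the point is only that every op type hardcodes which predictions come next.</p>

```python
# Hypothetical vocabulary: each op token declares which parameter tokens
# must follow it in the flat sequence.
OP_PARAMS = {
    "conv": ["kernel_size", "stride", "num_filters"],
    "pool": ["kernel_size", "stride"],
    "fc": ["num_units"],
}


def decode(tokens):
    """Turn a flat token list into a list of (op, params) layer specs."""
    layers = []
    i = 0
    while i < len(tokens):
        op = tokens[i]
        if op not in OP_PARAMS:
            raise ValueError(f"unknown op token {op!r} at position {i}")
        names = OP_PARAMS[op]
        values = tokens[i + 1 : i + 1 + len(names)]
        if len(values) < len(names):
            raise ValueError(f"op {op!r} is missing parameter tokens")
        layers.append((op, dict(zip(names, values))))
        i += 1 + len(names)  # jump to the next op token
    return layers


print(decode(["conv", 3, 1, 32, "pool", 2, 2, "fc", 10]))
```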
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NAS-search-space.png" alt="The sequential layer-wise operation search space" /></p>
<p><em>Figure 2. (Top) A sequential representation of CNN. (Bottom) A sequential representation of the tree structure of a recurrent cell. (Image source: <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>)</em></p>
<p>To make sure the generated architecture is valid, additional rules might be needed (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>):</p>
<ul>
<li>If a layer is not connected to any input layer then it is used as the input layer;</li>
<li>At the final layer, take all layer outputs that have not been connected and concatenate them;</li>
<li>If one layer has many input layers, then all input layers are concatenated in the depth dimension;</li>
<li>If input layers to be concatenated have different sizes, we pad the small layers with zeros so that the concatenated layers have the same sizes.</li>
</ul>
<p>The skip connection can be predicted as well, using an <a href="/lil-log/2018/06/24/attention-attention.html">attention</a>-style mechanism. At layer \(i\), an anchor point is added with \(i-1\) content-based sigmoids to indicate which of the previous layers should be connected. Each sigmoid takes as input the hidden state of the current node \(h_i\) and that of one previous node \(h_j, j=1, \dots, i-1\).</p>
\[P(\text{Layer j is an input to layer i}) = \text{sigmoid}(v^\top \tanh(\mathbf{W}_\text{prev} h_j + \mathbf{W}_\text{curr} h_i))\]
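<p>A minimal plain-Python sketch of the anchor-point probability above. The hidden states and weights are made-up inputs (in NAS they come from the controller and are learned), and layers are 0-indexed here, so layer \(i\) has candidate inputs \(j = 0, \dots, i-1\).</p>

```python
import math


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def skip_probs(h, i, W_prev, W_curr, v):
    """P(layer j is an input to layer i) for every earlier layer j < i."""
    curr = matvec(W_curr, h[i])
    probs = []
    for j in range(i):
        prev = matvec(W_prev, h[j])
        score = sum(vk * math.tanh(p + c) for vk, p, c in zip(v, prev, curr))
        probs.append(1.0 / (1.0 + math.exp(-score)))  # sigmoid
    return probs


h = [[0.1, -0.2], [0.3, 0.0], [-0.5, 0.4], [0.2, 0.2]]  # toy hidden states
W = [[0.5, -0.1], [0.2, 0.3]]
print(skip_probs(h, 3, W_prev=W, W_curr=W, v=[1.0, -1.0]))
```

<p>Each previous layer gets its own independent probability, so a layer can receive any number of skip inputs, including none.</p>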
<p>The sequential search space has a lot of representation power, but it is very large, and covering it exhaustively consumes a ton of computation resources. In the experiments by <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>, they ran 800 GPUs in parallel for 28 days, and <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a> restricted the search space to contain at most 2 <code class="language-plaintext highlighter-rouge">FC</code> layers.</p>
<h3 id="cell-based-representation">Cell-based Representation</h3>
<p>Inspired by the use of repeated modules in successful vision model architectures (e.g. Inception, ResNet), the <em>NASNet search space</em> (<a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>) defines the architecture of a conv net as the same cell repeated multiple times, where each cell contains several operations predicted by the NAS algorithm. A well-designed cell module enables transferability between datasets. It is also easy to scale the model size down or up by adjusting the number of cell repeats.</p>
<p>Precisely, the NASNet search space learns two types of cells for network construction:</p>
<ol>
<li><em>Normal Cell</em>: The input and output feature maps have the same dimension.</li>
<li><em>Reduction Cell</em>: The output feature map has its width and height reduced by half.</li>
</ol>
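<p>The sketch below assembles a macro architecture by stacking the two cell types and tracks the output shape after each cell. The stage layout (\(N\) normal cells, then one reduction cell) and doubling the channel count at each reduction are assumptions made for illustration. <code class="language-plaintext highlighter-rouge">num_normal</code> is the "number of cell repeats" knob for scaling the model.</p>

```python
def stack_cells(input_hw, channels, num_normal, num_reductions):
    """List the (cell_type, H, W, C) output shape after every cell."""
    h, w, c = input_hw[0], input_hw[1], channels
    shapes = []
    for stage in range(num_reductions + 1):
        for _ in range(num_normal):
            shapes.append(("normal", h, w, c))  # same dimensions in and out
        if stage < num_reductions:
            # Reduction cell halves width and height; doubling channels is
            # an assumed convention in this sketch.
            h, w, c = h // 2, w // 2, c * 2
            shapes.append(("reduction", h, w, c))
    return shapes


for shape in stack_cells((32, 32), 32, num_normal=2, num_reductions=2):
    print(shape)
```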
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NASNet-search-space.png" alt="NASNet search space" /></p>
<p><em>Figure 3. The NASNet search space constrains the architecture as a repeated stack of cells. The cell architecture is optimized via NAS algorithms. (Image source: <a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>)</em></p>
<p>The predictions for each cell are grouped into \(B\) blocks (\(B=5\) in the NASNet paper), where each block has 5 prediction steps made by 5 distinct softmax classifiers corresponding to discrete choices of the elements of a block. Note that the NASNet search space does not have residual connections between cells; the model only learns skip connections on its own within blocks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/cell-prediction-steps.png" alt="5 prediction steps in one block" /></p>
<p><em>Figure 4. (a) Each cell consists of \(B\) blocks and each block is predicted by 5 discrete decisions. (b) A concrete example of what operations can be chosen in each decision step.</em></p>
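<p>The 5 decisions per block can be sketched as follows: pick two inputs from the outputs of the previous two cells or of earlier blocks in the same cell, pick an operation for each input, then pick how to combine the two results. Uniform random choices stand in for the softmax classifiers, and <code class="language-plaintext highlighter-rouge">OPS</code> is a hypothetical subset of candidate operations.</p>

```python
import random

# Hypothetical subset of candidate operations.
OPS = ["identity", "3x3 sep conv", "5x5 sep conv", "3x3 max pool", "3x3 avg pool"]
COMBINE = ["add", "concat"]


def sample_cell(num_blocks, rng):
    """Sample a cell as num_blocks blocks of 5 discrete decisions each."""
    hidden_states = ["h_prev", "h_prev_prev"]  # outputs of the two previous cells
    blocks = []
    for b in range(num_blocks):
        in1 = rng.choice(hidden_states)  # step 1: first input
        in2 = rng.choice(hidden_states)  # step 2: second input
        op1 = rng.choice(OPS)            # step 3: op on the first input
        op2 = rng.choice(OPS)            # step 4: op on the second input
        comb = rng.choice(COMBINE)       # step 5: how to combine both results
        blocks.append((in1, in2, op1, op2, comb))
        hidden_states.append(f"block_{b}")  # later blocks may use this output
    return blocks


for block in sample_cell(num_blocks=5, rng=random.Random(0)):
    print(block)
```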
<p><a name="ScheduledDropPath"></a>During the experiments, they discovered that a modified version of <a href="https://arxiv.org/abs/1605.07648"><em>DropPath</em></a>, named <em>ScheduledDropPath</em>, significantly improves the final performance of NASNet experiments. DropPath stochastically drops out paths (i.e. edges with operations attached in NASNet) with a fixed probability. ScheduledDropPath is DropPath with a linearly increasing probability of path dropping during training time.</p>
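<p>A minimal plain-Python sketch of ScheduledDropPath, assuming the drop probability ramps linearly from 0 to a final value over training. Rescaling surviving paths by \(1/(1-p)\) is a common dropout convention assumed here, not something stated above.</p>

```python
import random


def scheduled_drop_prob(step, total_steps, final_drop_prob):
    """Drop probability that grows linearly from 0 to final_drop_prob."""
    return final_drop_prob * min(step, total_steps) / total_steps


def drop_path(path_outputs, drop_prob, rng):
    """Independently zero out each path with probability drop_prob.

    Surviving paths are scaled by 1 / (1 - drop_prob) (assumed convention)
    so the expected output stays unchanged.
    """
    keep_prob = 1.0 - drop_prob
    result = []
    for out in path_outputs:
        if rng.random() < keep_prob:
            result.append([v / keep_prob for v in out])
        else:
            result.append([0.0] * len(out))
    return result


rng = random.Random(0)
for step in (0, 500, 1000):
    p = scheduled_drop_prob(step, 1000, final_drop_prob=0.3)
    print(step, p, drop_path([[1.0, 2.0], [3.0, 4.0]], p, rng))
```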
<p><a href="https://arxiv.org/abs/1808.05377">Elsken, et al (2019)</a> point out three major advantages of the NASNet search space:</p>
<ol>
<li>The search space size is reduced drastically;</li>
<li>The <a href="https://en.wikipedia.org/wiki/Network_motif">motif</a>-based architecture can be more easily transferred to different datasets;</li>
<li>It provides strong evidence for a useful design pattern in architecture engineering: repeatedly stacking modules. For example, we can build strong models by stacking residual blocks in CNN or stacking multi-headed attention blocks in Transformer.</li>
</ol>
<h3 id="hierarchical-structure">Hierarchical Structure</h3>
<p>To take advantage of already discovered well-designed network <a href="https://en.wikipedia.org/wiki/Network_motif">motifs</a>, the NAS search space can be constrained as a hierarchical structure, as in <em>Hierarchical NAS</em> (<strong>HNAS</strong>; <a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>). It starts with a small set of primitives, including individual operations such as convolution, pooling, identity, etc. Then small sub-graphs (or “motifs”) that consist of primitive operations are recursively used to form higher-level computation graphs.</p>
<p>A computation motif at level \(\ell=1, \dots, L\) can be represented by \((G^{(\ell)}, \mathcal{O}^{(\ell)})\), where:</p>
<ul>
<li>\(\mathcal{O}^{(\ell)}\) is a set of operations, \(\mathcal{O}^{(\ell)} = \{ o^{(\ell)}_1, o^{(\ell)}_2, \dots \}\)</li>
<li>\(G^{(\ell)}\) is an adjacency matrix, where the entry \(G_{ij}=k\) indicates that operation \(o^{(\ell)}_k\) is placed between nodes \(i\) and \(j\). The node indices follow the <a href="https://en.wikipedia.org/wiki/Topological_sorting">topological ordering</a> of the DAG, where index \(1\) is the source and the maximal index is the sink node.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/hierarchical-NAS-search-space.png" alt="Hierarchical search space" /></p>
<p><em>Figure 5. (Top) Three level-1 primitive operations are composed into a level-2 motif. (Bottom) Three level-2 motifs are plugged into a base network structure and assembled into a level-3 motif. (Image source: <a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>)</em></p>
<p>To build a network according to the hierarchical structure, we start from the lowest level \(\ell=1\) and recursively define the \(m\)-th motif operation at level \(\ell\) as</p>
\[o^{(\ell)}_m = \text{assemble}\Big( G_m^{(\ell)}, \mathcal{O}^{(\ell-1)} \Big)\]
<p>A hierarchical representation becomes \(\Big( \big\{ \{ G_m^{(\ell)} \}_{m=1}^{M_\ell} \big\}_{\ell=2}^L, \mathcal{O}^{(1)} \Big), \forall \ell=2, \dots, L\), where \(\mathcal{O}^{(1)}\) contains a set of primitive operations.</p>
<p>The \(\text{assemble}()\) process is equivalent to sequentially computing the feature map of node \(i\) by aggregating the feature maps of all its predecessor nodes \(j\), following the topological ordering:</p>
\[x_i = \text{merge} \big[ \{ o^{(\ell)}_{G^{(\ell)}_{ij}}(x_j) \}_{j < i} \big], i = 2, \dots, \vert G^{(\ell)} \vert\]
<p>where \(\text{merge}[]\) is implemented as depth-wise concatenation in the <a href="https://arxiv.org/abs/1711.00436">paper</a>.</p>
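<p>To see how the recursion works, the sketch below implements \(\text{assemble}()\) in plain Python. <code class="language-plaintext highlighter-rouge">G[i][j]</code> holds the index of the lower-level op on edge \(j \to i\) (or <code class="language-plaintext highlighter-rouge">None</code> for no edge), feature maps are plain lists, and \(\text{merge}[]\) is list concatenation as a stand-in for depth-wise concatenation. The toy primitives (identity and doubling) are hypothetical stand-ins for real operations such as convolution or pooling. Nodes are 0-indexed, so node 0 is the source.</p>

```python
def assemble(G, ops):
    """Build a higher-level op from adjacency matrix G and lower-level ops.

    G[i][j] = k means ops[k] sits on the edge from node j to node i (j < i);
    None means no edge. Node 0 is the source, the last node is the sink.
    """
    n = len(G)

    def motif(x):
        feats = [x]
        for i in range(1, n):  # nodes in topological order
            merged = []
            for j in range(i):
                k = G[i][j]
                if k is not None:
                    merged.extend(ops[k](feats[j]))  # merge = concatenation
            feats.append(merged)
        return feats[-1]

    return motif


def identity(x):
    return list(x)


def double(x):  # toy stand-in for a "real" primitive
    return [2.0 * v for v in x]


# Level 2: edges 0 -> 1 (identity), 0 -> 2 (double), 1 -> 2 (identity).
m2 = assemble([[None, None, None],
               [0, None, None],
               [1, 0, None]], [identity, double])

# Level 3: the level-2 motif is reused as an op on edges 0 -> 1 and 1 -> 2.
m3 = assemble([[None, None, None],
               [0, None, None],
               [1, 0, None]], [m2, identity])

print(m2([1.0]), m3([1.0]))
```

<p>Because an assembled motif is itself a callable op, the same function builds every level of the hierarchy.</p>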
<p>As with NASNet, experiments in <a href="https://arxiv.org/abs/1711.00436">Liu et al (2017)</a> focused on discovering good cell architectures within a predefined “macro” structure with repeated modules. They showed that the power of simple search methods (e.g. random search or evolutionary algorithms) can be substantially enhanced by well-designed search spaces.</p>
<p><a href="https://arxiv.org/abs/1806.02639">Cai et al (2018b)</a> propose a tree-structure search space using path-level network transformation. Each node in a tree structure defines an <em>allocation</em> scheme for splitting inputs for child nodes and a <em>merge</em> scheme for combining results from child nodes. The path-level network transformation allows replacing a single layer with a multi-branch motif if its corresponding merge scheme is add or concat.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/path-level-network-transformations.png" alt="Path-level network transformation" /></p>
<p><em>Figure 6. An illustration of transforming a single layer to a tree-structured motif via path-level transformation operations. (Image source: <a href="https://arxiv.org/abs/1806.02639">Cai et al. 2018b</a>)</em></p>
<h3 id="memory-bank-representation">Memory-bank Representation</h3>
<p>A memory-bank representation of feed-forward networks is proposed by <a href="https://arxiv.org/abs/1708.05344">Brock et al. (2017)</a> in <a href="#prediction-based">SMASH</a>. Instead of a graph of operations, they view a neural network as a system with multiple memory blocks that can be read from and written to. Each layer operation is designed to: (1) read from a subset of memory blocks; (2) compute results; and finally (3) write the results into another subset of blocks. For example, in a sequential model, a single memory block would get read and overwritten repeatedly.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NAS-memory-bank-view-representation.png" alt="Memory-bank representation" /></p>
<p><em>Figure 7. Memory-bank representation of several popular network architecture blocks. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<h2 id="search-algorithms">Search Algorithms</h2>
<p>NAS search algorithms sample a population of child networks. They receive the child models’ performance metrics as rewards and learn to generate high-performance architecture candidates. You may notice that this has a lot in common with the field of hyperparameter search.</p>
<h3 id="random-search">Random Search</h3>
<p>Random search is the most naive baseline. It samples a valid architecture candidate from the search space <em>at random</em> and no learning model is involved. Random search has proved to be quite useful in hyperparameter search (<a href="http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf">Bergstra & Bengio 2012</a>). With a well-designed search space, random search could be a very challenging baseline to beat.</p>
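<p>A minimal random-search loop, assuming a hypothetical dictionary-style search space and a toy <code class="language-plaintext highlighter-rouge">toy_evaluate</code> function that stands in for "train the child model and measure validation accuracy":</p>

```python
import random

# Hypothetical search space: each key is one decision, each value its options.
SEARCH_SPACE = {
    "num_layers": [2, 4, 8],
    "op": ["conv3x3", "conv5x5", "maxpool"],
    "width": [16, 32, 64],
}


def sample_architecture(space, rng):
    return {name: rng.choice(options) for name, options in space.items()}


def random_search(space, evaluate, num_trials, seed=0):
    """Sample num_trials architectures uniformly at random; keep the best."""
    rng = random.Random(seed)
    best_arch, best_score = None, float("-inf")
    for _ in range(num_trials):
        arch = sample_architecture(space, rng)
        score = evaluate(arch)  # no learning: each sample is independent
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score


def toy_evaluate(arch):
    """Toy stand-in for training a child model and measuring its accuracy."""
    bonus = 0.05 if arch["op"] == "conv3x3" else 0.0
    return 0.01 * arch["num_layers"] + 0.001 * arch["width"] + bonus


print(random_search(SEARCH_SPACE, toy_evaluate, num_trials=50))
```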
<h3 id="reinforcement-learning">Reinforcement Learning</h3>
<p>The initial design of <strong>NAS</strong> (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>) involves an RL-based controller for proposing child model architectures for evaluation. The controller is implemented as an RNN, outputting a variable-length sequence of tokens used for configuring a network architecture.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/NAS.png" alt="NAS" /></p>
<p><em>Figure 8. A high level overview of NAS, containing an RNN controller and a pipeline for evaluating child models. (Image source: <a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>)</em></p>
<p>The controller is trained as a <em>RL task</em> using <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>.</p>
<ul>
<li><strong>Action space</strong>: The action space is a list of tokens for defining a child network predicted by the controller (See more in the above <a href="#sequential-layer-wise-operations">section</a>). The controller outputs <em>action</em>, \(a_{1:T}\), where \(T\) is the total number of tokens.</li>
<li><strong>Reward</strong>: The accuracy of a child network that can be achieved at convergence is the reward for training the controller, \(R\).</li>
<li><strong>Loss</strong>: NAS optimizes the controller parameters \(\theta\) with a REINFORCE loss. We want to maximize the expected reward (high accuracy) with the gradient as follows. The nice thing here with policy gradient is that it works even when the reward is non-differentiable.</li>
</ul>
\[\nabla_{\theta} J(\theta) = \sum_{t=1}^T \mathbb{E}[\nabla_{\theta} \log P(a_t \vert a_{1:(t-1)}; \theta) R ]\]
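<p>The update can be sketched on a toy problem. To keep it short, the RNN controller is swapped for independent per-step softmax distributions (so \(P(a_t \vert a_{1:(t-1)}; \theta)\) ignores the history here), and the reward is a made-up score counting how many tokens match a hidden target. Subtracting a moving-average baseline from \(R\) is a standard variance-reduction trick that is not part of the formula above. For a softmax, the gradient of \(\log P(a_t)\) with respect to logit \(k\) is \(\mathbb{1}[k = a_t] - P(k)\).</p>

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce(reward_fn, num_steps, num_choices, iters=3000, lr=0.1, seed=0):
    """Train per-step logits theta[t][k] with REINFORCE; return final probs."""
    rng = random.Random(seed)
    theta = [[0.0] * num_choices for _ in range(num_steps)]
    baseline = 0.0
    for _ in range(iters):
        probs = [softmax(row) for row in theta]
        actions = [rng.choices(range(num_choices), weights=p)[0] for p in probs]
        reward = reward_fn(actions)  # non-differentiable is fine
        advantage = reward - baseline
        baseline = 0.9 * baseline + 0.1 * reward  # moving-average baseline
        for t, a in enumerate(actions):
            for k in range(num_choices):
                grad_log_p = (1.0 if k == a else 0.0) - probs[t][k]
                theta[t][k] += lr * advantage * grad_log_p  # gradient ascent
    return [softmax(row) for row in theta]


TARGET = [2, 0, 3, 1]  # hypothetical "best architecture" tokens


def toy_reward(actions):
    return sum(a == t for a, t in zip(actions, TARGET)) / len(TARGET)


final_probs = reinforce(toy_reward, num_steps=4, num_choices=4)
print([max(range(4), key=lambda k: p[k]) for p in final_probs])
```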
<p><strong>MetaQNN</strong> (<a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>) trains an agent to sequentially choose CNN layers using <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#q-learning-off-policy-td-control"><em>Q-learning</em></a> with an <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#%CE%B5-greedy-algorithm">\(\epsilon\)-greedy</a> exploration strategy and experience replay. The reward is the validation accuracy as well.</p>
\[Q^{(t+1)}(s_t, a_t) = (1 - \alpha)Q^{(t)}(s_t, a_t) + \alpha (R_t + \gamma \max_{a' \in \mathcal{A}} Q^{(t)}(s_{t+1}, a'))\]
<p>where a state \(s_t\) is a tuple of a layer operation and its related parameters. An action \(a\) determines the connectivity between operations. The Q-value reflects how confident we are that two connected operations lead to high accuracy.</p>
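<p>One tabular Q-learning step from the equation above, as a minimal sketch. States and actions can be any hashable values (for example, tuples describing a layer); the names in the usage line are hypothetical. A terminal state with no available actions contributes a future value of 0.</p>

```python
from collections import defaultdict


def q_update(Q, s, a, reward, s_next, next_actions, alpha=0.1, gamma=1.0):
    """One tabular Q-learning step; Q maps (state, action) -> value."""
    best_next = max((Q[(s_next, a2)] for a2 in next_actions), default=0.0)
    Q[(s, a)] = (1 - alpha) * Q[(s, a)] + alpha * (reward + gamma * best_next)
    return Q[(s, a)]


Q = defaultdict(float)  # unseen (state, action) pairs start at 0
Q[("conv3x3", "pool")] = 0.4
print(q_update(Q, "input", "conv3x3", reward=1.0, s_next="conv3x3",
               next_actions=["pool", "fc"], alpha=0.5))
```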
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/MetaQNN.png" alt="MetaQNN" /></p>
<p><em>Figure 9. Overview of MetaQNN - designing CNN models with Q-Learning. (Image source: <a href="https://arxiv.org/abs/1611.02167">Baker et al. 2017</a>)</em></p>
<h3 id="evolutionary-algorithms">Evolutionary Algorithms</h3>
<p><strong>NEAT</strong> (short for <em>NeuroEvolution of Augmenting Topologies</em>) is an approach for evolving neural network topologies with a <a href="https://en.wikipedia.org/wiki/Genetic_algorithm">genetic algorithm (GA)</a>, proposed by <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen</a> in 2002. NEAT evolves both connection weights and network topology together. Each genome encodes the full information for configuring a network, including nodes, connections and their weights. The population grows by applying mutations to both weights and connections, as well as crossover between two parent genomes. For more on neuroevolution, please refer to the in-depth <a href="https://www.nature.com/articles/s42256-018-0006-z">survey</a> by Stanley et al. (2019).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NEAT-mutations.png" alt="Mutation operations in NEAT" /></p>
<p><em>Figure 10. Mutations in the NEAT algorithm. (Image source: Fig 3 & 4 in <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">Stanley & Miikkulainen, 2002</a>)</em></p>
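<p>The add-node mutation from Figure 10 can be sketched as below. An existing connection is disabled and replaced by two new ones through a fresh node: the incoming connection gets weight 1 and the outgoing one inherits the old weight, which keeps the immediate effect of the mutation small. The <code class="language-plaintext highlighter-rouge">ConnGene</code> class and the simplified bookkeeping of node ids and innovation numbers are assumptions of this sketch.</p>

```python
import random
from dataclasses import dataclass


@dataclass
class ConnGene:
    in_node: int
    out_node: int
    weight: float
    enabled: bool
    innovation: int  # historical marker used to align genes in crossover


def mutate_add_node(genome, next_node_id, next_innovation, rng):
    """Split a random enabled connection in -> out into in -> new -> out."""
    old = rng.choice([g for g in genome if g.enabled])
    old.enabled = False  # the original connection is kept but disabled
    new = next_node_id
    genome.append(ConnGene(old.in_node, new, 1.0, True, next_innovation))
    genome.append(ConnGene(new, old.out_node, old.weight, True, next_innovation + 1))
    return new


genome = [ConnGene(1, 3, 0.5, True, 1), ConnGene(2, 3, -0.7, True, 2)]
mutate_add_node(genome, next_node_id=4, next_innovation=3, rng=random.Random(0))
for gene in genome:
    print(gene)
```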
<p><a href="https://arxiv.org/abs/1802.01548">Real et al. (2018)</a> adopt evolutionary algorithms (EA) as a way to search for high-performance network architectures, named <strong>AmoebaNet</strong>. They apply the <a href="https://en.wikipedia.org/wiki/Tournament_selection">tournament selection</a> method, which at each iteration picks the best candidate out of a random set of samples and places its mutated offspring back into the population. When the tournament size is \(1\), it is equivalent to random selection.</p>
<p><a href="aging-evolutionary-algorithms"></a>AmoebaNet modified the tournament selection to favor <em>younger</em> genotypes and always discard the oldest model within each cycle. Such an approach, named <em>aging evolution</em>, allows AmoebaNet to cover and explore more of the search space, rather than narrowing down on well-performing models too early.</p>
<p>Precisely, in every cycle of the tournament selection with aging regularization (See Figure 11):</p>
<ol>
<li>Sample \(S\) models from the population and the one with highest accuracy is chosen as <em>parent</em>.</li>
<li>A <em>child</em> model is produced by mutating <em>parent</em>.</li>
<li>Then the child model is trained, evaluated and added back into the population.</li>
<li>The oldest model is removed from the population.</li>
</ol>
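<p>The four steps above map onto a short loop with a FIFO queue. The toy genotype (a tuple of op indices), the operation mutation, and using <code class="language-plaintext highlighter-rouge">sum</code> as the "accuracy" are hypothetical stand-ins; this sketch also samples the \(S\) candidates without replacement, which is a simplification.</p>

```python
import collections
import random


def aging_evolution(init_arch, mutate, evaluate, population_size, sample_size,
                    cycles, seed=0):
    """Aging (regularized) evolution; returns the best (arch, accuracy) seen."""
    rng = random.Random(seed)
    population = collections.deque()
    history = []
    while len(population) < population_size:  # random initial population
        arch = init_arch(rng)
        model = (arch, evaluate(arch))
        population.append(model)
        history.append(model)
    while len(history) < cycles:
        candidates = rng.sample(list(population), sample_size)  # step 1
        parent = max(candidates, key=lambda m: m[1])
        child_arch = mutate(parent[0], rng)                    # step 2
        child = (child_arch, evaluate(child_arch))             # step 3
        population.append(child)
        history.append(child)
        population.popleft()  # step 4: remove the oldest, not the worst
    return max(history, key=lambda m: m[1])


NUM_POSITIONS, NUM_OPS = 6, 5  # toy genotype: one op index per position


def random_arch(rng):
    return tuple(rng.randrange(NUM_OPS) for _ in range(NUM_POSITIONS))


def op_mutation(arch, rng):
    """Replace the op at one random position with a different random op."""
    i = rng.randrange(len(arch))
    new_op = rng.choice([o for o in range(NUM_OPS) if o != arch[i]])
    return arch[:i] + (new_op,) + arch[i + 1:]


print(aging_evolution(random_arch, op_mutation, evaluate=sum,
                      population_size=20, sample_size=5, cycles=500))
```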
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/aging-evolution-algorithm.png" alt="Aging evolution algorithm" /></p>
<p><em>Figure 11. The algorithm of aging evolution. (Image source: <a href="https://arxiv.org/abs/1802.01548">Real et al. 2018</a>)</em></p>
<p>Two types of mutations are applied:</p>
<ol>
<li><em>Hidden state mutation</em>: randomly chooses a pairwise combination and rewires a random end such that there is no loop in the graph.</li>
<li><em>Operation mutation</em>: randomly replaces an existing operation with a random one.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/AmoebaNet-mutations.png" alt="Mutations in AmoebaNet" /></p>
<p><em>Figure 12. Two types of mutations in AmoebaNet. (Image source: <a href="https://arxiv.org/abs/1802.01548">Real et al. 2018</a>)</em></p>
<p>In their experiments, EA and RL work equally well in terms of the final validation accuracy, but EA has better anytime performance and is able to find smaller models. Still, using EA in NAS remains expensive in terms of computation: each experiment took 7 days with 450 GPUs.</p>
<p><strong>HNAS</strong> (<a href="https://arxiv.org/abs/1711.00436">Liu et al 2017</a>) also employs an evolutionary algorithm (the original tournament selection) as its search strategy. In the <a href="#hierarchical-structure">hierarchical structure</a> search space, each edge is an operation. Thus genotype mutation in their experiments is applied by replacing a random edge with a different operation. The replacement set includes a <code class="language-plaintext highlighter-rouge">none</code> op, so it can alter, remove and add an edge. The initial set of genotypes is created by applying a large number of random mutations on “trivial” motifs (all identity mappings).</p>
<h3 id="progressive-decision-process">Progressive Decision Process</h3>
<p>Constructing a model architecture is a sequential process. Every additional operator or layer brings extra complexity. If we guide the search model to start the investigation from simple models and gradually evolve to more complex architectures, it is like introducing a <a href="/lil-log/2020/01/29/curriculum-for-reinforcement-learning.html">“curriculum”</a> into the search model’s learning process.</p>
<p><em>Progressive NAS</em> (<strong>PNAS</strong>; <a href="https://arxiv.org/abs/1712.00559">Liu, et al 2018</a>) frames the problem of NAS as a progressive procedure for searching models of increasing complexity. Instead of RL or EA, PNAS adopts Sequential Model-based Bayesian Optimization (SMBO) as the search strategy. PNAS works similarly to A* search, as it searches for models from simple to hard while simultaneously learning a surrogate function to guide the search.</p>
<blockquote>
<p><a href="https://en.wikipedia.org/wiki/A*_search_algorithm">A* search algorithm</a> (“best-first search”) is a popular algorithm for path finding. The problem is framed as finding a path of smallest cost from a specific starting node to a given target node in a weighted graph. At each iteration, A* finds a path to extend by minimizing: \(f(n)=g(n)+h(n)\), where \(n\) is the next node, \(g(n)\) is the cost from start to \(n\), and \(h(n)\) is the heuristic function that estimates the minimum cost of going from node \(n\) to the goal.</p>
</blockquote>
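<p>For reference, a compact A* in Python that pops nodes from a priority queue ordered by \(f(n)=g(n)+h(n)\). The graph and heuristic values are a made-up example; the heuristic never overestimates the remaining cost.</p>

```python
import heapq


def a_star(graph, start, goal, h):
    """Return (cost, path) of a cheapest path, or None if goal is unreachable.

    graph: dict node -> list of (neighbor, edge_cost); h: heuristic function.
    """
    frontier = [(h(start), 0, start, [start])]  # (f = g + h, g, node, path)
    best_g = {start: 0}
    while frontier:
        _, g, node, path = heapq.heappop(frontier)
        if node == goal:
            return g, path
        if g > best_g[node]:
            continue  # stale entry: a cheaper route to node was found later
        for nbr, cost in graph.get(node, []):
            g_new = g + cost
            if g_new < best_g.get(nbr, float("inf")):
                best_g[nbr] = g_new
                heapq.heappush(frontier, (g_new + h(nbr), g_new, nbr, path + [nbr]))
    return None


GRAPH = {"A": [("B", 1), ("C", 4)], "B": [("C", 2), ("D", 5)], "C": [("D", 1)], "D": []}
H = {"A": 3, "B": 2, "C": 1, "D": 0}
print(a_star(GRAPH, "A", "D", H.get))
```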
<p>PNAS uses the <a href="#cell-based-representation">NASNet</a> search space. Each block is specified as a 5-element tuple, and PNAS only considers element-wise addition as the step-5 combination operator, with no concatenation. Moreover, instead of fixing the number of blocks \(B\), PNAS starts with \(B=1\), a model with only one block in a cell, and gradually increases \(B\).</p>
<p>The performance on a validation set is used as feedback to train a <em>surrogate</em> model for <em>predicting</em> the performance of novel architectures. With this predictor, we can decide which models should be prioritized for evaluation next. Since the performance predictor should handle variable-sized inputs, be accurate, and be sample-efficient, they ended up using an RNN model.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/progressive-NAS-algorithm.png" alt="Progressive NAS" /></p>
<p><em>Figure 13. The algorithm of Progressive NAS. (Image source: <a href="https://arxiv.org/abs/1712.00559">Liu, et al 2018</a>)</em></p>
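<p>A simplified sketch of the progressive loop: train all 1-block cells, then repeatedly expand the kept cells by one block, rank the expansions with the surrogate, and train only the top \(K\). The surrogate here is a deliberately simple stand-in (the mean observed accuracy of the blocks a cell contains) instead of the RNN predictor, and the block choices and <code class="language-plaintext highlighter-rouge">evaluate</code> function are toy assumptions.</p>

```python
class MeanSurrogate:
    """Toy stand-in for PNAS's learned predictor: score a cell by the mean
    observed accuracy of the blocks it contains (unseen blocks score 0)."""

    def __init__(self):
        self.sums, self.counts = {}, {}

    def fit(self, cells, scores):
        for cell, score in zip(cells, scores):
            for block in cell:
                self.sums[block] = self.sums.get(block, 0.0) + score
                self.counts[block] = self.counts.get(block, 0) + 1

    def predict(self, cell):  # works for cells of any length
        vals = [self.sums[b] / self.counts[b] if b in self.counts else 0.0
                for b in cell]
        return sum(vals) / len(vals)


def progressive_search(block_choices, evaluate, max_blocks, beam_size):
    surrogate = MeanSurrogate()
    cells = [(b,) for b in block_choices]  # B = 1: train every cell
    scores = [evaluate(c) for c in cells]
    surrogate.fit(cells, scores)
    history = list(zip(cells, scores))
    ranked = sorted(history, key=lambda cs: cs[1], reverse=True)
    beam = [c for c, _ in ranked[:beam_size]]
    for _ in range(2, max_blocks + 1):
        expanded = [c + (b,) for c in beam for b in block_choices]  # add a block
        expanded.sort(key=surrogate.predict, reverse=True)          # cheap ranking
        beam = expanded[:beam_size]                                  # train top K only
        new_scores = [evaluate(c) for c in beam]
        surrogate.fit(beam, new_scores)
        history.extend(zip(beam, new_scores))
    return max(history, key=lambda cs: cs[1])


print(progressive_search([0, 1, 2, 3], evaluate=sum, max_blocks=3, beam_size=2))
```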
<h3 id="gradient-descent">Gradient descent</h3>
<p>Using gradient descent to update the architecture search model requires making the process of choosing discrete operations differentiable. These approaches usually combine the learning of both architecture parameters and network weights into one model. See more in the <a href="#one-shot-approach-search--evaluation">section</a> on the <em>“one-shot”</em> approach.</p>
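<p>One common way to make the choice differentiable is to replace the hard pick with a softmax-weighted mixture of all candidate ops, parameterized by architecture logits \(\alpha\). Below is a minimal scalar sketch with hypothetical candidate ops. With \(p = \text{softmax}(\alpha)\) and \(y = \sum_k p_k o_k(x)\), the gradient is \(\partial y / \partial \alpha_k = p_k (o_k(x) - y)\), and a discrete op can be read off at the end by taking the largest \(\alpha_k\).</p>

```python
import math


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x, ops, alpha):
    """Continuous relaxation: y = sum_k softmax(alpha)_k * o_k(x)."""
    return sum(p * op(x) for p, op in zip(softmax(alpha), ops))


def grad_alpha(x, ops, alpha):
    """Analytic gradient dy/dalpha_k = p_k * (o_k(x) - y)."""
    p = softmax(alpha)
    outs = [op(x) for op in ops]
    y = sum(pk * o for pk, o in zip(p, outs))
    return [pk * (o - y) for pk, o in zip(p, outs)]


ops = [lambda x: x, lambda x: 2.0 * x, lambda x: 0.0]  # toy candidate ops
alpha = [0.1, -0.3, 0.5]
print(mixed_op(1.5, ops, alpha), grad_alpha(1.5, ops, alpha))
chosen = max(range(len(alpha)), key=lambda k: alpha[k])  # discretize at the end
```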
<h2 id="evaluation-strategy">Evaluation Strategy</h2>
<p>We need to measure, estimate or predict the performance of every child model in order to obtain feedback for optimizing the search algorithm. The process of candidate evaluation could be very expensive and many new evaluation methods have been proposed to save time or computation. When evaluating a child model, we mostly care about its performance measured as accuracy on a validation set. Recent work has started looking into other factors of a model, such as model size and latency, as certain devices may have limitations on memory or demand fast response time.</p>
<h3 id="training-from-scratch">Training from Scratch</h3>
<p>The most naive approach is to train every child network independently from scratch until <em>convergence</em> and then measure its accuracy on a validation set (<a href="https://arxiv.org/abs/1611.01578">Zoph & Le 2017</a>). It provides solid performance numbers, but one complete train-converge-evaluate loop only generates a single data sample for training the RL controller (not to mention that RL is known to be sample-inefficient in general). Thus it is very expensive in terms of computation.</p>
<h3 id="proxy-task-performance">Proxy Task Performance</h3>
<p>There are several approaches for using a proxy task performance as the performance estimator of a child network, which is generally cheaper and faster to calculate:</p>
<ul>
<li>Train on a smaller dataset.</li>
<li>Train for fewer epochs.</li>
<li>Train and evaluate a down-scaled model in the search stage. For example, once a cell structure is learned, we can play with the number of cell repeats or scale up the number of filters (<a href="https://arxiv.org/abs/1707.07012">Zoph et al. 2018</a>).</li>
<li>Predict the learning curve. <a href="https://arxiv.org/abs/1705.10823">Baker et al (2018)</a> model the prediction of validation accuracies as a time-series regression problem. The features for the regression model (\(\nu\)-support vector machine regressions; \(\nu\)-SVR) include the early sequences of accuracy per epoch, architecture parameters, and hyperparameters.</li>
</ul>
<h3 id="parameter-sharing">Parameter Sharing</h3>
<p>Instead of training every child model independently from scratch, you may ask: what if we fabricate dependencies between them and find a way to reuse weights? Some researchers have succeeded in making such approaches work.</p>
<p>Inspired by the <a href="https://arxiv.org/abs/1511.05641">Net2net</a> transformation, <a href="https://arxiv.org/abs/1707.04873">Cai et al (2017)</a> proposed <em>Efficient Architecture Search</em> (<strong>EAS</strong>). EAS sets up an RL agent, known as a meta-controller, to predict function-preserving network transformations so as to grow the network depth or layer width. Because the network grows incrementally, the weights of previously validated networks can be <em>reused</em> for further exploration. With inherited weights, newly constructed networks only need some lightweight training.</p>
<p>A meta-controller learns to generate <em>network transformation actions</em> given the current network architecture, which is specified with a variable-length string. In order to handle architecture configuration of a variable length, the meta-controller is implemented as a bi-directional recurrent network. Multiple actor networks output different transformation decisions:</p>
<ol>
<li><em>Net2WiderNet</em> operation replaces a layer with a wider layer, meaning more units for fully-connected layers or more filters for convolutional layers, while preserving the functionality.</li>
<li><em>Net2DeeperNet</em> operation inserts a new layer initialized as an identity mapping between two layers, so as to preserve the functionality.</li>
</ol>
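<p>Both transformations can be checked on a tiny two-layer ReLU network in plain Python. In this sketch, Net2WiderNet copies the incoming weights of randomly chosen existing units and splits the outgoing weights evenly among the replicas; Net2DeeperNet inserts an identity matrix, which preserves the output because \(\text{ReLU}(\text{ReLU}(x)) = \text{ReLU}(x)\). Adding noise to break symmetry between replicated units is omitted here.</p>

```python
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def forward(layers, x):
    """Apply each weight matrix in turn, with ReLU after all but the last."""
    for idx, W in enumerate(layers):
        x = matvec(W, x)
        if idx < len(layers) - 1:
            x = [max(0.0, v) for v in x]
    return x


def net2wider(W1, W2, new_width, rng):
    """Widen the hidden layer between W1 (in -> hidden) and W2 (hidden -> out)."""
    old_width = len(W1)
    # Mapping g: keep every old unit, then replicate randomly chosen ones.
    g = list(range(old_width)) + [rng.randrange(old_width)
                                  for _ in range(new_width - old_width)]
    counts = [g.count(j) for j in range(old_width)]
    W1_new = [list(W1[g[u]]) for u in range(new_width)]  # copy incoming weights
    # Split each old unit's outgoing weight evenly across its replicas.
    W2_new = [[row[g[u]] / counts[g[u]] for u in range(new_width)] for row in W2]
    return W1_new, W2_new


def net2deeper_identity(width):
    """A new layer initialized to the identity matrix."""
    return [[1.0 if i == j else 0.0 for j in range(width)] for i in range(width)]


rng = random.Random(0)
W1 = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]  # 3 in -> 4 hidden
W2 = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(2)]  # 4 hidden -> 2 out
x = [0.5, -1.0, 2.0]
W1_wide, W2_wide = net2wider(W1, W2, new_width=7, rng=rng)
print(forward([W1, W2], x))
print(forward([W1_wide, W2_wide], x))                # same output, up to rounding
print(forward([W1, net2deeper_identity(4), W2], x))  # same output
```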
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/EAS-meta-controller.png" alt="EAS meta-controller" /></p>
<p><em>Figure 14. Overview of the RL based meta-controller in Efficient Architecture Search (EAS). After encoding the architecture configuration, it outputs net2net transformation actions through two separate actor networks. (Image source: <a href="https://arxiv.org/abs/1707.04873">Cai et al 2017</a>)</em></p>
<p><a name="ENAS"></a>With similar motivation, <em>Efficient NAS</em> (<strong>ENAS</strong>; <a href="https://arxiv.org/abs/1802.03268">Pham et al. 2018</a>) makes NAS roughly 1000x less expensive by aggressively sharing parameters among child models. The core motivation behind ENAS is the observation that all of the sampled architecture graphs can be viewed as <em>sub-graphs</em> of a larger <em>supergraph</em>. All the child networks share the weights of this supergraph.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ENAS-example.png" alt="ENAS example" /></p>
<p><em>Figure 15. (Left) The graph represents the entire search space for a 4-node recurrent cell, but only connections in red are active. (Middle) An example of how the left active sub-graph can be translated into a child model architecture. (Right) The network parameters produced by an RNN controller for the architecture in the middle. (Image source: <a href="https://arxiv.org/abs/1802.03268">Pham et al. 2018</a>)</em></p>
<p>ENAS alternates between training the shared model weights \(\omega\) and training the controller \(\theta\):</p>
<ol>
<li>The parameters of the controller LSTM \(\theta\) are trained with <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>, where the reward \(R(\mathbf{m}, \omega)\) is computed on the validation set.</li>
<li>The shared parameters of the child models \(\omega\) are trained with standard supervised learning loss. Note that different operators associated with the same node in the supergraph would have their own distinct parameters.</li>
</ol>
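<p>The weight-sharing idea and the controller update can be sketched in a few lines of Python. This is a toy under explicit assumptions: the supergraph holds one scalar weight per (node, operation) pair, the "supervised training" step just moves the active weights toward hypothetical per-operation targets, and the controller is a simple per-node softmax policy rather than the LSTM used by ENAS. Names such as <code>SharedSupergraph</code> and <code>reinforce_step</code> are illustrative, not from the paper.</p>

```python
import math

OPS = ["conv3x3", "conv5x5", "maxpool"]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


class SharedSupergraph:
    """Toy supergraph with one scalar weight per (node, operation) pair (omega).

    Every child model reads its parameters from this single table, so two
    children that pick the same operation at the same node share that weight.
    """

    def __init__(self, num_nodes):
        self.omega = {(n, op): 0.0 for n in range(num_nodes) for op in OPS}

    def child_params(self, arch):
        # arch[n] is the operation chosen at node n
        return [self.omega[(n, op)] for n, op in enumerate(arch)]

    def train_step(self, arch, targets, lr=0.5):
        # Stand-in for a supervised update: only the weights of active edges move.
        for n, op in enumerate(arch):
            key = (n, op)
            self.omega[key] += lr * (targets[op] - self.omega[key])


def reinforce_step(logits, chosen, reward, baseline, lr=0.1):
    """One REINFORCE update of a single node's softmax policy over OPS."""
    probs = softmax(logits)
    advantage = reward - baseline
    # d log p(chosen) / d logit_i = 1[i == chosen] - p_i
    return [
        logit + lr * advantage * ((1.0 if i == chosen else 0.0) - p)
        for i, (logit, p) in enumerate(zip(logits, probs))
    ]


graph = SharedSupergraph(num_nodes=2)
toy_targets = {"conv3x3": 1.0, "conv5x5": 0.5, "maxpool": 0.2}  # hypothetical
child_a = ["conv3x3", "maxpool"]
child_b = ["conv3x3", "conv5x5"]
graph.train_step(child_a, toy_targets)
print(graph.child_params(child_b))  # [0.5, 0.0]: node 0 was trained through child_a

new_logits = reinforce_step([0.0, 0.0, 0.0], chosen=0, reward=1.0, baseline=0.0)
print(softmax(new_logits))  # probability of OPS[0] rises above 1/3
```

In a full loop, the reward passed to <code>reinforce_step</code> would be the validation performance of a sampled child evaluated with the shared weights, and the two kinds of updates would alternate.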
<h3 id="prediction-based">Prediction-Based</h3>
<p>The routine way to evaluate a child model is to train its weights via standard gradient descent. SMASH (<a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>) proposes a different and interesting idea: <em>Can we predict the model weights directly based on the network architecture parameters?</em></p>
<p>They employ a <a href="https://blog.otoro.net/2016/09/28/hyper-networks/">HyperNet</a> (<a href="https://arxiv.org/abs/1609.09106">Ha et al 2016</a>) to directly generate the weights of a model conditioned on an encoding of its architecture configuration. Then the model with HyperNet-generated weights is validated directly. Note that we don’t need extra training for every child model but we do need to train the HyperNet.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SMASH-algorithm.png" alt="SMASH algorithm" /></p>
<p><em>Figure 16. The algorithm of SMASH. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<p>The correlation between model performance with SMASH-generated weights and true validation errors suggests that predicted weights can be used for model comparison, to some extent. We do need a HyperNet of large enough capacity, as the correlation would be corrupted if the HyperNet model is too small compared to the child model size.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SMASH-error-correlation.png" alt="SMASH error correlation" /></p>
<p><em>Figure 17. Correlation between the validation errors of models using SMASH-generated weights and their true validation errors. (Image source: <a href="https://arxiv.org/abs/1708.05344">Brock et al. 2017</a>)</em></p>
<p>SMASH can be viewed as another way to implement the idea of <a href="#parameter-sharing">parameter sharing</a>. One problem with SMASH, as pointed out by <a href="https://arxiv.org/abs/1802.03268">Pham et al. (2018)</a>, is that using a HyperNet restricts the weights of SMASH child models to a <em>low-rank space</em>, because weights are generated via tensor products. In comparison, <a href="#ENAS">ENAS</a> has no such restriction.</p>
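<p>A toy example can make both points concrete: weights are generated from an architecture encoding and used for evaluation without per-child training, and every generated weight vector is confined to a subspace whose dimension is at most the encoding size. The sketch below assumes a <em>linear</em> HyperNet \(w = Hc\) and tiny linear child models; SMASH's actual HyperNet is a convolutional network conditioned on a memory-bank encoding, so this is only a simplified analogue of the low-rank restriction. All values are made up.</p>

```python
def hypernet(H, code):
    """Linear HyperNet: generated weights w = H @ code.

    H has shape (num_weights, code_dim) and is the only trained parameter.
    """
    return [sum(h * c for h, c in zip(row, code)) for row in H]


def child_mse(weights, data):
    """Validation MSE of a linear child model that directly uses generated weights."""
    total = 0.0
    for x, y in data:
        err = sum(w * xi for w, xi in zip(weights, x)) - y
        total += err * err
    return total / len(data)


def det3(a, b, c):
    """Determinant of the 3x3 matrix whose columns are a, b, c."""
    return (a[0] * (b[1] * c[2] - b[2] * c[1])
            - b[0] * (a[1] * c[2] - a[2] * c[1])
            + c[0] * (a[1] * b[2] - a[2] * b[1]))


H = [[1, 0], [0, 1], [1, 1]]                     # 3 weights from a 2-dim code (toy values)
codes = {"A": [1, 0], "B": [0, 1], "C": [1, 1]}  # hypothetical architecture encodings
valid = [([1, 0, 0], 1), ([0, 0, 1], 1), ([0, 1, 0], 0)]

generated = {name: hypernet(H, c) for name, c in codes.items()}
ranking = sorted(codes, key=lambda name: child_mse(generated[name], valid))
print(ranking[0])  # "A", found without training any child model
# Three generated 3-dim weight vectors are linearly dependent: rank <= code_dim = 2.
print(det3(generated["A"], generated["B"], generated["C"]))  # 0
```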
<h2 id="one-shot-approach-search--evaluation">One-Shot Approach: Search + Evaluation</h2>
<p>Running search & evaluation independently for a large population of child models is expensive. We have seen promising approaches like <a href="https://arxiv.org/abs/1708.05344">Brock et al. (2017)</a> or <a href="https://arxiv.org/abs/1802.03268">Pham et al. (2018)</a>, where training a single model is enough for emulating any child model in the search space.</p>
<p>The <strong>one-shot</strong> architecture search extends the idea of weight sharing and further couples learning the architecture generation with learning the weight parameters. The following approaches all treat child architectures as different sub-graphs of a supergraph, with weights shared between common edges in the supergraph.</p>
<p><a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al (2018)</a> construct a single large over-parameterized network, known as the <strong>One-Shot model</strong>, such that it contains every possible operation in the search space. With <a href="#ScheduledDropPath">ScheduledDropPath</a> (the dropout rate is increased over time, reaching \(r^{1/k}\) at the end of training, where \(0 < r < 1\) is a hyperparameter and \(k\) is the number of incoming paths) and some carefully designed tricks (e.g. ghost batch normalization, L2 regularization only on the active architecture), the training of such a giant model can be stabilized enough for it to be used for evaluating any child model sampled from the supergraph.</p>
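<p>The schedule can be written down directly. The sketch below assumes a linear ramp of the drop rate from 0 to \(r^{1/k}\) and, as our own simplification, keeps one random path if all \(k\) incoming paths happen to be dropped; the function names are hypothetical. One property of the exponent is worth noting: at the end of training, the chance that all \(k\) paths are dropped at once is \((r^{1/k})^k = r\), independent of \(k\).</p>

```python
import random


def scheduled_drop_rate(step, total_steps, r, k):
    """Drop rate for each of k incoming paths, ramped linearly up to r ** (1 / k)."""
    progress = min(step / total_steps, 1.0)
    return progress * r ** (1.0 / k)


def sample_path_mask(k, rate, rng):
    """True = keep the path. Each path is dropped independently with probability `rate`."""
    mask = [rng.random() >= rate for _ in range(k)]
    if not any(mask):
        # Simplifying assumption: never drop every incoming path of a node.
        mask[rng.randrange(k)] = True
    return mask


rng = random.Random(0)
r, k, total = 0.5, 4, 1000
for step in (0, 500, 1000):
    rate = scheduled_drop_rate(step, total, r, k)
    print(step, round(rate, 3), sample_path_mask(k, rate, rng))
```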
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/one-shot-model-architecture.png" alt="One-Shot model architecture" /></p>
<p><em>Figure 18. The architecture of the One-Shot model in <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>. Each cell has \(N\) choice blocks and each choice block can select up to 2 operations. Solid edges are used in every architecture, while dashed edges are optional. (Image source: <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>)</em></p>
<p>Once the one-shot model is trained, it is used for evaluating the performance of many different architectures sampled at random by zeroing out or removing some operations. This sampling process can be replaced by RL or evolution.</p>
<p>They observed that the difference between the accuracy measured with the one-shot model and the accuracy of the same architecture after a small fine-tuning could be very large. Their hypothesis is that the one-shot model automatically learns to focus on the <em>most useful</em> operations in the network and comes to <em>rely on</em> these operations when they are available. Thus zeroing out useful operations leads to a big reduction in model accuracy, while removing less important components only causes a small impact. Therefore, we see a larger variance in scores when using the one-shot model for evaluation.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/one-shot-model-accuracy-correlation.png" alt="One-shot accuracy" /></p>
<p><em>Figure 19. A stratified sample of models with different one-shot model accuracy versus their true validation accuracy as stand-alone models. (Image source: <a href="http://proceedings.mlr.press/v80/bender18a/bender18a.pdf">Bender et al 2018</a>)</em></p>
<p>Designing such a search graph is clearly not a trivial task, but it demonstrates the strong potential of the one-shot approach: it works well with gradient descent alone, and no additional algorithm like RL or EA is required.</p>
<p>Some believe that one main cause of inefficiency in NAS is treating the architecture search as a <em>black-box optimization</em>, which leads us to methods like RL, evolution, SMBO, etc. If we instead rely on standard gradient descent, we could potentially make the search process more efficient. As a result, <a href="https://arxiv.org/abs/1806.09055">Liu et al (2019)</a> propose <em>Differentiable Architecture Search</em> (<strong>DARTS</strong>). DARTS introduces a continuous relaxation on each path in the search supergraph, making it possible to jointly train architecture parameters and weights via gradient descent.</p>
<p>Let’s use the directed acyclic graph (DAG) representation here. A cell is a DAG consisting of a topologically ordered sequence of \(N\) nodes. Each node has a latent representation \(x_i\) to be learned. Each edge \((i, j)\) is tied to some operation \(o^{(i,j)} \in \mathcal{O}\) that transforms \(x_j\) to compose \(x_i\):</p>
\[x_i = \sum_{j < i} o^{(i,j)}(x_j)\]
<p>To make the search space continuous, DARTS relaxes the categorical choice of a particular operation into a softmax over all the operations, and the task of architecture search is reduced to learning a set of mixing probabilities \(\alpha = \{ \alpha^{(i,j)} \}\).</p>
\[\bar{o}^{(i,j)}(x) = \sum_{o\in\mathcal{O}} \frac{\exp(\alpha_{ij}^o)}{\sum_{o'\in\mathcal{O}} \exp(\alpha^{o'}_{ij})} o(x)\]
<p>where \(\alpha_{ij}\) is a vector of dimension \(\vert \mathcal{O} \vert\), containing weights between nodes \(i\) and \(j\) over different operations.</p>
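<p>The continuous relaxation can be sketched on scalars: each edge computes the softmax-weighted mixture \(\bar{o}^{(i,j)}\), and each node sums the mixtures over its predecessors. The three candidate operations below are hypothetical stand-ins for convolution, pooling or skip connections, chosen only so the arithmetic is easy to follow. The <code>derive</code> helper is a simplified discretization that keeps the highest-weighted operation on each edge; the paper's derivation additionally keeps only the strongest incoming edges per node.</p>

```python
import math

# Hypothetical scalar stand-ins for the candidate set O.
OPS = {
    "skip": lambda x: x,
    "zero": lambda x: 0.0 * x,
    "relu": lambda x: max(0.0, x),
}
OP_NAMES = list(OPS)


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    s = sum(exps)
    return [e / s for e in exps]


def mixed_op(x, alpha_ij):
    """o_bar^{(i,j)}(x): softmax(alpha_ij)-weighted sum of all candidate operations."""
    return sum(p * OPS[name](x) for p, name in zip(softmax(alpha_ij), OP_NAMES))


def cell(x0, alpha, num_nodes):
    """x_i = sum_{j < i} o_bar^{(i,j)}(x_j); alpha maps an edge (i, j) to a vector."""
    xs = [x0]
    for i in range(1, num_nodes):
        xs.append(sum(mixed_op(xs[j], alpha[(i, j)]) for j in range(i)))
    return xs


def derive(alpha):
    """Discretize: keep the highest-weighted operation on every edge."""
    return {edge: OP_NAMES[max(range(len(a)), key=a.__getitem__)] for edge, a in alpha.items()}


alpha = {(1, 0): [0.0, 0.0, 0.0], (2, 0): [0.0, 0.0, 0.0], (2, 1): [2.0, -1.0, 0.5]}
print(cell(3.0, alpha, num_nodes=3))
print(derive(alpha))  # (2, 1) -> "skip"; uniform edges break ties toward the first op
```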
<p>This leads to a bilevel optimization problem, as we want to optimize both the network weights \(w\) and the architecture representation \(\alpha\):</p>
\[\begin{aligned}
\min_\alpha & \mathcal{L}_\text{val} (w^*(\alpha), \alpha) \\
\text{s.t.} & w^*(\alpha) = \arg\min_w \mathcal{L}_\text{train} (w, \alpha)
\end{aligned}\]
<p>At step \(k\), given the current architecture parameters \(\alpha_{k-1}\), we first obtain weights \(w_k\) by moving \(w_{k-1}\) in the direction that minimizes the training loss \(\mathcal{L}_\text{train}(w_{k-1}, \alpha_{k-1})\), with a learning rate \(\xi\). Next, while keeping the newly updated weights \(w_k\) fixed, we update the mixing probabilities so as to minimize the validation loss <em>after a single step of gradient descent w.r.t. the weights</em>:</p>
\[J_\alpha = \mathcal{L}_\text{val}(w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}), \alpha_{k-1})\]
<p>The motivation here is that we want to find an architecture with a low validation loss when its weights are optimized by gradient descent, and the one-step unrolled weights serve as the <em>surrogate</em> for \(w^*(\alpha)\).</p>
<blockquote>
<p>Side note: We have seen a similar formulation earlier in <a href="/lil-log/2018/11/30/meta-learning.html#maml">MAML</a>, where the two-step optimization happens between task losses and the meta-learner update, and in framing <a href="/lil-log/2019/05/05/domain-randomization.html#dr-as-optimization">Domain Randomization</a> as a bilevel optimization for better transfer in the real environment.</p>
</blockquote>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DARTS-illustration.png" alt="DARTS" /></p>
<p><em>Figure 20. An illustration of how DARTS applies continuous relaxation on edges in DAG supergraph and identifies the final model. (Image source: <a href="https://arxiv.org/abs/1806.09055">Liu et al 2019</a>)</em></p>
\[\begin{aligned}
\text{Let }w'_k &= w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}) & \\
J_\alpha &= \mathcal{L}_\text{val}(w_k - \xi \nabla_w \mathcal{L}_\text{train}(w_k, \alpha_{k-1}), \alpha_{k-1}) = \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) & \\
\nabla_\alpha J_\alpha
&= \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) \nabla_\alpha \alpha_{k-1} + \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\nabla_\alpha w'_k & \\& \text{; multivariable chain rule}\\
&= \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) + \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) \big( - \xi \color{red}{\nabla^2_{\alpha, w} \mathcal{L}_\text{train}(w_k, \alpha_{k-1})} \big) & \\
&\approx \nabla_{\alpha_{k-1}} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) - \xi \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1}) \color{red}{\frac{\nabla_\alpha \mathcal{L}_\text{train}(w_k^+, \alpha_{k-1}) - \nabla_\alpha \mathcal{L}_\text{train}(w_k^-, \alpha_{k-1}) }{2\epsilon}} & \\
& \text{; apply numerical differentiation approximation}
\end{aligned}\]
<p>where the red part is approximated by finite differences (numerical differentiation), with \(w_k^+ = w_k + \epsilon \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\) and \(w_k^- = w_k - \epsilon \nabla_{w'_k} \mathcal{L}_\text{val}(w'_k, \alpha_{k-1})\) for a small scalar \(\epsilon\).</p>
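<p>The finite-difference trick can be checked numerically on a toy problem with one scalar weight \(w\) and one scalar architecture parameter \(\alpha\). The quadratic losses below are made up for illustration and are not from the paper; because they are quadratic, the finite-difference estimate of the red term is exact up to rounding, so the approximate architecture gradient should match a direct numerical derivative of \(J_\alpha\). Setting \(\xi = 0\) recovers the cheaper first-order variant, which drops the second-order term.</p>

```python
# Toy losses (hypothetical): L_train(w, a) = (w - 2a)^2, L_val(w, a) = (w - 1)^2 + 0.1 a^2
def grad_w_train(w, a):
    return 2.0 * (w - 2.0 * a)


def grad_a_train(w, a):
    return -4.0 * (w - 2.0 * a)


def loss_val(w, a):
    return (w - 1.0) ** 2 + 0.1 * a ** 2


def grad_w_val(w, a):
    return 2.0 * (w - 1.0)


def grad_a_val(w, a):
    return 0.2 * a


def darts_arch_grad(w, a, xi, eps=1e-2):
    """Approximate grad_alpha J_alpha using the finite-difference trick."""
    w_prime = w - xi * grad_w_train(w, a)        # one-step unrolled weights w'_k
    g = grad_w_val(w_prime, a)                   # grad of L_val w.r.t. w'_k
    w_plus, w_minus = w + eps * g, w - eps * g   # w_k^+ and w_k^-
    hvp = (grad_a_train(w_plus, a) - grad_a_train(w_minus, a)) / (2.0 * eps)
    return grad_a_val(w_prime, a) - xi * hvp


def j_alpha(w, a, xi):
    """J_alpha = L_val(w - xi * grad_w L_train(w, a), a), as a function of a."""
    return loss_val(w - xi * grad_w_train(w, a), a)


w, a, xi, h = 0.5, 0.3, 0.1, 1e-5
approx = darts_arch_grad(w, a, xi)
numeric = (j_alpha(w, a + h, xi) - j_alpha(w, a - h, xi)) / (2.0 * h)
print(approx, numeric)  # both close to -0.324
```

The finite-difference form needs only two extra gradient evaluations of \(\mathcal{L}_\text{train}\) with respect to \(\alpha\), instead of an explicit second-derivative computation.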
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DARTS-algorithm.png" alt="DARTS algorithm" /></p>
<p><em>Figure 21. The algorithm overview of DARTS. (Image source: <a href="https://arxiv.org/abs/1806.09055">Liu et al 2019</a>)</em></p>
<p>As another idea similar to DARTS, Stochastic NAS (<a href="https://arxiv.org/abs/1812.09926">Xie et al., 2019</a>) applies a continuous relaxation by employing the concrete distribution (CONCRETE = CONtinuous relaxations of disCRETE random variables; <a href="https://arxiv.org/abs/1611.00712">Maddison et al 2017</a>) and reparametrization tricks. The goal is the same as in DARTS: to make the discrete distribution differentiable and thus enable optimization by gradient descent.
<!--- TBA: maybe add more details on SNAS --></p>
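<p>A concrete (Gumbel-softmax) sample can be sketched as follows; this illustrates the relaxation only and is not the SNAS training loop. The noise is drawn separately from the parameters, so given the noise, the relaxed one-hot vector is a deterministic, differentiable function of the log-probabilities \(\log \alpha\), which is the reparametrization trick. Lowering the temperature pushes samples toward one-hot vectors. The function names and the clamping constant are our own choices.</p>

```python
import math
import random


def gumbel_noise(n, rng):
    # g = -log(-log(u)), u ~ Uniform(0, 1); clamp u away from 0 and 1 to avoid log(0)
    us = [min(max(rng.random(), 1e-12), 1.0 - 1e-12) for _ in range(n)]
    return [-math.log(-math.log(u)) for u in us]


def concrete_from_noise(log_alpha, gumbels, temperature):
    """Relaxed one-hot sample: softmax((log_alpha + g) / temperature)."""
    scores = [(la + g) / temperature for la, g in zip(log_alpha, gumbels)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
log_alpha = [math.log(0.6), math.log(0.3), math.log(0.1)]
for temperature in (5.0, 1.0, 0.1):
    z = concrete_from_noise(log_alpha, gumbel_noise(3, rng), temperature)
    print(temperature, [round(v, 3) for v in z])  # sharper as temperature drops
```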
<p>DARTS greatly reduces the cost in GPU hours. Their experiments for searching CNN cells have \(N=7\) and took only 1.5 days with a single GPU. However, it suffers from high GPU memory consumption due to its continuous representation of the network architecture, since every candidate operation on every edge has to be computed and kept in memory. In order to fit the model into the memory of a single GPU, they picked a small \(N\).</p>
<p>To constrain the GPU memory consumption, <strong>ProxylessNAS</strong> (<a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>) views NAS as a path-level pruning process in a DAG and binarizes the architecture parameters to force only one path to be active between two nodes at a time. The probabilities for an edge being masked out or not are then learned by sampling a few binarized architectures and using <em>BinaryConnect</em> (<a href="https://arxiv.org/abs/1511.00363">Courbariaux et al., 2015</a>) to update the corresponding probabilities. ProxylessNAS demonstrates a strong connection between NAS and model compression. By using path-level compression, it reduces memory consumption by an order of magnitude.</p>
<p>Let’s continue with the graph representation. Consider a DAG adjacency matrix \(G\), where \(G_{ij}\) represents the edge between nodes \(i\) and \(j\), and its value is chosen from the set of \(\vert \mathcal{O} \vert\) candidate primitive operations, \(\mathcal{O} = \{ o_1, \dots \}\). The One-Shot model, DARTS and ProxylessNAS all consider each edge as a mixture of operations, \(m_\mathcal{O}\), but with different tweaks.</p>
<p>In One-Shot, \(m_\mathcal{O}(x)\) is the sum of all the operations. In DARTS, it is a weighted sum where weights are softmax over a real-valued architecture weighting vector \(\alpha\) of length \(\vert \mathcal{O} \vert\). ProxylessNAS transforms the softmax probabilities of \(\alpha\) into a binary gate and uses the binary gate to keep only one operation active at a time.</p>
\[\begin{aligned}
m^\text{one-shot}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} o_i(x) \\
m^\text{DARTS}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} p_i o_i(x) = \sum_{i=1}^{\vert \mathcal{O} \vert} \frac{\exp(\alpha_i)}{\sum_j \exp(\alpha_j)} o_i(x) \\
m^\text{binary}_\mathcal{O}(x) &= \sum_{i=1}^{\vert \mathcal{O} \vert} g_i o_i(x) = \begin{cases}
o_1(x) & \text{with probability }p_1, \\
\dots &\\
o_{\vert \mathcal{O} \vert}(x) & \text{with probability }p_{\vert \mathcal{O} \vert}
\end{cases} \\
\text{ where } g &= \text{binarize}(p_1, \dots, p_{\vert \mathcal{O} \vert}) = \begin{cases}
[1, 0, \dots, 0] & \text{with probability }p_1, \\
\dots & \\
[0, 0, \dots, 1] & \text{with probability }p_{\vert \mathcal{O} \vert}. \\
\end{cases}
\end{aligned}\]
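<p>The three mixtures can be compared on scalar toy operations, which are hypothetical stand-ins for real layers. The sketch also checks one property that follows from the definitions: averaged over many sampled gates, \(m^\text{binary}_\mathcal{O}\) matches \(m^\text{DARTS}_\mathcal{O}\) in expectation, while each individual forward pass executes only a single operation, which is where the memory saving comes from.</p>

```python
import math
import random

OPS = [lambda x: x, lambda x: 2.0 * x, lambda x: -x]  # hypothetical candidate operations


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    s = sum(exps)
    return [e / s for e in exps]


def m_one_shot(x):
    return sum(op(x) for op in OPS)


def m_darts(x, alpha):
    return sum(p * op(x) for p, op in zip(softmax(alpha), OPS))


def binarize(probs, rng):
    """Sample a one-hot gate vector g; index i is active with probability p_i."""
    i = rng.choices(range(len(probs)), weights=probs)[0]
    return [1 if j == i else 0 for j in range(len(probs))]


def m_binary(x, gates):
    # Only the active operation is executed, so only one path's activation is needed.
    return OPS[gates.index(1)](x)


rng = random.Random(0)
alpha = [0.0, 0.0, 0.0]
probs = softmax(alpha)
samples = [m_binary(1.0, binarize(probs, rng)) for _ in range(20000)]
print(m_one_shot(1.0), m_darts(1.0, alpha), sum(samples) / len(samples))
```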
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/proxylessNAS-training.png" alt="Training steps of ProxylessNAS" /></p>
<p><em>Figure 22. ProxylessNAS has two training steps running alternately. (Image source: <a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>)</em></p>
<p>ProxylessNAS runs two training steps alternately:</p>
<ol>
<li>When training weight parameters \(w\), it freezes the architecture parameters \(\alpha\) and stochastically samples binary gates \(g\) according to the above \(m^\text{binary}_\mathcal{O}(x)\). The weight parameters can be updated with standard gradient descent.</li>
<li>When training architecture parameters \(\alpha\), it freezes \(w\), resets the binary gates and then updates \(\alpha\) on the validation set. Following the idea of <em>BinaryConnect</em>, the gradient w.r.t. architecture parameters can be approximately estimated using \(\partial \mathcal{L} / \partial g_i\) in place of \(\partial \mathcal{L} / \partial p_i\):</li>
</ol>
\[\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \alpha_i}
&= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial p_j} \frac{\partial p_j}{\partial \alpha_i}
\approx \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\partial p_j}{\partial \alpha_i}
= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\partial \frac{e^{\alpha_j}}{\sum_k e^{\alpha_k}}}{\partial \alpha_i} \\
&= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} \frac{\sum_k e^{\alpha_k} (\mathbf{1}_{i=j} e^{\alpha_j}) - e^{\alpha_j} e^{\alpha_i} }{(\sum_k e^{\alpha_k})^2}
= \sum_{j=1}^{\vert \mathcal{O} \vert} \frac{\partial \mathcal{L}}{\partial g_j} p_j (\mathbf{1}_{i=j} -p_i)
\end{aligned}\]
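<p>The last line of the derivation is the softmax Jacobian \(\partial p_j / \partial \alpha_i = p_j(\mathbf{1}_{i=j} - p_i)\) applied to the gate gradients. A small sketch with a finite-difference check follows; to make the check exact, we assume a toy loss that is linear in \(p\), so that \(\partial \mathcal{L}/\partial p_j\) is a constant playing the role of \(\partial \mathcal{L}/\partial g_j\). The helper names are ours.</p>

```python
import math


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    s = sum(exps)
    return [e / s for e in exps]


def arch_grad(dL_dg, alpha):
    """dL/dalpha_i ~= sum_j dL/dg_j * p_j * (1[i == j] - p_i)."""
    p = softmax(alpha)
    n = len(alpha)
    return [
        sum(dL_dg[j] * p[j] * ((1.0 if i == j else 0.0) - p[i]) for j in range(n))
        for i in range(n)
    ]


# Toy check: L(alpha) = sum_j c_j * p_j(alpha), so dL/dp_j = c_j exactly.
c = [1.0, -2.0, 0.5]
alpha = [0.2, -0.1, 0.5]
analytic = arch_grad(c, alpha)


def toy_loss(al):
    return sum(cj * pj for cj, pj in zip(c, softmax(al)))


h = 1e-6
numeric = []
for i in range(len(alpha)):
    up = list(alpha)
    up[i] += h
    down = list(alpha)
    down[i] -= h
    numeric.append((toy_loss(up) - toy_loss(down)) / (2 * h))
print(analytic, numeric)  # the two lists agree closely
```

Because the softmax is invariant to adding a constant to every \(\alpha_i\), the components of this gradient always sum to zero.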
<p>Instead of BinaryConnect, REINFORCE can also be used to update the architecture parameters with the goal of maximizing the reward, and no RNN meta-controller is involved.</p>
<p>Computing \(\partial \mathcal{L} / \partial g_i\) requires calculating and storing \(o_i(x)\) for every operation, which takes \(\vert \mathcal{O} \vert\) times the GPU memory. To resolve this issue, they factorize the task of choosing one path out of \(\vert \mathcal{O} \vert\) into multiple binary selection tasks (Intuition: “if a path is the best choice, it should be better than any other path”). At every update step, only two paths are sampled while the others are masked. These two selected paths are updated according to the above equation and then scaled properly so that the weights of the other paths are unchanged. After this process, one of the sampled paths is enhanced (its path weight increases) and the other is attenuated (its path weight decreases), while all other paths stay unaltered.</p>
<p>Besides accuracy, ProxylessNAS also considers <em>latency</em> as an important metric to optimize, as different devices (e.g. GPU, CPU, mobile) might have very different requirements on inference latency. To make latency differentiable, they model latency as a continuous function of the network dimensions. The expected latency of a mixed operation can be written as \(\mathbb{E}[\text{latency}] = \sum_j p_j F(o_j)\), where \(F(\cdot)\) is a latency prediction model:</p>
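<p>One update of this two-path scheme can be sketched as below. Assumptions: the gate gradients of the two sampled paths are given, the step is plain gradient descent with a hypothetical learning rate, and the "proper scaling" is implemented as adding the same offset to both sampled \(\alpha\) values so that \(e^{\alpha_i} + e^{\alpha_j}\) is unchanged. Under that choice the softmax denominator over all paths stays the same, so every unsampled path keeps exactly its probability.</p>

```python
import math
import random


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    s = sum(exps)
    return [e / s for e in exps]


def sample_two_paths(alpha, rng):
    """Sample two distinct paths according to the current path probabilities."""
    probs = softmax(alpha)
    first = rng.choices(range(len(alpha)), weights=probs)[0]
    rest = [k for k in range(len(alpha)) if k != first]
    second = rng.choices(rest, weights=[probs[k] for k in rest])[0]
    return first, second


def two_path_update(alpha, pair, dL_dg_pair, lr=0.5):
    i, j = pair
    gi, gj = dL_dg_pair
    old_mass = math.exp(alpha[i]) + math.exp(alpha[j])
    pi = math.exp(alpha[i]) / old_mass        # softmax restricted to the two sampled paths
    pj = 1.0 - pi
    grad_i = gi * pi * (1.0 - pi) + gj * pj * (0.0 - pi)
    grad_j = gi * pi * (0.0 - pj) + gj * pj * (1.0 - pj)
    new = list(alpha)
    new[i] -= lr * grad_i
    new[j] -= lr * grad_j
    # Shift both sampled logits so that their total exp-mass is preserved.
    offset = math.log(old_mass / (math.exp(new[i]) + math.exp(new[j])))
    new[i] += offset
    new[j] += offset
    return new


rng = random.Random(0)
alpha = [0.0, 0.5, -0.3, 0.1]
pair = sample_two_paths(alpha, rng)
# Hypothetical gate gradients: the first sampled path lowers the loss, the second raises it.
updated = two_path_update(alpha, pair, dL_dg_pair=(-1.0, 1.0))
print(pair, softmax(alpha), softmax(updated))
```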
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/proxylessNAS-latency.png" alt="proxylessNAS latency" /></p>
<p><em>Figure 23. Add a differentiable latency loss into the training of ProxylessNAS. (Image source: <a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>)</em></p>
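<p>Because \(\mathbb{E}[\text{latency}]\) is linear in the probabilities, its gradient has a simple closed form, \(\partial \mathbb{E}[\text{latency}] / \partial \alpha_i = p_i \big(F(o_i) - \mathbb{E}[\text{latency}]\big)\), which follows from the softmax Jacobian in the BinaryConnect derivation. The sketch below assumes \(F\) is a per-operation lookup table with made-up numbers, that the network's expected latency is the sum over its mixed operations, and a hypothetical trade-off weight <code>lam</code> for adding it to the task loss.</p>

```python
import math


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    s = sum(exps)
    return [e / s for e in exps]


def expected_latency(alpha, latency_table):
    """E[latency] = sum_j p_j * F(o_j) for one mixed operation."""
    return sum(p * f for p, f in zip(softmax(alpha), latency_table))


def expected_latency_grad(alpha, latency_table):
    """d E[latency] / d alpha_i = p_i * (F(o_i) - E[latency])."""
    p = softmax(alpha)
    e = sum(pi * f for pi, f in zip(p, latency_table))
    return [pi * (f - e) for pi, f in zip(p, latency_table)]


def total_loss(task_loss, alphas, latency_tables, lam=0.1):
    # Network latency modeled as the sum over its mixed operations (assumption).
    latency = sum(expected_latency(a, t) for a, t in zip(alphas, latency_tables))
    return task_loss + lam * latency


table_ms = [10.0, 20.0]  # hypothetical F(o_j) in milliseconds for two candidate ops
print(expected_latency([0.0, 0.0], table_ms))       # 15.0
print(expected_latency_grad([0.0, 0.0], table_ms))  # [-2.5, 2.5]
```

The gradient pushes probability mass away from operations slower than the current expectation, which is how the latency term steers the search toward faster architectures.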
<h2 id="whats-the-future">What’s the Future?</h2>
<p>So far we have seen many interesting new ideas on automating network architecture engineering through neural architecture search, and many have achieved very impressive performance. However, it is hard to infer <em>why</em> some architectures work well and how we can develop modules that generalize across tasks rather than being very dataset-specific.</p>
<p>As also noted in <a href="https://arxiv.org/abs/1808.05377">Elsken, et al (2019)</a>:</p>
<blockquote>
<p>“…, so far it provides little insights into why specific architectures work well and how similar the architectures derived in independent runs would be. Identifying common motifs, providing an understanding why those motifs are important for high performance, and investigating if these motifs generalize over different problems would be desirable.”</p>
</blockquote>
<p>In the meantime, purely focusing on improving validation accuracy might not be enough (<a href="https://arxiv.org/abs/1812.00332">Cai et al., 2019</a>). Devices like mobile phones in daily use generally have limited memory and computation power. As AI applications increasingly affect our daily life, it is inevitable that NAS becomes more <em>device-specific</em>.</p>
<p>Another interesting direction is to consider <em>unlabelled datasets</em> and <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised learning</a> for NAS. Labelled datasets are always limited in size, and it is not easy to tell whether such a dataset has biases or deviates substantially from the real-world data distribution.</p>
<p><a href="https://arxiv.org/abs/2003.12056">Liu et al (2020)</a> delve into the question <em>“Can we find high-quality neural architecture without human-annotated labels?”</em> and propose a new setup called <em>Unsupervised Neural Architecture Search</em> (<strong>UnNAS</strong>). The quality of the architecture needs to be estimated in an unsupervised fashion during the search phase. The paper experimented with three unsupervised <a href="/lil-log/2019/11/10/self-supervised-learning.html#images-based">pretext tasks</a>: image rotation prediction, colorization, and solving the jigsaw puzzle.</p>
<p>They observed in a set of UnNAS experiments that:</p>
<ol>
<li>High rank correlation between supervised accuracy and pretext accuracy <em>on the same dataset</em>. Typically the rank correlation is higher than 0.8, regardless of the dataset, the search space, and the pretext task.</li>
<li>High rank correlation between supervised accuracy and pretext accuracy <em>across datasets</em>.</li>
<li>Better pretext accuracy translates to better supervised accuracy.</li>
<li>Performance of UnNAS architecture is comparable to supervised counterparts, though not better yet.</li>
</ol>
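<p>Rank correlation compares the <em>orderings</em> of architectures under two metrics rather than their raw values. As an illustration we assume Spearman's \(\rho = 1 - \frac{6\sum_i d_i^2}{n(n^2-1)}\) (valid without ties), where \(d_i\) is the rank difference of architecture \(i\) under the two metrics; the accuracies in the example are invented for illustration and are not results from the paper.</p>

```python
def ranks(values):
    """Rank 1 = smallest value; assumes no ties to keep the formula simple."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, i in enumerate(order, start=1):
        result[i] = rank
    return result


def spearman_rho(xs, ys):
    n = len(xs)
    d_squared = sum((rx - ry) ** 2 for rx, ry in zip(ranks(xs), ranks(ys)))
    return 1 - 6 * d_squared / (n * (n ** 2 - 1))


# Invented accuracies of four architectures, for illustration only.
supervised_acc = [0.90, 0.85, 0.80, 0.70]
pretext_acc = [0.60, 0.62, 0.50, 0.40]
print(spearman_rho(supervised_acc, pretext_acc))  # 0.8
```

Only the top two architectures swap places between the two metrics, so the correlation stays high even though the raw accuracies are on quite different scales.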
<p>One hypothesis is that architecture quality is correlated with image statistics. Because CIFAR-10 and ImageNet both consist of natural images, they are comparable and the results are transferable. UnNAS could potentially bring a much larger amount of unlabelled data into the search phase, which would capture image statistics better.</p>
<p>Hyperparameter search is a long-standing topic in the ML community, and NAS automates architecture engineering. Gradually we are trying to automate processes in ML that usually demand a lot of human effort. Taking one more step further, is it possible to automatically discover ML algorithms? <strong>AutoML-Zero</strong> (<a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>) investigates this idea. Using <a href="#aging-evolutionary-algorithms">aging evolutionary algorithms</a>, AutoML-Zero automatically searches for whole ML algorithms with little restriction on their form, using only simple mathematical operations as building blocks.</p>
<p>It learns three component functions. Each function only adopts very basic operations.</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">Setup</code>: initialize memory variables (weights).</li>
<li><code class="language-plaintext highlighter-rouge">Learn</code>: modify memory variables.</li>
<li><code class="language-plaintext highlighter-rouge">Predict</code>: make a prediction from an input \(x\).</li>
</ul>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/AutoML-zero-evaluation.png" alt="AutoML-zero evaluation" /></p>
<p><em>Figure 24. Algorithm evaluation on one task. (Image source: <a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>)</em></p>
<p>Three types of operations are considered when mutating a parent genotype:</p>
<ol>
<li>Insert a random instruction or remove an instruction at a random location in a component function;</li>
<li>Randomize all the instructions in a component function;</li>
<li>Modify one of the arguments of an instruction by replacing it with a random choice (e.g. “swap the output address” or “change the value of a constant”).</li>
</ol>
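<p>A heavily simplified version of this setup fits in a few dozen lines. The sketch below is not the AutoML-Zero code: it assumes a scalar-only memory (the paper also uses vector and matrix memory), a four-instruction vocabulary, and reserved addresses for the label, prediction and input that we picked for illustration. It shows how a program made of <code>Setup</code>, <code>Predict</code> and <code>Learn</code> is evaluated on one task, and how the three mutation types edit a copy of a parent program. A hand-written program implementing SGD for 1-D linear regression serves as a sanity check.</p>

```python
import random

NUM_ADDR = 8                   # scalar memory s0..s7
LABEL, PRED, INPUT = 0, 1, 2   # reserved addresses (our assumption)
ARITH = ("add", "sub", "mul")


def run(instrs, mem):
    for op, out, a, b in instrs:
        if op == "const":
            mem[out] = a                     # for "const", a holds the constant value
        elif op == "add":
            mem[out] = mem[a] + mem[b]
        elif op == "sub":
            mem[out] = mem[a] - mem[b]
        elif op == "mul":
            mem[out] = mem[a] * mem[b]


def evaluate(program, train, valid):
    """Run Setup once, Predict + Learn per training example, then return validation MSE."""
    mem = [0.0] * NUM_ADDR
    run(program["setup"], mem)
    for x, y in train:
        mem[INPUT], mem[LABEL] = x, 0.0      # hide the label while predicting
        run(program["predict"], mem)
        mem[LABEL] = y                       # only Learn sees the label
        run(program["learn"], mem)
    loss = 0.0
    for x, y in valid:
        mem[INPUT], mem[LABEL] = x, 0.0
        run(program["predict"], mem)
        err = mem[PRED] - y
        loss += err * err                    # avoids the OverflowError ** can raise on huge floats
    return loss / len(valid)


def random_instruction(rng):
    op = rng.choice(ARITH + ("const",))
    out = rng.randrange(NUM_ADDR)
    if op == "const":
        return (op, out, rng.uniform(-1.0, 1.0), None)
    return (op, out, rng.randrange(NUM_ADDR), rng.randrange(NUM_ADDR))


def mutate(parent, rng):
    """Apply one of the three mutation types to a copy of one component function."""
    child = {name: list(instrs) for name, instrs in parent.items()}
    name = rng.choice(sorted(child))
    instrs = child[name]
    kind = rng.randrange(3)
    if kind == 0:                            # 1. insert or remove one instruction
        if instrs and rng.random() < 0.5:
            del instrs[rng.randrange(len(instrs))]
        else:
            instrs.insert(rng.randrange(len(instrs) + 1), random_instruction(rng))
    elif kind == 1:                          # 2. randomize every instruction in the function
        child[name] = [random_instruction(rng) for _ in instrs]
    elif instrs:                             # 3. modify one argument of one instruction
        i = rng.randrange(len(instrs))
        op, out, a, b = instrs[i]
        if op == "const" and rng.random() < 0.5:
            a = rng.uniform(-1.0, 1.0)       # change the value of a constant
        else:
            out = rng.randrange(NUM_ADDR)    # swap the output address
        instrs[i] = (op, out, a, b)
    return child


sgd_program = {
    "setup": [("const", 3, 0.0, None),       # s3: weight
              ("const", 6, 0.1, None)],      # s6: learning rate
    "predict": [("mul", PRED, 3, INPUT)],    # s1 = s3 * s2
    "learn": [("sub", 4, LABEL, PRED),       # s4 = label - prediction
              ("mul", 4, 4, INPUT),          # s4 = s4 * x
              ("mul", 4, 4, 6),              # s4 = s4 * lr
              ("add", 3, 3, 4)],             # s3 = s3 + s4
}
train = [(x, 2.0 * x) for x in (1.0, -0.5, 0.8, -1.2)] * 50   # toy task: y = 2x
valid = [(x, 2.0 * x) for x in (0.3, -0.7)]
print(evaluate(sgd_program, train, valid))   # close to 0
child = mutate(sgd_program, random.Random(0))
print(evaluate(child, train, valid))         # fitness of one mutated child
```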
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/AutoML-zero-progress.png" alt="Progress of AutoML-zero experiment" /></p>
<p><em>Figure 25. An illustration of evolutionary progress on projected binary CIFAR-10 with example code. (Image source: <a href="https://arxiv.org/abs/2003.03384">Real et al 2020</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020nas,
title = "Neural Architecture Search",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/08/06/neural-architecture-search.html"
}
</code></pre></div></div>
<h2 id="appendix-summary-of-nas-papers">Appendix: Summary of NAS Papers</h2>
<table class="info">
<thead>
<tr>
<th>Model name</th>
<th>Search space</th>
<th>Search algorithms</th>
<th>Child model evaluation</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (2002)</a></td>
<td>-</td>
<td>Evolution (Genetic algorithm)</td>
<td>-</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1611.01578">NAS (2017)</a></td>
<td>Sequential layer-wise ops</td>
<td>RL (REINFORCE)</td>
<td>Train from scratch until convergence</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1611.02167">MetaQNN (2017)</a></td>
<td>Sequential layer-wise ops</td>
<td>RL (Q-learning with \(\epsilon\)-greedy)</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1711.00436">HNAS (2017)</a></td>
<td>Hierarchical structure</td>
<td>Evolution (Tournament selection)</td>
<td>Train for a fixed number of iterations</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1707.07012">NASNet (2018)</a></td>
<td>Cell-based</td>
<td>RL (PPO)</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1802.01548">AmoebaNet (2018)</a></td>
<td>NASNet search space</td>
<td>Evolution (Tournament selection with aging regularization)</td>
<td>Train for 25 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1707.04873">EAS (2018a)</a></td>
<td>Network transformation</td>
<td>RL (REINFORCE)</td>
<td>2-stage training</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1712.00559">PNAS (2018)</a></td>
<td>Reduced version of NASNet search space</td>
<td>SMBO; Progressive search for architectures of increasing complexity</td>
<td>Train for 20 epochs</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1802.03268">ENAS (2018)</a></td>
<td>Both sequential and cell-based search space</td>
<td>RL (REINFORCE)</td>
<td>Train one model with shared weights</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1708.05344">SMASH (2017)</a></td>
<td>Memory-bank representation</td>
<td>Random search</td>
<td>HyperNet predicts weights of evaluated architectures.</td>
</tr>
<tr>
<td><a href="http://proceedings.mlr.press/v80/bender18a.html">One-Shot (2018)</a></td>
<td>An over-parameterized one-shot model</td>
<td>Random search (zero out some paths at random)</td>
<td>Train the one-shot model</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1806.09055">DARTS (2019)</a></td>
<td>NASNet search space</td>
<td colspan="2">Gradient descent (Softmax weights over operations)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1812.00332">ProxylessNAS (2019)</a></td>
<td>Tree structure architecture</td>
<td colspan="2">Gradient descent (BinaryConnect) or REINFORCE</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1812.09926">SNAS (2019)</a></td>
<td>NASNet search space</td>
<td colspan="2">Gradient descent (concrete distribution)</td>
</tr>
</tbody>
</table>
<h2 id="reference">Reference</h2>
<p>[1] Thomas Elsken, Jan Hendrik Metzen, Frank Hutter. <a href="https://arxiv.org/abs/1808.05377">“Neural Architecture Search: A Survey”</a> JMLR 20 (2019) 1-21.</p>
<p>[2] Kenneth O. Stanley, et al. <a href="https://www.nature.com/articles/s42256-018-0006-z">“Designing neural networks through neuroevolution”</a> Nature Machine Intelligence volume 1, pages 24–35 (2019).</p>
<p>[3] Kenneth O. Stanley & Risto Miikkulainen. <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">“Evolving Neural Networks through Augmenting Topologies”</a> Evolutionary Computation 10(2): 99-127 (2002).</p>
<p>[4] Barret Zoph, Quoc V. Le. <a href="https://arxiv.org/abs/1611.01578">“Neural architecture search with reinforcement learning”</a> ICLR 2017.</p>
<p>[5] Bowen Baker, et al. <a href="https://arxiv.org/abs/1611.02167">“Designing Neural Network Architectures using Reinforcement Learning”</a> ICLR 2017.</p>
<p>[6] Bowen Baker, et al. <a href="https://arxiv.org/abs/1705.10823">“Accelerating neural architecture search using performance prediction”</a> ICLR Workshop 2018.</p>
<p>[7] Barret Zoph, et al. <a href="https://arxiv.org/abs/1707.07012">“Learning transferable architectures for scalable image recognition”</a> CVPR 2018.</p>
<p>[8] Hanxiao Liu, et al. <a href="https://arxiv.org/abs/1711.00436">“Hierarchical representations for efficient architecture search.”</a> ICLR 2018.</p>
<p>[9] Esteban Real, et al. <a href="https://arxiv.org/abs/1802.01548">“Regularized Evolution for Image Classifier Architecture Search”</a> arXiv:1802.01548 (2018).</p>
<p>[10] Han Cai, et al. <a href="https://arxiv.org/abs/1707.04873">“Efficient architecture search by network transformation”</a> AAAI 2018a.</p>
<p>[11] Han Cai, et al. <a href="https://arxiv.org/abs/1806.02639">“Path-Level Network Transformation for Efficient Architecture Search”</a> ICML 2018b.</p>
<p>[12] Han Cai, Ligeng Zhu & Song Han. <a href="https://arxiv.org/abs/1812.00332">“ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware”</a> ICLR 2019.</p>
<p>[13] Chenxi Liu, et al. <a href="https://arxiv.org/abs/1712.00559">“Progressive neural architecture search”</a> ECCV 2018.</p>
<p>[14] Hieu Pham, et al. <a href="https://arxiv.org/abs/1802.03268">“Efficient neural architecture search via parameter sharing”</a> ICML 2018.</p>
<p>[15] Andrew Brock, et al. <a href="https://arxiv.org/abs/1708.05344">“SMASH: One-shot model architecture search through hypernetworks.”</a> ICLR 2018.</p>
<p>[16] Gabriel Bender, et al. <a href="http://proceedings.mlr.press/v80/bender18a.html">“Understanding and simplifying one-shot architecture search.”</a> ICML 2018.</p>
<p>[17] Hanxiao Liu, Karen Simonyan, Yiming Yang. <a href="https://arxiv.org/abs/1806.09055">“DARTS: Differentiable Architecture Search”</a> ICLR 2019.</p>
<p>[18] Sirui Xie, Hehui Zheng, Chunxiao Liu, Liang Lin. <a href="https://arxiv.org/abs/1812.09926">“SNAS: Stochastic Neural Architecture Search”</a> ICLR 2019.</p>
<p>[19] Chenxi Liu et al. <a href="https://arxiv.org/abs/2003.12056">“Are Labels Necessary for Neural Architecture Search?”</a> ECCV 2020.</p>
<p>[20] Esteban Real, et al. <a href="https://arxiv.org/abs/2003.03384">“AutoML-Zero: Evolving Machine Learning Algorithms From Scratch”</a> ICML 2020.</p>Lilian WengNeural Architecture Search (NAS) automates network architecture engineering. It aims to learn a network topology that can achieve best performance on a certain task. By dissecting the methods for NAS into three components: search space, search algorithm and child model evolution strategy, this post reviews many interesting ideas for better, faster and more cost-efficient automatic neural architecture search.Exploration Strategies in Deep Reinforcement Learning2020-06-07T12:00:00+00:002020-06-07T12:00:00+00:00https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning<blockquote>
<p>Exploitation versus exploration is a critical topic in reinforcement learning. This post introduces several common approaches for better exploration in Deep RL.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-06-17: Add <a href="#exploration-via-disagreement">“exploration via disagreement”</a> in the “Forward Dynamics” <a href="#forward-dynamics">section</a>.</span></p>
<p><a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html">Exploitation versus exploration</a> is a critical topic in Reinforcement Learning. We’d like the RL agent to find the best solution as fast as possible. However, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">RL</a> <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html">algorithms</a> that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more of an open topic.</p>
<p>I would like to discuss several common exploration strategies in Deep RL here. As this is a very big topic, this post can by no means cover all the important subtopics. I plan to update it periodically and keep enriching the content over time.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-exploration-strategies" id="markdown-toc-classic-exploration-strategies">Classic Exploration Strategies</a></li>
<li><a href="#key-exploration-problems" id="markdown-toc-key-exploration-problems">Key Exploration Problems</a> <ul>
<li><a href="#the-hard-exploration-problem" id="markdown-toc-the-hard-exploration-problem">The Hard-Exploration Problem</a></li>
<li><a href="#the-noisy-tv-problem" id="markdown-toc-the-noisy-tv-problem">The Noisy-TV Problem</a></li>
</ul>
</li>
<li><a href="#intrinsic-rewards-as-exploration-bonuses" id="markdown-toc-intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</a> <ul>
<li><a href="#count-based-exploration" id="markdown-toc-count-based-exploration">Count-based Exploration</a> <ul>
<li><a href="#counting-by-density-model" id="markdown-toc-counting-by-density-model">Counting by Density Model</a></li>
<li><a href="#counting-after-hashing" id="markdown-toc-counting-after-hashing">Counting after Hashing</a></li>
</ul>
</li>
<li><a href="#prediction-based-exploration" id="markdown-toc-prediction-based-exploration">Prediction-based Exploration</a> <ul>
<li><a href="#forward-dynamics" id="markdown-toc-forward-dynamics">Forward Dynamics</a></li>
<li><a href="#random-networks" id="markdown-toc-random-networks">Random Networks</a></li>
<li><a href="#physical-properties" id="markdown-toc-physical-properties">Physical Properties</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#memory-based-exploration" id="markdown-toc-memory-based-exploration">Memory-based Exploration</a> <ul>
<li><a href="#episodic-memory" id="markdown-toc-episodic-memory">Episodic Memory</a></li>
<li><a href="#direct-exploration" id="markdown-toc-direct-exploration">Direct Exploration</a></li>
</ul>
</li>
<li><a href="#q-value-exploration" id="markdown-toc-q-value-exploration">Q-Value Exploration</a></li>
<li><a href="#varitional-options" id="markdown-toc-varitional-options">Variational Options</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="classic-exploration-strategies">Classic Exploration Strategies</h2>
<p>As a quick recap, let’s first go through several classic exploration algorithms that work out pretty well in the multi-armed bandit problem or simple tabular RL.</p>
<ul>
<li><strong>Epsilon-greedy</strong>: The agent explores randomly with probability \(\epsilon\) and, the rest of the time (with probability \(1-\epsilon\)), takes the action that is optimal under its current value estimates.</li>
<li><strong>Upper confidence bounds</strong>: The agent selects the greediest action to maximize the upper confidence bound \(\hat{Q}_t(a) + \hat{U}_t(a)\), where \(\hat{Q}_t(a)\) is the average reward associated with action \(a\) up to time \(t\) and \(\hat{U}_t(a)\) is a function that shrinks as the number of times action \(a\) has been taken grows. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#upper-confidence-bounds">here</a> for more details.</li>
<li><strong>Boltzmann exploration</strong>: The agent draws actions from a <a href="https://en.wikipedia.org/wiki/Boltzmann_distribution">Boltzmann distribution</a> (softmax) over the learned Q values, regulated by a temperature parameter \(\tau\).</li>
<li><strong>Thompson sampling</strong>: The agent keeps track of a belief over the probability of optimal actions and samples from this distribution. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">here</a> for more details.</li>
</ul>
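<p>To make the four rules concrete, here is a minimal NumPy sketch on a Bernoulli multi-armed bandit. The function names, the UCB1-style bonus \(\sqrt{c \ln t / N_t(a)}\) with \(c = 2\), the temperature \(\tau = 0.1\), \(\epsilon = 0.1\), and the Beta(1, 1) priors are our own illustrative choices, not taken from any specific paper.</p>

```python
import numpy as np


def epsilon_greedy(q, eps, rng):
    """With probability eps pick a uniformly random arm, otherwise the greedy arm."""
    if rng.random() < eps:
        return int(rng.integers(len(q)))
    return int(np.argmax(q))


def ucb1(q, counts, t, c=2.0):
    """Pick argmax of Q_hat(a) + U_hat(a), where U_hat(a) = sqrt(c * ln t / N(a))."""
    counts = np.asarray(counts, dtype=float)
    if np.any(counts == 0):  # play every arm once before trusting the bound
        return int(np.argmin(counts))
    return int(np.argmax(q + np.sqrt(c * np.log(t) / counts)))


def boltzmann(q, tau, rng):
    """Sample an arm from softmax(Q / tau); a larger tau gives more uniform exploration."""
    z = (np.asarray(q, dtype=float) - np.max(q)) / tau  # shift for numerical stability
    p = np.exp(z) / np.exp(z).sum()
    return int(rng.choice(len(p), p=p))


def thompson(successes, failures, rng):
    """Beta-Bernoulli Thompson sampling: sample each arm's success rate, act greedily."""
    return int(np.argmax(rng.beta(successes + 1.0, failures + 1.0)))


def run_bandit(strategy, true_p, steps=2000, seed=0):
    """Return the average reward of a strategy on a Bernoulli bandit."""
    rng = np.random.default_rng(seed)
    k = len(true_p)
    q, n = np.zeros(k), np.zeros(k)  # running mean reward and pull count per arm
    wins, losses = np.zeros(k), np.zeros(k)
    total = 0.0
    for t in range(1, steps + 1):
        if strategy == "epsilon_greedy":
            a = epsilon_greedy(q, 0.1, rng)
        elif strategy == "ucb":
            a = ucb1(q, n, t)
        elif strategy == "boltzmann":
            a = boltzmann(q, 0.1, rng)
        else:
            a = thompson(wins, losses, rng)
        r = float(rng.random() < true_p[a])
        n[a] += 1
        q[a] += (r - q[a]) / n[a]  # incremental mean update
        wins[a] += r
        losses[a] += 1.0 - r
        total += r
    return total / steps
```

<p>For example, <code>run_bandit("ucb", [0.1, 0.5, 0.9])</code> should approach the best arm's mean of 0.9 as <code>steps</code> grows; the exact value depends on the seed.</p>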
<p>The following strategies could be used for better exploration in deep RL training when neural networks are used for function approximation:</p>
<ul>
<li><strong>Entropy loss term</strong>: Add an entropy term \(H(\pi(a \vert s))\) into the loss function, encouraging the policy to take diverse actions.</li>
<li><strong>Noise-based Exploration</strong>: Add noise into the observation, action or even parameter space (<a href="https://arxiv.org/abs/1706.10295">Fortunato, et al. 2017</a>, <a href="https://arxiv.org/abs/1706.01905">Plappert, et al. 2017</a>).</li>
</ul>
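<p>As a minimal sketch of the entropy loss term, assuming a discrete action space and a REINFORCE-style policy-gradient surrogate, the NumPy function below computes the loss value only; in practice it would be written in an autodiff framework so that gradients flow into the policy logits. The coefficient name <code>ent_coef</code> and its default of 0.01 are illustrative assumptions.</p>

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def policy_loss_with_entropy(logits, actions, advantages, ent_coef=0.01):
    """Policy-gradient surrogate loss minus an entropy bonus (lower is better).

    logits: (B, A) unnormalized action scores, actions: (B,) ints, advantages: (B,)
    """
    logp = log_softmax(np.asarray(logits, dtype=float))
    entropy = -(np.exp(logp) * logp).sum(axis=-1)        # H(pi(.|s)) for each state
    logp_taken = logp[np.arange(len(actions)), actions]  # log pi(a_t | s_t)
    pg_loss = -(np.asarray(advantages) * logp_taken).mean()
    # Subtracting the entropy rewards the policy for keeping its action distribution spread out.
    return pg_loss - ent_coef * entropy.mean(), entropy
```

<p>A uniform policy over \(A\) actions attains the maximum entropy \(\log A\), so the bonus pushes the policy away from collapsing onto a single action too early.</p>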
<h2 id="key-exploration-problems">Key Exploration Problems</h2>
<p>Good exploration becomes especially hard when the environment rarely provides rewards as feedback or contains distracting noise. Many exploration strategies have been proposed to solve one or both of the following problems.</p>
<h3 id="the-hard-exploration-problem">The Hard-Exploration Problem</h3>
<p>The “hard-exploration” problem refers to exploration in an environment with very sparse or even deceptive reward. It is difficult because random exploration in such scenarios can rarely discover successful states or obtain meaningful feedback.</p>
<p><a href="https://en.wikipedia.org/wiki/Montezuma%27s_Revenge_(video_game)">Montezuma’s Revenge</a> is a concrete example of the hard-exploration problem. It remains one of the few Atari games that are still challenging for deep RL, and many papers use it to benchmark their results.</p>
<h3 id="the-noisy-tv-problem">The Noisy-TV Problem</h3>
<p>The “Noisy-TV” problem started as a thought experiment in <a href="https://arxiv.org/abs/1810.12894">Burda, et al. (2018)</a>. Imagine that an RL agent is rewarded for seeking novel experience: a TV with uncontrollable and unpredictable random noise would attract the agent’s attention forever. The agent consistently obtains new rewards from the noisy TV, but it fails to make any meaningful progress and becomes a “couch potato”.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/the-noisy-TV-problem.gif" alt="The noisy-TV problem" /></p>
<p><em>Fig. 1. An agent is rewarded for novel experience in the experiment. If a maze has a noisy TV set up, the agent is attracted to it and stops moving through the maze. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<h2 id="intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</h2>
<p>One common approach to better exploration, especially for solving the <a href="#the-hard-exploration-problem">hard-exploration</a> problem, is to augment the environment reward with an additional bonus signal to encourage extra exploration. The policy is thus trained with a reward composed of two terms, \(r_t = r^e_t + \beta r^i_t\), where \(\beta\) is a hyperparameter adjusting the balance between exploitation and exploration.</p>
<ul>
<li>\(r^e_t\) is an <em>extrinsic</em> reward from the environment at time \(t\), defined according to the task at hand.</li>
<li>\(r^i_t\) is an <em>intrinsic</em> exploration bonus at time \(t\).</li>
</ul>
<p>This intrinsic reward is somewhat inspired by <em>intrinsic motivation</em> in psychology (<a href="https://www.researchgate.net/profile/Pierre-Yves_Oudeyer/publication/29614795_How_can_we_define_intrinsic_motivation/links/09e415107f1b4c8041000000/How-can-we-define-intrinsic-motivation.pdf">Oudeyer & Kaplan, 2008</a>). Exploration driven by curiosity might be an important way for children to grow and learn. In other words, exploratory activities should be rewarding intrinsically in the human mind to encourage such behavior. The intrinsic rewards could be correlated with curiosity, surprise, familiarity of the state, and many other factors.</p>
<p>The same ideas can be applied to RL algorithms. In the following sections, methods of bonus-based exploration rewards are roughly grouped into two categories:</p>
<ol>
<li>Discovery of novel states</li>
<li>Improvement of the agent’s knowledge about the environment.</li>
</ol>
<h3 id="count-based-exploration">Count-based Exploration</h3>
<p>If we consider intrinsic rewards as rewarding conditions that surprise us, we need a way to measure whether a state is novel or appears often. One intuitive way is to count how many times a state has been encountered and to assign a bonus accordingly. The bonus guides the agent’s behavior to prefer rarely visited states to common states. This is known as the <strong>count-based exploration</strong> method.</p>
<p>Let \(N_n(s)\) be the <em>empirical count</em> function that tracks the real number of visits of a state \(s\) in the sequence \(s_{1:n}\). Unfortunately, using \(N_n(s)\) for exploration directly is not practical, because most states would have \(N_n(s)=0\), especially considering that the state space is often continuous or high-dimensional. We need a non-zero count for most states, even when they haven’t been seen before.</p>
<h4 id="counting-by-density-model">Counting by Density Model</h4>
<p><a href="https://arxiv.org/abs/1606.01868">Bellemare, et al. (2016)</a> used a <strong>density model</strong> to approximate the frequency of state visits and a novel algorithm for deriving a <em>pseudo-count</em> from this density model. Let’s first define a conditional probability over the state space, \(\rho_n(s) = \rho(s \vert s_{1:n})\) as the probability of the \((n+1)\)-th state being \(s\) given the first \(n\) states are \(s_{1:n}\). To measure this empirically, we can simply use \(N_n(s)/n\).</p>
<p>Let’s also define a <em>recoding probability</em> of a state \(s\) as the probability assigned by the density model to \(s\) <em>after observing a new occurrence of</em> \(s\), \(\rho'_n(s) = \rho(s \vert s_{1:n}s)\).</p>
<p>The paper introduced two concepts to better regulate the density model, a <em>pseudo-count</em> function \(\hat{N}_n(s)\) and a <em>pseudo-count total</em> \(\hat{n}\). As they are designed to imitate an empirical count function, we would have:</p>
\[\rho_n(s) = \frac{\hat{N}_n(s)}{\hat{n}} \leq \rho'_n(s) = \frac{\hat{N}_n(s) + 1}{\hat{n} + 1}\]
<p>The relationship between \(\rho_n(s)\) and \(\rho'_n(s)\) requires the density model to be <em>learning-positive</em>: for all \(s_{1:n} \in \mathcal{S}^n\) and all \(s \in \mathcal{S}\), \(\rho_n(s) \leq \rho'_n(s)\). In other words, after observing one instance of \(s\), the density model’s prediction for that same \(s\) should increase. Apart from being learning-positive, the density model should be trained completely <em>online</em> with non-randomized mini-batches of experienced states, so naturally we have \(\rho'_n = \rho_{n+1}\).</p>
<p>The pseudo-count can be computed from \(\rho_n(s)\) and \(\rho'_n(s)\) after solving the above linear system:</p>
\[\hat{N}_n(s) = \hat{n} \rho_n(s) = \frac{\rho_n(s)(1 - \rho'_n(s))}{\rho'_n(s) - \rho_n(s)}\]
<p>Or estimated by the <em>prediction gain (PG)</em>:</p>
\[\hat{N}_n(s) \approx (e^{\text{PG}_n(s)} - 1)^{-1} = (e^{\log \rho'_n(s) - \log \rho_n(s)} - 1)^{-1}\]
<p>A common choice of a count-based intrinsic bonus is \(r^i_t = N(s_t, a_t)^{-1/2}\) (as in MBIE-EB; <a href="https://www.ics.uci.edu/~dechter/courses/ics-295/fall-2019/papers/2008-littman-aij-main.pdf">Strehl & Littman, 2008</a>). The pseudo-count-based exploration bonus is shaped in a similar form, \(r^i_t = \big(\hat{N}_n(s_t, a_t) + 0.01 \big)^{-1/2}\).</p>
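<p>The two pseudo-count formulas above are easy to check numerically. In this plain-Python sketch (function names are hypothetical), we plug in the empirical density \(\rho_n(s) = N_n(s)/n\) and its recoding probability \((N_n(s)+1)/(n+1)\): the closed-form pseudo-count then recovers the true count exactly, while the prediction-gain version is only close when \(\rho'_n(s)\) is small.</p>

```python
import math


def pseudo_count(rho, rho_prime):
    """Closed-form pseudo-count: rho * (1 - rho') / (rho' - rho)."""
    if rho_prime <= rho:
        raise ValueError("density model must be learning-positive (rho' > rho)")
    return rho * (1.0 - rho_prime) / (rho_prime - rho)


def pseudo_count_from_pg(rho, rho_prime):
    """Approximation via prediction gain PG = log rho' - log rho: N_hat ~ 1 / (e^PG - 1)."""
    pg = math.log(rho_prime) - math.log(rho)
    return 1.0 / math.expm1(pg)


def count_bonus(n_hat):
    """Pseudo-count exploration bonus (N_hat + 0.01)^(-1/2)."""
    return (n_hat + 0.01) ** -0.5


# A state seen 3 times in 1000 steps, under the empirical density model.
N, n = 3, 1000
rho, rho_prime = N / n, (N + 1) / (n + 1)
print(pseudo_count(rho, rho_prime))          # ≈ 3.0
print(pseudo_count_from_pg(rho, rho_prime))  # ≈ 3.012
print(count_bonus(pseudo_count(rho, rho_prime)))
```

<p>With a learned density model such as CTS or PixelCNN, \(\rho_n(s)\) and \(\rho'_n(s)\) would instead be the model's probability of \(s\) before and after one training step on \(s\).</p>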
<p>Experiments in <a href="https://arxiv.org/abs/1606.01868">Bellemare et al. (2016)</a> adopted a simple <a href="http://proceedings.mlr.press/v32/bellemare14.html">CTS</a> (Context Tree Switching) density model to estimate pseudo-counts. The CTS model takes as input a 2D image and assigns to it a probability according to the product of location-dependent L-shaped filters, where the prediction of each filter is given by a CTS algorithm trained on past images. The CTS model is simple but limited in expressiveness, scalability, and data efficiency. In a follow-up paper, <a href="https://arxiv.org/abs/1703.01310">Georg Ostrovski, et al. (2017)</a> improved the approach by training a PixelCNN (<a href="https://arxiv.org/abs/1606.05328">van den Oord et al., 2016</a>) as the density model.</p>
<p>The density model can also be a Gaussian Mixture Model, as in <a href="https://arxiv.org/abs/1902.08039">Zhao & Tresp (2018)</a>. They used a variational GMM to estimate the density of trajectories (e.g. a concatenation of a sequence of states) and used its predicted probabilities to guide prioritization in experience replay in an off-policy setting.</p>
<h4 id="counting-after-hashing">Counting after Hashing</h4>
<p>Another idea to make it possible to count high-dimensional states is to map states into <strong>hash codes</strong> so that the occurrences of states become trackable (<a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>). The state space is discretized with a hash function \(\phi: \mathcal{S} \mapsto \mathbb{Z}^k\). An exploration bonus \(r^{i}: \mathcal{S} \mapsto \mathbb{R}\) is added to the reward function, defined as \(r^{i}(s) = {N(\phi(s))}^{-1/2}\), where \(N(\phi(s))\) is an empirical count of occurrences of \(\phi(s)\).</p>
<p><a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> proposed to use <em>Locality-Sensitive Hashing</em> (<a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing"><em>LSH</em></a>) to convert continuous, high-dimensional data to discrete hash codes. LSH is a popular class of hash functions for querying nearest neighbors based on certain similarity metrics. A hashing scheme \(x \mapsto h(x)\) is locality-sensitive if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. (See how LSH is used in <a href="/lil-log/2020/04/07/the-transformer-family.html#LSH">Transformer improvement</a> if interested.) <a href="https://www.cs.princeton.edu/courses/archive/spr04/cos598B/bib/CharikarEstim.pdf">SimHash</a> is a type of computationally efficient LSH, and it measures similarity by angular distance:</p>
\[\phi(s) = \text{sgn}(A g(s)) \in \{-1, 1\}^k\]
<p>where \(A \in \mathbb{R}^{k \times D}\) is a matrix with each entry drawn i.i.d. from a standard Gaussian and \(g: \mathcal{S} \mapsto \mathbb{R}^D\) is an optional preprocessing function. The dimension of binary codes is \(k\), controlling the granularity of the state space discretization. A higher \(k\) leads to higher granularity and fewer collisions.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/count-hashing-exploration.png" alt="#Exploration" /></p>
<p><em>Fig. 2. Algorithm of count-based exploration through hashing high-dimensional states by SimHash. (Image source: <a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>)</em></p>
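<p>A minimal NumPy sketch of this SimHash counting scheme, assuming \(g\) is the identity (or that the caller passes already preprocessed features \(g(s)\)); the class name and the default \(k = 32\) are illustrative. Because SimHash keeps only the signs of random projections, it depends only on the direction of \(g(s)\), so rescaling a state leaves its code unchanged.</p>

```python
import numpy as np
from collections import defaultdict


class SimHashCounter:
    """Count-based bonus r^i(s) = N(phi(s))^(-1/2) with phi(s) = sgn(A g(s))."""

    def __init__(self, feature_dim, k=32, seed=0):
        rng = np.random.default_rng(seed)
        self.A = rng.standard_normal((k, feature_dim))  # fixed; entries i.i.d. N(0, 1)
        self.counts = defaultdict(int)                   # hash code -> number of visits

    def code(self, g_s):
        """k-bit code: bit i is 1 where (A g(s))_i >= 0, i.e. the sign pattern of A g(s)."""
        bits = (self.A @ np.asarray(g_s, dtype=float)) >= 0
        return bits.tobytes()  # hashable dictionary key

    def bonus(self, g_s):
        """Increment the visit count of g(s)'s bucket and return N^(-1/2)."""
        key = self.code(g_s)
        self.counts[key] += 1
        return self.counts[key] ** -0.5
```

<p>A larger <code>k</code> splits the space into more buckets (finer granularity, fewer collisions), at the cost of counts staying near zero for longer.</p>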
<p>For high-dimensional images, SimHash may not work well at the raw pixel level. <a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> designed an autoencoder (AE) that takes states \(s\) as input to learn hash codes. Its middle layer is a special dense layer of \(k\) sigmoid units serving as the latent state; the sigmoid activations \(b(s)\) of this layer are binarized by rounding to their closest binary values, \(\lfloor b(s)\rceil \in \{0, 1\}^k\), which serve as the binary hash code of state \(s\). The AE loss over \(N\) states includes two terms:</p>
\[\mathcal{L}(\{s_n\}_{n=1}^N) = \underbrace{-\frac{1}{N} \sum_{n=1}^N \log p(s_n)}_\text{reconstruction loss} + \underbrace{\frac{1}{N} \frac{\lambda}{k} \sum_{n=1}^N\sum_{i=1}^k \min \big \{ (1-b_i(s_n))^2, b_i(s_n)^2 \big\}}_\text{sigmoid activation being closer to binary}\]
<p>One problem with this approach is that dissimilar inputs \(s_i, s_j\) may be mapped to identical hash codes but the AE still reconstructs them perfectly. One can imagine replacing the bottleneck layer \(b(s)\) with the hash codes \(\lfloor b(s)\rceil\), but then gradients cannot be back-propagated through the rounding function. Injecting uniform noise could mitigate this effect, as the AE has to learn to push the latent variable far apart to counteract the noise.</p>
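<p>The second loss term and the binarization step can be sketched in NumPy as below. The function names, the default \(\lambda = 10\), and the noise scale \(a = 0.3\) are illustrative assumptions, and adding \(U(-a, a)\) noise to the sigmoid outputs during training is our reading of the noise-injection trick; the reconstruction term and the encoder/decoder networks are omitted.</p>

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def binarization_penalty(b, lam=10.0):
    """(1/N) (lambda/k) sum_n sum_i min{(1 - b_i)^2, b_i^2} for b of shape (N, k)."""
    n, k = b.shape
    return lam / k * np.minimum((1.0 - b) ** 2, b ** 2).sum() / n


def noisy_bottleneck(pre_activation, a=0.3, rng=None):
    """Training-time bottleneck: sigmoid outputs b(s) plus uniform noise U(-a, a)."""
    rng = np.random.default_rng() if rng is None else rng
    b = sigmoid(pre_activation)
    return b + rng.uniform(-a, a, size=b.shape)


def hash_code(b):
    """Binary hash code: round each sigmoid activation to the nearest of {0, 1}."""
    return np.rint(b).astype(np.int8)
```

<p>The penalty is zero only when every activation sits exactly at 0 or 1, and the injected noise makes codes that are close to each other hard to reconstruct distinctly, which pushes different states toward well-separated activations.</p>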
<h3 id="prediction-based-exploration">Prediction-based Exploration</h3>
<p>The second category of intrinsic exploration bonuses rewards improvement of the agent’s knowledge about the environment. The agent’s familiarity with the environment dynamics can be estimated through a prediction model. This idea of using a prediction model to measure <em>curiosity</em> was actually proposed quite a long time ago (<a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">Schmidhuber, 1991</a>).</p>
<h4 id="forward-dynamics">Forward Dynamics</h4>
<p>Learning a <strong>forward dynamics prediction model</strong> is a great way to approximate how much knowledge our model has obtained about the environment and the task MDPs. It captures an agent’s capability of predicting the consequence of its own behavior, \(f: (s_t, a_t) \mapsto s_{t+1}\). Such a model cannot be perfect (e.g. due to partial observability), and its error \(e(s_t, a_t) = \| f(s_t, a_t) - s_{t+1} \|^2_2\) can be used to provide intrinsic exploration rewards. The higher the prediction error, the less familiar we are with that state. The faster the error drops, the more learning progress we observe.</p>
<p><em>Intelligent Adaptive Curiosity</em> (<strong>IAC</strong>; <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>) sketched an idea of using a forward dynamics prediction model to estimate learning progress and assigned intrinsic exploration reward accordingly.</p>
<p>IAC relies on a memory which stores all the experiences encountered by the robot, \(M=\{(s_t, a_t, s_{t+1})\}\) and a forward dynamics model \(f\). IAC incrementally splits the state space (i.e. sensorimotor space in the context of robotics, as discussed in the paper) into separate regions based on the transition samples, using a process similar to how a decision tree is split: The split happens when the number of samples is larger than a threshold, and the variance of states in each leaf should be minimal. Each tree node is characterized by its exclusive set of samples and has its own forward dynamics predictor \(f\), named “expert”.</p>
<p>The prediction error \(e_t\) of an expert is pushed into a list associated with each region. The <em>learning progress</em> is then measured as the difference between the mean error over a moving window with offset \(\tau\) and the mean error over the current moving window. The intrinsic reward is defined to track the learning progress: \(r^i_t = \frac{1}{k}\sum_{i=0}^{k-1}(e_{t-i-\tau} - e_{t-i})\), where \(k\) is the moving window size. So the larger the decrease in prediction error, the higher the intrinsic reward assigned to the agent. In other words, the agent is encouraged to take actions that quickly improve its knowledge of the environment.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/IAC.png" alt="IAC" /></p>
<p><em>Fig. 3. Architecture of the IAC (Intelligent Adaptive Curiosity) module: the intrinsic reward is assigned w.r.t the learning progress in reducing prediction error of the dynamics model. (Image source: <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>)</em></p>
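<p>The learning-progress reward of a single region can be sketched with a bounded queue of recent prediction errors. This is a minimal illustration of the formula above rather than the full IAC system (there is no region splitting and no per-region expert); the class name and the defaults \(k = 5\), \(\tau = 10\) are hypothetical.</p>

```python
from collections import deque


class LearningProgressReward:
    """r^i_t = (1/k) * sum_{i=0}^{k-1} (e_{t-i-tau} - e_{t-i}) for one region (tau >= 1)."""

    def __init__(self, k=5, tau=10):
        self.k, self.tau = k, tau
        self.errors = deque(maxlen=k + tau)  # e_{t-k-tau+1}, ..., e_t

    def update(self, error):
        """Record the newest prediction error e_t and return the intrinsic reward."""
        self.errors.append(error)
        if len(self.errors) < self.k + self.tau:
            return 0.0  # not enough history for both windows yet
        e = list(self.errors)
        recent = e[-self.k:]                     # e_{t-k+1}, ..., e_t
        older = e[-self.k - self.tau:-self.tau]  # e_{t-k+1-tau}, ..., e_{t-tau}
        return (sum(older) - sum(recent)) / self.k
```

<p>If the error in a region stays flat, whether high or low, the reward is zero; only a region where the expert is actually improving pays out, which is what distinguishes learning progress from raw prediction error.</p>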
<p><a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a> trained a forward dynamics model in the encoding space defined by \(\phi\), \(f_\phi: (\phi(s_t), a_t) \mapsto \phi(s_{t+1})\). The model’s prediction error at time \(t\) is normalized by the maximum error up to time \(t\), \(\bar{e}_t = \frac{e_t}{\max_{i \leq t} e_i}\), so it is always between 0 and 1. The intrinsic reward is defined accordingly: \(r^i_t = \frac{\bar{e}_t(s_t, a_t)}{t \cdot C}\), where \(C > 0\) is a decay constant.</p>
<p>Encoding the state space via \(\phi(.)\) is necessary, as experiments in the paper showed that a dynamics model trained directly on raw pixels behaves <em>very poorly</em>, assigning the same exploration bonus to all states. In <a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a>, the encoding function \(\phi\) is learned via an autoencoder (AE), and \(\phi(.)\) is the output of one of the AE’s layers. The AE can be trained statically on a set of images collected by a random agent, or dynamically together with the policy, where the early frames are gathered using <a href="#classic-exploration-strategies">\(\epsilon\)-greedy</a> exploration.</p>
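<p>The normalized bonus from Stadie et al. (2015) amounts to a few lines. This plain-Python sketch tracks the running maximum error and assumes time steps start at \(t = 1\); the class name and the default \(C = 1\) are illustrative, and the encoder and dynamics model that produce the error are left out.</p>

```python
class NormalizedErrorBonus:
    """r^i_t = e_bar_t / (t * C), where e_bar_t = e_t / max_{i <= t} e_i."""

    def __init__(self, C=1.0):
        if C <= 0:
            raise ValueError("decay constant C must be positive")
        self.C = C
        self.max_error = 0.0

    def __call__(self, error, t):
        self.max_error = max(self.max_error, error)
        e_bar = error / self.max_error if self.max_error > 0 else 0.0  # always in [0, 1]
        return e_bar / (t * self.C)
```

<p>The \(1/t\) factor makes the bonus fade over training even if the normalized error stays high.</p>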
<p><a name="ICM"></a>Instead of an autoencoder, <em>Intrinsic Curiosity Module</em> (<strong>ICM</strong>; <a href="https://arxiv.org/abs/1705.05363">Pathak, et al., 2017</a>) learns the state space encoding \(\phi(.)\) with a self-supervised <strong>inverse dynamics</strong> model. Predicting the next state given the agent’s own action is not easy, especially considering that some factors in the environment cannot be controlled by the agent or do not affect the agent. ICM assumes that a good state feature space should exclude such factors because <em>they cannot influence the agent’s behavior and thus the agent has no incentive for learning them</em>. By learning an inverse dynamics model \(g: (\phi(s_t), \phi(s_{t+1})) \mapsto a_t\), the feature space captures only those changes in the environment related to the agent’s actions and ignores the rest.</p>
<p>Given a forward model \(f\), an inverse dynamics model \(g\) and an observation \((s_t, a_t, s_{t+1})\):</p>
\[g_{\psi_I}(\phi(s_t), \phi(s_{t+1})) = \hat{a}_t \quad
f_{\psi_F}(\phi(s_t), a_t) = \hat{\phi}(s_{t+1}) \quad
r_t^i = \| \hat{\phi}(s_{t+1}) - \phi(s_{t+1}) \|_2^2\]
<p>Such \(\phi(.)\) is expected to be robust to uncontrollable aspects of the environment.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ICM.png" alt="ICM" /></p>
<p><em>Fig. 4. ICM (Intrinsic Curiosity Module) assigns the forward dynamics prediction error to the agent as the intrinsic reward. This dynamics model operates in a state encoding space learned through an inverse dynamics model to exclude environmental factors that do not affect the agent’s behavior. (Image source: <a href="https://arxiv.org/abs/1705.05363">Pathak, et al. 2017</a>)</em></p>
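<p>To show how the three pieces fit together, here is a forward pass of an ICM-style module with tiny random linear networks in NumPy. The sizes, weight initialization, and the tanh encoder are illustrative assumptions; training (minimizing the inverse-model cross-entropy and the forward-model error, as in the paper) would need an autodiff framework and is omitted.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, FEAT_DIM, N_ACTIONS = 16, 8, 4  # hypothetical sizes

W_enc = rng.standard_normal((FEAT_DIM, OBS_DIM)) * 0.1               # encoder phi
W_inv = rng.standard_normal((N_ACTIONS, 2 * FEAT_DIM)) * 0.1         # inverse model g
W_fwd = rng.standard_normal((FEAT_DIM, FEAT_DIM + N_ACTIONS)) * 0.1  # forward model f


def phi(s):
    return np.tanh(W_enc @ s)


def icm_forward(s_t, a_t, s_next):
    """Return (intrinsic reward, inverse-model loss) for one transition."""
    f_t, f_next = phi(s_t), phi(s_next)

    # Inverse dynamics g(phi(s_t), phi(s_{t+1})) -> logits over a_t; cross-entropy loss.
    logits = W_inv @ np.concatenate([f_t, f_next])
    log_probs = logits - logits.max()
    log_probs -= np.log(np.exp(log_probs).sum())
    inverse_loss = -log_probs[a_t]

    # Forward dynamics f(phi(s_t), a_t) -> predicted phi(s_{t+1}).
    phi_next_hat = W_fwd @ np.concatenate([f_t, np.eye(N_ACTIONS)[a_t]])

    # Intrinsic reward: squared L2 error of the forward prediction in feature space.
    r_i = float(np.sum((phi_next_hat - f_next) ** 2))
    return r_i, float(inverse_loss)
```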
<p><a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. (2018)</a> ran a set of large-scale comparison experiments on purely curiosity-driven learning, meaning that only intrinsic rewards are provided to the agent. In this study, the reward is \(r_t = r^i_t = \| f(s_t, a_t) - \phi(s_{t+1})\|_2^2\). A good choice of \(\phi\) is crucial to learning forward dynamics: it is expected to be <em>compact</em>, <em>sufficient</em> and <em>stable</em>, which makes the prediction task more tractable and filters out irrelevant observations.</p>
<p>They compared 4 encoding functions:</p>
<ol>
<li>Raw image pixels: No encoding, \(\phi(x) = x\).</li>
<li><a name="random-feature"></a>Random features (RF): Each state is compressed through a fixed random neural network.</li>
<li><a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#vae-variational-autoencoder">VAE</a>: The probabilistic encoder is used for encoding, \(\phi(x) = q(z \vert x)\).</li>
<li>Inverse dynamic features (IDF): The same feature space as used in <a href="#ICM">ICM</a>.</li>
</ol>
<p>In all experiments, the reward signals are normalized by a running estimate of the standard deviation of the cumulative returns, and all experiments run in an infinite-horizon setting to avoid the “done” flag leaking information.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/large-scale-curiosity-learning.png" alt="Large-scale curiosity learning" /></p>
<p><em>Fig. 5. The mean reward in different games when training with only curiosity signals, generated by different state encoding functions.
(Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>Interestingly, <em>random features</em> turn out to be quite competitive, but in feature transfer experiments (i.e. training an agent in Super Mario Bros level 1-1 and then testing it in another level), learned IDF features generalize better.</p>
<p>They also compared RF and IDF in an environment with a <a href="#the-noisy-tv-problem">noisy TV</a> on. Unsurprisingly, the noisy TV drastically slows down learning, and extrinsic rewards are much lower over time.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/noisy-TV-experiment.png" alt="Noisy TV experiment" /></p>
<p><em>Fig. 6. Experiments using RF and IDF feature encoding in an environment with noisy TV on or off. The plot tracks extrinsic reward per episode as the training progresses. (Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>The forward dynamics optimization can be modeled via variational inference as well. <strong>VIME</strong> (short for <em>“Variational information maximizing exploration”</em>; <a href="https://arxiv.org/abs/1605.09674">Houthooft, et al. 2017</a>) is an exploration strategy based on maximization of <em>information gain</em> about the agent’s belief of environment dynamics. How much additional information has been obtained about the forward dynamics can be measured as the reduction in entropy.</p>
<p>Let \(\mathcal{P}\) be the environment transition function, \(p(s_{t+1}\vert s_t, a_t; \theta)\) be the forward prediction model, parameterized by \(\theta \in \Theta\), and \(\xi_t = \{s_1, a_1, \dots, s_t\}\) be the trajectory history. We would like to reduce the entropy of the belief over \(\theta\) after taking a new action and observing the next state, i.e. to maximize the sum over \(t\) of the per-step entropy reduction, where each term can be rewritten as:</p>
\[\begin{aligned}
&H(\Theta \vert \xi_t, a_t) - H(\Theta \vert S_{t+1}, \xi_t, a_t) \\
=& I(\Theta; S_{t+1} \vert \xi_t, a_t) \quad \scriptstyle{\text{; because } I(X; Y) = H(X) - H(X \vert Y)} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t, a_t))] \quad \scriptstyle{\text{; because } I(X; Y) = \mathbb{E}_Y [D_\text{KL} (p_{X \vert Y} \| p_X)]} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t))] \quad \scriptstyle{\text{; because } \theta \text{ does not depend on } a_t}
\end{aligned}\]
<p>Taking the expectation over possible next states, the agent should prefer actions that increase the KL divergence (<em>“information gain”</em>) between its new belief over the prediction model and the old one. This term can be added to the reward function as an intrinsic reward: \(r^i_t = D_\text{KL} [p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t)]\).</p>
<p>However, computing the posterior \(p(\theta \vert \xi_t, a_t, s_{t+1})\) is generally intractable.</p>
\[\begin{aligned}
p(\theta \vert \xi_t, a_t, s_{t+1})
&= \frac{p(\theta \vert \xi_t, a_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} \\
&= \frac{p(\theta \vert \xi_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} & \scriptstyle{\text{; because action doesn't affect the belief.}} \\
&= \frac{\color{red}{p(\theta \vert \xi_t)} p(s_{t+1} \vert \xi_t, a_t; \theta)}{\int_\Theta p(s_{t+1}\vert\xi_t, a_t; \theta) \color{red}{p(\theta \vert \xi_t)} d\theta} & \scriptstyle{\text{; red part is hard to compute directly.}}
\end{aligned}\]
<p>Since it is difficult to compute \(p(\theta\vert\xi_t)\) directly, a natural choice is to approximate it with an alternative distribution \(q_\phi(\theta)\). With the variational lower bound, we know that optimizing \(q_\phi(\theta)\) is equivalent to maximizing the expected log-likelihood \(\mathbb{E}_{\theta \sim q_\phi}[\log p(\xi_t\vert\theta)]\) while minimizing \(D_\text{KL}[q_\phi(\theta) \| p(\theta)]\).</p>
<p>Using the approximation distribution \(q\), the intrinsic reward becomes:</p>
\[r^i_t = D_\text{KL} [q_{\phi_{t+1}}(\theta) \| q_{\phi_t}(\theta)]\]
<p>where \(\phi_{t+1}\) represents \(q\)’s parameters associated with the new belief after seeing \(a_t\) and \(s_{t+1}\). When used as an exploration bonus, it is normalized by dividing by the moving median of this KL divergence value.</p>
<p>Here the dynamics model is parameterized as a <a href="https://link.springer.com/book/10.1007/978-1-4612-0745-0">Bayesian neural network</a> (BNN), as it maintains a distribution over its weights. The BNN weight distribution \(q_\phi(\theta)\) is modeled as a fully <em>factorized</em> Gaussian with \(\phi = \{\mu, \sigma\}\) and we can easily sample \(\theta \sim q_\phi(.)\). After applying a second-order Taylor expansion, the KL term \(D_\text{KL}[q_{\phi + \lambda \Delta\phi}(\theta) \| q_{\phi}(\theta)]\) can be estimated using the <a href="/lil-log/2019/09/05/evolution-strategies.html#estimation-using-fisher-information-matrix">Fisher Information Matrix</a> \(\mathbf{F}_\phi\), which is easy to compute because \(q_\phi\) is a factorized Gaussian and thus its covariance matrix is diagonal. See more details in <a href="https://arxiv.org/abs/1605.09674">the paper</a>, especially sections 2.3-2.5.</p>
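<p>Because \(q_\phi\) is a fully factorized Gaussian, the KL term also has a closed form that sums over weights. The NumPy sketch below computes \(D_\text{KL}[q_{\phi_{t+1}} \| q_{\phi_t}]\) exactly for given \((\mu, \sigma)\) pairs and normalizes it by a moving median; it does not implement VIME’s BNN update or its Fisher-based approximation, and the function/class names and the window size are assumptions.</p>

```python
import numpy as np
from collections import deque


def kl_diag_gaussians(mu_new, sigma_new, mu_old, sigma_old):
    """KL[N(mu_new, diag(sigma_new^2)) || N(mu_old, diag(sigma_old^2))], summed over weights."""
    mu_new, sigma_new = np.asarray(mu_new, float), np.asarray(sigma_new, float)
    mu_old, sigma_old = np.asarray(mu_old, float), np.asarray(sigma_old, float)
    return float(np.sum(
        np.log(sigma_old / sigma_new)
        + (sigma_new ** 2 + (mu_new - mu_old) ** 2) / (2.0 * sigma_old ** 2)
        - 0.5
    ))


class MovingMedianNormalizer:
    """Divide each KL value by the median of the most recent `window` KL values."""

    def __init__(self, window=100):
        self.history = deque(maxlen=window)

    def __call__(self, kl):
        self.history.append(kl)
        median = float(np.median(list(self.history)))
        return kl / median if median > 0 else 0.0
```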
<p><a name="exploration-via-disagreement"></a>All the methods above depend on a single prediction model. If we have multiple such models, we could use the disagreement among models to set the exploration bonus (<a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>). High disagreement indicates low confidence in prediction and thus calls for more exploration. <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. (2019)</a> proposed to train a set of forward dynamics models and to use the variance over the ensemble of model outputs as \(r_t^i\). Specifically, they encode the state space with <a href="#random-feature">random features</a> and learn an ensemble of 5 models.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/exploration-via-disagreement.png" alt="Disagreement" /></p>
<p><em>Fig. 7. Illustration of training architecture for self-supervised exploration via disagreement. (Image source: <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>)</em></p>
<p>Because \(r^i_t\) is differentiable, the intrinsic reward can be optimized directly through gradient descent, with the gradients telling the policy how to change its actions. This differentiable exploration approach is very efficient but limited by a short exploration horizon.</p>
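<p>A minimal sketch of the disagreement bonus: an ensemble of 5 forward models on top of fixed random features, with the bonus taken as the variance of their predictions across the ensemble, averaged over feature dimensions. The linear models, sizes, and the averaging over dimensions are illustrative assumptions; the models here are untrained, whereas in the method each one is trained on the agent’s transitions.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, FEAT_DIM, N_ACTIONS, N_MODELS = 16, 8, 4, 5  # hypothetical sizes

W_rf = rng.standard_normal((FEAT_DIM, OBS_DIM)) / np.sqrt(OBS_DIM)  # fixed random features
ensemble = [rng.standard_normal((FEAT_DIM, FEAT_DIM + N_ACTIONS)) * 0.1
            for _ in range(N_MODELS)]  # each maps (phi(s_t), one-hot a_t) -> phi(s_{t+1})


def disagreement(predictions):
    """Variance across ensemble members, averaged over feature dimensions."""
    return float(np.var(np.asarray(predictions, dtype=float), axis=0).mean())


def intrinsic_reward(s, a):
    x = np.concatenate([np.tanh(W_rf @ s), np.eye(N_ACTIONS)[a]])
    preds = np.stack([W @ x for W in ensemble])  # shape (N_MODELS, FEAT_DIM)
    return disagreement(preds)
```

<p>Unlike a single model's prediction error, the variance does not grow with inherent noise in the transition once all members have converged to the same mean prediction, which is one reason disagreement is attractive in stochastic environments.</p>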
<h4 id="random-networks">Random Networks</h4>
<p>But what if the prediction task is not about the environment dynamics at all? It turns out that even a prediction task unrelated to the main task can still help exploration.</p>
<p><strong>DORA</strong> (short for <em>“Directed Outreaching Reinforcement Action-Selection”</em>; <a href="https://arxiv.org/abs/1804.04012">Fox & Choshen, et al. 2018</a>) is a novel framework that injects exploration signals based on a newly introduced, <strong>task-independent</strong> MDP. The idea of DORA depends on two parallel MDPs:</p>
<ul>
<li>One is the original task MDP;</li>
<li>The other is an identical MDP but with <em>no reward attached</em>: Rather, every state-action pair is designed to have value 0. The Q-value learned for the second MDP is called <em>E-value</em>. If the model cannot perfectly predict E-value to be zero, it is still missing information.</li>
</ul>
<p>E-values are initialized to 1. Such positive initialization encourages directed exploration toward better E-value prediction. State-action pairs with high E-value estimates have not gathered enough information yet, at least not enough to rule out their high E-values. To some extent, the logarithm of E-values can be considered a generalization of <em>visit counters</em>.</p>
<p>When using a neural network to do function approximation for E-value, another value head is added to predict E-value and it is simply expected to predict zero. Given a predicted E-value \(E(s_t, a_t)\), the exploration bonus is \(r^i_t = \frac{1}{\sqrt{-\log E(s_t, a_t)}}\).</p>
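<p>In the tabular case, E-values can be learned with the same SARSA-style update as Q-values, just with zero reward. The sketch below follows that reading of DORA; the class name, \(\alpha\), and \(\gamma_E\) are our assumptions. With \(\gamma_E = 0\), each visit multiplies \(E(s, a)\) by \(1 - \alpha\), so \(\log_{1-\alpha} E(s, a)\) recovers the exact visit count, which is the sense in which log E-values generalize counters.</p>

```python
import math
import numpy as np


class EValues:
    """Tabular E-values on a zero-reward copy of the MDP, initialized to 1."""

    def __init__(self, n_states, n_actions, alpha=0.1, gamma_e=0.0):
        self.E = np.ones((n_states, n_actions))
        self.alpha, self.gamma_e = alpha, gamma_e

    def update(self, s, a, s_next, a_next):
        """SARSA-style update with reward 0: E <- (1 - alpha) E + alpha * gamma_E * E(s', a')."""
        target = self.gamma_e * self.E[s_next, a_next]
        self.E[s, a] += self.alpha * (target - self.E[s, a])

    def generalized_count(self, s, a):
        """log_{1 - alpha} E(s, a); equals the visit count when gamma_E = 0."""
        return math.log(self.E[s, a]) / math.log(1.0 - self.alpha)

    def bonus(self, s, a):
        """Exploration bonus 1 / sqrt(-log E(s, a)); infinite for never-updated pairs."""
        neg_log = -math.log(self.E[s, a])
        return math.inf if neg_log <= 0 else 1.0 / math.sqrt(neg_log)
```

<p>With \(\gamma_E > 0\), a pair's E-value also stays high when it leads to poorly explored successors, so the bonus propagates toward the frontier instead of only rewarding the pair itself.</p>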
<p><a name="RND"></a>Similar to DORA, <strong>Random Network Distillation</strong> (<strong>RND</strong>; <a href="https://arxiv.org/abs/1810.12894">Burda, et al. 2018</a>) introduces a prediction task <em>independent of the main task</em>. The RND exploration bonus is defined as the error of a neural network \(\hat{f}(s_t)\) predicting features of the observations given by a <em>fixed randomly initialized</em> neural network \(f(s_t)\). The motivation is that given a new state, if similar states have been visited many times in the past, the prediction should be easier and thus have a lower error. The exploration bonus is \(r^i(s_t) = \|\hat{f}(s_t; \theta) - f(s_t) \|_2^2\).</p>
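<p>A minimal NumPy sketch of the RND bonus, under simplifying assumptions that do not come from the paper: the fixed random target \(f\) is a single random <code>tanh</code> layer, the predictor \(\hat{f}\) is a linear model trained by plain gradient descent, and observations are unit-normalized vectors. These choices only make the mechanism visible: the bonus on a repeatedly visited state shrinks, while a never-visited state keeps a large prediction error.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, feat_dim = 8, 16

# Fixed, randomly initialized target network f(s); it is never trained.
W_target = rng.normal(size=(obs_dim, feat_dim))

# Predictor f_hat(s; theta): a linear model here, trained to match f(s).
W_pred = rng.normal(scale=0.1, size=(obs_dim, feat_dim))


def target_features(s):
    return np.tanh(s @ W_target)


def rnd_bonus(s):
    """Intrinsic reward r^i(s) = ||f_hat(s; theta) - f(s)||_2^2."""
    err = s @ W_pred - target_features(s)
    return float(np.sum(err ** 2))


def train_step(s, lr=0.1):
    """One gradient-descent step on the squared prediction error at observation s."""
    global W_pred
    err = s @ W_pred - target_features(s)   # shape (feat_dim,)
    W_pred -= lr * 2.0 * np.outer(s, err)   # gradient of ||s W - f(s)||^2 w.r.t. W


def unit(v):
    return v / np.linalg.norm(v)


visited, novel = unit(rng.normal(size=obs_dim)), unit(rng.normal(size=obs_dim))
before = rnd_bonus(visited)
for _ in range(200):  # the agent keeps revisiting the same state
    train_step(visited)
print(f"visited: {before:.3f} -> {rnd_bonus(visited):.6f}; never visited: {rnd_bonus(novel):.3f}")
```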
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RND.png" alt="RND" /></p>
<p><em>Fig. 8. How RND (Random Network Distillation) works for providing an intrinsic reward. The features \(O_{i+1} \mapsto f_{i+1}\) are generated by a fixed random neural network. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<p>Two factors are important in RND experiments:</p>
<ol>
<li>Non-episodic setting results in better exploration, especially when not using any extrinsic rewards. It means that the return is not truncated at “Game over” and intrinsic return can spread across multiple episodes.</li>
<li>Normalization is important since the scale of the reward is tricky to adjust given a random neural network as a prediction target. The intrinsic reward is normalized by dividing it by a running estimate of the standard deviation of the intrinsic return.</li>
</ol>
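<p>The normalization in point 2 can be sketched as follows. This is an illustrative version with hypothetical helper names, assuming rollouts of raw intrinsic rewards with shape <code>(T, n_envs)</code> and a discount of 0.99: a discounted intrinsic return is accumulated without resetting at episode ends (matching the non-episodic setting in point 1), its running variance is tracked across rollouts, and raw rewards are divided by the running standard deviation.</p>

```python
import numpy as np


class RunningMeanStd:
    """Running mean and variance of a stream of values (batched parallel update)."""

    def __init__(self, epsilon=1e-4):
        self.mean, self.var, self.count = 0.0, 1.0, epsilon

    def update(self, batch):
        b_mean, b_var, b_count = batch.mean(), batch.var(), batch.size
        delta = b_mean - self.mean
        total = self.count + b_count
        self.mean += delta * b_count / total
        m2 = self.var * self.count + b_var * b_count + delta ** 2 * self.count * b_count / total
        self.var = m2 / total
        self.count = total


class DiscountedReturn:
    """Discounted sum of intrinsic rewards per env; never reset at 'game over'."""

    def __init__(self, gamma):
        self.gamma, self.ret = gamma, None

    def update(self, rewards):
        self.ret = rewards if self.ret is None else self.ret * self.gamma + rewards
        return self.ret


def normalize_intrinsic(rollout, ret_filter, stats):
    """Divide raw rewards by the running std of the intrinsic return.

    Both helpers should persist across rollouts so the statistics keep accumulating.
    """
    returns = np.stack([ret_filter.update(r_t) for r_t in rollout])
    stats.update(returns.ravel())
    return rollout / np.sqrt(stats.var)
```

<p>Because the reward is divided by a statistic of its own scale, multiplying all raw prediction errors by a constant leaves the normalized rewards essentially unchanged, which is what makes the arbitrary scale of a random target network harmless.</p>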
<p>The RND setup works well for resolving the hard-exploration problem. For example, maximizing the RND exploration bonus consistently finds more than half of the rooms in Montezuma’s Revenge.</p>
<h4 id="physical-properties">Physical Properties</h4>
<p>Unlike games in simulators, some RL applications such as robotics require understanding objects and reasoning intuitively about the physical world. Some prediction tasks require the agent to perform a sequence of interactions with the environment and observe the corresponding consequences, such as estimating hidden physical properties (e.g. mass, friction).</p>
<p>Motivated by such ideas, <a href="https://arxiv.org/abs/1611.01843">Denil, et al. (2017)</a> found that DRL agents can learn to perform the exploration necessary to discover such hidden properties. Specifically, they considered two experiments:</p>
<ol>
<li><em>“Which is heavier?”</em> — The agent has to interact with the blocks and infer which one is heavier.</li>
<li><em>“Towers”</em> — The agent needs to infer how many rigid bodies a tower is composed of by knocking it down.</li>
</ol>
<p>The agent in the experiments first goes through an exploration phase to interact with the environment and collect information. Once the exploration phase ends, the agent is asked to output a <em>labeling</em> action to answer the question. A positive reward is assigned if the answer is correct; otherwise a negative one is assigned. Because the answer requires a decent amount of interaction with items in the scene, the agent has to learn to play around efficiently to figure out the physics and the correct answer, so exploration emerges naturally.</p>
<p>In their experiments, the agent is able to learn both tasks, with performance varying with task difficulty. The paper did not use the physics prediction task to provide an intrinsic reward bonus alongside the extrinsic reward of another learning task; rather, it focused on the exploration tasks themselves. Still, I do enjoy the idea of encouraging sophisticated exploration behavior by predicting hidden physics properties in the environment.</p>
<h2 id="memory-based-exploration">Memory-based Exploration</h2>
<p>Reward-based exploration suffers from several drawbacks:</p>
<ul>
<li>Function approximation is slow to catch up.</li>
<li>Exploration bonus is non-stationary.</li>
<li>Knowledge fading, meaning that states cease to be novel and cannot provide intrinsic reward signals in time.</li>
</ul>
<p>Methods in this section rely on external memory to resolve disadvantages of reward bonus-based exploration.</p>
<h3 id="episodic-memory">Episodic Memory</h3>
<p>As mentioned above, <a href="#RND">RND</a> works better in a non-episodic setting, meaning the prediction knowledge is accumulated across multiple episodes. The exploration strategy <strong>Never Give Up</strong> (<strong>NGU</strong>; <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>) combines an episodic novelty module that can rapidly adapt within one episode with RND as a life-long novelty module.</p>
<p>Precisely, the intrinsic reward in NGU consists of two exploration bonuses from two modules, <em>within one episode</em> and <em>across multiple episodes</em>, respectively.</p>
<p>The short-term per-episode reward is provided by an <em>episodic novelty module</em>. It contains an episodic memory \(M\), a dynamically-sized slot-based memory, and an IDF (inverse dynamics features) embedding function \(\phi\), the same as the feature encoding in <a href="#ICM">ICM</a>.</p>
<ol>
<li>At every step the current state embedding \(\phi(s_t)\) is added into \(M\).</li>
<li>The intrinsic bonus is determined by comparing how similar the current observation is to the content of \(M\). A larger difference results in a larger bonus.
<br />
\(r^\text{episodic}_t \approx \frac{1}{\sqrt{\sum_{\phi_i \in N_k} K(\phi(x_t), \phi_i)} + c}\)
<br />
where \(K(x, y)\) is a kernel function measuring the similarity between two samples (closer samples get larger values). \(N_k\) is the set of \(k\) nearest neighbors in \(M\) according to \(K(., .)\). \(c\) is a small constant that keeps the denominator non-zero. In the paper, \(K(x, y)\) is configured to be the inverse kernel:
<br />
\(K(x, y) = \frac{\epsilon}{\frac{d^2(x, y)}{d^2_m} + \epsilon}\)
<br />
where \(d(.,.)\) is the Euclidean distance between two samples and \(d^2_m\) is a running average of the squared Euclidean distances of the \(k\) nearest neighbors, used for better robustness. \(\epsilon\) is a small constant.</li>
</ol>
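<p>The episodic bonus above can be sketched in a few lines of NumPy. This is a simplified illustration rather than NGU's implementation: the class name and default constants are assumptions, \(d^2_m\) is tracked with a plain exponential moving average, and any further refinements from the paper are omitted. The bonus is computed before the current embedding is written to \(M\); otherwise every state would find an exact match of itself.</p>

```python
import numpy as np


class EpisodicNovelty:
    """Sketch of a per-episode novelty bonus using an inverse kernel over k nearest neighbors."""

    def __init__(self, k=10, eps=1e-3, c=1e-3, decay=0.99):
        self.k, self.eps, self.c, self.decay = k, eps, c, decay
        self.memory = []        # episodic memory M of embeddings phi(x)
        self.d2_running = None  # running average d_m^2 of squared k-NN distances

    def reset(self):
        """Clear M at the start of each episode."""
        self.memory = []

    def step(self, emb):
        """Return r^episodic_t for embedding phi(x_t), then add it to M."""
        emb = np.asarray(emb, dtype=float)
        bonus = 1.0 / self.c  # empty memory: the kernel sum is 0
        if self.memory:
            d2 = np.sum((np.stack(self.memory) - emb) ** 2, axis=1)
            knn_d2 = np.sort(d2)[: self.k]  # squared distances to N_k
            mean_d2 = float(knn_d2.mean())
            if self.d2_running is None:
                self.d2_running = mean_d2
            else:
                self.d2_running = self.decay * self.d2_running + (1 - self.decay) * mean_d2
            kernel = self.eps / (knn_d2 / max(self.d2_running, 1e-8) + self.eps)  # K(x, y)
            bonus = 1.0 / (np.sqrt(kernel.sum()) + self.c)
        self.memory.append(emb)
        return float(bonus)
```

<p>An exact revisit contributes a kernel value of 1, so its bonus can never exceed \(1/(1 + c)\), while an embedding far from everything in \(M\) gets kernel values near 0 and a large bonus.</p>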
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NGU.png" alt="NGU" /></p>
<p><em>Fig. 9. The architecture of NGU’s embedding function (left) and reward generator (right). (Image source: <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>)</em></p>
<p>The long-term across-episode novelty relies on the RND prediction error in the <em>life-long novelty module</em>. The exploration bonus is \(\alpha_t = 1 + \frac{e^\text{RND}(s_t) - \mu_e}{\sigma_e}\) where \(\mu_e\) and \(\sigma_e\) are the running mean and standard deviation of the RND error \(e^\text{RND}(s_t)\).</p>
<blockquote>
<p>However in the conclusion section of the <a href="https://arxiv.org/abs/1810.12894">RND paper</a>, I noticed the following statement:</p>
<p>“We find that the RND exploration bonus is sufficient to deal with local exploration, i.e. exploring the consequences of short-term decisions, like whether to interact with a particular object, or avoid it. However global exploration that involves coordinated decisions over long time horizons is beyond the reach of our method.”</p>
<p>This confuses me a bit: how can RND then serve as a good life-long novelty bonus provider? If you know why, feel free to leave a comment below.</p>
</blockquote>
<p>The final combined intrinsic reward is \(r^i_t = r^\text{episodic}_t \cdot \text{clip}(\alpha_t, 1, L)\), where \(L\) is a constant maximum reward scalar.</p>
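<p>The life-long modulator and the final combination can be sketched as below, assuming scalar RND errors arrive one at a time. Welford's online algorithm for the running mean and standard deviation is my choice (any running estimator would do), the default \(L = 5\) is an illustrative assumption, and the class and function names are hypothetical.</p>

```python
import numpy as np


class LifelongModulator:
    """alpha_t = 1 + (e_RND(s_t) - mu_e) / sigma_e with running mean/std of RND errors."""

    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def __call__(self, rnd_error):
        # Welford's online update; the current error is included in the statistics.
        self.n += 1
        delta = rnd_error - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (rnd_error - self.mean)
        std = np.sqrt(self.m2 / self.n) if self.n > 1 else 1.0
        return 1.0 + (rnd_error - self.mean) / max(std, 1e-8)


def ngu_intrinsic_reward(r_episodic, alpha, L=5.0):
    """r^i_t = r^episodic_t * clip(alpha_t, 1, L)."""
    return r_episodic * float(np.clip(alpha, 1.0, L))
```

<p>Because the clip's lower bound is 1, the life-long signal can only amplify the episodic bonus (by at most a factor of \(L\)); states with a below-average RND error are never pushed below their episodic value.</p>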
<p>The design of NGU enables it to have two nice properties:</p>
<ol>
<li><em>Rapidly discourages</em> revisiting the same state <em>within</em> the same episode;</li>
<li><em>Slowly discourages</em> revisiting states that have been visited many times <em>across</em> episodes.</li>
</ol>
<p>Later, built on top of NGU, DeepMind proposed “Agent57” (<a href="https://arxiv.org/abs/2003.13350">Badia, et al. 2020b</a>), the first deep RL agent that outperforms the standard human benchmark on <em>all</em> 57 Atari games. Two major improvements in Agent57 over NGU are:</p>
<ol>
<li>A <em>population</em> of policies is trained in Agent57, each equipped with a different exploration parameter pair \(\{(\beta_j, \gamma_j)\}_{j=1}^N\). Recall that given \(\beta_j\), the reward is constructed as \(r_{j,t} = r_t^e + \beta_j r^i_t\) and \(\gamma_j\) is the reward discounting factor. It is natural to expect policies with higher \(\beta_j\) and lower \(\gamma_j\) to make more progress early in training, while the opposite would be expected as training progresses. A meta-controller (<a href="https://arxiv.org/pdf/0805.3415.pdf">sliding-window UCB bandit algorithm</a>) is trained to select which policies should be prioritized.</li>
<li>The second improvement is a new parameterization of Q-value function that decomposes the contributions of the intrinsic and extrinsic rewards in a similar form as the bundled reward: \(Q(s, a; \theta_j) = Q(s, a; \theta_j^e) + \beta_j Q(s, a; \theta_j^i)\). During training, \(Q(s, a; \theta_j^e)\) and \(Q(s, a; \theta_j^i)\) are optimized separately with rewards \(r_j^e\) and \(r_j^i\), respectively.</li>
</ol>
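<p>The meta-controller in point 1 treats each policy (each \((\beta_j, \gamma_j)\) pair) as a bandit arm. Below is a minimal sketch of a sliding-window UCB selector in that spirit: only the last <code>window</code> (arm, episode return) outcomes are kept, so the statistics can follow the non-stationary returns of policies that are still learning. The window size, the exploration constant <code>c</code>, and the toy returns are illustrative assumptions, not the values used in Agent57.</p>

```python
import math
from collections import deque


class SlidingWindowUCB:
    """Pick among policies (arms) using only the last `window` (arm, return) outcomes."""

    def __init__(self, n_arms, window=90, c=1.0):
        self.n_arms, self.c = n_arms, c
        self.history = deque(maxlen=window)  # old outcomes fall out automatically

    def select(self):
        counts = [0] * self.n_arms
        sums = [0.0] * self.n_arms
        for arm, ret in self.history:
            counts[arm] += 1
            sums[arm] += ret
        for arm in range(self.n_arms):
            if counts[arm] == 0:  # arms absent from the window are retried first
                return arm
        log_t = math.log(len(self.history))  # log(min(t, window))
        scores = [sums[a] / counts[a] + self.c * math.sqrt(log_t / counts[a])
                  for a in range(self.n_arms)]
        return max(range(self.n_arms), key=scores.__getitem__)

    def update(self, arm, episode_return):
        self.history.append((arm, episode_return))


# Toy run: arm 1 always returns 1.0, arm 0 always returns 0.0.
bandit = SlidingWindowUCB(n_arms=2, window=50, c=0.5)
picks = []
for _ in range(200):
    arm = bandit.select()
    picks.append(arm)
    bandit.update(arm, float(arm))
print(picks[-50:].count(1), "of the last 50 picks went to arm 1")
```

<p>The better arm dominates, but the worse arm is still retried whenever it drops out of the window; that periodic retry is what would let a sliding-window bandit notice a policy that was weak early in training becoming strong later.</p>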
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/agent57.png" alt="Agent57" /></p>
<p><em>Fig. 10. A pretty cool illustration of techniques developed over time since DQN in 2015, eventually leading to Agent57. (Image source: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">DeepMind Blog: “Agent57: Outperforming the human Atari benchmark”</a>)</em></p>
<p>Instead of using the Euclidean distance to measure the closeness of states in episodic memory, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> took transitions between states into consideration and proposed the <strong>Episodic Curiosity (EC)</strong> module, which estimates how many steps are needed to reach the current state from the states in memory. The novelty bonus depends on reachability between states.</p>
<ol>
<li>At the beginning of each episode, the agent starts with an empty episodic memory \(M\).</li>
<li>At every step, the agent compares the current state with saved states in memory to determine novelty bonus: If the current state is novel (i.e., takes more steps to reach from observations in memory than a threshold), the agent gets a bonus.</li>
<li>The current state is added into the episodic memory if the novelty bonus is high enough. (Imagine that if all the states were added into memory, any new state could be reached within 1 step, so nothing would ever look novel.)</li>
<li>Repeat steps 2-3 until the end of this episode.</li>
</ol>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/transition-graph.png" alt="Transition graph" /></p>
<p><em>Fig. 11. The nodes in the graph are states, the edges are possible transitions. The blue nodes are states in memory. The green nodes are reachable from the memory within \(k = 2\) steps (not novel). The orange nodes are further away, so they are considered as novel states. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<p>In order to estimate reachability between states, we need access to the transition graph, which is unfortunately not entirely known. Thus, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> trained a <a href="/lil-log/2018/11/30/meta-learning.html#convolutional-siamese-neural-network">siamese</a> neural network to predict how many steps separate two states. It contains one embedding network \(\phi: \mathcal{S} \mapsto \mathbb{R}^n\) to first encode the states to feature vectors and then one comparator network \(C: \mathbb{R}^n \times \mathbb{R}^n \mapsto [0, 1]\) to output the probability that two states are close enough (i.e., reachable within \(k\) steps) in the transition graph, \(C(\phi(s_i), \phi(s_j)) \mapsto [0, 1]\).</p>
<p>An episodic memory buffer \(M\) stores embeddings of some past observations within the same episode. A new observation is compared with existing state embeddings via \(C\) and the results are aggregated (e.g. max, 90th percentile) to provide a reachability score \(C^M(\phi(s_t))\). The exploration bonus is \(r^i_t = \big(C' - C^M(\phi(s_t))\big)\), where \(C'\) is a predefined threshold for determining the sign of the reward (e.g. \(C'=0.5\) works well for fixed-duration episodes). A high bonus is awarded to new states when they are not easily reachable from states in the memory buffer.</p>
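<p>A minimal sketch of steps 1-3 with this bonus follows. The trained comparator network is replaced by a hypothetical stand-in (<code>toy_comparator</code>, a Gaussian similarity on embeddings), scores are aggregated with the 90th percentile, and two details are my own assumptions: an empty memory counts as reachability 0, and a state is stored whenever its bonus is positive.</p>

```python
import numpy as np


def toy_comparator(e1, e2):
    """Hypothetical stand-in for the trained comparator C: ~1 if close, ~0 if far."""
    return float(np.exp(-np.sum((np.asarray(e1) - np.asarray(e2)) ** 2)))


class EpisodicCuriosity:
    def __init__(self, comparator, c_prime=0.5, add_threshold=0.0, percentile=90):
        self.comparator = comparator
        self.c_prime = c_prime              # C' in r^i_t = C' - C^M(phi(s_t))
        self.add_threshold = add_threshold  # assumed novelty level for writing to memory
        self.percentile = percentile        # aggregation of comparator scores
        self.memory = []

    def reset(self):
        """Step 1: start every episode with an empty memory."""
        self.memory = []

    def step(self, emb):
        """Steps 2-3: compute the bonus, then store the state if it is novel enough."""
        if self.memory:
            scores = [self.comparator(m, emb) for m in self.memory]
            reach = float(np.percentile(scores, self.percentile))  # C^M(phi(s_t))
        else:
            reach = 0.0  # assumption: nothing in memory, so nothing is reachable
        bonus = self.c_prime - reach
        if bonus > self.add_threshold:
            self.memory.append(emb)
        return bonus
```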
<p>They claimed that the EC module can overcome the <a href="#the-noisy-tv-problem">noisy-TV</a> problem.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/episodic-memory-overview.png" alt="EC module" /></p>
<p><em>Fig. 12. The architecture of episodic curiosity (EC) module for intrinsic reward generation. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<h3 id="direct-exploration">Direct Exploration</h3>
<p><strong>Go-Explore</strong> (<a href="https://arxiv.org/abs/1901.10995">Ecoffet, et al., 2019</a>) is an algorithm aiming to solve the “hard-exploration” problem. It is composed of the following two phases.</p>
<p><strong>Phase 1 (“Explore until solved”)</strong> feels quite like <a href="https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm">Dijkstra’s algorithm</a> for finding shortest paths in a graph. Indeed, no neural network is involved in phase 1. By maintaining a memory of interesting states as well as trajectories leading to them, the agent can go back (given a simulator is <em>deterministic</em>) to promising states and continue doing <em>random</em> exploration from there. The state is mapped into a short discretized code (named “cell”) in order to be memorized. The memory is updated if a new state appears or a better/shorter trajectory is found. When selecting which past states to return to, the agent might select one in the memory uniformly or according to heuristics like recency, visit count, count of neighbors in the memory, etc. This process is repeated until the task is solved and at least one solution trajectory is found.</p>
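<p>Phase 1 can be illustrated on a toy problem. The sketch below uses a hypothetical deterministic grid world, maps states to coarse cells by 2x2 downscaling, picks cells with weight \(1/\sqrt{\text{visits}+1}\) (one of many possible heuristics), returns to a cell by replaying its stored action sequence from reset (which only works because the environment is deterministic), and then explores randomly. All of these specifics are assumptions for illustration, and Phase 2 is not included.</p>

```python
import random


class GridEnv:
    """Hypothetical deterministic grid world: start at (0, 0), goal at the far corner."""

    MOVES = [(1, 0), (-1, 0), (0, 1), (0, -1)]

    def __init__(self, size=10):
        self.size = size
        self.pos = (0, 0)

    def reset(self):
        self.pos = (0, 0)
        return self.pos

    def step(self, action):
        dx, dy = self.MOVES[action]
        x = min(max(self.pos[0] + dx, 0), self.size - 1)
        y = min(max(self.pos[1] + dy, 0), self.size - 1)
        self.pos = (x, y)
        return self.pos, self.pos == (self.size - 1, self.size - 1)


def to_cell(state):
    """Downscale a state into a coarse cell (2x2 blocks of the grid)."""
    return (state[0] // 2, state[1] // 2)


def go_explore_phase1(env, iterations=2000, explore_steps=10, seed=0):
    rng = random.Random(seed)
    archive = {to_cell(env.reset()): {"traj": [], "visits": 0}}
    for _ in range(iterations):
        # Select a cell to return to, favoring rarely selected cells.
        cells = list(archive)
        weights = [1.0 / (archive[c]["visits"] + 1) ** 0.5 for c in cells]
        cell = rng.choices(cells, weights=weights)[0]
        archive[cell]["visits"] += 1
        # Go: return to the cell by replaying its stored trajectory (needs determinism).
        env.reset()
        traj = list(archive[cell]["traj"])
        for a in traj:
            env.step(a)
        # Explore: random actions from the restored state.
        for _ in range(explore_steps):
            a = rng.randrange(len(env.MOVES))
            state, done = env.step(a)
            traj.append(a)
            c = to_cell(state)
            # Add a new cell, or keep the shorter trajectory for a known one.
            if c not in archive or len(traj) < len(archive[c]["traj"]):
                visits = archive[c]["visits"] if c in archive else 0
                archive[c] = {"traj": list(traj), "visits": visits}
            if done:
                return traj, archive
    return None, archive
```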
<p>The high-performance trajectories found above may not work well in evaluation environments with any stochasticity. Thus, <strong>Phase 2 (“Robustification”)</strong> is needed to make the solution robust via imitation learning. They adopted the <a href="https://arxiv.org/abs/1812.03381">Backward Algorithm</a>, in which the agent starts near the last state in the trajectory and then runs RL optimization from there.</p>
<p>One important note in phase 1 is: In order to go back to a state deterministically without exploration, Go-Explore depends on a resettable and deterministic simulator, which is a big disadvantage.</p>
<p>To make the algorithm useful in environments with stochasticity, an enhanced version of Go-Explore (<a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>), named <strong>policy-based Go-Explore</strong>, was proposed later.</p>
<ul>
<li>Instead of resetting the simulator state effortlessly, the policy-based Go-Explore learns a <em>goal-conditioned policy</em> and uses that to access a known state in memory repeatedly. The goal-conditioned policy is trained to follow the best trajectory that previously led to the selected states in memory. They include a <strong>Self-Imitation Learning</strong> (<strong>SIL</strong>; <a href="https://arxiv.org/abs/1806.05635">Oh, et al. 2018</a>) loss to help extract as much information as possible from successful trajectories.</li>
<li>Also, they found sampling from policy works better than random actions when the agent returns to promising states to continue exploration.</li>
<li>Another improvement in policy-based Go-Explore is to make the downscaling function of images to cells adjustable. It is optimized so that there would be neither too many nor too few cells in the memory.</li>
</ul>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/policy-based-Go-Explore.png" alt="Policy-based Go-Explore" /></p>
<p><em>Fig. 13. An overview of the Go-Explore algorithm. (Image source: <a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>)</em></p>
<p>After vanilla Go-Explore, <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. (2019)</a> proposed <strong>DTSIL</strong> (Diverse Trajectory-conditioned Self-Imitation Learning), which shares a similar idea with policy-based Go-Explore above. DTSIL maintains a memory of diverse demonstrations collected during training and uses them to train a trajectory-conditioned policy via <a href="https://arxiv.org/abs/1806.05635">SIL</a>. During sampling, they prioritize trajectories that end with a rare state.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DTSIL-algo.png" alt="DTSIL" /></p>
<p><em>Fig. 14. Algorithm of DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning). (Image source: <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. 2019</a>)</em></p>
<p>A similar approach is also seen in <a href="https://arxiv.org/abs/1906.07805">Guo, et al. (2019)</a>. The main idea is to store goals with <em>high uncertainty</em> in memory so that later the agent can revisit these goal states repeatedly with a goal-conditioned policy. In each episode, the agent flips a coin (probability 0.5) to decide whether it will act greedily w.r.t. the policy or do directed exploration by sampling goals from the memory.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/directed-exploration.png" alt="Directed exploration" /></p>
<p><em>Fig. 15. Different components in directed exploration with function approximation. (Image source: <a href="https://arxiv.org/abs/1906.07805">Guo, et al. 2019</a>)</em></p>
<p>The uncertainty measure of a state can be something simple like count-based bonuses or something complex like density or Bayesian models. The paper trained a forward dynamics model and took its prediction error as the uncertainty metric.</p>
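<p>The memory-plus-coin-flip logic can be sketched as follows. It is an illustration under assumptions: the memory keeps the <code>capacity</code> states with the largest uncertainty seen so far (a retention rule chosen for simplicity), uncertainty is the squared prediction error of some forward dynamics model whose predictions are passed in, and the goal-conditioned policy itself is left out. All names are hypothetical.</p>

```python
import heapq
import random


class UncertainGoalMemory:
    """Keeps the `capacity` states with the highest uncertainty seen so far."""

    def __init__(self, capacity=100):
        self.capacity = capacity
        self.heap = []     # min-heap of (uncertainty, insertion_id, state)
        self._next_id = 0  # tie-breaker so states themselves are never compared

    def add(self, state, uncertainty):
        item = (uncertainty, self._next_id, state)
        self._next_id += 1
        if len(self.heap) < self.capacity:
            heapq.heappush(self.heap, item)
        elif uncertainty > self.heap[0][0]:
            heapq.heapreplace(self.heap, item)  # evict the least uncertain goal

    def sample_goal(self, rng):
        return rng.choice(self.heap)[2]


def prediction_error(predicted_next, actual_next):
    """Uncertainty of a state: squared error of a forward dynamics model's prediction."""
    return sum((p - q) ** 2 for p, q in zip(predicted_next, actual_next))


def choose_episode_mode(memory, rng, p_explore=0.5):
    """Per-episode coin flip: act greedily, or pursue an uncertain goal from memory."""
    if memory.heap and rng.random() < p_explore:
        return "explore", memory.sample_goal(rng)
    return "greedy", None
```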
<h2 id="q-value-exploration">Q-Value Exploration</h2>
<p>Inspired by <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">Thompson sampling</a>, <strong>Bootstrapped DQN</strong> (<a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>) introduces a notion of uncertainty in Q-value approximation in classic <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#deep-q-network">DQN</a> by using the <a href="https://en.wikipedia.org/wiki/Bootstrapping_(statistics)">bootstrapping</a> method. Bootstrapping approximates a distribution by repeatedly sampling with replacement from the same population and then aggregating the results.</p>
<p>Multiple Q-value heads are trained in parallel but each only consumes a bootstrapped sub-sampled set of data and each has its own corresponding target network. All the Q-value heads share the same backbone network.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bootstrapped-DQN-algo.png" alt="Bootstrapped DQN" /></p>
<p><em>Fig. 16. The algorithm of Bootstrapped DQN. (Image source: <a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>)</em></p>
<p>At the beginning of each episode, one Q-value head is sampled uniformly and used to collect experience data for that episode. A binary mask is then sampled from the masking distribution \(m \sim \mathcal{M}\), deciding which heads can use this data for training. The choice of masking distribution \(\mathcal{M}\) determines how bootstrapped samples are generated; for example,</p>
<ul>
<li>If \(\mathcal{M}\) is an independent Bernoulli distribution with \(p=0.5\), this corresponds to the double-or-nothing bootstrap.</li>
<li>If \(\mathcal{M}\) always returns an all-one mask, the algorithm reduces to an ensemble method.</li>
</ul>
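<p>The masking scheme can be sketched with NumPy. Assumed here: each stored transition carries a mask over the \(K\) heads, sampled once when the transition is collected; each head has its own target values; and the loss is averaged over active (transition, head) pairs, which is one reasonable normalization rather than necessarily the paper's. Function names are hypothetical.</p>

```python
import numpy as np


def start_episode(n_heads, rng):
    """Pick one Q-value head uniformly; it acts for the whole episode."""
    return int(rng.integers(n_heads))


def sample_mask(n_heads, rng, kind="bernoulli", p=0.5):
    """Mask m ~ M attached to one transition: which heads may train on it."""
    if kind == "bernoulli":  # independent Bernoulli(p); p = 0.5 is double-or-nothing
        return (rng.random(n_heads) < p).astype(float)
    if kind == "ones":       # all-one mask: every head sees all data (plain ensemble)
        return np.ones(n_heads)
    raise ValueError(f"unknown mask kind: {kind}")


def masked_td_loss(q_sa, q_next_target, rewards, dones, masks, gamma=0.99):
    """
    q_sa:           (batch, K) online values Q_k(s, a) for each head k
    q_next_target:  (batch, K, n_actions) target-network values Q_k^-(s', .)
    rewards, dones: (batch,)
    masks:          (batch, K) bootstrap masks stored with each transition
    """
    targets = rewards[:, None] + gamma * (1.0 - dones[:, None]) * q_next_target.max(axis=2)
    sq_err = (targets - q_sa) ** 2
    # Each head only learns from transitions whose mask entry is 1.
    return float((masks * sq_err).sum() / max(masks.sum(), 1.0))
```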
<p>However, this kind of exploration is still restricted, because uncertainty introduced by bootstrapping fully relies on the training data. It is better to inject some prior information independent of the data. This “noisy” prior is expected to drive the agent to keep exploring when the reward is sparse. The algorithm of adding random prior into bootstrapped DQN for better exploration (<a href="https://arxiv.org/abs/1806.03335">Osband, et al. 2018</a>) depends on Bayesian linear regression. The core idea of Bayesian regression is: We can <em>“generate posterior samples by training on noisy versions of the data, together with some random regularization”</em>.</p>
<p>Let \(\theta\) be the Q function parameter and \(\theta^-\) for the target Q, the loss function using a randomized prior function \(p\) is:</p>
\[\mathcal{L}(\theta, \theta^{-}, p, \mathcal{D}; \gamma) = \sum_{t\in\mathcal{D}}\Big( r_t + \gamma \max_{a'\in\mathcal{A}} \underbrace{(Q_{\theta^-} + p)}_\text{target Q}(s'_t, a') - \underbrace{(Q_\theta + p)}_\text{Q to optimize}(s_t, a_t) \Big)^2\]
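<p>A minimal NumPy sketch of this loss with linear Q-functions, assuming a fixed random linear prior \(p(s, \cdot) = s^\top P\) and a toy batch of transitions. The scale <code>beta</code> on the prior is an extra knob added for illustration; <code>beta=1</code> reproduces the equation above. The point the code makes concrete is that \(P\) enters both the target and the prediction but never receives a gradient: only \(\theta\) is updated.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, n_actions = 4, 3

# Fixed random prior p(s, .) = s @ P: drawn once and never trained.
P = rng.normal(size=(obs_dim, n_actions))


def q_with_prior(theta, s, beta=1.0):
    """(Q_theta + beta * p)(s, .) for a batch of states s of shape (batch, obs_dim)."""
    return s @ theta + beta * (s @ P)


def td_targets(theta_target, r, s_next, gamma, beta):
    return r + gamma * q_with_prior(theta_target, s_next, beta).max(axis=1)


def prior_td_loss(theta, theta_target, batch, gamma=0.99, beta=1.0):
    s, a, r, s_next = batch
    target = td_targets(theta_target, r, s_next, gamma, beta)
    pred = q_with_prior(theta, s, beta)[np.arange(len(a)), a]
    return float(np.sum((target - pred) ** 2))


def sgd_step(theta, theta_target, batch, lr=1e-3, gamma=0.99, beta=1.0):
    """One gradient step on theta only; P and the target parameters stay fixed."""
    s, a, r, s_next = batch
    target = td_targets(theta_target, r, s_next, gamma, beta)
    pred = q_with_prior(theta, s, beta)[np.arange(len(a)), a]
    grad = np.zeros_like(theta)
    for i, act in enumerate(a):
        grad[:, act] += -2.0 * (target[i] - pred[i]) * s[i]
    return theta - lr * grad


# Toy batch of transitions (s, a, r, s').
batch = (rng.normal(size=(32, obs_dim)), rng.integers(n_actions, size=32),
         rng.normal(size=32), rng.normal(size=(32, obs_dim)))
theta = rng.normal(scale=0.1, size=(obs_dim, n_actions))
theta_target = theta.copy()
before = prior_td_loss(theta, theta_target, batch)
after = prior_td_loss(sgd_step(theta, theta_target, batch), theta_target, batch)
print(before, after)
```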
<h2 id="varitional-options">Variational Options</h2>
<p>Options are policies with termination conditions. There is a large set of options available in the search space, and they are independent of an agent’s intentions. By explicitly including intrinsic options into modeling, the agent can obtain intrinsic rewards for exploration.</p>
<p><strong>VIC</strong> (short for <em>“Variational Intrinsic Control”</em>; <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>) is such a framework for providing the agent with intrinsic exploration bonuses based on modeling options and learning policies conditioned on options. Let \(\Omega\) represent an option which starts from \(s_0\) and ends at \(s_f\). An environment probability distribution \(p^J(s_f \vert s_0, \Omega)\) defines where an option \(\Omega\) terminates given a starting state \(s_0\). A controllability distribution \(p^C(\Omega \vert s_0)\) defines the probability distribution of options we can sample from. And by definition we have \(p(s_f, \Omega \vert s_0) = p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0)\).</p>
<p>While choosing options, we would like to achieve two goals:</p>
<ul>
<li>Achieve a diverse set of final states from \(s_0\) ⇨ Maximization of \(H(s_f \vert s_0)\).</li>
<li>Know precisely which state a given option \(\Omega\) can end with ⇨ Minimization of \(H(s_f \vert s_0, \Omega)\).</li>
</ul>
<p>Combining them, we get mutual information \(I(\Omega; s_f \vert s_0)\) to maximize:</p>
\[\begin{aligned}
I(\Omega; s_f \vert s_0)
&= H(s_f \vert s_0) - H(s_f \vert s_0, \Omega) \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p(s_f, \Omega \vert s_0) \log \frac{p(s_f, \Omega \vert s_0)}{p^C(\Omega \vert s_0)} \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log p^J(s_f \vert s_0, \Omega) \\
\end{aligned}\]
<p>Because mutual information is symmetric, we can switch \(s_f\) and \(\Omega\) in several places without breaking the equivalence. Also, because \(p(\Omega \vert s_0, s_f)\) is difficult to compute, let us replace it with an approximate distribution \(q\). According to the variational lower bound, we then have \(I(\Omega; s_f \vert s_0) \geq I^{VB}(\Omega; s_f \vert s_0)\).</p>
\[\begin{aligned}
I(\Omega; s_f \vert s_0)
&= I(s_f; \Omega \vert s_0) \\
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{p(\Omega \vert s_0, s_f)}\\
I^{VB}(\Omega; s_f \vert s_0)
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{q(\Omega \vert s_0, s_f)} \\
I(\Omega; s_f \vert s_0) &\geq I^{VB}(\Omega; s_f \vert s_0)
\end{aligned}\]
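<p>The bound can be checked numerically on a small discrete example. The sketch below draws random distributions (3 options, 4 final states, all strictly positive so the logs are finite), computes \(I(\Omega; s_f \vert s_0)\) from the first derivation and \(I^{VB}\) from the second. Plugging in the exact posterior \(p(\Omega \vert s_0, s_f)\) for \(q\) makes the bound tight, while any other \(q\) gives a smaller value. The summand \(\log q(\Omega \vert s_0, s_f) - \log p^C(\Omega \vert s_0)\) can be read as a per-sample reward whose expectation is \(I^{VB}\). Function names are hypothetical.</p>

```python
import numpy as np


def mutual_information(p_c, p_j):
    """I(Omega; s_f | s_0) = H(s_f | s_0) - H(s_f | s_0, Omega).

    p_c: (n_options,) controllability distribution p^C(Omega | s_0)
    p_j: (n_options, n_states) termination distribution p^J(s_f | s_0, Omega)
    """
    joint = p_c[:, None] * p_j  # p(s_f, Omega | s_0)
    p_sf = joint.sum(axis=0)    # p(s_f | s_0)
    h_sf = -np.sum(p_sf * np.log(p_sf))
    h_sf_given_omega = -np.sum(joint * np.log(p_j))
    return float(h_sf - h_sf_given_omega)


def variational_bound(p_c, p_j, q):
    """I^VB = H(Omega | s_0) + E[log q(Omega | s_0, s_f)]; column f of q is q(. | s_0, s_f)."""
    joint = p_c[:, None] * p_j
    h_omega = -np.sum(p_c * np.log(p_c))
    return float(h_omega + np.sum(joint * np.log(q)))


def intrinsic_reward(p_c, q, omega, s_f):
    """Per-sample term log q(Omega | s_0, s_f) - log p^C(Omega | s_0)."""
    return float(np.log(q[omega, s_f]) - np.log(p_c[omega]))


rng = np.random.default_rng(0)
p_c = rng.dirichlet(np.ones(3))          # 3 options
p_j = rng.dirichlet(np.ones(4), size=3)  # row k: p^J(. | s_0, Omega = k) over 4 final states
joint = p_c[:, None] * p_j
q_star = joint / joint.sum(axis=0)       # exact posterior p(Omega | s_0, s_f)
q_uniform = np.full((3, 4), 1.0 / 3.0)   # a deliberately poor approximation
print(mutual_information(p_c, p_j), variational_bound(p_c, p_j, q_star),
      variational_bound(p_c, p_j, q_uniform))
```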
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VIC-explicit-options.png" alt="VIC" /></p>
<p><em>Fig. 17. The algorithm for VIC (Variational Intrinsic Control). (Image source: <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>)</em></p>
<p>Here \(\pi(a \vert \Omega, s)\) can be optimized with any RL algorithm. The option inference function \(q(\Omega \vert s_0, s_f)\) is trained with supervised learning. The prior \(p^C\) is updated so that it tends to choose \(\Omega\) with higher rewards. Note that \(p^C\) can also be fixed (e.g. a Gaussian). Through learning, different \(\Omega\) come to produce different behaviors. Additionally, <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. (2017)</a> observed that VIC with explicit options is difficult to make work in practice with function approximation, so they also proposed another version of VIC with implicit options.</p>
<p>Different from VIC which models \(\Omega\) conditioned only on the start and end states, <strong>VALOR</strong> (short for <em>“Variational Auto-encoding Learning of Options by Reinforcement”</em>; <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>) relies on the whole trajectory to extract the option context \(c\), which is sampled from a fixed Gaussian distribution. In VALOR:</p>
<ul>
<li>A policy acts as an encoder, translating contexts from a noise distribution into trajectories.</li>
<li>A decoder attempts to recover the contexts from the trajectories, and rewards the policies for making contexts easier to distinguish. The decoder never sees the actions during training, so the agent has to interact with the environment in a way that facilitates communication with the decoder for better prediction. Also, the decoder recurrently takes in a sequence of steps in one trajectory to better model the correlation between timesteps.</li>
</ul>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/VALOR-decoder.png" alt="VALOR" /></p>
<p><em>Fig. 18. The decoder of VALOR is a biLSTM which takes \(N = 11\) equally spaced observations from one trajectory as inputs. (Image source: <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>)</em></p>
<p>DIAYN (“Diversity is all you need”; <a href="https://arxiv.org/abs/1802.06070">Eysenbach, et al. 2018</a>) follows the same direction under a different name: DIAYN models the policies conditioned on a latent <em>skill</em> variable. See my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#learning-with-random-rewards">previous post</a> for more details.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020exploration,
title = "Exploration Strategies in Deep Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Pierre-Yves Oudeyer & Frederic Kaplan. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.567.6524&rep=rep1&type=pdf">“How can we define intrinsic motivation?”</a>. Conf. on Epigenetic Robotics, 2008.</p>
<p>[2] Marc G. Bellemare, et al. <a href="https://arxiv.org/abs/1606.01868">“Unifying Count-Based Exploration and Intrinsic Motivation”</a>. NIPS 2016.</p>
<p>[3] Georg Ostrovski, et al. <a href="https://arxiv.org/abs/1703.01310">“Count-Based Exploration with Neural Density Models”</a>. PMLR 2017.</p>
<p>[4] Rui Zhao & Volker Tresp. <a href="https://arxiv.org/abs/1902.08039">“Curiosity-Driven Experience Prioritization via Density Estimation”</a>. NIPS 2018.</p>
<p>[5] Haoran Tang, et al. <a href="https://arxiv.org/abs/1611.04717">“#Exploration: A Study of Count-Based Exploration for Deep Reinforcement Learning”</a>. NIPS 2017.</p>
<p>[6] Jürgen Schmidhuber. <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">“A possibility for implementing curiosity and boredom in model-building neural controllers”</a>. 1991.</p>
<p>[7] Pierre-Yves Oudeyer, et al. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">“Intrinsic Motivation Systems for Autonomous Mental Development”</a>. IEEE Transactions on Evolutionary Computation, 2007.</p>
<p>[8] Bradly C. Stadie, et al. <a href="https://arxiv.org/abs/1507.00814">“Incentivizing Exploration In Reinforcement Learning With Deep Predictive Models”</a>. ICLR 2016.</p>
<p>[9] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1705.05363">“Curiosity-driven Exploration by Self-supervised Prediction”</a>. CVPR 2017.</p>
<p>[10] Yuri Burda, Harri Edwards & Deepak Pathak, et al. <a href="https://arxiv.org/abs/1808.04355">“Large-Scale Study of Curiosity-Driven Learning”</a>. arXiv 1808.04355 (2018).</p>
<p>[11] Joshua Achiam & Shankar Sastry. <a href="https://arxiv.org/abs/1703.01732">“Surprise-Based Intrinsic Motivation for Deep Reinforcement Learning”</a>. NIPS 2016 Deep RL Workshop.</p>
<p>[12] Rein Houthooft, et al. <a href="https://arxiv.org/abs/1605.09674">“VIME: Variational information maximizing exploration”</a>. NIPS 2016.</p>
<p>[13] Leshem Choshen, Lior Fox & Yonatan Loewenstein. <a href="https://arxiv.org/abs/1804.04012">“DORA the explorer: Directed outreaching reinforcement action-selection”</a>. ICLR 2018.</p>
<p>[14] Yuri Burda, et al. <a href="https://arxiv.org/abs/1810.12894">“Exploration by Random Network Distillation”</a>. ICLR 2019.</p>
<p>[15] OpenAI Blog: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">“Reinforcement Learning with Prediction-Based Rewards”</a>. Oct 2018.</p>
<p>[16] Misha Denil, et al. <a href="https://arxiv.org/abs/1611.01843">“Learning to Perform Physics Experiments via Deep Reinforcement Learning”</a>. ICLR 2017.</p>
<p>[17] Ian Osband, et al. <a href="https://arxiv.org/abs/1602.04621">“Deep Exploration via Bootstrapped DQN”</a>. NIPS 2016.</p>
<p>[18] Ian Osband, John Aslanides & Albin Cassirer. <a href="https://arxiv.org/abs/1806.03335">“Randomized Prior Functions for Deep Reinforcement Learning”</a>. NIPS 2018.</p>
<p>[19] Karol Gregor, Danilo Jimenez Rezende & Daan Wierstra. <a href="https://arxiv.org/abs/1611.07507">“Variational Intrinsic Control”</a>. ICLR 2017.</p>
<p>[20] Joshua Achiam, et al. <a href="https://arxiv.org/abs/1807.10299">“Variational Option Discovery Algorithms”</a>. arXiv 1807.10299 (2018).</p>
<p>[21] Benjamin Eysenbach, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a>. ICLR 2019.</p>
<p>[22] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2002.06038">“Never Give Up (NGU): Learning Directed Exploration Strategies”</a>. ICLR 2020.</p>
<p>[23] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2003.13350">“Agent57: Outperforming the Atari Human Benchmark”</a>. arXiv 2003.13350 (2020).</p>
<p>[24] DeepMind Blog: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">“Agent57: Outperforming the human Atari benchmark”</a>. Mar 2020.</p>
<p>[25] Nikolay Savinov, et al. <a href="https://arxiv.org/abs/1810.02274">“Episodic Curiosity through Reachability”</a>. ICLR 2019.</p>
<p>[26] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/1901.10995">“Go-Explore: a New Approach for Hard-Exploration Problems”</a>. arXiv 1901.10995 (2019).</p>
<p>[27] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/2004.12919">“First return then explore”</a>. arXiv 2004.12919 (2020).</p>
<p>[28] Junhyuk Oh, et al. <a href="https://arxiv.org/abs/1806.05635">“Self-Imitation Learning”</a>. ICML 2018.</p>
<p>[29] Yijie Guo, et al. <a href="https://arxiv.org/abs/1907.10247">“Self-Imitation Learning via Trajectory-Conditioned Policy for Hard-Exploration Tasks”</a>. arXiv 1907.10247 (2019).</p>
<p>[30] Zhaohan Daniel Guo & Emma Brunskill. <a href="https://arxiv.org/abs/1906.07805">“Directed Exploration for Reinforcement Learning”</a>. arXiv 1906.07805 (2019).</p>
<p>[31] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1906.04161">“Self-Supervised Exploration via Disagreement”</a>. ICML 2019.</p>Lilian WengExploitation versus exploration is a critical topic in reinforcement learning. This post introduces several common approaches for better exploration in Deep RL.The Transformer Family2020-04-07T12:00:00+00:002020-04-07T12:00:00+00:00https://lilianweng.github.io/lil-log/2020/04/07/the-transformer-family<blockquote>
<p>Inspired by recent progress on various enhanced versions of Transformer models, this post presents how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving, etc.</p>
</blockquote>
<!--more-->
<p>It has been almost two years since my last post on <a href="/lil-log/2018/06/24/attention-attention.html">attention</a>. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#notations" id="markdown-toc-notations">Notations</a></li>
<li><a href="#attention-and-self-attention" id="markdown-toc-attention-and-self-attention">Attention and Self-Attention</a></li>
<li><a href="#multi-head-self-attention" id="markdown-toc-multi-head-self-attention">Multi-Head Self-Attention</a></li>
<li><a href="#transformer" id="markdown-toc-transformer">Transformer</a></li>
<li><a href="#adaptive-computation-time-act" id="markdown-toc-adaptive-computation-time-act">Adaptive Computation Time (ACT)</a></li>
<li><a href="#improved-attention-span" id="markdown-toc-improved-attention-span">Improved Attention Span</a> <ul>
<li><a href="#longer-attention-span-transformer-xl" id="markdown-toc-longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</a></li>
<li><a href="#adaptive-attention-span" id="markdown-toc-adaptive-attention-span">Adaptive Attention Span</a></li>
<li><a href="#localized-attention-span-image-transformer" id="markdown-toc-localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</a></li>
</ul>
</li>
<li><a href="#less-time-and-memory-cost" id="markdown-toc-less-time-and-memory-cost">Less Time and Memory Cost</a> <ul>
<li><a href="#sparse-attention-matrix-factorization-sparse-transformers" id="markdown-toc-sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</a></li>
<li><a href="#locality-sensitive-hashing-reformer" id="markdown-toc-locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</a></li>
</ul>
</li>
<li><a href="#make-it-recurrent-universal-transformer" id="markdown-toc-make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</a></li>
<li><a href="#stabilization-for-rl-gtrxl" id="markdown-toc-stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h3 id="notations">Notations</h3>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(d\)</td>
<td>The model size / hidden state dimension / positional encoding size.</td>
</tr>
<tr>
<td>\(h\)</td>
<td>The number of heads in a multi-head attention layer.</td>
</tr>
<tr>
<td>\(L\)</td>
<td>The segment length of the input sequence.</td>
</tr>
<tr>
<td>\(\mathbf{X} \in \mathbb{R}^{L \times d}\)</td>
<td>The input sequence where each element has been mapped into an embedding vector of size \(d\), same as the model size.</td>
</tr>
<tr>
<td>\(\mathbf{W}^k \in \mathbb{R}^{d \times d_k}\)</td>
<td>The key weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{W}^q \in \mathbb{R}^{d \times d_k}\)</td>
<td>The query weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{W}^v \in \mathbb{R}^{d \times d_v}\)</td>
<td>The value weight matrix. Often we have \(d_k = d_v = d\).</td>
</tr>
<tr>
<td>\(\mathbf{W}^k_i, \mathbf{W}^q_i \in \mathbb{R}^{d \times d_k/h}; \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}\)</td>
<td>The weight matrices per head.</td>
</tr>
<tr>
<td>\(\mathbf{W}^o \in \mathbb{R}^{d_v \times d}\)</td>
<td>The output weight matrix.</td>
</tr>
<tr>
<td>\(\mathbf{Q} = \mathbf{X}\mathbf{W}^q \in \mathbb{R}^{L \times d_k}\)</td>
<td>The query embedding inputs.</td>
</tr>
<tr>
<td>\(\mathbf{K} = \mathbf{X}\mathbf{W}^k \in \mathbb{R}^{L \times d_k}\)</td>
<td>The key embedding inputs.</td>
</tr>
<tr>
<td>\(\mathbf{V} = \mathbf{X}\mathbf{W}^v \in \mathbb{R}^{L \times d_v}\)</td>
<td>The value embedding inputs.</td>
</tr>
<tr>
<td>\(S_i\)</td>
<td>A collection of key positions for the \(i\)-th query \(\mathbf{q}_i\) to attend to.</td>
</tr>
<tr>
<td>\(\mathbf{A} \in \mathbb{R}^{L \times L}\)</td>
<td>The self-attention matrix between an input sequence of length \(L\) and itself. \(\mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})\).</td>
</tr>
<tr>
<td>\(a_{ij} \in \mathbf{A}\)</td>
<td>The scalar attention score between query \(\mathbf{q}_i\) and key \(\mathbf{k}_j\).</td>
</tr>
<tr>
<td>\(\mathbf{P} \in \mathbb{R}^{L \times d}\)</td>
<td>The positional encoding matrix, where the \(i\)-th row \(\mathbf{p}_i\) is the positional encoding for input \(\mathbf{x}_i\).</td>
</tr>
</tbody>
</table>
<h2 id="attention-and-self-attention">Attention and Self-Attention</h2>
<p><em>Attention</em> is a mechanism in neural networks by which a model learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights, and thus the output is usually formed as a weighted average.</p>
<p><em>Self-attention</em> is a type of attention mechanism where the model makes predictions for one part of a data sample using other parts of the same sample. Conceptually, it feels quite similar to <a href="https://en.wikipedia.org/wiki/Non-local_means">non-local means</a>. Also note that, without positional information, self-attention is permutation-equivariant: permuting the input elements simply permutes the outputs in the same way; in other words, it is an operation on sets.</p>
<p>There are various forms of attention / self-attention. Transformer (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al., 2017</a>) relies on the <em>scaled dot-product attention</em>: given a query matrix \(\mathbf{Q}\), a key matrix \(\mathbf{K}\) and a value matrix \(\mathbf{V}\), the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:</p>
\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q} {\mathbf{K}}^\top}{\sqrt{d_k}})\mathbf{V}\]
<p>And for a query and a key vector \(\mathbf{q}_i, \mathbf{k}_j \in \mathbb{R}^{d_k}\) (row vectors in the query and key matrices), we have a scalar score:</p>
\[a_{ij} = \text{softmax}(\frac{\mathbf{q}_i {\mathbf{k}_j}^\top}{\sqrt{d_k}})
= \frac{\exp(\mathbf{q}_i {\mathbf{k}_j}^\top / \sqrt{d_k})}{ \sum_{r \in S_i} \exp(\mathbf{q}_i {\mathbf{k}_r}^\top / \sqrt{d_k}) }\]
<p>where \(S_i\) is a collection of key positions for the \(i\)-th query to attend to.</p>
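<p>As a minimal sketch in plain Python (no tensor library assumed), the weights \(a_{ij}\) and the output for a single query \(\mathbf{q}_i\) over its key positions \(S_i\) can be computed as below. The helper names <code>dot</code> and <code>attention_row</code> are illustrative only; subtracting the maximum score before exponentiating is a standard numerical-stability trick that leaves the softmax unchanged.</p>

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_row(q_i: Vector, keys: List[Vector], values: List[Vector],
                  s_i: Sequence[int]) -> Tuple[Vector, Vector]:
    """Scaled dot-product attention for one query over the key positions in S_i.

    Returns (weights, output): weights[n] is a_ij for j = s_i[n], and output is
    the weighted sum of the corresponding value vectors.
    """
    d_k = len(q_i)
    # Scaled scores q_i k_j^T / sqrt(d_k) for every j in S_i.
    scores = [dot(q_i, keys[j]) / math.sqrt(d_k) for j in s_i]
    # Softmax over S_i; shifting by the max avoids overflow in exp().
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted average of the value vectors v_j, j in S_i.
    d_v = len(values[0])
    output = [sum(w * values[j][c] for w, j in zip(weights, s_i)) for c in range(d_v)]
    return weights, output


# Toy example: two keys, and the query is aligned with the first one.
weights, output = attention_row([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
                                [[1.0, 0.0], [0.0, 1.0]], s_i=[0, 1])
```

<p>Here the first key receives the larger weight, \(1/(1 + e^{-1/\sqrt{2}}) \approx 0.67\), and because the value vectors are one-hot, the output equals the weight vector.</p>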
<p>See my old <a href="/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms">post</a> for other types of attention if interested.</p>
<h2 id="multi-head-self-attention">Multi-Head Self-Attention</h2>
<p>The <em>multi-head self-attention</em> module is a key component in Transformer. Rather than computing the attention only once, the multi-head mechanism projects the inputs into \(h\) smaller subspaces and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into the expected dimensions.</p>
\[\begin{aligned}
\text{MultiHeadAttention}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) &= [\text{head}_1; \dots; \text{head}_h] \mathbf{W}^o \\
\text{where head}_i &= \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i)
\end{aligned}\]
<p>where \([.;.]\) is a concatenation operation. \(\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}, \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}\) are weight matrices to map input embeddings of size \(L \times d\) into query, key and value matrices. And \(\mathbf{W}^o \in \mathbb{R}^{d_v \times d}\) is the output linear transformation. All the weights should be learned during training.</p>
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/multi-head-attention.png" alt="Multi-head scaled dot-product attention" /></p>
<p><em>Fig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>)</em></p>
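<p>A small plain-Python sketch of the formula above follows; the matrix helpers and random weights are illustrative assumptions, not code from the paper. Inside each head the softmax is scaled by the width of that head's queries, which is \(d_k/h\) in this post's notation, matching how the scaling is applied per head in practice.</p>

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, where d_k is the width of q."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in row] for row in scores]), v)


def multi_head_attention(x_q: Matrix, x_k: Matrix, x_v: Matrix,
                         w_q: List[Matrix], w_k: List[Matrix], w_v: List[Matrix],
                         w_o: Matrix) -> Matrix:
    """w_q[i], w_k[i], w_v[i] are the per-head projections W^q_i, W^k_i, W^v_i."""
    heads = [attention(matmul(x_q, wq), matmul(x_k, wk), matmul(x_v, wv))
             for wq, wk, wv in zip(w_q, w_k, w_v)]
    # [head_1; ...; head_h]: concatenate the heads along the feature dimension.
    concat = [[x for head in heads for x in head[row]] for row in range(len(x_q))]
    return matmul(concat, w_o)


# Example with L = 3, d = d_k = d_v = 4 and h = 2 heads of width 2.
rng = random.Random(0)


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


X = rand_matrix(3, 4)
out = multi_head_attention(X, X, X,
                           [rand_matrix(4, 2) for _ in range(2)],
                           [rand_matrix(4, 2) for _ in range(2)],
                           [rand_matrix(4, 2) for _ in range(2)],
                           rand_matrix(4, 4))
```

<p>The output keeps the input shape \(L \times d\). With a single head and identity projections, the module reduces to plain scaled dot-product self-attention.</p>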
<h2 id="transformer">Transformer</h2>
<p>The <strong>Transformer</strong> (which will be referred to as “vanilla Transformer” to distinguish it from other enhanced versions; <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>) model has an encoder-decoder architecture, as commonly used in many <a href="/lil-log/2018/06/24/attention-attention.html#born-for-translation">NMT</a> models. Later, Transformers that keep only one half of this architecture were shown to achieve great performance on language modeling tasks, like the decoder-only <a href="/lil-log/2019/01/31/generalized-language-models.html#openai-gpt">GPT</a> and the encoder-only BERT.</p>
<p><strong>Encoder-Decoder Architecture</strong></p>
<p>The <strong>encoder</strong> generates an attention-based representation with the capability to locate a specific piece of information from a large context. It consists of a stack of 6 identical modules, each containing two submodules, a <em>multi-head self-attention</em> layer and a <em>point-wise</em> fully connected feed-forward network. By point-wise, it means that the same transformation (with the same weights) is applied to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension \(d\).</p>
<p>The function of Transformer <strong>decoder</strong> is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is <em>masked</em> to prevent positions from attending to the future.</p>
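<p>The masking in the decoder's first attention submodule can be sketched in plain Python as follows (the function name is illustrative): query \(i\) only sees keys \(j \leq i\), and future positions receive zero weight. A batched implementation would instead add \(-\infty\) to the masked entries of the score matrix before the softmax, which has the same effect.</p>

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention_weights(q: Matrix, k: Matrix) -> Matrix:
    """Attention weights where query i may only attend to key positions j <= i.

    Positions j > i get weight 0, which is what setting their scores to -inf
    before the softmax achieves in a batched implementation.
    """
    d_k = len(q[0])
    weights = []
    for i, q_i in enumerate(q):
        scores = [sum(a * b for a, b in zip(q_i, k[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (len(k) - i - 1))
    return weights


# With identical queries and keys, each row spreads its weight uniformly
# over the visible (current and past) positions.
w = causal_attention_weights([[1.0, 1.0]] * 3, [[1.0, 1.0]] * 3)
```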
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer.png" alt="Transformer" /></p>
<p><em>Fig. 2. The architecture of the vanilla Transformer model. (Image source: <a href="/lil-log/2018/06/24/attention-attention.html#full-architecture">Figure 17</a>)</em></p>
<p><strong>Positional Encoding</strong></p>
<p>Because the self-attention operation by itself carries no information about the order of its inputs, it is important to use proper <strong>positional encoding</strong> to provide <em>order information</em> to the model. The positional encoding \(\mathbf{P} \in \mathbb{R}^{L \times d}\) has the same dimension as the input embedding, so it can be added to the input directly. The vanilla Transformer considered two types of encodings:</p>
<p>(1) <em>Sinusoidal positional encoding</em> is defined as follows, given the token position \(i=1,\dots,L\) and the dimension \(\delta=1,\dots,d\):</p>
\[\text{PE}(i,\delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases}\]
<p>In this way, each dimension of the positional encoding corresponds to a sinusoid, with wavelengths ranging from \(2\pi\) to \(10000 \cdot 2\pi\) across dimensions.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sinoidual-positional-encoding.png" alt="Transformer" /></p>
<p><em>Fig. 3. Sinusoidal positional encoding with \(L=32\) and \(d=128\). The value is between -1 (black) and 1 (white) and the value 0 is in gray.</em></p>
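<p>A direct plain-Python transcription of the sinusoidal formula above, keeping its 1-based indexing (positions \(i = 1, \dots, L\), dimensions \(\delta = 1, \dots, d\)), might look like this. The function name is illustrative, and the example call uses the same \(L=32\), \(d=128\) as Fig. 3.</p>

```python
import math
from typing import List


def sinusoidal_encoding(L: int, d: int) -> List[List[float]]:
    """Return an L x d matrix P with P[i-1][delta-1] = PE(i, delta).

    sin is used on even delta (= 2 * delta') and cos on odd delta (= 2 * delta' + 1).
    """
    P = []
    for i in range(1, L + 1):
        row = []
        for delta in range(1, d + 1):
            delta_prime = delta // 2
            angle = i / (10000 ** (2 * delta_prime / d))
            row.append(math.sin(angle) if delta % 2 == 0 else math.cos(angle))
        P.append(row)
    return P


P = sinusoidal_encoding(32, 128)
# The encoding is then added element-wise to the input embeddings: x_i + p_i.
```

<p>The first dimension oscillates with wavelength \(2\pi\) (\(\cos(i)\)) and the last with wavelength \(10000 \cdot 2\pi\) (\(\sin(i/10000)\)), which is where the range quoted above comes from.</p>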
<p>(2) <em>Learned positional encoding</em>, as its name suggests, assigns each element a learned column vector which encodes its <em>absolute</em> position (<a href="https://arxiv.org/abs/1705.03122">Gehring, et al. 2017</a>).</p>
<p><strong>Quick Follow-ups</strong></p>
<p>Following the vanilla Transformer, <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a> added a set of auxiliary losses to enable training a deep Transformer model on character-level language modeling, which outperformed LSTMs. Several types of auxiliary tasks are used:</p>
<ul>
<li>Instead of producing only one prediction at the sequence end, every <em>intermediate position</em> is also asked to make a correct prediction, forcing the model to predict given smaller contexts (e.g. the first couple of tokens at the beginning of a context window).</li>
<li>Each intermediate Transformer layer is used for making predictions as well. Lower layers are weighted to contribute less and less to the total loss as training progresses.</li>
<li>Each position in the sequence can predict multiple targets, i.e. two or more predictions of the future tokens.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-aux-losses.png" alt="Transformer" /></p>
<p><em>Fig. 4. Auxiliary prediction tasks used in deep Transformer for character-level language modeling. (Image source: <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a>)</em></p>
<h2 id="adaptive-computation-time-act">Adaptive Computation Time (ACT)</h2>
<p><strong>Adaptive Computation Time</strong> (short for <strong>ACT</strong>; <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>) is a mechanism for dynamically deciding how many computational steps are needed in a recurrent neural network. Here is a cool <a href="https://distill.pub/2016/augmented-rnns/#adaptive-computation-time">tutorial</a> on ACT from distill.pub.</p>
<p>Suppose we have an RNN model \(\mathcal{R}\) composed of input weights \(W_x\), a parametric state transition function \(\mathcal{S}(.)\), a set of output weights \(W_y\) and an output bias \(b_y\). Given an input sequence \((x_1, \dots, x_L)\), the output sequence \((y_1, \dots, y_L)\) is computed by:</p>
\[s_t = \mathcal{S}(s_{t-1}, W_x x_t), \quad y_t = W_y s_t + b_y\quad\text{for }t=1, \dots, L\]
<p>ACT enables the above RNN setup to perform a variable number of steps at each input element. Multiple computational steps lead to a sequence of intermediate states \((s_t^1, \dots, s_t^{N(t)})\) and outputs \((y_t^1, \dots, y_t^{N(t)})\) — they all share the same state transition function \(\mathcal{S}(.)\), as well as the same output weights \(W_y\) and bias \(b_y\):</p>
\[\begin{aligned}
s_t^0 &= s_{t-1} \\
s_t^n &= \mathcal{S}(s_{t}^{n-1}, x_t^n) = \mathcal{S}(s_{t}^{n-1}, x_t + \delta_{n,1}) \text{ for } n=1, \dots, N(t)\\
y_t^n &= W_y s_t^n + b_y
\end{aligned}\]
<p>where \(\delta_{n,1}\) is a binary flag indicating whether the input step has been incremented.</p>
<p>The number of steps \(N(t)\) is determined by an extra sigmoidal halting unit \(h\), with associated weight matrix \(W_h\) and bias \(b_h\), outputting a halting probability \(h_t^n\) at intermediate step \(n\) for the \(t\)-th input element:</p>
\[h_t^n = \sigma(W_h s_t^n + b_h)\]
<p>In order to allow the computation to halt after a single step, ACT introduces a small constant \(\epsilon\) (e.g. 0.01), so that whenever the cumulative probability goes above \(1-\epsilon\), the computation stops.</p>
\[\begin{aligned}
N(t) &= \min(\min\{n': \sum_{n=1}^{n'} h_t^n \geq 1 -\epsilon\}, M) \\
p_t^n &= \begin{cases}
h_t^n & \text{if }n < N(t) \\
R(t) = 1 - \sum_{n=1}^{N(t)-1} h_t^n & \text{if }n= N(t)\\
\end{cases}
\end{aligned}\]
<p>where \(M\) is an upper limit on the number of intermediate steps allowed.</p>
<p>The final state and output are mean-field updates:</p>
\[s_t = \sum_{n=1}^{N(t)} p_t^n s_t^n,\quad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n\]
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/ACT-computation-graph.png" alt="ACT computation graph" /></p>
<p><em>Fig. 5. The computation graph of an RNN with ACT mechanism. (Image source: <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>)</em></p>
<p>To avoid unnecessary pondering over each input, ACT adds a <em>ponder cost</em> \(\mathcal{P}(x) = \sum_{t=1}^L \big(N(t) + R(t)\big)\) to the loss function to encourage a smaller number of intermediate computational steps.</p>
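<p>The halting rule, the remainder \(R(t)\), the mean-field update and the per-element ponder cost \(N(t) + R(t)\) can be sketched in plain Python as below. As a simplifying assumption, the halting outputs \(h_t^1, h_t^2, \dots\) are passed in as a precomputed list; in the actual model each \(h_t^n = \sigma(W_h s_t^n + b_h)\) is computed as the intermediate states are produced. Function names are illustrative.</p>

```python
from typing import List, Tuple


def act_halting(h: List[float], eps: float = 0.01, max_steps: int = 10
                ) -> Tuple[int, float, List[float], float]:
    """ACT halting for one input element x_t.

    h[n-1] is the halting unit output h_t^n for intermediate step n = 1, 2, ...
    Returns (N(t), R(t), [p_t^1, ..., p_t^N(t)], ponder cost N(t) + R(t)).
    If the threshold 1 - eps is never reached, N(t) = max_steps (the limit M),
    so h must then hold at least max_steps values.
    """
    n_steps = max_steps
    cumulative = 0.0
    for n in range(1, max_steps + 1):
        cumulative += h[n - 1]
        if cumulative >= 1 - eps:
            n_steps = n
            break
    remainder = 1.0 - sum(h[:n_steps - 1])  # R(t)
    probs = h[:n_steps - 1] + [remainder]   # p_t^n; sums to 1 by construction
    return n_steps, remainder, probs, n_steps + remainder


def mean_field(probs: List[float], states: List[List[float]]) -> List[float]:
    """s_t = sum_n p_t^n s_t^n (the same formula gives y_t from the y_t^n)."""
    return [sum(p * s[c] for p, s in zip(probs, states)) for c in range(len(states[0]))]


# Cumulative halting mass 0.1, 0.4, 1.1 crosses 1 - eps at step 3.
N, R, p, cost = act_halting([0.1, 0.3, 0.7, 0.5])
```

<p>In this example \(N(t) = 3\), \(R(t) = 0.6\), the step weights are \((0.1, 0.3, 0.6)\) and the ponder cost for this element is \(3.6\); summing it over \(t\) gives \(\mathcal{P}(x)\).</p>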
<h2 id="improved-attention-span">Improved Attention Span</h2>
<p>The goal of improving attention span is to make the context that can be used in self-attention longer, more efficient and flexible.</p>
<h3 id="longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</h3>
<p>The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segment during each update step, and no information can flow across separated fixed-length segments.</p>
<p>This <em>context segmentation</em> causes several issues:</p>
<ul>
<li>The model cannot capture very long term dependencies.</li>
<li>It is hard to predict the first few tokens in each segment given no or thin context.</li>
<li>The evaluation is expensive: whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, even though most of its tokens overlap with the previous segment.</li>
</ul>
<p><strong>Transformer-XL</strong> (<a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>; “XL” means “extra long”) solves the context segmentation problem with two main modifications:</p>
<ol>
<li>Reusing hidden states between segments.</li>
<li>Adopting a new positional encoding that is suitable for reused states.</li>
</ol>
<p><strong>Hidden State Reuse</strong></p>
<p>The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-XL-training.png" alt="Training phrase of Transformer-XL" /></p>
<p><em>Fig. 6. A comparison between the training phrase of vanilla Transformer & Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in <a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>).</em></p>
<p>Let’s label the hidden state of the \(n\)-th layer for the \((\tau + 1)\)-th segment in the model as \(\mathbf{h}_{\tau+1}^{(n)} \in \mathbb{R}^{L \times d}\). In addition to the hidden state of the last layer for the same segment \(\mathbf{h}_{\tau+1}^{(n-1)}\), it also depends on the hidden state of the same layer for the previous segment \(\mathbf{h}_{\tau}^{(n)}\). By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.</p>
\[\begin{aligned}
\color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} &= [\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}] \\
\mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)}\mathbf{W}^q \\
\mathbf{K}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^k \\
\mathbf{V}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^v \\
\mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)})
\end{aligned}\]
<p>Note that both key and value rely on the extended hidden state, while the query only consumes the hidden state of the current segment. The concatenation operation \([. \circ .]\) is along the sequence length dimension.</p>
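<p>The hidden-state reuse can be sketched in plain Python as follows. This covers only the attention part of the layer, and the relative positional encoding discussed next is omitted. Two assumptions are labeled in the code: stop-gradient is just a comment because there is no autodiff here, and a causal mask is applied because Transformer-XL is used as an autoregressive language model, so query \(i\) sees every cached position plus the current-segment positions up to \(i\).</p>

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def xl_attention(h_mem: Matrix, h_cur: Matrix,
                 w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Attention with segment-level recurrence (relative positions omitted).

    h_mem: cached hidden states h_tau^(n-1) of the previous segment.
    h_cur: hidden states h_{tau+1}^(n-1) of the current segment.
    """
    # Extended context [SG(h_mem) o h_cur], concatenated along the sequence axis.
    # Without autodiff, stop-gradient is a no-op here; in a training framework the
    # cached memory would be detached so no gradient flows into the old segment.
    h_ext = h_mem + h_cur
    q = matmul(h_cur, w_q)  # queries come from the current segment only
    k = matmul(h_ext, w_k)  # keys and values use the extended context
    v = matmul(h_ext, w_v)
    d_k, m_len = len(q[0]), len(h_mem)
    out = []
    for i, q_i in enumerate(q):
        # Assumed causal LM setting: all cached positions plus current ones <= i.
        visible = range(m_len + i + 1)
        scores = [sum(a * b for a, b in zip(q_i, k[j])) / math.sqrt(d_k) for j in visible]
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        total = sum(exps)
        out.append([sum(e * v[j][c] for e, j in zip(exps, visible)) / total
                    for c in range(len(v[0]))])
    return out
```

<p>The output has one row per current-segment position, while each query can look back over the whole cached segment as well.</p>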
<p><strong>Relative Positional Encoding</strong></p>
<p>In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding. If we used the same approach as the vanilla Transformer and encoded the absolute position, the previous and current segments would be assigned the same encoding, which is undesirable.</p>
<p>To keep the positional information flowing coherently across segments, Transformer-XL encodes the <em>relative</em> position instead, as it is sufficient to know the position offset, i.e. \(i-j\), between a key vector \(\mathbf{k}_{\tau, j}\) and its query \(\mathbf{q}_{\tau, i}\) for making good predictions.</p>
<p>If omitting the scalar \(1/\sqrt{d_k}\) and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position \(i\) and key at position \(j\) as:</p>
\[\begin{aligned}
a_{ij}
&= \mathbf{q}_i {\mathbf{k}_j}^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q ((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k)^\top \\
&= \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top
\end{aligned}\]
<p>Transformer-XL reparameterizes the above four terms as follows:</p>
\[a_{ij}^\text{rel} =
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{content-based addressing} +
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{content-dependent positional bias} +
\underbrace{ \color{red}{\mathbf{u}} \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{global content bias} +
\underbrace{ \color{red}{\mathbf{v}} \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{global positional bias}\]
<ul>
<li>Replace \(\mathbf{p}_j\) with relative positional encoding \(\mathbf{r}_{i-j} \in \mathbb{R}^{d}\);</li>
<li>Replace \(\mathbf{p}_i\mathbf{W}^q\) with two trainable parameters \(\mathbf{u}\) (for content) and \(\mathbf{v}\) (for location) in two different terms;</li>
<li>Split \(\mathbf{W}^k\) into two matrices, \(\mathbf{W}^k_E\) for content information and \(\mathbf{W}^k_R\) for location information.</li>
</ul>
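<p>The four terms of \(a_{ij}^\text{rel}\) can be computed one by one, as in the plain-Python sketch below. The relative encoding \(\mathbf{r}_{i-j}\) is simply passed in as a vector here, and all names are illustrative. Since \(\mathbf{u}\) and \(\mathbf{v}\) take the place of \(\mathbf{p}_i\mathbf{W}^q\), the score can equivalently be regrouped as \((\mathbf{x}_i\mathbf{W}^q + \mathbf{u})(\mathbf{x}_j\mathbf{W}^k_E)^\top + (\mathbf{x}_i\mathbf{W}^q + \mathbf{v})(\mathbf{r}_{i-j}\mathbf{W}^k_R)^\top\).</p>

```python
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def vecmat(x: Sequence[float], w: Matrix) -> Vector:
    return [sum(x[r] * w[r][c] for r in range(len(x))) for c in range(len(w[0]))]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


def rel_attention_score(x_i: Vector, x_j: Vector, r_ij: Vector,
                        w_q: Matrix, w_ke: Matrix, w_kr: Matrix,
                        u: Vector, v: Vector) -> float:
    """Unnormalized Transformer-XL score a_ij^rel as the sum of its four terms.

    r_ij is the relative encoding r_{i-j}; u and v are the learned global biases.
    """
    q_i = vecmat(x_i, w_q)    # x_i W^q
    k_e = vecmat(x_j, w_ke)   # x_j W^k_E (content key)
    k_r = vecmat(r_ij, w_kr)  # r_{i-j} W^k_R (location key)
    return (dot(q_i, k_e)     # content-based addressing
            + dot(q_i, k_r)   # content-dependent positional bias
            + dot(u, k_e)     # global content bias
            + dot(v, k_r))    # global positional bias
```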
<h3 id="adaptive-attention-span">Adaptive Attention Span</h3>
<p>One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further back at some times than at others, and one attention head may have a different attention pattern from another. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost, which in turn would allow a longer maximum context size in the model.</p>
<p>This is the motivation for <strong>Adaptive Attention Span</strong>. <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al., (2019)</a> proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 7) and thus the optimal span would be trained separately per head.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/attention-per-head.png" alt="Attention per head" /></p>
<p><em>Fig. 7. Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>)</em></p>
<p>Given the \(i\)-th token, we need to compute the attention weights between this token and other keys at positions \(j \in S_i\), where \(S_i\) defines the \(i\)-th token’s context window.</p>
\[\begin{aligned}
e_{ij} &= \mathbf{q}_i {\mathbf{k}_j}^\top \\
a_{ij} &= \text{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1} \exp(e_{ir})} \\
\mathbf{y}_i &= \sum_{r=i-s}^{i-1}a_{ir}\mathbf{v}_r = \sum_{r=i-s}^{i-1}a_{ir}\mathbf{x}_r\mathbf{W}^v
\end{aligned}\]
<p>A <em>soft mask function</em> \(m_z\) is added to control an effective, adjustable attention span; it maps the distance between query and key into a value in [0, 1]. \(m_z\) is parameterized by \(z \in [0, s]\), and \(z\) is to be learned:</p>
\[m_z(x) = \text{clamp}(\frac{1}{R}(R+z-x), 0, 1)\]
<p>where \(R\) is a hyper-parameter which defines the softness of \(m_z\).</p>
<p style="width: 55%;" class="center"><img src="/lil-log/assets/images/soft-masking-function.png" alt="Soft masking function" /></p>
<p><em>Fig. 8. The soft masking function used in the adaptive attention span. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>.)</em></p>
<p>The soft mask function is applied to the softmax elements in the attention weights:</p>
\[a_{ij} = \frac{m_z(i-j)\exp(e_{ij})}{\sum_{r=i-s}^{i-1}m_z(i-r) \exp(e_{ir})}\]
<p>In the above equation, \(z\) is differentiable so it is trained jointly with other parts of the model. Parameters \(z^{(i)}, i=1, \dots, h\) are learned <em>separately per head</em>. Moreover, the loss function has an extra L1 penalty on \(\sum_{i=1}^h z^{(i)}\).</p>
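<p>A plain-Python sketch of the soft-masked weights for a single query, under the indexing of the equations above (the \(s\) keys before position \(i\), with query-key distance \(i - r\)), could look like this. The names are illustrative; in an autodiff framework \(z\) would be a trainable parameter per head, and the L1 penalty on the spans would simply be added to the training loss with some coefficient.</p>

```python
import math
from typing import List


def soft_mask(x: float, z: float, R: float) -> float:
    """m_z(x) = clamp((R + z - x) / R, 0, 1) for query-key distance x."""
    return min(max((R + z - x) / R, 0.0), 1.0)


def adaptive_span_weights(e_i: List[float], z: float, R: float) -> List[float]:
    """Soft-masked attention weights a_ij for one query i.

    e_i[k] is the score e_ir for r = i - s + k (k = 0..s-1), i.e. the s keys
    before position i in order, so the query-key distance is i - r = s - k.
    """
    s = len(e_i)
    mx = max(e_i)  # shift for numerical stability; it cancels in the ratio
    masked = [soft_mask(s - k, z, R) * math.exp(e - mx) for k, e in enumerate(e_i)]
    total = sum(masked)
    if total == 0.0:
        raise ValueError("span z and softness R mask out every key")
    return [w / total for w in masked]


# s = 8 candidate keys with equal scores, learned span z = 3, softness R = 2.
weights = adaptive_span_weights([0.0] * 8, z=3.0, R=2.0)
```

<p>Keys up to distance \(z = 3\) keep full weight, the key at distance 4 falls inside the ramp of width \(R\) and is half-weighted, and keys further back are masked out entirely.</p>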
<p>Using <a href="#adaptive-computation-time-act">Adaptive Computation Time</a>, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter \(z_t\) of an attention head at time \(t\) is a sigmoidal function, \(z_t = S \sigma(\mathbf{v} \cdot \mathbf{x}_t +b)\), where the vector \(\mathbf{v}\) and the bias scalar \(b\) are learned jointly with other parameters.</p>
<p>In the experiments of Transformer with adaptive attention span, <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. (2019)</a> found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.</p>
<h3 id="localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</h3>
<p>The original, also the most popular, use case for Transformer is to do language modeling. The text sequence is one-dimensional in a clearly defined chronological order and thus the attention span grows linearly with increased context size.</p>
<p>However, if we want to use Transformer on images, it is unclear how to define the scope of context or the order. <strong>Image Transformer</strong> (<a href="https://arxiv.org/abs/1802.05751">Parmar, et al. 2018</a>) embraces a formulation of image generation similar to sequence modeling within the Transformer framework. Additionally, Image Transformer restricts the self-attention span to only <em>local</em> neighborhoods, so that the model can scale up to process more images in parallel and keep the likelihood loss tractable.</p>
<p>The encoder-decoder architecture remains for image-conditioned generation:</p>
<ul>
<li>The encoder generates a contextualized, per-pixel-channel representation of the source image;</li>
<li>The decoder <em>autoregressively</em> generates an output image, one channel per pixel at each time step.</li>
</ul>
<p>Let’s label the representation of the current pixel to be generated as the query \(\mathbf{q}\). The other positions whose representations are used for computing the output at \(\mathbf{q}\) provide the key vectors \(\mathbf{k}_1, \mathbf{k}_2, \dots\), which together form a memory matrix \(\mathbf{M}\). The scope of \(\mathbf{M}\) defines the context window for pixel query \(\mathbf{q}\).</p>
<p>Image Transformer introduced two types of localized \(\mathbf{M}\), as illustrated below.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/image-transformer-attention.png" alt="Attention patterns in Image Transformer" /></p>
<p><em>Fig. 9. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1802.05751">Parmar et al., 2018</a>)</em></p>
<p>(1) <em>1D Local Attention</em>: The input image is flattened in the <a href="https://en.wikipedia.org/wiki/Raster_scan#Scanning_pattern">raster scanning</a> order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as \(\mathbf{q}\) and a fixed number of additional pixels generated before this query block.</p>
<p>(2) <em>2D Local Attention</em>: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory block. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.</p>
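<p>Which positions end up in the memory block under the two schemes can be sketched in plain Python as follows. Block sizes and extension amounts are free parameters chosen for illustration, and the causal mask that hides not-yet-generated pixels inside the memory block is assumed to be applied on top and is omitted here.</p>

```python
from typing import List, Tuple


def local_1d_memory(pos: int, n_pixels: int, block_len: int, mem_len: int) -> List[int]:
    """Raster-order positions in the 1D local memory block of query pixel `pos`.

    A pixel at (row, col) has raster position pos = row * width + col. The
    flattened image is split into non-overlapping query blocks of length
    block_len; the memory covers the query's block plus mem_len earlier positions.
    """
    start = (pos // block_len) * block_len
    return list(range(max(0, start - mem_len), min(n_pixels, start + block_len)))


def local_2d_memory(row: int, col: int, height: int, width: int,
                    block_h: int, block_w: int, mem_h: int, mem_w: int
                    ) -> List[Tuple[int, int]]:
    """(row, col) pixels in the 2D local memory block of the query at (row, col).

    The query block (block_h x block_w) is extended by mem_h rows upward and
    mem_w columns to the left and right, clipped at the image border.
    """
    r0 = (row // block_h) * block_h
    c0 = (col // block_w) * block_w
    rows = range(max(0, r0 - mem_h), min(height, r0 + block_h))
    cols = range(max(0, c0 - mem_w), min(width, c0 + block_w + mem_w))
    return [(r, c) for r in rows for c in cols]
```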
<h2 id="less-time-and-memory-cost">Less Time and Memory Cost</h2>
<p>This section introduces several improvements made on Transformer to reduce the computation time and memory consumption.</p>
<h3 id="sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</h3>
<p>The compute and memory cost of the vanilla Transformer grows quadratically with sequence length, which makes it hard to apply to very long sequences.</p>
<p><strong>Sparse Transformer</strong> (<a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a>) introduced <em>factorized self-attention</em>, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.</p>
<p>Consider a set of attention connectivity patterns \(\mathcal{S} = \{S_1, \dots, S_n\}\), where each \(S_i\) records the set of key positions that the \(i\)-th query vector attends to. The attention output is then:</p>
\[\begin{aligned}
\text{Attend}(\mathbf{X}, \mathcal{S}) &= \Big( a(\mathbf{x}_i, S_i) \Big)_{i \in \{1, \dots, L\}} \\
\text{ where } a(\mathbf{x}_i, S_i) &= \text{softmax}\Big(\frac{(\mathbf{x}_i \mathbf{W}^q)(\mathbf{x}_j \mathbf{W}^k)_{j \in S_i}^\top}{\sqrt{d_k}}\Big) (\mathbf{x}_j \mathbf{W}^v)_{j \in S_i}
\end{aligned}\]
<p>Note that although the size of \(S_i\) is not fixed, \(a(\mathbf{x}_i, S_i)\) is always of size \(d_v\) and thus \(\text{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{L \times d_v}\).</p>
<p>In auto-regressive models, the attention span is defined as \(S_i = \{j: j \leq i\}\), as it allows each token to attend to all the positions in the past.</p>
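<p>\(\text{Attend}(\mathbf{X}, \mathcal{S})\) with an arbitrary connectivity pattern can be sketched in plain Python as below, using 0-based positions, so the dense autoregressive pattern becomes <code>S[i] = [0, ..., i]</code>. The names are illustrative; the point is that every output row has \(d_v\) entries no matter how many positions \(S_i\) contains.</p>

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def vecmat(x: Sequence[float], w: Matrix) -> Vector:
    return [sum(x[r] * w[r][c] for r in range(len(x))) for c in range(len(w[0]))]


def attend(X: Matrix, S: List[Sequence[int]],
           w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Attend(X, S): row i attends only to the key positions listed in S[i]."""
    d_k = len(w_q[0])
    out = []
    for i, x_i in enumerate(X):
        q = vecmat(x_i, w_q)
        scores = [sum(a * b for a, b in zip(q, vecmat(X[j], w_k))) / math.sqrt(d_k)
                  for j in S[i]]
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        total = sum(exps)
        values = [vecmat(X[j], w_v) for j in S[i]]
        # Whatever |S_i| is, the output row has d_v entries.
        out.append([sum(e * v[c] for e, v in zip(exps, values)) / total
                    for c in range(len(w_v[0]))])
    return out


def autoregressive_pattern(L: int) -> List[List[int]]:
    """Dense autoregressive pattern S_i = {j : j <= i}, with 0-based positions."""
    return [list(range(i + 1)) for i in range(L)]
```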
<p>In factorized self-attention, the set \(S_i\) is decomposed into a <em>tree</em> of dependencies, such that for every pair of \((i, j)\) where \(j \leq i\), there is a path connecting \(i\) back to \(j\) and \(i\) can attend to \(j\) either directly or indirectly.</p>
<p>Precisely, the set \(S_i\) is divided into \(p\) <em>non-overlapping</em> subsets, where the \(m\)-th subset is denoted as \(A^{(m)}_i \subset S_i, m = 1,\dots, p\). Therefore the path between the output position \(i\) and any \(j\) has a maximum length \(p + 1\). For example, if \((j, a, b, c, \dots, i)\) is a path of indices between \(i\) and \(j\), we would have \(j \in A_a^{(1)}, a \in A_b^{(2)}, b \in A_c^{(3)}, \dots\), so on and so forth.</p>
<p><strong>Sparse Factorized Attention</strong></p>
<p>Sparse Transformer proposed two types of fractorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sparse-attention.png" alt="Sparse attention" /></p>
<p><em>Fig. 10. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: <a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a> + a few of extra annotations.)</em></p>
<p>(1) <em>Strided</em> attention with stride \(\ell \sim \sqrt{n}\). This works well with image data, as the structure is aligned with strides. In the image case, each pixel attends to all the previous \(\ell\) pixels in raster scanning order (which naturally covers the entire width of the image), and those pixels in turn attend to others in the same column (defined by another attention connectivity subset).</p>
\[\begin{aligned}
A_i^{(1)} &= \{ t, t+1, \dots, i\} \text{, where } t = \max(0, i - \ell) \\
A_i^{(2)} &= \{j: (i-j) \mod \ell = 0\}
\end{aligned}\]
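<p>To make the two strided subsets concrete, here is a minimal Python sketch that enumerates \(A_i^{(1)}\) and \(A_i^{(2)}\) for one output position and checks that every earlier position is reachable within two hops (first through \(A^{(2)}\), then through \(A^{(1)}\)). It assumes 0-indexed positions and a causal setting, so both subsets are restricted to \(j \leq i\); the function names are hypothetical and not taken from the paper’s code.</p>

```python
def strided_sets(i, stride):
    """Return (A1, A2) for output position i under strided attention.

    Positions are 0-indexed and causal (only j <= i is allowed)."""
    t = max(0, i - stride)
    a1 = set(range(t, i + 1))  # the previous `stride` positions and i itself
    a2 = {j for j in range(i + 1) if (i - j) % stride == 0}  # every stride-th position back to 0
    return a1, a2


def two_hop_coverage(i, stride):
    """Positions reachable from i by one hop through A2 followed by one hop through A1."""
    _, a2 = strided_sets(i, stride)
    covered = set()
    for a in a2:
        a1_of_a, _ = strided_sets(a, stride)
        covered |= a1_of_a
    return covered


a1, a2 = strided_sets(10, 3)
print(sorted(a1))  # [7, 8, 9, 10]
print(sorted(a2))  # [1, 4, 7, 10]
print(two_hop_coverage(10, 3) == set(range(11)))  # True
```

<p>Because the smallest element of \(A_i^{(2)}\) is always below \(\ell\), its \(A^{(1)}\) window reaches position 0, which is why the two hops cover the whole past.</p>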
<p>(2) <em>Fixed</em> attention. A small set of tokens summarizes previous locations and propagates that information to all future locations.</p>
\[\begin{aligned}
A_i^{(1)} &= \{j: \lfloor \frac{j}{\ell} \rfloor = \lfloor \frac{i}{\ell} \rfloor \} \\
A_i^{(2)} &= \{j: j \mod \ell \in \{\ell-c, \dots, \ell-1\} \}
\end{aligned}\]
<p>where \(c\) is a hyperparameter. Setting \(c=1\) significantly restricts the representation, because many positions would then depend on only a few summary positions. The paper chose \(c\in \{ 8, 16, 32 \}\) for \(\ell \in \{ 128, 256 \}\).</p>
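<p>To make the fixed pattern concrete, here is a minimal Python sketch that enumerates both subsets for one output position, assuming 0-indexed, causal positions (\(j \leq i\)); the function name is hypothetical. \(A_i^{(1)}\) is the block containing \(i\), and \(A_i^{(2)}\) collects the last \(c\) “summary” positions of every block seen so far.</p>

```python
def fixed_sets(i, stride, c):
    """Return (A1, A2) for output position i under fixed attention (0-indexed, causal j <= i)."""
    block = i // stride
    a1 = {j for j in range(i + 1) if j // stride == block}  # all earlier positions in i's own block
    # The last c positions of every block act as summary cells visible to all later positions.
    a2 = {j for j in range(i + 1) if j % stride >= stride - c}
    return a1, a2


a1, a2 = fixed_sets(10, 4, c=1)
print(sorted(a1))  # [8, 9, 10]
print(sorted(a2))  # [3, 7]
print(sorted(fixed_sets(10, 4, c=2)[1]))  # [2, 3, 6, 7, 10]
```

<p>With \(c=1\), every block exposes a single summary position, which illustrates why a small \(c\) funnels a lot of information through very few cells.</p>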
<p><strong>Use Factorized Self-Attention in Transformer</strong></p>
<p>There are three ways to use sparse factorized attention patterns in Transformer architecture:</p>
<ol>
<li>One attention type per residual block and then interleave them, <br />
\(\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, A^{(n \mod p)}) \mathbf{W}^o\), where \(n\) is the index of the current residual block.</li>
<li>Set up a single head which attends to locations that all the factorized heads attend to, <br />
\(\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, \cup_{m=1}^p A^{(m)}) \mathbf{W}^o\).</li>
<li>Use a multi-head attention mechanism, but unlike the vanilla Transformer, each head may adopt one of the patterns presented above, 1 or 2. This option often performs best.</li>
</ol>
<p>Sparse Transformer also proposed a set of changes to train Transformers with up to hundreds of layers, including gradient checkpointing (recomputing attention and FF layers during the backward pass), mixed precision training, an efficient block-sparse implementation, etc. Please check the <a href="https://arxiv.org/abs/1904.10509">paper</a> for more details.</p>
<h3 id="locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</h3>
<p>The improvements proposed by the <strong>Reformer</strong> model (<a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>) aim to solve the following pain points in Transformer:</p>
<ul>
<li>Memory in a model with \(N\) layers is \(N\)-times larger than in a single-layer model because we need to store activations for back-propagation.</li>
<li>The intermediate FF layers are often quite large.</li>
<li>The attention matrix on sequences of length \(L\) often requires \(O(L^2)\) in both memory and time.</li>
</ul>
<p>Reformer proposed two main changes:</p>
<ol>
<li>Replace the dot-product attention with <em>locality-sensitive hashing (LSH) attention</em>, reducing the complexity from \(O(L^2)\) to \(O(L\log L)\).</li>
<li>Replace the standard residual blocks with <em>reversible residual layers</em>, which allows storing activations only once during training instead of \(N\) times (i.e. proportional to the number of layers).</li>
</ol>
<p><a name="LSH"></a><strong>Locality-Sensitive Hashing Attention</strong></p>
<p>In the \(\mathbf{Q} \mathbf{K}^\top\) part of the <a href="#attention-and-self-attention">attention formula</a>, we are only interested in the largest elements, as only large elements contribute significantly after the softmax. For each query \(\mathbf{q}_i \in \mathbf{Q}\), we are looking for the row vectors in \(\mathbf{K}\) closest to \(\mathbf{q}_i\). In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-Sensitive Hashing (LSH)</a> into its attention mechanism.</p>
<p>A hashing scheme \(x \mapsto h(x)\) is <em>locality-sensitive</em> if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. Reformer adopts such a hashing scheme: given a fixed random matrix \(\mathbf{R} \in \mathbb{R}^{d \times b/2}\), where \(b\) is the number of hash buckets (a hyperparameter), the hash function is \(h(x) = \arg\max([xR; -xR])\).</p>
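<p>The hash function above fits in a few lines of plain Python. This is a minimal sketch, assuming \(\mathbf{R}\) has i.i.d. standard normal entries and that \(b\) is even; <code class="language-plaintext highlighter-rouge">make_lsh_hash</code> is a hypothetical helper, not Reformer’s actual implementation. Because the argmax is taken over \([xR; -xR]\), scaling \(x\) by a positive factor never changes its bucket, and \(-x\) always lands \(b/2\) buckets away from \(x\).</p>

```python
import random


def make_lsh_hash(d, n_buckets, seed=0):
    """Build h(x) = argmax([xR; -xR]) with a fixed Gaussian random matrix R of shape (d, n_buckets/2)."""
    if n_buckets % 2:
        raise ValueError("n_buckets must be even")
    half = n_buckets // 2
    rng = random.Random(seed)
    R = [[rng.gauss(0.0, 1.0) for _ in range(half)] for _ in range(d)]

    def h(x):
        xR = [sum(x[k] * R[k][m] for k in range(d)) for m in range(half)]
        scores = xR + [-v for v in xR]  # concatenation [xR; -xR]
        return max(range(n_buckets), key=scores.__getitem__)  # index of the largest score

    return h


h = make_lsh_hash(d=4, n_buckets=8)
x = [0.3, -1.2, 0.5, 2.0]
print(h(x), h([2 * v for v in x]))  # same bucket: a positive rescaling keeps the argmax
```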
<!-- If we omit the scalar in self-attention and summarize the denominator into a normalizing term $$Z(.)$$, an normal attention output looks as follows:
$$
\mathbf{o}_i = \sum_{j \in S_i} \exp(\mathbf{q}_i \cdot \mathbf{k}_j - Z(i, S_i)) \mathbf{v}_j \text{, where } S_i = \{j: j \leq i\}
$$
-->
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/LSH-attention-matrix.png" alt="LSH attention matrix" /></p>
<p><em>Fig. 11. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
<p>In LSH attention, a query can only attend to positions in the same hashing bucket, \(S_i = \{j: h(\mathbf{q}_i) = h(\mathbf{k}_j)\}\). It is carried out in the following process, as illustrated in Fig. 11:</p>
<ul>
<li>(a) The attention matrix for full attention is often sparse.</li>
<li>(b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets.</li>
<li>(c) Set \(\mathbf{Q} = \mathbf{K}\) (precisely \(\mathbf{k}_j = \mathbf{q}_j / \|\mathbf{q}_j\|\)), so that there are equal numbers of keys and queries in one bucket, which makes batching easier. Interestingly, this “shared-QK” configuration does not affect the performance of the Transformer.</li>
<li>(d) Apply batching where chunks of \(m\) consecutive queries are grouped together.</li>
</ul>
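<p>Steps (b) to (d) can be sketched as follows, given bucket ids that have already been computed for each position. This is a simplified illustration with hypothetical names: it assumes shared-QK (so keys and queries are the same positions) and, as in the Reformer paper, lets each chunk also look one chunk back so that a bucket split across a chunk boundary is not cut off. Masking details (causal masking, and how Reformer treats a token attending to itself under shared-QK) are left out.</p>

```python
def lsh_attention_candidates(buckets, chunk_size):
    """Map each position to the key positions it may attend to after sorting and chunking by bucket."""
    n = len(buckets)
    # (b) Sort positions by bucket id, breaking ties by original position.
    order = sorted(range(n), key=lambda pos: (buckets[pos], pos))
    # (d) Split the sorted sequence into chunks of `chunk_size` consecutive queries.
    chunks = [order[s:s + chunk_size] for s in range(0, n, chunk_size)]
    allowed = {}
    for ci, chunk in enumerate(chunks):
        # Each chunk looks at itself and at the previous chunk.
        window = chunk + (chunks[ci - 1] if ci > 0 else [])
        for q in chunk:
            # (c) Shared-QK: keys are the same positions as queries; keep same-bucket keys only.
            allowed[q] = sorted(k for k in window if buckets[k] == buckets[q])
    return allowed


buckets = [1, 0, 1, 0, 2, 1]  # hypothetical hash buckets for 6 positions
print(lsh_attention_candidates(buckets, chunk_size=2))
# {1: [1, 3], 3: [1, 3], 0: [0, 2], 2: [0, 2], 5: [0, 2, 5], 4: [4]}
```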
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/LSH-attention.png" alt="LSH attention" /></p>
<p><em>Fig. 12. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
<p><strong>Reversible Residual Network</strong></p>
<p>Another improvement by Reformer is to use <em>reversible residual layers</em> (<a href="https://arxiv.org/abs/1707.04585">Gomez et al. 2017</a>). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.</p>
<p>Given a layer \(x \mapsto y\), the normal residual layer does \(y = x + F(x)\), but the reversible layer splits both input and output into pairs \((x_1, x_2) \mapsto (y_1, y_2)\) and then executes the following:</p>
\[y_1 = x_1 + F(x_2),\; y_2 = x_2 + G(y_1)\]
<p>and reversing is easy:</p>
\[x_2 = y_2 - G(y_1), \; x_1 = y_1 - F(x_2)\]
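<p>The claim that the inputs can be recovered from the outputs is easy to check numerically. In this minimal sketch, <code class="language-plaintext highlighter-rouge">F</code> and <code class="language-plaintext highlighter-rouge">G</code> are arbitrary stand-in functions (hypothetical, chosen only for illustration) and vectors are plain Python lists; note that neither function needs to be invertible itself.</p>

```python
import math


def rev_forward(x1, x2, F, G):
    """y1 = x1 + F(x2); y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, F(x2))]
    y2 = [a + b for a, b in zip(x2, G(y1))]
    return y1, y2


def rev_inverse(y1, y2, F, G):
    """Recover the inputs from the outputs: x2 = y2 - G(y1); x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, G(y1))]
    x1 = [a - b for a, b in zip(y1, F(x2))]
    return x1, x2


def F(v):  # stand-in residual function (hypothetical)
    return [math.tanh(e) for e in v]


def G(v):  # another stand-in; it is not invertible, and that is fine
    return [0.5 * e * e for e in v]


x1, x2 = [0.1, -0.4], [1.5, 0.2]
y1, y2 = rev_forward(x1, x2, F, G)
r1, r2 = rev_inverse(y1, y2, F, G)
print(max(abs(a - b) for a, b in zip(r1 + r2, x1 + x2)) < 1e-12)  # True
```

<p>Since only \((y_1, y_2)\) is needed to rebuild \((x_1, x_2)\), a backward pass can recompute each layer's inputs on the fly instead of keeping them in memory.</p>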
<p>Reformer applies the same idea to Transformer by combining attention (\(F\)) and feed-forward layers (\(G\)) within a reversible net block:</p>
\[Y_1 = X_1 + \text{Attention}(X_2), \; Y_2 = X_2 + \text{FeedForward}(Y_1)\]
<p>The memory can be further reduced by chunking the feed-forward computation, which is possible because the feed-forward layer is applied to each position independently:
\(Y_2 = [Y_2^{(1)}; \dots; Y_2^{(c)}] = [X_2^{(1)} + \text{FeedForward}(Y_1^{(1)}); \dots; X_2^{(c)} + \text{FeedForward}(Y_1^{(c)})]\)</p>
<p>The resulting reversible Transformer does not need to store activations in every layer.</p>
<h2 id="make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</h2>
<p>The <strong>Universal Transformer</strong> (<a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN.</p>
<p>Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using <a href="#adaptive-computation-time-act">adaptive computation time</a>. If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.</p>
<p>At a high level, the Universal Transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions, and information between positions is shared through self-attention.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer-loop.png" alt="Universal Transformer Recurrent Step" /></p>
<p><em>Fig. 13. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>).</em></p>
<p>Given an input sequence of length \(L\), Universal Transformer iteratively updates the representation \(\mathbf{H}^t \in \mathbb{R}^{L \times d}\) at step \(t\) for an adjustable number of steps. At step 0, \(\mathbf{H}^0\) is initialized to be the same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.</p>
\[\begin{aligned}
\mathbf{A}^t &= \text{LayerNorm}(\mathbf{H}^{t-1} + \text{MultiHeadAttention}(\mathbf{H}^{t-1} + \mathbf{P}^t)) \\
\mathbf{H}^t &= \text{LayerNorm}(\mathbf{A}^t + \text{Transition}(\mathbf{A}^t))
\end{aligned}\]
<p>where \(\text{Transition}(.)\) is either a <a href="https://arxiv.org/abs/1610.02357">separable convolution</a> or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of \(\mathbf{A}^t\) individually) affine transformations with a ReLU in between.</p>
<p>The positional encoding \(\mathbf{P}^t\) uses sinusoidal position signal but with an additional time dimension:</p>
\[\text{PE}(i, t, \delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) \oplus \sin(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) \oplus \cos(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases}\]
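<p>Here is a minimal sketch of this two-dimensional (position, step) encoding for a single position \(i\) and step \(t\). It treats \(\oplus\) as elementwise addition, which is an assumption of this sketch; <code class="language-plaintext highlighter-rouge">ut_position_encoding</code> is a hypothetical name.</p>

```python
import math


def ut_position_encoding(i, t, d):
    """d-dimensional encoding for position i at recurrent step t (⊕ taken as elementwise addition)."""
    pe = []
    for delta in range(d):
        half = delta // 2  # delta' in the formula
        denom = 10000 ** (2 * half / d)
        if delta % 2 == 0:
            pe.append(math.sin(i / denom) + math.sin(t / denom))
        else:
            pe.append(math.cos(i / denom) + math.cos(t / denom))
    return pe


print(ut_position_encoding(0, 0, 4))  # [0.0, 2.0, 0.0, 2.0]
```

<p>The same sinusoid bank encodes both the token position and the recurrent step, so the model can tell which refinement step a representation comes from.</p>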
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer.png" alt="Universal Transformer" /></p>
<p><em>Fig. 14. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation \(\mathbf{H}^T\). (Image source: Figure 2 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>)</em></p>
<p>In the adaptive version of Universal Transformer, the number of recurrent steps \(T\) is dynamically determined by <a href="#adaptive-computation-time-act">ACT</a>. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.</p>
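<p>The per-position halting logic can be sketched as a loop. This is a simplified illustration with hypothetical names: <code class="language-plaintext highlighter-rouge">step_fn</code> stands in for one shared recurrent step (self-attention plus transition) applied to all positions, <code class="language-plaintext highlighter-rouge">halt_prob_fn</code> stands in for the learned halting unit, and a position halts once its accumulated halting probability reaches a threshold. It keeps the last state of a halted position, matching the copy-through behaviour described above, and omits ACT’s remainder-weighted averaging of states and its ponder cost.</p>

```python
def act_recurrence(states, step_fn, halt_prob_fn, threshold=0.99, max_steps=8):
    """Apply a shared recurrent step to all positions until each one halts (simplified ACT).

    Returns the final per-position states and the number of updates each position received."""
    states = list(states)
    n = len(states)
    cum_halt = [0.0] * n
    halted = [False] * n
    steps_taken = [0] * n
    for _ in range(max_steps):
        if all(halted):
            break
        new_states = step_fn(states)  # same parameters at every step; halted positions are still visible
        for p in range(n):
            if halted[p]:
                continue  # a halted position just copies its current value forward
            states[p] = new_states[p]
            steps_taken[p] += 1
            cum_halt[p] += halt_prob_fn(states[p])
            if cum_halt[p] >= threshold:
                halted[p] = True
    return states, steps_taken


def toy_step(s):  # hypothetical shared step on scalar states
    return [v + 1.0 for v in s]


def toy_halt(v):  # hypothetical halting unit: halt quickly once the value reaches 3
    return 1.0 if v >= 3 else 0.25


print(act_recurrence([0.0, 5.0], toy_step, toy_halt))  # ([3.0, 6.0], [3, 1])
```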
<h2 id="stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</h2>
<p>The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can certainly benefit from these traits. <em>However</em>, it is quite difficult to train Transformers even in supervised learning, let alone in the RL context. After all, stabilizing and training an LSTM agent on its own can already be quite challenging.</p>
<p>The <strong>Gated Transformer-XL</strong> (<strong>GTrXL</strong>; <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of <a href="#longer-attention-span-transformer-xl">Transformer-XL</a>:</p>
<ol>
<li>The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit of this reordering is that it allows the original input to flow from the first to the last layer.</li>
<li>The residual connection is replaced with a GRU-style (Gated Recurrent Unit; <a href="https://arxiv.org/abs/1412.3555">Chung et al., 2014</a>) <em>gating</em> mechanism.</li>
</ol>
\[\begin{aligned}
r &= \sigma(W_r^{(l)} y + U_r^{(l)} x) \\
z &= \sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\
\hat{h} &= \tanh(W_g^{(l)} y + U_g^{(l)} (r \odot x)) \\
g^{(l)}(x, y) &= (1-z)\odot x + z\odot \hat{h}
\end{aligned}\]
<p>The gating function parameters are explicitly initialized to be close to an identity map, which is why there is a \(b_g\) term: with \(b_g > 0\), the update gate \(z\) starts close to 0, so \(g^{(l)}(x, y) \approx x\). A \(b_g > 0\) greatly speeds up learning.</p>
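<p>The gating equations translate directly into code. In this minimal sketch, vectors are Python lists and weight matrices are lists of rows (an assumption made for a dependency-free example); all function names are hypothetical. With all weights at zero and a large \(b_g\), the gate returns almost exactly \(x\), which is the near-identity initialization described above.</p>

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def gru_gate(x, y, W_r, U_r, W_z, U_z, W_g, U_g, b_g):
    """GRU-style gate g(x, y): x is the residual-stream input, y the sub-layer output."""
    r = [sigmoid(a + b) for a, b in zip(matvec(W_r, y), matvec(U_r, x))]  # reset gate
    z = [sigmoid(a + b - b_g) for a, b in zip(matvec(W_z, y), matvec(U_z, x))]  # update gate
    rx = [ri * xi for ri, xi in zip(r, x)]
    h_hat = [math.tanh(a + b) for a, b in zip(matvec(W_g, y), matvec(U_g, rx))]  # candidate
    return [(1 - zi) * xi + zi * hi for zi, xi, hi in zip(z, x, h_hat)]


zero = [[0.0, 0.0], [0.0, 0.0]]
x, y = [1.0, -2.0], [3.0, 4.0]
print(gru_gate(x, y, zero, zero, zero, zero, zero, zero, b_g=10.0))  # very close to x
print(gru_gate(x, y, zero, zero, zero, zero, zero, zero, b_g=0.0))   # [0.5, -1.0]
```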
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/gated-transformer-XL.png" alt="GTrXL" /></p>
<p><em>Fig. 15. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020transformer,
title = "The Transformer Family",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/03/27/the-transformer-family.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Ashish Vaswani, et al. <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[2] Rami Al-Rfou, et al. <a href="https://arxiv.org/abs/1808.04444">“Character-level language modeling with deeper self-attention.”</a> AAAI 2019.</p>
<p>[3] Olah & Carter, <a href="http://doi.org/10.23915/disti">“Attention and Augmented Recurrent Neural Networks”</a>, Distill, 2016.</p>
<p>[4] Sainbayar Sukhbaatar, et al. <a href="https://arxiv.org/abs/1905.07799">“Adaptive Attention Span in Transformers”</a>. ACL 2019.</p>
<p>[5] Rewon Child, et al. <a href="https://arxiv.org/abs/1904.10509">“Generating Long Sequences with Sparse Transformers”</a> arXiv:1904.10509 (2019).</p>
<p>[6] Nikita Kitaev, et al. <a href="https://arxiv.org/abs/2001.04451">“Reformer: The Efficient Transformer”</a> ICLR 2020.</p>
<p>[7] Alex Graves. <a href="https://arxiv.org/abs/1603.08983">“Adaptive Computation Time for Recurrent Neural Networks”</a> arXiv:1603.08983 (2016).</p>
<p>[8] Niki Parmar, et al. <a href="https://arxiv.org/abs/1802.05751">“Image Transformer”</a> ICML 2018.</p>
<p>[9] Zihang Dai, et al. <a href="https://arxiv.org/abs/1901.02860">“Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.”</a> ACL 2019.</p>
<p>[10] Aidan N. Gomez, et al. <a href="https://arxiv.org/abs/1707.04585">“The Reversible Residual Network: Backpropagation Without Storing Activations”</a> NIPS 2017.</p>
<p>[11] Mostafa Dehghani, et al. <a href="https://arxiv.org/abs/1807.03819">“Universal Transformers”</a> ICLR 2019.</p>
<p>[12] Emilio Parisotto, et al. <a href="https://arxiv.org/abs/1910.06764">“Stabilizing Transformers for Reinforcement Learning”</a> arXiv:1910.06764 (2019).</p>Lilian WengInspired by recent progress on various enhanced versions of Transformer models, this post presents how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving, etc.Curriculum for Reinforcement Learning2020-01-29T18:00:00+00:002020-01-29T18:00:00+00:00https://lilianweng.github.io/lil-log/2020/01/29/curriculum-for-reinforcement-learning<blockquote>
<p>A curriculum is an efficient tool for humans to progressively learn from simple concepts to hard problems. It breaks down complex knowledge by providing a sequence of learning steps of increasing difficulty. In this post, we will examine how the idea of curriculum can help reinforcement learning models learn to solve complicated tasks.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-02-03: mentioning <a href="#pcg">PCG</a> in the “Task-Specific Curriculum” section.</span><br />
<span style="color: #286ee0;">[Updated on 2020-02-04: Add a new <a href="#curriculum-through-distillation">“curriculum through distillation”</a> section.</span></p>
<p>Teaching integrals or derivatives to a 3-year-old who does not even know basic arithmetic sounds like an impossible task. That’s why education is important: it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard. A curriculum makes learning difficult things easier and more approachable for us humans. But how about machine learning models? Can we train our models more efficiently with a curriculum? Can we design a curriculum to speed up learning?</p>
<p>Back in 1993, Jeffrey Elman proposed the idea of training neural networks with a curriculum. His early work on learning simple language grammar demonstrated the importance of such a strategy: starting with a restricted set of simple data and gradually increasing the complexity of training samples; otherwise the model was not able to learn at all.</p>
<p>Compared to training without a curriculum, we would expect a curriculum to speed up convergence, while it may or may not improve the final model performance. Designing an efficient and effective curriculum is not easy. Keep in mind that a bad curriculum may even hamper learning.</p>
<p>Next, we will look into several categories of curriculum learning, as illustrated in Fig. 1. Most cases are applied to Reinforcement Learning, with a few exceptions in Supervised Learning.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/types-of-curriculum-2.png" alt="Types of curriculum" /></p>
<p><em>Fig. 1. Five types of curriculum for reinforcement learning.</em></p>
<p>In “The importance of starting small” paper (<a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.4487&rep=rep1&type=pdf">Elman 1993</a>), I especially like the opening sentences and find them both inspiring and affecting:</p>
<blockquote>
<p>“Humans differ from other species along many dimensions, but two are particularly noteworthy. Humans display an exceptional capacity to learn; and humans are remarkable for the unusually long time it takes to reach maturity. The adaptive advantage of learning is clear, and it may be argued that, through culture, learning has created the basis for a non-genetically based transmission of behaviors which may accelerate the evolution of our species.”</p>
</blockquote>
<p>Indeed, learning is probably the best superpower we humans have.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#task-specific-curriculum" id="markdown-toc-task-specific-curriculum">Task-Specific Curriculum</a></li>
<li><a href="#teacher-guided-curriculum" id="markdown-toc-teacher-guided-curriculum">Teacher-Guided Curriculum</a></li>
<li><a href="#curriculum-through-self-play" id="markdown-toc-curriculum-through-self-play">Curriculum through Self-Play</a></li>
<li><a href="#automatic-goal-generation" id="markdown-toc-automatic-goal-generation">Automatic Goal Generation</a></li>
<li><a href="#skill-based-curriculum" id="markdown-toc-skill-based-curriculum">Skill-Based Curriculum</a></li>
<li><a href="#curriculum-through-distillation" id="markdown-toc-curriculum-through-distillation">Curriculum through Distillation</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="task-specific-curriculum">Task-Specific Curriculum</h2>
<p><a href="https://www.researchgate.net/profile/Y_Bengio/publication/221344862_Curriculum_learning/links/546cd2570cf2193b94c577ac/Curriculum-learning.pdf">Bengio, et al. (2009)</a> provided a good overview of curriculum learning in the old days. The paper presented two ideas with toy experiments using a manually designed task-specific curriculum:</p>
<ol>
<li>Cleaner examples may yield better generalization faster.</li>
<li>Introducing gradually more difficult examples speeds up online training.</li>
</ol>
<p>It is plausible that some curriculum strategies could be useless or even harmful. A good question to answer in the field is: <em>What could be the general principles that make some curriculum strategies work better than others?</em> The Bengio 2009 paper hypothesized it would be beneficial to make learning focus on “interesting” examples that are neither too hard nor too easy.</p>
<p>If our naive curriculum is to train the model on samples with a gradually increasing level of complexity, we first need a way to quantify the difficulty of a task. One idea is to use its minimal loss with respect to another model that has been pretrained on other tasks (<a href="https://arxiv.org/abs/1802.03796">Weinshall, et al. 2018</a>). In this way, the knowledge of the pretrained model can be transferred to the new model by suggesting a ranking of training samples. Fig. 2 shows the effectiveness of the <code class="language-plaintext highlighter-rouge">curriculum</code> group (green), compared to <code class="language-plaintext highlighter-rouge">control</code> (random order; yellow) and <code class="language-plaintext highlighter-rouge">anti</code> (reversed order; red) groups.</p>
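<p>The three orderings compared in Fig. 2 can be sketched as follows, assuming a difficulty score per sample where lower means easier (for example, the loss or one minus the confidence under a pretrained model). This only shows the ordering step with hypothetical names; how quickly harder samples are introduced during training (the pacing) is not modelled.</p>

```python
import random


def order_training_set(samples, difficulty, mode="curriculum", seed=0):
    """Order samples for training; `difficulty` maps a sample to a score where lower means easier."""
    if mode == "curriculum":  # easy to hard
        return sorted(samples, key=difficulty)
    if mode == "anti":  # hard to easy
        return sorted(samples, key=difficulty, reverse=True)
    if mode == "control":  # random order
        shuffled = list(samples)
        random.Random(seed).shuffle(shuffled)
        return shuffled
    raise ValueError(f"unknown mode: {mode}")


scores = {"a": 0.1, "b": 0.5, "c": 0.9}  # hypothetical difficulty scores
print(order_training_set(["c", "a", "b"], scores.get))          # ['a', 'b', 'c']
print(order_training_set(["c", "a", "b"], scores.get, "anti"))  # ['c', 'b', 'a']
```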
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/curriculum-by-transfer-learning.png" alt="Curriculum by transfer learning" /></p>
<p><em>Fig. 2. Image classification accuracy on test image set (5 member classes of “small mammals” in CIFAR100). There are 4 experimental groups, (a) <code class="language-plaintext highlighter-rouge">curriculum</code>: sort the labels by the confidence of another trained classifier (e.g. the margin of an SVM); (b) <code class="language-plaintext highlighter-rouge">control-curriculum</code>: sort the labels randomly; (c) <code class="language-plaintext highlighter-rouge">anti-curriculum</code>: sort the labels reversely; (d) <code class="language-plaintext highlighter-rouge">None</code>: no curriculum. (Image source: <a href="https://arxiv.org/abs/1802.03796">Weinshall, et al. 2018</a>)</em></p>
<p><a href="https://arxiv.org/abs/1410.4615">Zaremba & Sutskever (2014)</a> did an interesting experiment on training LSTM to predict the output of a short Python program for mathematical ops without actually executing the code. They found curriculum is necessary for learning. The program’s complexity is controlled by two parameters, <code class="language-plaintext highlighter-rouge">length</code> ∈ [1, a] and <code class="language-plaintext highlighter-rouge">nesting</code>∈ [1, b]. Three strategies are considered:</p>
<ol>
<li>Naive curriculum: increase <code class="language-plaintext highlighter-rouge">length</code> first until reaching <code class="language-plaintext highlighter-rouge">a</code>; then increase <code class="language-plaintext highlighter-rouge">nesting</code> and reset <code class="language-plaintext highlighter-rouge">length</code> to 1; repeat this process until both reach maximum.</li>
<li>Mix curriculum: randomly sample <code class="language-plaintext highlighter-rouge">length</code> from [1, a] and <code class="language-plaintext highlighter-rouge">nesting</code> from [1, b].</li>
<li>Combined: each training example comes from either the naive curriculum or the mix strategy.</li>
</ol>
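<p>The three strategies can be sketched as sample generators over (<code class="language-plaintext highlighter-rouge">length</code>, <code class="language-plaintext highlighter-rouge">nesting</code>). This assumes uniform sampling for the mix strategy and a hypothetical mixing probability <code class="language-plaintext highlighter-rouge">p_mix</code> for the combined strategy; the functions are illustrative, not the authors’ code.</p>

```python
import random


def naive_schedule(a, b):
    """Sweep length 1..a at each nesting level, then increase nesting and reset length."""
    for nesting in range(1, b + 1):
        for length in range(1, a + 1):
            yield length, nesting


def mix_sample(a, b, rng):
    """Draw length from [1, a] and nesting from [1, b] (uniformly, as assumed here)."""
    return rng.randint(1, a), rng.randint(1, b)


def combined_sample(current, a, b, rng, p_mix=0.5):
    """With probability p_mix use the mix strategy; otherwise use the current naive-curriculum level."""
    return mix_sample(a, b, rng) if rng.random() < p_mix else current


rng = random.Random(0)
print(list(naive_schedule(2, 2)))  # [(1, 1), (2, 1), (1, 2), (2, 2)]
level = (2, 1)  # current position in the naive schedule
print([combined_sample(level, a=5, b=3, rng=rng) for _ in range(3)])
```

<p>The mixed-in samples keep easy configurations in the training stream, which is the mechanism the authors credit for avoiding forgetting.</p>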
<p>They noticed that the combined strategy always outperformed the naive curriculum and would generally (but not always) outperform the mix strategy, indicating that it is quite important to mix in easy tasks during training to <em>avoid forgetting</em>.</p>
<p><a name="pcg"></a>Procedural content generation (<a href="https://en.wikipedia.org/wiki/Procedural_generation">PCG</a>) is a popular approach for creating video games of various levels of difficulty. PCG involves algorithmic randomness and a heavy dose of human expertise in designing game elements and the dependencies among them. Procedurally generated levels have been introduced into several benchmark environments for evaluating whether an RL agent can generalize to a new level that it was not trained on (<a href="/lil-log/2019/06/23/meta-reinforcement-learning.html">meta-RL</a>!), such as <a href="http://www.gvgai.net/">GVGAI</a>, OpenAI <a href="https://openai.com/blog/quantifying-generalization-in-reinforcement-learning/">CoinRun</a> and <a href="https://openai.com/blog/procgen-benchmark/">Procgen benchmark</a>. Using GVGAI, <a href="https://arxiv.org/abs/1806.10729">Justesen, et al. (2018)</a> demonstrated that an RL policy can easily overfit to a specific game, but that training over a simple curriculum that grows the task difficulty together with the model performance helps it generalize to new human-designed levels. Similar results are also found in CoinRun (<a href="https://arxiv.org/abs/1812.02341">Cobbe, et al. 2018</a>). POET (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>) is another example of leveraging an evolutionary algorithm and procedurally generated game levels to improve RL generalization, which I’ve described in detail in my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#evolutionary-algorithm-on-environment-generation">meta-RL post</a>.</p>
<p>To follow the curriculum learning approaches described above, generally we need to figure out two problems in the training procedure:</p>
<ol>
<li>Design a metric to quantify how hard a task is so that we can sort tasks accordingly.</li>
<li>Provide a sequence of tasks with an increasing level of difficulty to the model during training.</li>
</ol>
<p>However, the order of tasks does not have to be sequential. In our Rubik’s cube paper (<a href="https://arxiv.org/abs/1910.07113">OpenAI et al, 2019</a>), we relied on <em>Automatic domain randomization</em> (<strong>ADR</strong>) to generate a curriculum by growing a distribution of environments with increasing complexity. The difficulty of each task (i.e. solving a Rubik’s cube in a set of environments) depends on the randomization ranges of various environmental parameters. Even with a simplified assumption that all the environmental parameters are uncorrelated, we were able to create a decent curriculum for our robot hand to learn the task.</p>
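<p>A stripped-down sketch of the idea, assuming a simplified boundary-update rule: each environment parameter has a range that starts as a single value, performance measured with that parameter pinned at one of its boundaries is buffered, and once the buffer is full the boundary moves outward if average performance is above a high threshold or inward if it is below a low threshold. The class name, thresholds and step size are hypothetical, and the real ADR system involves more machinery than shown here.</p>

```python
import random


class SimpleADR:
    """Per-parameter randomization ranges that widen where the agent does well and shrink where it fails."""

    def __init__(self, names, init_value, step, t_low, t_high, buffer_size):
        self.bounds = {n: [init_value, init_value] for n in names}  # start with no randomization
        self.step, self.t_low, self.t_high = step, t_low, t_high
        self.buffer_size = buffer_size
        self.buffers = {(n, side): [] for n in names for side in ("low", "high")}

    def sample_env(self, rng):
        """Sample every parameter uniformly from its current range."""
        return {n: rng.uniform(lo, hi) for n, (lo, hi) in self.bounds.items()}

    def boundary_env(self, name, side, rng):
        """Like sample_env, but pin one parameter to its lower or upper boundary for evaluation."""
        env = self.sample_env(rng)
        env[name] = self.bounds[name][0 if side == "low" else 1]
        return env

    def report(self, name, side, performance):
        """Record performance measured at a boundary; move that boundary once the buffer is full."""
        buf = self.buffers[(name, side)]
        buf.append(performance)
        if len(buf) < self.buffer_size:
            return
        avg = sum(buf) / len(buf)
        buf.clear()
        lo, hi = self.bounds[name]
        outward = -self.step if side == "low" else self.step  # moving outward widens the range
        if avg >= self.t_high:
            shift = outward
        elif avg <= self.t_low:
            shift = -outward
        else:
            return
        if side == "low":
            self.bounds[name][0] = min(lo + shift, hi)  # never let the range invert
        else:
            self.bounds[name][1] = max(hi + shift, lo)


adr = SimpleADR(["friction"], init_value=1.0, step=0.1, t_low=0.2, t_high=0.8, buffer_size=2)
for perf in (0.9, 0.95):  # the agent does well at the upper boundary
    adr.report("friction", "high", perf)
print(adr.bounds["friction"])  # [1.0, 1.1]
```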
<h2 id="teacher-guided-curriculum">Teacher-Guided Curriculum</h2>
<p><a name="grave-et-al-2017"></a>The idea of <em>Automatic Curriculum Learning</em> was proposed by <a href="https://arxiv.org/abs/1704.03003">Graves, et al. 2017</a> slightly earlier. It considers a \(N\)-task curriculum as an <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html">\(N\)-armed bandit</a> problem and an adaptive policy which learns to optimize the returns from this bandit.</p>
<p>Two categories of learning signals have been considered in the paper:</p>
<ol>
<li>Loss-driven progress: the change in the loss function before and after one gradient update. This type of reward signal tracks the speed of the learning process, because the greatest decrease in task loss corresponds to the fastest learning.</li>
<li>Complexity-driven progress: the KL divergence between the posterior and prior distributions over network weights. This type of learning signal is inspired by the <a href="https://en.wikipedia.org/wiki/Minimum_description_length">MDL</a> principle: “increasing the model complexity by a certain amount is only worthwhile if it compresses the data by a greater amount”. The model complexity is therefore expected to increase most in response to the model nicely generalizing to training examples.</li>
</ol>
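<p>As a rough illustration of the bandit view, the sketch below uses a plain Exp3 bandit over \(N\) tasks with the loss-driven progress signal as reward. Exp3 is a stand-in chosen for simplicity (the paper relies on a bandit variant adapted to non-stationary rewards, and in practice rewards should be rescaled to a bounded range); all names and hyperparameters are hypothetical.</p>

```python
import math
import random


def loss_driven_progress(loss_before, loss_after):
    """Reward = decrease in loss on the chosen task after one gradient update."""
    return loss_before - loss_after


class Exp3TaskSelector:
    """Each of the N tasks is a bandit arm; the teacher learns which tasks yield the most progress."""

    def __init__(self, n_tasks, gamma=0.1, seed=0):
        self.n, self.gamma = n_tasks, gamma
        self.log_w = [0.0] * n_tasks  # log-weights, kept in log space for numerical stability
        self.rng = random.Random(seed)

    def probs(self):
        m = max(self.log_w)
        w = [math.exp(lw - m) for lw in self.log_w]
        total = sum(w)
        # Mix the exponential weights with uniform exploration.
        return [(1 - self.gamma) * wi / total + self.gamma / self.n for wi in w]

    def select(self):
        return self.rng.choices(range(self.n), weights=self.probs())[0]

    def update(self, task, reward):
        # Importance-weighted reward estimate, applied to the chosen arm only.
        p = self.probs()[task]
        self.log_w[task] += self.gamma * (reward / p) / self.n


teacher = Exp3TaskSelector(n_tasks=3)
task = teacher.select()
# ... train the student on `task` for one step, measuring its loss before and after ...
teacher.update(task, loss_driven_progress(loss_before=1.0, loss_after=0.8))
```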
<p><a name="TSCL"></a>This framework of proposing curriculum automatically through another RL agent was formalized as <em>Teacher-Student Curriculum Learning</em> (<strong>TSCL</strong>; <a href="https://arxiv.org/abs/1707.00183">Matiisen, et al. 2017</a>). In TSCL, a <em>student</em> is an RL agent working on actual tasks while a <em>teacher</em> agent is a policy for selecting tasks. The student aims to master a complex task that might be hard to learn directly. To make this task easier to learn, we set up the teacher agent to guide the student’s training process by picking proper sub-tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/teacher-student-curriculum.png" alt="Teacher-student curriculum" /></p>
<p><em>Fig. 3. The setup of teacher-student curriculum learning. (Image source: <a href="https://arxiv.org/abs/1707.00183">Matiisen, et al. 2017</a> + my annotation in red.)</em></p>
<p>In the process, the student should learn tasks which:</p>
<ol>
<li>can help the student make the fastest learning progress, or</li>
<li>are at risk of being forgotten.</li>
</ol>
<blockquote>
<p>Note: The setup of framing the teacher model as an RL problem feels quite similar to Neural Architecture Search (NAS), but the RL model in TSCL operates on the task space, whereas NAS operates on the main model architecture space.</p>
</blockquote>
<p>Training the teacher model amounts to solving a <a href="https://en.wikipedia.org/wiki/Partially_observable_Markov_decision_process">POMDP</a> problem:</p>
<ul>
<li>The unobserved \(s_t\) is the full state of the student model.</li>
<li>The observed \(o = (x_t^{(1)}, \dots, x_t^{(N)})\) is a list of scores for \(N\) tasks.</li>
<li>The action \(a\) is to pick one subtask.</li>
<li>The reward per step is the score delta, \(r_t = \sum_{i=1}^N (x_t^{(i)} - x_{t-1}^{(i)})\) (i.e., equivalent to maximizing the score of all tasks at the end of the episode).</li>
</ul>
<p>The method of estimating learning progress from noisy task scores while balancing exploration vs exploitation can be borrowed from the non-stationary multi-armed bandit problem — use <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#ε-greedy-algorithm">ε-greedy</a>, or <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">Thompson sampling</a>.</p>
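<p>A minimal sketch of an ε-greedy teacher, assuming that learning progress on a task is estimated as an exponentially weighted moving average of its score changes, and that the teacher picks the task with the largest <em>absolute</em> estimate so that tasks being forgotten (negative change) get revisited too. Class and parameter names are hypothetical.</p>

```python
import random


class EpsGreedyTeacher:
    """Picks the task whose score is changing fastest, with epsilon-greedy exploration."""

    def __init__(self, n_tasks, alpha=0.3, epsilon=0.1, seed=0):
        self.q = [0.0] * n_tasks  # smoothed estimate of the per-task score change
        self.last_score = [0.0] * n_tasks  # most recent observed score for each task
        self.alpha, self.epsilon = alpha, epsilon
        self.rng = random.Random(seed)

    def select(self):
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(len(self.q))  # explore: pick a random task
        # Exploit: large improvement and large forgetting both have a large |q|.
        return max(range(len(self.q)), key=lambda i: abs(self.q[i]))

    def update(self, task, new_score):
        reward = new_score - self.last_score[task]  # score delta for the task just trained on
        self.last_score[task] = new_score
        self.q[task] = self.alpha * reward + (1 - self.alpha) * self.q[task]


teacher = EpsGreedyTeacher(n_tasks=3, epsilon=0.0)
teacher.update(0, 0.1)
teacher.update(1, 0.5)
teacher.update(2, 0.2)
print(teacher.select())  # 1
```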
<p>The core idea, in summary, is to use one policy to propose tasks for another policy to learn better. Interestingly, both works above (in the discrete task space) found that uniformly sampling from all tasks is a surprisingly strong baseline.</p>
<p>What if the task space is continuous? <a href="https://arxiv.org/abs/1910.07224">Portelas, et al. (2019)</a> studied a continuous teacher-student framework, where the teacher has to sample parameters from a continuous task space to generate a learning curriculum. Given a newly sampled parameter \(p\), the absolute learning progress (ALP) is measured as \(\text{ALP}_p = \vert r - r_\text{old} \vert\), where \(r\) is the episodic reward associated with \(p\) and \(r_\text{old}\) is the reward associated with \(p_\text{old}\). Here, \(p_\text{old}\) is the previously sampled parameter closest to \(p\) in the task space, which can be retrieved by nearest neighbor search. Note how this ALP score differs from the learning signals in <a href="#TSCL">TSCL</a> or <a href="#grave-et-al-2017">Graves, et al. 2017</a> above: the ALP score measures the reward difference between two tasks rather than performance at two time steps of the same task.</p>
<p>On top of the task parameter space, a Gaussian mixture model is trained to fit the distribution of \(\text{ALP}_p\) over \(p\). ε-greedy is used when sampling tasks: with some probability, a random task is sampled; otherwise, tasks are sampled proportionally to their ALP score from the GMM.</p>
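<p>The ALP score itself is simple to compute. The sketch below keeps a history of (parameter, reward) pairs and uses a brute-force nearest-neighbor lookup; fitting the GMM and the ε-greedy sampling on top of it are omitted, and all names are hypothetical.</p>

```python
import math


class ALPTracker:
    """Stores (task parameter, episodic reward) pairs and scores new samples by ALP."""

    def __init__(self):
        self.history = []

    def add(self, p, r):
        """Return ALP_p = |r - r_old| for the nearest previous parameter, then record (p, r)."""
        if self.history:
            _, r_old = min(self.history, key=lambda item: math.dist(item[0], p))
            alp = abs(r - r_old)
        else:
            alp = 0.0  # no previous sample to compare against (an assumption of this sketch)
        self.history.append((p, r))
        return alp


tracker = ALPTracker()
print(tracker.add((0.0, 0.0), 1.0))  # 0.0
print(tracker.add((1.0, 1.0), 3.0))  # 2.0
print(tracker.add((0.1, 0.0), 0.5))  # 0.5, compared against (0.0, 0.0)
```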
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ALP-GMM-algorithm.png" alt="ALP-GMM" /></p>
<p><em>Fig. 4. The algorithm of ALP-GMM (absolute learning progress Gaussian mixture model). (Image source: <a href="https://arxiv.org/abs/1910.07224">Portelas, et al., 2019</a>)</em></p>
<h2 id="curriculum-through-self-play">Curriculum through Self-Play</h2>
<p>In the teacher-student framework, the two agents are doing very different things: the teacher learns to pick a task for the student without any knowledge of the actual task content. What if we want both of them to train on the main task directly? What about even making them compete with each other?</p>
<p><a href="https://arxiv.org/abs/1703.05407">Sukhbaatar, et al. (2017)</a> proposed a framework for automatic curriculum learning through <strong>asymmetric self-play</strong>. Two agents, Alice and Bob, play the same task with different goals: Alice challenges Bob to achieve the same state and Bob attempts to complete it as fast as he can.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-play-maze.png" alt="Self-play experiments in MazeBase" /></p>
<p><em>Fig. 5. Illustration of the self-play setup when training two agents. The example task is <a href="https://github.com/facebook/MazeBase">MazeBase</a>: an agent is asked to reach a goal flag in a maze with a light switch, a key and a wall with a door. Toggling the key switch can open or close the door, and turning off the light makes only the glowing light switch available to the agent. (Image source: <a href="https://arxiv.org/abs/1703.05407">Sukhbaatar, et al. 2017</a>)</em></p>
<p>Let us consider Alice and Bob as two separate copies of one RL agent trained in the same environment but with different brains. Each of them has its own parameters and loss objective. The self-play-driven training consists of two types of episodes:</p>
<ul>
<li>In the <em>self-play episode</em>, Alice alters the state from \(s_0\) to \(s_t\) and then Bob is asked to return the environment to its original state \(s_0\) to get an internal reward.</li>
<li>In the <em>target task episode</em>, Bob receives an external reward if he visits the target flag.</li>
</ul>
<p>Note that since Bob has to travel between the same pair of states \((s_0, s_t)\) as Alice, this framework only works in reversible or resettable environments.</p>
<p>Alice should learn to push Bob out of his comfort zone, but not give him impossible tasks. Bob’s reward is set as \(R_B = -\gamma t_B\) and Alice’s reward is \(R_A = \gamma \max(0, t_B - t_A)\), where \(t_B\) is the total time for Bob to complete the task, \(t_A\) is the time until Alice performs the STOP action and \(\gamma\) is a scalar constant to rescale the reward to be comparable with the external task reward. If Bob fails a task, \(t_B = t_\max - t_A\).
Both policies are goal-conditioned. The reward design implies:</p>
<ol>
<li>B wants to finish a task as soon as possible.</li>
<li>A prefers tasks that take B more time.</li>
<li>A does not want to take too many steps when B is failing.</li>
</ol>
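<p>The two internal rewards are simple enough to write down directly. This is a minimal sketch of the formulas above; the function name <code>self_play_rewards</code> and the default <code>gamma=0.01</code> are illustrative assumptions, and step counts are assumed to be integers.</p>

```python
def self_play_rewards(t_a, t_b, bob_succeeded, t_max, gamma=0.01):
    """Internal rewards (R_A, R_B) for one asymmetric self-play episode.

    t_a: steps Alice took before emitting STOP.
    t_b: steps Bob took (ignored if Bob failed).
    t_max: step budget of the whole episode.
    gamma: scale that makes internal rewards comparable to the task reward.
    """
    if not bob_succeeded:
        t_b = t_max - t_a  # Bob is charged all of his remaining steps
    r_bob = -gamma * t_b                 # Bob: finish as fast as possible
    r_alice = gamma * max(0, t_b - t_a)  # Alice: tasks that take Bob longer than her
    return r_alice, r_bob


print(self_play_rewards(3, 10, True, 50, gamma=1.0))
print(self_play_rewards(30, 0, False, 50, gamma=1.0))
```

<p>The second call illustrates point 3: Alice spent 30 of the 50 steps and Bob failed, so \(t_B = 20 < t_A\) and Alice receives nothing.</p>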
<p>In this way, the interaction between Alice and Bob automatically builds a curriculum of increasingly challenging tasks. Meanwhile, as A has done the task herself before proposing the task to B, the task is guaranteed to be solvable.</p>
<p>The paradigm of A suggesting tasks and then B solving them does sound similar to the teacher-student framework. However, in asymmetric self-play, Alice, who plays a teacher role, also works on the same task to find challenging cases for Bob, rather than explicitly optimizing B’s learning process.</p>
<h2 id="automatic-goal-generation">Automatic Goal Generation</h2>
<p>Often an RL policy needs to be able to perform over a set of tasks. The goal should be carefully chosen so that at every training stage, it is neither too hard nor too easy for the current policy. A goal \(g \in \mathcal{G}\) can be defined as a set of states \(S^g\), and a goal is considered achieved whenever an agent arrives at any of those states.</p>
<p>The approach of Generative Goal Learning (<a href="https://arxiv.org/abs/1705.06366">Florensa, et al. 2018</a>) relies on a <strong>Goal GAN</strong> to generate desired goals automatically. In their experiment, the reward is very sparse, just a binary flag for whether a goal is achieved or not, and the policy is conditioned on the goal,</p>
\[\begin{aligned}
\pi^{*}(a_t\vert s_t, g) &= \arg\max_\pi \mathbb{E}_{g\sim p_g(.)} R^g(\pi) \\
\text{where }R^g(\pi) &= \mathbb{E}_{\pi(.\mid s_t, g)} \mathbf{1}[\exists t \in [1,\dots, T]: s_t \in S^g]
\end{aligned}\]
<p>Here \(R^g(\pi)\) is the expected return, which is also the success probability: a trajectory sampled from the current policy gets a return of 1 as long as any of its states belongs to the goal set, and 0 otherwise.</p>
<p>Their approach iterates through 3 steps until the policy converges:</p>
<ol>
<li>Label a set of goals based on whether they are at the appropriate level of difficulty for the current policy.
<ul>
<li>The set of goals at the appropriate level of difficulty is named <strong>GOID</strong> (short for “Goals of Intermediate Difficulty”).<br />\(\text{GOID}_i := \{g : R_\text{min} \leq R^g(\pi_i) \leq R_\text{max} \} \subseteq \mathcal{G}\)</li>
<li>Here \(R_\text{min}\) and \(R_\text{max}\) can be interpreted as a minimum and maximum probability of reaching a goal over T time-steps.</li>
</ul>
</li>
<li>Train a Goal GAN model using labelled goals from step 1 to produce new goals.</li>
<li>Use these new goals to train the policy, improving its coverage objective.</li>
</ol>
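<p>Step 1 can be sketched by estimating \(R^g(\pi_i)\) as the fraction of rollouts that reach the goal set, then labelling each goal by whether it falls inside \([R_\text{min}, R_\text{max}]\). The <code>rollout_fn</code> callback (returning whether one rollout reached goal <code>g</code>) is hypothetical, and the thresholds 0.1 and 0.9 and the rollout count are illustrative values, not necessarily those used in the paper.</p>

```python
def estimate_success(rollout_fn, goal, n_rollouts=20):
    """Monte Carlo estimate of R^g(pi): fraction of rollouts reaching the goal set."""
    return sum(bool(rollout_fn(goal)) for _ in range(n_rollouts)) / n_rollouts


def label_goids(goals, rollout_fn, r_min=0.1, r_max=0.9, n_rollouts=20):
    """Return (goal, y_g) pairs, with y_g = 1 iff r_min <= R^g(pi) <= r_max."""
    labelled = []
    for g in goals:
        r = estimate_success(rollout_fn, g, n_rollouts)
        labelled.append((g, int(r_min <= r <= r_max)))
    return labelled
```

<p>Goals the policy always reaches and goals it never reaches both get \(y_g = 0\); only those in between count as GOID.</p>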
<p>The Goal GAN generates a curriculum automatically:</p>
<ul>
<li>Generator \(G(z)\): produces a new goal. => expected to be a goal uniformly sampled from \(GOID\) set.</li>
<li>Discriminator \(D(g)\): evaluates whether a goal can be achieved. => expected to tell whether a goal is from \(GOID\) set.</li>
</ul>
<p>The Goal GAN is constructed similarly to LSGAN (Least Squares GAN; <a href="https://arxiv.org/abs/1611.04076">Mao et al., (2017)</a>), which learns more stably than the vanilla GAN. According to LSGAN, we should minimize the following losses for \(D\) and \(G\) respectively:</p>
\[\begin{aligned}
\mathcal{L}_\text{LSGAN}(D) &= \frac{1}{2} \mathbb{E}_{g \sim p_\text{data}(g)} [ (D(g) - b)^2] + \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - a)^2] \\
\mathcal{L}_\text{LSGAN}(G) &= \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - c)^2]
\end{aligned}\]
<p>where \(a\) is the label for fake data, \(b\) for real data, and \(c\) is the value that \(G\) wants \(D\) to believe for fake data. In LSGAN paper’s experiments, they used \(a=-1, b=1, c=0\).</p>
<p>The Goal GAN introduces an extra binary flag \(y_g\) indicating whether a labelled goal \(g\) is real, i.e. in GOID (\(y_g = 1\)), or fake (\(y_g = 0\)), so that the model can use negative samples for training:</p>
\[\begin{aligned}
\mathcal{L}_\text{GoalGAN}(D) &= \frac{1}{2} \mathbb{E}_{g \sim p_\text{data}(g)} [ y_g (D(g) - b)^2 + (1-y_g) (D(g) - a)^2] + \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - a)^2] \\
\mathcal{L}_\text{GoalGAN}(G) &= \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - c)^2]
\end{aligned}\]
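<p>The least-squares losses can be computed from discriminator scores alone. In this sketch, each labelled goal contributes \(y_g (D(g)-b)^2 + (1-y_g)(D(g)-a)^2\), so goals outside GOID act as negatives, and expectations are replaced by batch means. The function names are hypothetical, scores are plain floats computed elsewhere, and the defaults \(a=-1, b=1, c=0\) follow the LSGAN setting mentioned above.</p>

```python
def mean(xs):
    return sum(xs) / len(xs)


def goal_gan_d_loss(d_labelled, y, d_generated, a=-1.0, b=1.0):
    """Least-squares discriminator loss with negative labels (Goal GAN style).

    d_labelled: D(g) for goals labelled by the current policy's performance.
    y: matching labels y_g in {0, 1} (1 = goal is in GOID).
    d_generated: D(G(z)) for goals freshly produced by the generator.
    """
    labelled_term = mean([
        yg * (d - b) ** 2 + (1 - yg) * (d - a) ** 2
        for d, yg in zip(d_labelled, y)
    ])
    fake_term = mean([(d - a) ** 2 for d in d_generated])
    return 0.5 * labelled_term + 0.5 * fake_term


def goal_gan_g_loss(d_generated, c=0.0):
    """Generator loss: push D(G(z)) towards c."""
    return 0.5 * mean([(d - c) ** 2 for d in d_generated])
```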
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/generative-goal-learning-algorithm.png" alt="Generative goal learning" /></p>
<p><em>Fig. 6. The algorithm of Generative Goal Learning. (Image source: <a href="https://arxiv.org/abs/1705.06366">Florensa, et al. 2018</a>)</em></p>
<p>Following the same idea, <a href="https://arxiv.org/abs/1909.12892">Racaniere & Lampinen, et al. (2019)</a> designed a method to make the objectives of the goal generator more sophisticated. Their method contains three components, similar to generative goal learning above:</p>
<ul>
<li><strong>Solver</strong>/Policy \(\pi\): In each episode, the solver gets a goal \(g\) at the beginning and gets a single binary reward \(R^g\) at the end.</li>
<li><strong>Judge</strong>/Discriminator \(D(.)\): A classifier to predict the binary reward (whether the goal can be achieved or not); precisely, it outputs the logit of the probability of achieving the given goal, \(\sigma(D(g)) = p(R^g=1\vert g)\), where \(\sigma\) is the sigmoid function.</li>
<li><strong>Setter</strong>/Generator \(G(.)\): The goal setter takes as input a desired feasibility score \(f \sim \text{Unif}(0, 1)\) and generates \(g = G(z, f)\), where the latent variable \(z\) is sampled by \(z \sim \mathcal{N}(0, I)\). The goal generator is designed to be invertible, so \(G^{-1}\) can map backwards from a goal \(g\) to a latent \(z = G^{-1}(g, f)\).</li>
</ul>
<p>The generator is optimized with three objectives:</p>
<ul>
<li>(1) Goal <strong>validity</strong>: The proposed goal should be achievable by an expert policy. The corresponding generative loss is designed to increase the likelihood of generating goals that the solver policy has achieved before (like in <a href="https://arxiv.org/abs/1707.01495">HER</a>).
<ul>
<li>\(\mathcal{L}_\text{val}\) is the negative log-likelihood of generated goals that have been solved by the solver in the past.</li>
<li>
\[\begin{align*}
\mathcal{L}_\text{val} = \mathbb{E}_{\substack{
g \sim \text{ achieved by solver}, \\
\xi \sim \text{Uniform}(0, \delta), \\
f \sim \text{Uniform}(0, 1)
}} \big[ -\log p(G^{-1}(g + \xi, f)) \big]
\end{align*}\]
</li>
</ul>
</li>
<li>(2) Goal <strong>feasibility</strong>: The proposed goal should be achievable by the current policy; that is, the level of difficulty should be appropriate.
<ul>
<li>\(\mathcal{L}_\text{feas}\) encourages the probability output by the judge model \(D\) on the generated goal \(G(z, f)\) to match the desired \(f\). Since \(D\) outputs a logit, the comparison is made against \(\sigma^{-1}(f)\).</li>
<li>
\[\begin{align*}
\mathcal{L}_\text{feas} = \mathbb{E}_{\substack{
z \sim \mathcal{N}(0, I), \\
f \sim \text{Uniform}(0, 1)
}} \big[ \big(D(G(z, f)) - \sigma^{-1}(f)\big)^2 \big]
\end{align*}\]
</li>
</ul>
</li>
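<li>(3) Goal <strong>coverage</strong>: We should maximize the entropy of generated goals to encourage diverse goals and to improve the coverage over the goal space. Since entropy is the negative expected log-likelihood, minimizing \(\mathcal{L}_\text{cov}\) below maximizes it.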
<ul>
<li>
\[\begin{align*}
\mathcal{L}_\text{cov} = \mathbb{E}_{\substack{
z \sim \mathcal{N}(0, I), \\
f \sim \text{Uniform}(0, 1)
}} \big[ \log p(G(z, f)) \big]
\end{align*}\]
</li>
</ul>
</li>
</ul>
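<p>The feasibility term is the easiest of the three to make concrete, because it only needs the judge's logits. In the sketch below, judge outputs \(D(G(z, f))\) are assumed to be precomputed floats paired with the feasibility \(f\) each goal was generated for. The function names are hypothetical.</p>

```python
import math


def logit(p):
    """Inverse of the sigmoid, sigma^{-1}(p), for 0 < p < 1."""
    return math.log(p / (1.0 - p))


def feasibility_loss(judge_logits, feasibilities):
    """Batch estimate of L_feas: mean of (D(G(z, f)) - sigma^{-1}(f))^2."""
    return sum((d - logit(f)) ** 2
               for d, f in zip(judge_logits, feasibilities)) / len(feasibilities)
```

<p>Matching in logit space means a setter asked for \(f = 0.5\) is penalized unless the judge is exactly undecided (logit 0) about the goal it produced.</p>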
<p>Their experiments showed that complex environments require all three losses above. When the environment changes between episodes, both the goal generator and the discriminator need to be conditioned on environmental observations to produce better results. If there is a desired goal distribution, an additional loss can be added to match it using the Wasserstein distance. Using this loss, the generator can push the solver toward mastering the desired tasks more efficiently.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/setter-judge-goal-generation.png" alt="Goal setter and judge models" /></p>
<p><em>Fig. 7. Training schematic for the (a) solver/policy, (b) judge/discriminator, and (c) setter/goal generator models. (Image source: <a href="https://arxiv.org/abs/1909.12892">Racaniere & Lampinen, et al., 2019</a>)</em></p>
<h2 id="skill-based-curriculum">Skill-Based Curriculum</h2>
<p>Another view is to decompose what an agent is able to complete into a variety of skills, where each skill set could be mapped to a task. Imagine an agent interacting with the environment in an unsupervised manner: is there a way to discover useful skills from such interaction and further build them into solutions for more complicated tasks through a curriculum?</p>
<p><a href="https://arxiv.org/abs/1912.04226">Jabri, et al. (2019)</a> developed an automatic curriculum, <strong>CARML</strong> (short for “Curricula for Unsupervised Meta-Reinforcement Learning”), by modeling unsupervised trajectories in a latent skill space, with a focus on training <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html">meta-RL</a> policies (i.e. policies that can transfer to unseen tasks). The setting of training environments in CARML is similar to <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#learning-with-random-rewards">DIAYN</a>. Unlike DIAYN, which operates on the true state space, CARML is trained on pixel-level observations. An RL algorithm \(\pi_\theta\), parameterized by \(\theta\), is trained via unsupervised interaction formulated as a CMP (controlled Markov process, i.e. an MDP without a reward function) combined with a learned reward function \(r\). This setting naturally suits the meta-learning purpose, since a task-specific reward function is given only at test time.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CARML.png" alt="CARML" /></p>
<p><em>Fig. 8. An illustration of CARML, containing two steps: (1) organizing experiential data into the latent skill space; (2) meta-training the policy with the reward function constructed from the learned skills. (Image source: <a href="https://arxiv.org/abs/1912.04226">Jabri, et al 2019</a>)</em></p>
<p>CARML is framed as a <a href="https://chrischoy.github.io/research/Expectation-Maximization-and-Variational-Inference/">variational Expectation-Maximization (EM)</a> procedure.</p>
<p>(1) <strong>E-Step</strong>: This is the stage for organizing experiential data. Collected trajectories are modeled with a mixture of latent components forming the <a href="https://en.wikipedia.org/wiki/Basis_(linear_algebra)">basis</a> of <em>skills</em>.</p>
<p>Let \(z\) be a latent task variable and \(q_\phi\) be a variational distribution of \(z\), which could be a mixture model with discrete \(z\) or a VAE with continuous \(z\). A variational posterior \(q_\phi(z \vert s)\) works like a classifier, predicting a skill given a state, and we would like to maximize \(q_\phi(z \vert s)\) to discriminate between data produced by different skills as much as possible. In the E-step, \(q_\phi\) is fitted to a set of trajectories produced by \(\pi_\theta\).</p>
<p>Precisely, given a trajectory \(\tau = (s_1,\dots,s_T)\), we would like to find \(\phi\) such that</p>
\[\max_\phi \mathbb{E}_{z\sim q_\phi(z)} \big[ \log q_\phi(\tau \vert z) \big]
= \max_\phi \mathbb{E}_{z\sim q_\phi(z)} \big[ \sum_{s_i \in \tau} \log q_\phi(s_i \vert z) \big]\]
<p>A simplifying assumption is made here to ignore the order of states in one trajectory.</p>
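<p>Under this order-ignoring assumption, a trajectory's log-likelihood under skill \(z\) is a sum of per-state terms, which also yields a posterior over skills for the whole trajectory. The sketch assumes a mixture of 1-D Gaussian components with already known <code>means</code>, <code>stds</code> and <code>priors</code> (hypothetical stand-ins); a real E-step would fit these to the collected trajectories, and CARML itself works on pixels.</p>

```python
import math


def gaussian_logpdf(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))


def trajectory_loglik(traj, mean, std):
    """log q(tau | z) = sum_i log q(s_i | z); the order of states is ignored."""
    return sum(gaussian_logpdf(s, mean, std) for s in traj)


def skill_posterior(traj, means, stds, priors):
    """q(z | tau) proportional to q(z) * prod_i q(s_i | z), normalised in log space."""
    scores = [math.log(p) + trajectory_loglik(traj, m, s)
              for m, s, p in zip(means, stds, priors)]
    top = max(scores)  # subtract the max before exponentiating, for stability
    weights = [math.exp(v - top) for v in scores]
    total = sum(weights)
    return [w / total for w in weights]
```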
<p>(2) <strong>M-Step</strong>: This is the stage for doing meta-RL training with \(\pi_\theta\). The learned skill space is considered as a training task distribution. CARML is agnostic to the type of meta-RL algorithm for policy parameter updates.</p>
<p>Given a trajectory \(\tau\), it makes sense for the policy to maximize the mutual information between \(\tau\) and \(z\), \(I(\tau;z) = H(\tau) - H(\tau \vert z)\), because:</p>
<ul>
<li>maximizing \(H(\tau)\) => diversity in the policy data space; expected to be large.</li>
<li>minimizing \(H(\tau \vert z)\) => given a certain skill, the behavior should be restricted; expected to be small.</li>
</ul>
<p>Then we have,</p>
\[\begin{aligned}
I(\tau; z)
&= \mathcal{H}(z) - \mathcal{H}(z \vert s_1,\dots, s_T) \\
&\geq \mathbb{E}_{s \in \tau} [\mathcal{H}(z) - \mathcal{H}(z\vert s)] & \scriptstyle{\text{; discard the order of states.}} \\
&= \mathbb{E}_{s \in \tau} [\mathcal{H}(s) - \mathcal{H}(s\vert z)] & \scriptstyle{\text{; by definition of MI.}} \\
&\geq \mathbb{E}_{z\sim q_\phi(z), s\sim \pi_\theta(s|z)} [\log q_\phi(s|z) - \log \pi_\theta(s)] & \scriptstyle{\text{; variational bound, } q_\phi(s|z) \text{ in place of } \pi_\theta(s|z)\text{.}} \\
&\approx \mathbb{E}_{z\sim q_\phi(z), s\sim \pi_\theta(s|z)} [\color{green}{\log q_\phi(s|z) - \log q_\phi(s)}] & \scriptstyle{\text{; assume learned marginal distr. matches policy.}}
\end{aligned}\]
<p>We can set the reward as \(\log q_\phi(s \vert z) - \log q_\phi(s)\), as shown in the <span style="color: green;">green</span> part of the equation above. To balance task-specific exploration (in <span style="color: red;">red</span> below) and latent skill matching (in <span style="color: blue;">blue</span> below), a parameter \(\lambda \in [0, 1]\) is added. Each realization of \(z \sim q_\phi(z)\) induces a reward function \(r_z(s)\) (remember that reward + CMP => MDP) as follows:</p>
\[\begin{aligned}
r_z(s)
&= \lambda \log q_\phi(s|z) - \log q_\phi(s) \\
&= \lambda \log q_\phi(s|z) - \log \frac{q_\phi(s|z) q_\phi(z)}{q_\phi(z|s)} \\
&= \lambda \log q_\phi(s|z) - \log q_\phi(s|z) - \log q_\phi(z) + \log q_\phi(z|s) \\
&= (\lambda - 1) \log \color{red}{q_\phi(s|z)} + \color{blue}{\log q_\phi(z|s)} + C
\end{aligned}\]
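<p>The identity above can be checked numerically. The sketch computes \(r_z(s)\) both as \(\lambda \log q_\phi(s|z) - \log q_\phi(s)\) and in the decomposed form with \(C = -\log q_\phi(z)\), using a hypothetical mixture of 1-D Gaussians for \(q_\phi\) in place of a learned model.</p>

```python
import math


def gaussian_pdf(x, mean, std):
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def marginal(s, means, stds, priors):
    """q(s) = sum_z q(z) q(s | z) for the discrete mixture."""
    return sum(p * gaussian_pdf(s, m, sd) for m, sd, p in zip(means, stds, priors))


def carml_reward(s, z, lam, means, stds, priors):
    """r_z(s) = lam * log q(s|z) - log q(s)."""
    q_s_given_z = gaussian_pdf(s, means[z], stds[z])
    return lam * math.log(q_s_given_z) - math.log(marginal(s, means, stds, priors))


def carml_reward_decomposed(s, z, lam, means, stds, priors):
    """Same reward as (lam - 1) log q(s|z) + log q(z|s) + C, with C = -log q(z)."""
    q_s_given_z = gaussian_pdf(s, means[z], stds[z])
    q_z_given_s = priors[z] * q_s_given_z / marginal(s, means, stds, priors)  # Bayes
    return ((lam - 1) * math.log(q_s_given_z)
            + math.log(q_z_given_s) - math.log(priors[z]))
```

<p>With \(\lambda = 1\) the first term vanishes and the reward reduces to \(\log q_\phi(z|s) - \log q_\phi(z)\), a pure skill-discriminability reward; smaller \(\lambda\) adds a bonus for states that are unlikely under the current skill.</p>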
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CARML-algorithm.png" alt="CARML algorithm" /></p>
<p><em>Fig. 9. The algorithm of CARML. (Image source: <a href="https://arxiv.org/abs/1912.04226">Jabri, et al 2019</a>)</em></p>
<p>Learning a latent skill space can be done in different ways, such as in <a href="https://openreview.net/forum?id=rk07ZXZRb">Hausman, et al. 2018</a>. The goal of their approach is to learn a task-conditioned policy, \(\pi(a \vert s, t^{(i)})\), where \(t^{(i)}\) is from a discrete list of \(N\) tasks, \(\mathcal{T} = [t^{(1)}, \dots, t^{(N)}]\). However, rather than learning \(N\) separate solutions, one per task, it would be nice to learn a latent skill space so that each task could be represented as a distribution over skills and thus skills are <em>reused between tasks</em>. The policy is defined as \(\pi_\theta(a \vert s,t) = \int \pi_\theta(a \vert z,s,t) p_\phi(z \vert t)\mathrm{d}z\), where \(\pi_\theta\) and \(p_\phi\) are policy and embedding networks to learn, respectively. If \(z\) is discrete, i.e. drawn from a set of \(K\) skills, then the policy becomes a mixture of \(K\) sub-policies. The policy training uses <a href="/lil-log/2018/04/07/policy-gradient-algorithms.html#sac">SAC</a> and the dependency on \(z\) is introduced in the entropy term.</p>
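<p>For discrete \(z\), the integral becomes a weighted sum of sub-policies. The sketch below shows only that marginalization for one fixed state and task, with action distributions given as lists; it assumes the networks \(\pi_\theta\) and \(p_\phi\) have already been evaluated, and the function name is hypothetical.</p>

```python
def mixture_policy(action_probs_per_skill, skill_probs_given_task):
    """pi(a | s, t) = sum_z pi(a | z, s, t) p(z | t) for K discrete skills.

    action_probs_per_skill: K lists, each an action distribution for fixed (s, t).
    skill_probs_given_task: length-K distribution p(z | t).
    """
    n_actions = len(action_probs_per_skill[0])
    return [
        sum(p_z * probs[a]
            for probs, p_z in zip(action_probs_per_skill, skill_probs_given_task))
        for a in range(n_actions)
    ]
```

<p>Two tasks that share skills differ only in \(p_\phi(z \vert t)\), which is what lets the sub-policies be reused across tasks.</p>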
<h2 id="curriculum-through-distillation">Curriculum through Distillation</h2>
<p>[I was thinking about the name of this section for a while, deciding between cloning, inheritance, and distillation. Eventually, I picked distillation because it sounds the coolest B-)]</p>
<p>The motivation for the <strong>progressive neural network</strong> (<a href="https://arxiv.org/abs/1606.04671">Rusu et al. 2016</a>) architecture is to efficiently transfer learned skills between different tasks and in the meantime avoid catastrophic forgetting. The curriculum is realized through a set of progressively stacked neural network towers (or “columns”, as in the paper).</p>
<p>A progressive network has the following structure:</p>
<ol>
<li>It starts with a single column containing \(L\) layers of neurons, in which the corresponding activation layers are labelled as \(h^{(1)}_i, i=1, \dots, L\). We first train this single-column network for one task to convergence, achieving parameter config \(\theta^{(1)}\).</li>
<li>Once we switch to the next task, we add a new column to adapt to the new context while freezing \(\theta^{(1)}\) to lock down the learned skills from the previous task. The new column has activation layers labelled as \(h^{(2)}_i, i=1, \dots, L\), and parameters \(\theta^{(2)}\).</li>
<li>
<p>Step 2 can be repeated with every new task. The \(i\)-th layer activation in the \(k\)-th column depends on the previous activation layers in all the existing columns:</p>
\[h^{(k)}_i = f(W^{(k)}_i h^{(k)}_{i-1} + \sum_{j < k} U_i^{(k:j)} h^{(j)}_{i-1})\]
<p>where \(W^{(k)}_i\) is the weight matrix of layer \(i\) in column \(k\), and \(U_i^{(k:j)}, j < k\) are the weight matrices projecting layer \(i-1\) of column \(j\) to layer \(i\) of column \(k\). These weight matrices are learned. \(f(.)\) is a non-linear activation function of choice.</p>
</li>
</ol>
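<p>A forward pass of the recurrence above can be sketched with NumPy. This only covers inference. Training a new column would mean excluding earlier columns' parameters from the optimizer, which is not shown. The class <code>ProgressiveNet</code>, the ReLU choice for \(f\), the random initialization, and adding lateral connections only from the second layer on (the first layer sees the shared input) are assumptions.</p>

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


class ProgressiveNet:
    """Forward pass of a progressive network with lateral connections.

    Column k, layer i computes
        h_i^(k) = f(W_i^(k) h_{i-1}^(k) + sum_{j<k} U_i^(k:j) h_{i-1}^(j)).
    """

    def __init__(self, sizes, seed=0):
        self.sizes = sizes  # [input_dim, hidden_1, ..., hidden_L]
        self.rng = np.random.default_rng(seed)
        self.columns = []   # per column: list of W_i
        self.laterals = []  # per column: per layer, list of U_i^(k:j) for j < k

    def add_column(self):
        """Add a new column; earlier columns keep their (frozen) weights."""
        k = len(self.columns)
        W, U = [], []
        for i in range(1, len(self.sizes)):
            n_in, n_out = self.sizes[i - 1], self.sizes[i]
            W.append(self.rng.normal(0.0, 0.1, (n_out, n_in)))
            # Layer 1 reads the shared input, so laterals start at layer 2.
            U.append([self.rng.normal(0.0, 0.1, (n_out, n_in)) for _ in range(k)]
                     if i > 1 else [])
        self.columns.append(W)
        self.laterals.append(U)

    def forward(self, x):
        """Return the last-layer activations of every column."""
        acts = []  # acts[k][i] = h_i^(k), with acts[k][0] = x
        for W, U in zip(self.columns, self.laterals):
            h = [x]
            for i in range(len(W)):
                pre = W[i] @ h[i]
                for j, U_ij in enumerate(U[i]):
                    pre = pre + U_ij @ acts[j][i]  # lateral input from column j < k
                h.append(relu(pre))
            acts.append(h)
        return [h[-1] for h in acts]
```

<p>Because old columns never read from newer ones, adding a column leaves the outputs of earlier columns unchanged, which is how the architecture avoids catastrophic forgetting.</p>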
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/progressive-networks.png" alt="Progressive networks" /></p>
<p><em>Fig. 10. The progressive neural network architecture. (Image source: <a href="https://arxiv.org/abs/1610.04286">Rusu, et al. 2017</a>)</em></p>
<p>The paper experimented with Atari games by training a progressive network on multiple games to check whether features learned in one game can transfer to another. That is indeed the case, though interestingly, a high dependency on features in the previous columns does not always indicate good transfer performance on the new task. One hypothesis is that features learned from the old task might introduce biases into the new task, causing the policy to get trapped in a sub-optimal solution. Overall, the progressive network works better than only fine-tuning the top layer and can achieve similar transfer performance as fine-tuning the entire network.</p>
<p>One use case for the progressive network is sim2real transfer (<a href="https://arxiv.org/abs/1610.04286">Rusu, et al. 2017</a>), in which the first column is trained in a simulator with a lot of samples and then additional columns (possibly for different real-world tasks) are added and trained with a few real data samples.</p>
<p><a href="https://arxiv.org/abs/1806.01780">Czarnecki, et al. (2018)</a> proposed another RL training framework, <strong>Mix & Match</strong> (short for <strong>M&M</strong>), to provide a curriculum by copying knowledge between agents. It considers a sequence of agents from simple to complex, \(\pi_1, \dots, \pi_K\), each parameterized with some shared weights (e.g. by sharing some lower layers). M&M trains a mixture of agents, but only the final performance of the most complex one \(\pi_K\) matters.</p>
<p>In the meantime, M&M maintains a categorical distribution \(c \sim \text{Categorical}(1, \dots, K \vert \alpha)\) with <a href="https://en.wikipedia.org/wiki/Probability_mass_function">pmf</a> \(p(c=i) = \alpha_i\), which determines which policy to use at a given time. The mixed M&M policy is a simple weighted sum: \(\pi_\text{mm}(a \vert s) = \sum_{i=1}^K \alpha_i \pi_i(a \vert s)\). Curriculum learning is realized by dynamically adjusting \(\alpha_i\), from \(\alpha_K=0\) to \(\alpha_K=1\). The tuning of \(\alpha\) can be manual or through <a href="/lil-log/2019/09/05/evolution-strategies.html#hyperparameter-tuning-pbt">population-based training</a>.</p>
<p>To encourage cooperation rather than competition among policies, besides the RL loss \(\mathcal{L}_\text{RL}\), another <a href="https://arxiv.org/abs/1511.06295">distillation</a>-like loss \(\mathcal{L}_\text{mm}(\theta)\) is added. The knowledge transfer loss \(\mathcal{L}_\text{mm}(\theta)\) measures the KL divergence between two policies, \(\propto D_\text{KL}(\pi_{i}(. \vert s) \| \pi_j(. \vert s))\) for \(i < j\). It encourages complex agents to match the simpler ones early on. The final loss is \(\mathcal{L} = \mathcal{L}_\text{RL}(\theta \vert \pi_\text{mm}) + \lambda \mathcal{L}_\text{mm}(\theta)\).</p>
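<p>The pieces of M&M can be illustrated on discrete action distributions for a single state: the mixed policy, a pairwise KL term in the spirit of \(\mathcal{L}_\text{mm}\), and an \(\alpha\) schedule. The linear schedule is a hypothetical stand-in for the manual or population-based tuning described above, and summing the KL over all pairs \(i < j\) with equal weight is an assumption. The paper's exact weighting may differ.</p>

```python
import math


def mm_policy(alphas, policies):
    """pi_mm(a|s) = sum_i alpha_i pi_i(a|s), for action distributions at one state."""
    n_actions = len(policies[0])
    return [sum(al * pi[a] for al, pi in zip(alphas, policies))
            for a in range(n_actions)]


def kl(p, q, eps=1e-12):
    """D_KL(p || q) for discrete distributions (eps guards against log(0))."""
    return sum(pi * math.log((pi + eps) / (qi + eps))
               for pi, qi in zip(p, q) if pi > 0)


def mm_distill_loss(policies):
    """Sum of D_KL(pi_i || pi_j) over pairs i < j (agents ordered simple to complex)."""
    return sum(kl(policies[i], policies[j])
               for i in range(len(policies))
               for j in range(i + 1, len(policies)))


def linear_alpha_schedule(step, total_steps, k):
    """Hypothetical schedule: alpha_K grows linearly from 0 to 1 and the
    remaining mass is split evenly among the K - 1 simpler agents."""
    alpha_k = min(1.0, step / total_steps)
    rest = (1.0 - alpha_k) / (k - 1)
    return [rest] * (k - 1) + [alpha_k]
```

<p>At the end of the schedule \(\alpha_K = 1\), so \(\pi_\text{mm}\) coincides with the most complex agent, the only one whose final performance matters.</p>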
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/mix-and-match.png" alt="Mix & Match" /></p>
<p><em>Fig. 11. The Mix & Match architecture for training a mixture of policies. (Image source: <a href="https://arxiv.org/abs/1806.01780">Czarnecki, et al., 2018</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020curriculum,
title = "Curriculum for Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/01/29/curriculum-for-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Jeffrey L. Elman. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.4487&rep=rep1&type=pdf">“Learning and development in neural networks: The importance of starting small.”</a> Cognition 48.1 (1993): 71-99.</p>
<p>[2] Yoshua Bengio, et al. <a href="https://www.researchgate.net/profile/Y_Bengio/publication/221344862_Curriculum_learning/links/546cd2570cf2193b94c577ac/Curriculum-learning.pdf">“Curriculum learning.”</a> ICML 2009.</p>
<p>[3] Daphna Weinshall, Gad Cohen, and Dan Amir. <a href="https://arxiv.org/abs/1802.03796">“Curriculum learning by transfer learning: Theory and experiments with deep networks.”</a> ICML 2018.</p>
<p>[4] Wojciech Zaremba and Ilya Sutskever. <a href="https://arxiv.org/abs/1410.4615">“Learning to execute.”</a> arXiv preprint arXiv:1410.4615 (2014).</p>
<p>[5] Tambet Matiisen, et al. <a href="https://arxiv.org/abs/1707.00183">“Teacher-student curriculum learning.”</a> IEEE Trans. on neural networks and learning systems (2017).</p>
<p>[6] Alex Graves, et al. <a href="https://arxiv.org/abs/1704.03003">“Automated curriculum learning for neural networks.”</a> ICML 2017.</p>
<p>[7] Remy Portelas, et al. <a href="https://arxiv.org/abs/1910.07224">“Teacher algorithms for curriculum learning of Deep RL in continuously parameterized environments.”</a> CoRL 2019.</p>
<p>[8] Sainbayar Sukhbaatar, et al. <a href="https://arxiv.org/abs/1703.05407">“Intrinsic Motivation and Automatic Curricula via Asymmetric Self-Play.”</a> ICLR 2018.</p>
<p>[9] Carlos Florensa, et al. <a href="https://arxiv.org/abs/1705.06366">“Automatic Goal Generation for Reinforcement Learning Agents.”</a> ICML 2018.</p>
<p>[10] Sebastien Racaniere & Andrew K. Lampinen, et al. <a href="https://arxiv.org/abs/1909.12892">“Automated Curriculum through Setter-Solver Interactions.”</a> ICLR 2020.</p>
<p>[11] Allan Jabri, et al. <a href="https://arxiv.org/abs/1912.04226">“Unsupervised Curricula for Visual Meta-Reinforcement Learning.”</a> NeurIPS 2019.</p>
<p>[12] Karol Hausman, et al. <a href="https://openreview.net/forum?id=rk07ZXZRb">“Learning an Embedding Space for Transferable Robot Skills.”</a> ICLR 2018.</p>
<p>[13] Josh Merel, et al. <a href="https://arxiv.org/abs/1911.06636">“Reusable neural skill embeddings for vision-guided whole body movement and object manipulation.”</a> arXiv preprint arXiv:1911.06636 (2019).</p>
<p>[14] OpenAI, et al. <a href="https://arxiv.org/abs/1910.07113">“Solving Rubik’s Cube with a Robot Hand.”</a> arXiv preprint arXiv:1910.07113 (2019).</p>
<p>[15] Niels Justesen, et al. <a href="https://arxiv.org/abs/1806.10729">“Illuminating Generalization in Deep Reinforcement Learning through Procedural Level Generation.”</a> NeurIPS 2018 Deep RL Workshop.</p>
<p>[16] Karl Cobbe, et al. <a href="https://arxiv.org/abs/1812.02341">“Quantifying Generalization in Reinforcement Learning.”</a> arXiv preprint arXiv:1812.02341 (2018).</p>
<p>[17] Andrei A. Rusu et al. <a href="https://arxiv.org/abs/1606.04671">“Progressive Neural Networks.”</a> arXiv preprint arXiv:1606.04671 (2016).</p>
<p>[18] Andrei A. Rusu et al. <a href="https://arxiv.org/abs/1610.04286">“Sim-to-Real Robot Learning from Pixels with Progressive Nets.”</a> CoRL 2017.</p>
<p>[19] Wojciech Marian Czarnecki, et al. <a href="https://arxiv.org/abs/1806.01780">“Mix & Match – Agent Curricula for Reinforcement Learning.”</a> ICML 2018.</p>Lilian WengA curriculum is an efficient tool for humans to progressively learn from simple concepts to hard problems. It breaks down complex knowledge by providing a sequence of learning steps of increasing difficulty. In this post, we will examine how the idea of curriculum can help reinforcement learning models learn to solve complicated tasks.Self-Supervised Representation Learning2019-11-10T18:00:00+00:002019-11-10T18:00:00+00:00https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning<blockquote>
<p>Self-supervised learning opens up a huge opportunity for better utilizing unlabelled data, while learning in a supervised learning manner. This post covers many interesting ideas of self-supervised learning tasks on images, videos, and control problems.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-01-09: add a new section on <a href="#contrastive-predictive-coding">Contrastive Predictive Coding</a>].</span>
<br />
<span style="color: #286ee0;">[Updated on 2020-04-13: add a <a href="#momentum-contrast">“Momentum Contrast”</a> section on MoCo, SimCLR and CURL.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2020-07-08: add a <a href="#bisimulation">“Bisimulation”</a> section on DeepMDP and DBC.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2020-09-12: add <a href="#mocov2">MoCo V2</a> and <a href="#BYOL">BYOL</a> in the <a href="#momentum-contrast">“Momentum Contrast”</a> section.]</span></p>
<p>Given a task and enough labels, supervised learning can solve it really well. Good performance usually requires a decent amount of labels, but collecting manual labels is expensive (e.g. ImageNet) and hard to scale up. Considering that the amount of unlabelled data (e.g. free text, all the images on the Internet) is substantially larger than a limited number of human-curated labelled datasets, it is kinda wasteful not to use it. However, unsupervised learning is not easy and usually works much less efficiently than supervised learning.</p>
<p>What if we can get labels for free for unlabelled data and train on an unlabelled dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form: predict only a subset of the information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as <em>self-supervised learning</em>.</p>
<p>This idea has been widely used in language modeling. The default task for a language model is to predict the next word given the past sequence. <a href="/lil-log/2019/01/31/generalized-language-models.html#bert">BERT</a> adds two other auxiliary tasks and both rely on self-generated labels.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/self-sup-lecun.png" alt="Self-supervised learning summary" /></p>
<p><em>Fig. 1. A great summary of how self-supervised learning tasks can be constructed (Image source: <a href="https://www.youtube.com/watch?v=7I0Qt7GALVk">LeCun’s talk</a>)</em></p>
<p><a href="https://github.com/jason718/awesome-self-supervised-learning">Here</a> is a nicely curated list of papers in self-supervised learning. Please check it out if you are interested in reading more in depth.</p>
<p>Note that this post does not focus on either NLP / <a href="/lil-log/2019/01/31/generalized-language-models.html">language modeling</a> or <a href="https://lilianweng.github.io/lil-log/tag/generative-model">generative modeling</a>.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#why-self-supervised-learning" id="markdown-toc-why-self-supervised-learning">Why Self-Supervised Learning?</a></li>
<li><a href="#images-based" id="markdown-toc-images-based">Images-Based</a> <ul>
<li><a href="#distortion" id="markdown-toc-distortion">Distortion</a></li>
<li><a href="#patches" id="markdown-toc-patches">Patches</a></li>
<li><a href="#colorization" id="markdown-toc-colorization">Colorization</a></li>
<li><a href="#generative-modeling" id="markdown-toc-generative-modeling">Generative Modeling</a></li>
<li><a href="#contrastive-predictive-coding" id="markdown-toc-contrastive-predictive-coding">Contrastive Predictive Coding</a></li>
<li><a href="#momentum-contrast" id="markdown-toc-momentum-contrast">Momentum Contrast</a></li>
</ul>
</li>
<li><a href="#video-based" id="markdown-toc-video-based">Video-Based</a> <ul>
<li><a href="#tracking" id="markdown-toc-tracking">Tracking</a></li>
<li><a href="#frame-sequence" id="markdown-toc-frame-sequence">Frame Sequence</a></li>
<li><a href="#video-colorization" id="markdown-toc-video-colorization">Video Colorization</a></li>
</ul>
</li>
<li><a href="#control-based" id="markdown-toc-control-based">Control-Based</a> <ul>
<li><a href="#multi-view-metric-learning" id="markdown-toc-multi-view-metric-learning">Multi-View Metric Learning</a></li>
<li><a href="#autonomous-goal-generation" id="markdown-toc-autonomous-goal-generation">Autonomous Goal Generation</a></li>
<li><a href="#bisimulation" id="markdown-toc-bisimulation">Bisimulation</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="why-self-supervised-learning">Why Self-Supervised Learning?</h2>
<p>Self-supervised learning empowers us to exploit a variety of labels that come with the data for free. The motivation is quite straightforward. Producing a dataset with clean labels is expensive but unlabeled data is being generated all the time. To make use of this much larger amount of unlabeled data, one way is to set the learning objectives properly so as to get supervision from the data itself.</p>
<p>The <em>self-supervised task</em>, also known as <em>pretext task</em>, guides us to a supervised loss function. However, we usually don’t care about the final performance of this invented task. Rather we are interested in the learned intermediate representation with the expectation that this representation can carry good semantic or structural meanings and can be beneficial to a variety of practical downstream tasks.</p>
<p>For example, we might rotate images at random and train a model to predict how each input image is rotated. The rotation prediction task is made-up, so the actual accuracy is unimportant, like how we treat auxiliary tasks. But we expect the model to learn high-quality latent variables for real-world tasks, such as constructing an object recognition classifier with very few labeled samples.</p>
<p>Broadly speaking, all the generative models can be considered as self-supervised, but with different goals: generative models focus on creating diverse and realistic images, while self-supervised representation learning cares about producing good features that are generally helpful for many tasks. Generative modeling is not the focus of this post, but feel free to check my <a href="https://lilianweng.github.io/lil-log/tag/generative-model">previous posts</a>.</p>
<h2 id="images-based">Images-Based</h2>
<p>Many ideas have been proposed for self-supervised representation learning on images. A common workflow is to train a model on one or multiple pretext tasks with unlabelled images and then use one intermediate feature layer of this model to feed a multinomial logistic regression classifier on ImageNet classification. The final classification accuracy quantifies how good the learned representation is.</p>
<p>Recently, some researchers proposed to train a supervised task on labelled data and self-supervised pretext tasks on unlabelled data simultaneously with shared weights, as in <a href="https://arxiv.org/abs/1905.03670">Zhai et al, 2019</a> and <a href="https://arxiv.org/abs/1909.11825">Sun et al, 2019</a>.</p>
<h3 id="distortion">Distortion</h3>
<p>We expect that a small distortion of an image does not modify its original semantic meaning or geometric form. Slightly distorted images are considered the same as the original, and thus the learned features are expected to be invariant to distortion.</p>
<p><mark><b>Exemplar-CNN</b></mark> (<a href="https://arxiv.org/abs/1406.6909">Dosovitskiy et al., 2015</a>) creates surrogate training datasets with unlabeled image patches:</p>
<ol>
<li>Sample \(N\) patches of size 32 × 32 pixels from different images at varying positions and scales, only from regions containing considerable gradients as those areas cover edges and tend to contain objects or parts of objects. They are <em>“exemplary”</em> patches.</li>
<li>Each patch is distorted by applying a variety of random transformations (e.g., translation, rotation, and scaling). All the resulting distorted patches are considered to belong to the <em>same surrogate class</em>.</li>
<li>The pretext task is to discriminate between a set of surrogate classes. We can arbitrarily create as many surrogate classes as we want.</li>
</ol>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/examplar-cnn.png" alt="Examplar CNN" /></p>
<p><em>Fig. 2. The original patch of a cute deer is in the top left corner. Random transformations are applied, resulting in a variety of distorted patches. All of them should be classified into the same class in the pretext task. (Image source: <a href="https://arxiv.org/abs/1406.6909">Dosovitskiy et al., 2015</a>)</em></p>
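<p>The three steps above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' pipeline: the helper names <code>sample_exemplars</code>, <code>random_transform</code> and <code>make_surrogate_dataset</code> are hypothetical, patches are taken at a single scale, "regions containing considerable gradients" is approximated by keeping the random candidate patches with the largest mean gradient magnitude, and the transformation family (90° rotations, flips, small circular shifts, intensity scaling) is an assumed stand-in for the paper's.</p>

```python
import numpy as np

def sample_exemplars(images, n_patches, size=32, n_candidates=200, rng=None):
    """Keep the n_patches random candidate patches with the largest mean gradient magnitude."""
    rng = np.random.default_rng(rng)
    candidates = []
    for _ in range(n_candidates):
        img = images[rng.integers(len(images))]          # (H, W, C) float image
        h, w = img.shape[:2]
        y = rng.integers(0, h - size + 1)
        x = rng.integers(0, w - size + 1)
        candidates.append(img[y:y + size, x:x + size])
    candidates = np.stack(candidates)                     # (n_candidates, size, size, C)
    gy, gx = np.gradient(candidates.mean(axis=-1), axis=(1, 2))
    energy = np.sqrt(gx ** 2 + gy ** 2).mean(axis=(1, 2))
    keep = np.argsort(energy)[::-1][:n_patches]          # highest-gradient patches first
    return candidates[keep]

def random_transform(patch, rng):
    """One random distortion of a square patch (assumed transformation family)."""
    out = np.rot90(patch, k=int(rng.integers(4)), axes=(0, 1))   # rotation by a multiple of 90 degrees
    if rng.random() > 0.5:
        out = out[:, ::-1]                                        # horizontal flip
    dy, dx = rng.integers(-3, 4, size=2)
    out = np.roll(out, (int(dy), int(dx)), axis=(0, 1))           # crude translation
    return out * rng.uniform(0.7, 1.3)                            # intensity scaling

def make_surrogate_dataset(exemplars, n_per_class, rng=None):
    """Every distorted copy of exemplar i gets surrogate class label i."""
    rng = np.random.default_rng(rng)
    xs, ys = [], []
    for label, patch in enumerate(exemplars):
        for _ in range(n_per_class):
            xs.append(random_transform(patch, rng))
            ys.append(label)
    return np.stack(xs), np.array(ys)
```

<p>A classifier is then trained on <code>(X, y)</code> with one output per exemplar, which is why the number of surrogate classes grows with the number of sampled patches.</p>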
<p><mark><b>Rotation</b></mark> of an entire image (<a href="https://arxiv.org/abs/1803.07728">Gidaris et al. 2018</a>) is another interesting and cheap way to modify an input image while the semantic content stays unchanged. Each input image is first rotated by a multiple of \(90^\circ\) at random, corresponding to \([0^\circ, 90^\circ, 180^\circ, 270^\circ]\). The model is trained to predict which rotation has been applied, thus a 4-class classification problem.</p>
<p>To tell which rotation has been applied to an image, the model has to learn to recognize high-level object parts, such as heads, noses, and eyes, and the relative positions of these parts, rather than local patterns. In this way, the pretext task drives the model to learn semantic concepts of objects.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-sup-rotation.png" alt="Self supervised by rotation prediction" /></p>
<p><em>Fig. 3. Illustration of self-supervised learning by rotating the entire input images. The model learns to predict which rotation is applied. (Image source: <a href="https://arxiv.org/abs/1803.07728">Gidaris et al. 2018</a>)</em></p>
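<p>Generating training data for the rotation task takes only a few lines. A minimal NumPy sketch, assuming square images stored as <code>(H, W, C)</code> arrays; the function name is hypothetical, and building all four rotated copies of every image is simply the choice made in this sketch:</p>

```python
import numpy as np

def rotation_pretext_batch(images):
    """Return 4 rotated copies of each image and rotation labels 0..3 (k * 90 degrees)."""
    xs, ys = [], []
    for img in images:                       # img: (H, W, C) with H == W
        for k in range(4):                   # k quarter-turns counter-clockwise
            xs.append(np.rot90(img, k=k, axes=(0, 1)))
            ys.append(k)
    return np.stack(xs), np.array(ys)
```

<p>The labels come for free from the transformation itself, so a standard 4-way classifier can be trained on <code>(X, y)</code>.</p>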
<h3 id="patches">Patches</h3>
<p>The second category of self-supervised learning tasks extracts multiple patches from one image and asks the model to predict the relationship between these patches.</p>
<p><a href="https://arxiv.org/abs/1505.05192">Doersch et al. (2015)</a> formulate the pretext task as predicting the <mark><b>relative position</b></mark> between two random patches from one image. A model needs to understand the spatial context of objects in order to tell the relative position between parts.</p>
<p>The training patches are sampled in the following way:</p>
<ol>
<li>Randomly sample the first patch without any reference to image content.</li>
<li>Consider the first patch to be placed in the middle of a 3x3 grid; the second patch is sampled from one of the 8 neighboring locations around it.</li>
<li>To avoid the model only catching low-level trivial signals, such as connecting a straight line across boundary or matching local patterns, additional noise is introduced by:
<ul>
<li>Add gaps between patches</li>
<li>Small jitters</li>
<li>Randomly downsample some patches to as little as 100 total pixels and then upsample them, to build robustness to pixelation.</li>
<li>Shift green and magenta toward gray or randomly drop 2 of 3 color channels (See <a href="#chromatic-aberration">“chromatic aberration”</a> below)</li>
</ul>
</li>
<li>The model is trained to predict which one of 8 neighboring locations the second patch is selected from, a classification problem over 8 classes.</li>
</ol>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/self-sup-by-relative-position.png" alt="Self-supervised learning by context" /></p>
<p><em>Fig. 4. Illustration of self-supervised learning by predicting the relative position of two random patches. (Image source: <a href="https://arxiv.org/abs/1505.05192">Doersch et al., 2015</a>)</em></p>
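<p>The patch sampling procedure above can be sketched as follows. This is a minimal NumPy illustration, assuming a single <code>(H, W, C)</code> image; the function name, patch size, gap and jitter values are hypothetical, and only the gap and jitter noise are implemented (downsampling and color handling are left out).</p>

```python
import numpy as np

# (row, col) offsets of the 8 neighbors in the 3x3 grid; the list index is the class label.
NEIGHBORS = [(-1, -1), (-1, 0), (-1, 1),
             (0, -1),           (0, 1),
             (1, -1),  (1, 0),  (1, 1)]

def sample_patch_pair(img, patch=32, gap=8, jitter=4, rng=None):
    """Return (center patch, neighbor patch, label in 0..7)."""
    rng = np.random.default_rng(rng)
    h, w = img.shape[:2]
    step = patch + gap                     # distance between neighboring patch origins
    margin = step + jitter                 # keeps every jittered neighbor inside the image
    cy = rng.integers(margin, h - margin - patch + 1)
    cx = rng.integers(margin, w - margin - patch + 1)
    label = int(rng.integers(8))
    dy, dx = NEIGHBORS[label]
    jy, jx = rng.integers(-jitter, jitter + 1, size=2)
    ny, nx = cy + dy * step + jy, cx + dx * step + jx
    first = img[cy:cy + patch, cx:cx + patch]
    second = img[ny:ny + patch, nx:nx + patch]
    return first, second, label
```

<p>The gap means neighboring patches never touch, so the model cannot solve the task by simply continuing lines or textures across the shared boundary.</p>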
<p><a href="#chromatic-aberration"></a>Other than trivial signals like boundary patterns or continuing textures, another interesting and somewhat surprising trivial solution was found, called <a href="https://en.wikipedia.org/wiki/Chromatic_aberration"><em>“chromatic aberration”</em></a>. It is caused by light of different wavelengths having different focal lengths when passing through the lens. As a result, there may be small offsets between color channels. Hence, the model can learn to tell the relative position by simply comparing how green and magenta are separated differently in two patches. This trivial solution has nothing to do with the image content. Pre-processing images by shifting green and magenta toward gray or randomly dropping 2 of 3 color channels can avoid it.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/chromatic-aberration.png" alt="Chromatic aberration" /></p>
<p><em>Fig. 5. Illustration of how chromatic aberration happens. (Image source: <a href="https://upload.wikimedia.org/wikipedia/commons/a/aa/Chromatic_aberration_lens_diagram.svg">wikipedia</a>)</em></p>
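<p>The channel-dropping pre-processing can be sketched in a few lines. In this minimal NumPy illustration (the function name is hypothetical), one of the three color channels is kept and the other two are replaced with low-amplitude Gaussian noise; the noise fill and its scale are assumptions of this sketch, chosen so the dropped channels carry no usable color offset.</p>

```python
import numpy as np

def drop_color_channels(patch, noise_scale=0.01, rng=None):
    """Keep one random color channel of an (H, W, 3) patch; fill the other two with small noise."""
    rng = np.random.default_rng(rng)
    keep = int(rng.integers(3))
    kept = patch[..., keep].astype(float)
    noise_std = kept.std() * noise_scale          # noise much weaker than the kept signal
    out = np.empty(patch.shape, dtype=float)
    for c in range(3):
        if c == keep:
            out[..., c] = kept
        else:
            out[..., c] = rng.normal(0.0, noise_std, size=kept.shape)
    return out, keep
```

<p>Since only one channel keeps real content, relative offsets between color channels can no longer reveal where a patch came from.</p>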
<p>Since we have already set up a 3x3 grid in each image in the above task, why not use all 9 patches rather than only 2 to make the task more difficult? Following this idea, <a href="https://arxiv.org/abs/1603.09246">Noroozi & Favaro (2016)</a> designed a <mark><b>jigsaw puzzle</b></mark> game as a pretext task: the model is trained to place 9 shuffled patches back in their original locations.</p>
<p>A convolutional network processes each patch independently with shared weights, and the per-patch features are combined to output a probability vector over the indices of a predefined set of permutations. To control the difficulty of the jigsaw puzzles, the paper shuffles patches only according to permutations from this predefined set, so the model solves a classification problem over the indices in the set.</p>
<p>Because the way the input patches are shuffled does not alter the correct order to predict, a potential improvement to speed up training is to use a permutation-invariant graph convolutional network (GCN), so that we don’t have to shuffle the same set of patches multiple times, the same idea as in this <a href="https://arxiv.org/abs/1911.00025">paper</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-sup-jigsaw-puzzle.png" alt="Jigsaw puzzle" /></p>
<p><em>Fig. 6. Illustration of self-supervised learning by solving jigsaw puzzle. (Image source: <a href="https://arxiv.org/abs/1603.09246">Noroozi & Favaro, 2016</a>)</em></p>
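<p>A minimal NumPy sketch of building the permutation set and a training example. The paper picks permutations that are far apart in Hamming distance; here that is approximated (an assumption of this sketch) by a greedy max-min selection over random candidate permutations instead of all \(9!\) of them. Function names and default sizes are hypothetical.</p>

```python
import numpy as np

def select_permutations(n_perms, n_tiles=9, n_candidates=2000, rng=None):
    """Greedily pick permutations that maximize the minimum Hamming distance to those already chosen."""
    rng = np.random.default_rng(rng)
    cands = np.stack([rng.permutation(n_tiles) for _ in range(n_candidates)])
    chosen = [cands[0]]
    min_dist = (cands != cands[0]).sum(axis=1)       # Hamming distance to the chosen set
    for _ in range(n_perms - 1):
        idx = int(np.argmax(min_dist))
        chosen.append(cands[idx])
        min_dist = np.minimum(min_dist, (cands != cands[idx]).sum(axis=1))
    return np.stack(chosen)

def make_puzzle(img, perms, rng=None):
    """Cut img into a 3x3 grid, shuffle tiles by a random permutation from perms; label = its index."""
    rng = np.random.default_rng(rng)
    h, w = img.shape[0] // 3, img.shape[1] // 3
    tiles = [img[r * h:(r + 1) * h, c * w:(c + 1) * w] for r in range(3) for c in range(3)]
    label = int(rng.integers(len(perms)))
    shuffled = np.stack([tiles[i] for i in perms[label]])   # position j shows tile perms[label][j]
    return shuffled, label
```

<p>Choosing permutations that differ in many positions keeps the classes distinguishable, while the size of the set controls how hard the puzzle is.</p>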
<p>Another idea is to consider “feature” or “visual primitives” as a scalar-value attribute that can be summed up over multiple patches and compared across different patches. Then the relationship between patches can be defined by <mark><b>counting features</b></mark> and simple arithmetic (<a href="https://arxiv.org/abs/1708.06734">Noroozi, et al, 2017</a>).</p>
<p>The paper considers two transformations:</p>
<ol>
<li><em>Scaling</em>: If an image is scaled up by 2x, the number of visual primitives should stay the same.</li>
<li><em>Tiling</em>: If an image is tiled into a 2x2 grid, the number of visual primitives in the whole image is expected to equal the sum of the counts over the 4 tiles.</li>
</ol>
<p>The model learns a feature encoder \(\phi(.)\) using the above feature counting relationship. Given an input image \(\mathbf{x} \in \mathbb{R}^{m \times n \times 3}\), we consider two types of transformation operators:</p>
<ol>
<li>Downsampling operator, \(D: \mathbb{R}^{m \times n \times 3} \mapsto \mathbb{R}^{\frac{m}{2} \times \frac{n}{2} \times 3}\): downsample by a factor of 2</li>
<li>Tiling operator \(T_i: \mathbb{R}^{m \times n \times 3} \mapsto \mathbb{R}^{\frac{m}{2} \times \frac{n}{2} \times 3}\): extract the \(i\)-th tile from a 2x2 grid of the image.</li>
</ol>
<p>We expect to learn:</p>
\[\phi(\mathbf{x}) = \phi(D \circ \mathbf{x}) = \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\]
<p><a href="#counting-feature-loss"></a>Thus the MSE loss is: \(\mathcal{L}_\text{feat} = \|\phi(D \circ \mathbf{x}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2\). To avoid the trivial solution \(\phi(\mathbf{x}) = \mathbf{0}, \forall{\mathbf{x}}\), another loss term is added to encourage the difference between features of two different images: \(\mathcal{L}_\text{diff} = \max(0, c -\|\phi(D \circ \mathbf{y}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2)\), where \(\mathbf{y}\) is another input image different from \(\mathbf{x}\) and \(c\) is a scalar constant. The final loss is:</p>
\[\mathcal{L}
= \mathcal{L}_\text{feat} + \mathcal{L}_\text{diff}
= \|\phi(D \circ \mathbf{x}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2 + \max(0, c -\|\phi(D \circ \mathbf{y}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2)\]
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/self-sup-counting-features.png" alt="Counting features" /></p>
<p><em>Fig. 7. Self-supervised representation learning by counting features. (Image source: <a href="https://arxiv.org/abs/1708.06734">Noroozi, et al, 2017</a>)</em></p>
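<p>The counting loss can be written directly from the definitions of \(D\), \(T_i\) and the margin \(c\). A minimal NumPy sketch, assuming \(D\) is 2x2 average pooling (the text only says "downsample by a factor of 2") and that <code>phi</code> is any callable mapping an \(\frac{m}{2} \times \frac{n}{2} \times 3\) array to a feature vector; all function names are hypothetical.</p>

```python
import numpy as np

def downsample(x):
    """D: halve height and width by 2x2 average pooling (assumed; H and W must be even)."""
    h, w, c = x.shape
    return x.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

def tiles(x):
    """T_1..T_4: the four tiles of a 2x2 grid over the image."""
    h, w = x.shape[0] // 2, x.shape[1] // 2
    return [x[:h, :w], x[:h, w:], x[h:, :w], x[h:, w:]]

def counting_loss(phi, x, y, c=10.0):
    """L_feat + L_diff for image x and a different image y."""
    tile_sum = sum(phi(t) for t in tiles(x))
    l_feat = np.sum((phi(downsample(x)) - tile_sum) ** 2)
    l_diff = max(0.0, c - np.sum((phi(downsample(y)) - tile_sum) ** 2))
    return float(l_feat + l_diff)
```

<p>With the collapsed encoder \(\phi \equiv \mathbf{0}\), \(\mathcal{L}_\text{feat}\) vanishes but \(\mathcal{L}_\text{diff}\) equals \(c\), which is exactly why the second term is needed.</p>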
<h3 id="colorization">Colorization</h3>
<p><mark><b>Colorization</b></mark> can be used as a powerful self-supervised task: a model is trained to color a grayscale input image; precisely the task is to map this image to a distribution over quantized color value outputs (<a href="https://arxiv.org/abs/1603.08511">Zhang et al. 2016</a>).</p>
<p>The model outputs colors in the <a href="https://en.wikipedia.org/wiki/CIELAB_color_space">CIE L*a*b* color space</a>. The L*a*b* color space is designed to approximate human vision, while, in contrast, RGB or CMYK model the color output of physical devices.</p>
<ul>
<li>L* component matches human perception of lightness; L* = 0 is black and L* = 100 indicates white.</li>
<li>a* component represents green (negative) / magenta (positive) value.</li>
<li>b* component models blue (negative) / yellow (positive) value.</li>
</ul>
<p>Due to the multimodal nature of the colorization problem, cross-entropy loss of the predicted probability distribution over binned color values works better than L2 loss of the raw color values. The a*b* color space is quantized with bucket size 10.</p>
<p>To balance between common colors (usually low a*b* values, of common backgrounds like clouds, walls, and dirt) and rare colors (which are likely associated with key objects in the image), the loss function is rebalanced with a weighting term that boosts the loss of infrequent color buckets. This is similar to why we need both <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf and idf</a> for scoring words in information retrieval models. The weight of each bucket is inversely proportional to a mixture: (1-λ) * Gaussian-kernel-smoothed empirical probability distribution + λ * a uniform distribution, where both distributions are over the quantized a*b* color space.</p>
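<p>The quantization and the rebalanced loss can be sketched in NumPy as below. Several details are assumptions of this sketch rather than facts from the text: a* and b* are taken to lie in \([-110, 110)\), every cell of the resulting grid counts as a bucket (no in-gamut filtering), targets are hard bucket indices, \(\lambda = 0.5\) and the smoothing width \(\sigma = 5\) (in a*b* units) are illustrative defaults, and the weights are normalized so that their expectation under the smoothed distribution is 1. All names are hypothetical.</p>

```python
import numpy as np

GRID = 10                               # bucket size from the text
AB_MIN, AB_MAX = -110, 110              # assumed range of a* and b*
N_SIDE = (AB_MAX - AB_MIN) // GRID      # 22 buckets per axis
Q = N_SIDE * N_SIDE                     # number of buckets in this sketch

def quantize_ab(ab):
    """Map an (..., 2) array of (a*, b*) values to flat bucket indices in [0, Q)."""
    idx = np.clip(np.floor((ab - AB_MIN) / GRID).astype(int), 0, N_SIDE - 1)
    return idx[..., 0] * N_SIDE + idx[..., 1]

def rebalance_weights(bucket_ids, lam=0.5, sigma=5.0):
    """Weights proportional to 1 / ((1 - lam) * smoothed_p + lam / Q), with E_p[w] = 1."""
    counts = np.bincount(np.ravel(bucket_ids), minlength=Q).astype(float)
    p = (counts / counts.sum()).reshape(N_SIDE, N_SIDE)
    # Separable Gaussian smoothing over the 2-D a*b* grid; sigma is in a*b* units.
    kernel = np.exp(-0.5 * (np.arange(-3, 4) * GRID / sigma) ** 2)
    kernel /= kernel.sum()
    p = np.apply_along_axis(np.convolve, 0, p, kernel, mode="same")
    p = np.apply_along_axis(np.convolve, 1, p, kernel, mode="same")
    p = p.ravel() / p.sum()
    w = 1.0 / ((1 - lam) * p + lam / Q)     # rare buckets receive larger weights
    return w / np.sum(p * w)

def weighted_cross_entropy(logits, target_ids, weights):
    """Cross-entropy over Q buckets, each pixel's term scaled by the weight of its true bucket."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, target_ids[..., None], axis=-1)[..., 0]
    return float(np.mean(weights[target_ids] * nll))
```

<p>Mixing in the uniform term with weight λ bounds the largest possible weight at \(Q / \lambda\) before normalization, so buckets that are almost never observed cannot dominate the loss.</p>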
<h3 id="generative-modeling">Generative Modeling</h3>
<p>The pretext task in generative modeling is to reconstruct the original input while learning meaningful latent representation.</p>
<p>The <mark><b>denoising autoencoder</b></mark> (<a href="https://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf">Vincent, et al, 2008</a>) learns to recover an image from a version that is partially corrupted or has random noise. The design is inspired by the fact that humans can easily recognize objects in pictures even with noise, indicating that key visual features can be extracted and separated from noise. See my <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#denoising-autoencoder">old post</a>.</p>
<p>The <mark><b>context encoder</b></mark> (<a href="https://arxiv.org/abs/1604.07379">Pathak, et al., 2016</a>) is trained to fill in a missing piece in the image. Let \(\hat{M}\) be a binary mask, 0 for dropped pixels and 1 for remaining input pixels. The model is trained with a combination of the reconstruction (L2) loss and the adversarial loss. The removed regions defined by the mask could be of any shape.</p>
\[\begin{aligned}
\mathcal{L}(\mathbf{x}) &= \mathcal{L}_\text{recon}(\mathbf{x}) + \mathcal{L}_\text{adv}(\mathbf{x})\\
\mathcal{L}_\text{recon}(\mathbf{x}) &= \|(1 - \hat{M}) \odot (\mathbf{x} - E(\hat{M} \odot \mathbf{x})) \|_2^2 \\
\mathcal{L}_\text{adv}(\mathbf{x}) &= \max_D \mathbb{E}_{\mathbf{x}} [\log D(\mathbf{x}) + \log(1 - D(E(\hat{M} \odot \mathbf{x})))]
\end{aligned}\]
<p>where \(E(.)\) is the context encoder (the encoder-decoder network that fills in the masked region) and \(D(.)\) is the adversarial discriminator.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/context-encoder.png" alt="Context encoder" /></p>
<p><em>Fig. 8. Illustration of context encoder. (Image source: <a href="https://arxiv.org/abs/1604.07379">Pathak, et al., 2016</a>)</em></p>
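<p>The reconstruction term only scores pixels that were dropped, because \((1 - \hat{M})\) is zero everywhere else. A minimal NumPy sketch, assuming a single centered square hole and treating <code>E</code> as any callable that maps a masked image to a full image; the names are hypothetical and the adversarial term is omitted.</p>

```python
import numpy as np

def center_mask(h, w, hole):
    """Binary mask M-hat: 1 for kept pixels, 0 for a dropped square in the center."""
    m = np.ones((h, w, 1))
    top, left = (h - hole) // 2, (w - hole) // 2
    m[top:top + hole, left:left + hole] = 0.0
    return m

def recon_loss(x, mask, E):
    """L_recon = || (1 - M) * (x - E(M * x)) ||_2^2, evaluated inside the dropped region only."""
    pred = E(mask * x)
    return float(np.sum(((1 - mask) * (x - pred)) ** 2))
```

<p>For example, an "identity" network that just copies its masked input predicts zeros inside the hole, so its loss equals the squared norm of the missing pixels.</p>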
<p>When applying a mask on an image, the context encoder removes information from all the color channels in the masked regions. How about only hiding a subset of channels? The <mark><b>split-brain autoencoder</b></mark> (<a href="https://arxiv.org/abs/1611.09842">Zhang et al., 2017</a>) does this by predicting a subset of color channels from the rest of the channels. Let the data tensor \(\mathbf{x} \in \mathbb{R}^{h \times w \times \vert C \vert }\) with \(C\) color channels be the input for the \(l\)-th layer of the network. It is split into two disjoint parts, \(\mathbf{x}_1 \in \mathbb{R}^{h \times w \times \vert C_1 \vert}\) and \(\mathbf{x}_2 \in \mathbb{R}^{h \times w \times \vert C_2 \vert}\), where \(C_1 , C_2 \subseteq C\). Then two sub-networks are trained to do two complementary predictions: one network \(f_1\) predicts \(\mathbf{x}_2\) from \(\mathbf{x}_1\) and the other network \(f_2\) predicts \(\mathbf{x}_1\) from \(\mathbf{x}_2\). The loss is either L1 loss or cross entropy if color values are quantized.</p>
<p>The split can happen once on the RGB-D or L*a*b* color space, or even in every layer of a CNN, in which the number of channels can be arbitrary.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/split-brain-autoencoder.png" alt="Split-brain autoencoder" /></p>
<p><em>Fig. 9. Illustration of split-brain autoencoder. (Image source: <a href="https://arxiv.org/abs/1611.09842">Zhang et al., 2017</a>)</em></p>
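<p>The two complementary predictions can be sketched as follows. In this minimal NumPy illustration, <code>f1</code> and <code>f2</code> are hypothetical stand-in callables for the two sub-networks, \(C_2\) is taken as the complement of \(C_1\), and the L1 regression variant of the loss is used.</p>

```python
import numpy as np

def split_brain_loss(x, f1, f2, c1):
    """x: (H, W, |C|) tensor; c1: channel indices of C_1, with C_2 as the complement."""
    c1 = np.asarray(c1)
    c2 = np.setdiff1d(np.arange(x.shape[-1]), c1)
    x1, x2 = x[..., c1], x[..., c2]
    loss_12 = np.mean(np.abs(f1(x1) - x2))   # f_1 predicts x_2 from x_1
    loss_21 = np.mean(np.abs(f2(x2) - x1))   # f_2 predicts x_1 from x_2
    return float(loss_12 + loss_21)
```

<p>On an L*a*b* image, <code>c1=[0]</code> gives the "predict a*b* from L*" and "predict L* from a*b*" pair of tasks.</p>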
<p>The generative adversarial networks (GANs) are able to learn to map from simple latent variables to arbitrarily complex data distributions. Studies have shown that the latent space of such generative models captures semantic variation in the data; e.g. when training GAN models on human faces, some latent variables are associated with facial expression, glasses, gender, etc. (<a href="https://arxiv.org/abs/1511.06434">Radford et al., 2016</a>).</p>
<p><mark><b>Bidirectional GANs</b></mark> (<a href="https://arxiv.org/abs/1605.09782">Donahue, et al, 2017</a>) introduce an additional encoder \(E(.)\) to learn the mapping from the input to the latent variable \(\mathbf{z}\). The discriminator \(D(.)\) predicts in the joint space of the input data and latent representation, \((\mathbf{x}, \mathbf{z})\), to tell apart the real pair \((\mathbf{x}, E(\mathbf{x}))\) from the generated one \((G(\mathbf{z}), \mathbf{z})\). The model is trained to optimize the objective: \(\min_{G, E} \max_D V(D, E, G)\), where the generator \(G\) and the encoder \(E\) learn to generate data and latent variables that are realistic enough to confuse the discriminator and at the same time the discriminator \(D\) tries to differentiate real and generated data.</p>
\[V(D, E, G) = \mathbb{E}_{\mathbf{x} \sim p_\mathbf{x}} [ \underbrace{\mathbb{E}_{\mathbf{z} \sim p_E(.\vert\mathbf{x})}[\log D(\mathbf{x}, \mathbf{z})]}_{\log D(\text{real})} ] + \mathbb{E}_{\mathbf{z} \sim p_\mathbf{z}} [ \underbrace{\mathbb{E}_{\mathbf{x} \sim p_G(.\vert\mathbf{z})}[\log (1 - D(\mathbf{x}, \mathbf{z}))]}_{\log(1- D(\text{fake}))} ]\]
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/bi-GAN.png" alt="BiGAN" /></p>
<p><em>Fig. 10. Illustration of how Bidirectional GAN works. (Image source: <a href="https://arxiv.org/abs/1605.09782">Donahue, et al, 2017</a>)</em></p>
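<p>With deterministic \(E\) and \(G\), the value function can be estimated by Monte Carlo over data samples and latent samples. A minimal NumPy sketch; <code>D</code>, <code>E</code> and <code>G</code> are hypothetical stand-in callables, and <code>D</code> is assumed to return a probability in \((0, 1)\).</p>

```python
import numpy as np

def bigan_value(D, E, G, xs, zs):
    """Monte Carlo estimate of V(D, E, G) with deterministic encoder E and generator G."""
    real = np.mean([np.log(D(x, E(x))) for x in xs])          # real pairs (x, E(x))
    fake = np.mean([np.log(1.0 - D(G(z), z)) for z in zs])    # generated pairs (G(z), z)
    return float(real + fake)
```

<p>A discriminator that cannot tell the pairs apart and always outputs 0.5 gives \(V = 2 \log 0.5 = -\log 4\); \(D\) is updated to increase this value, while \(G\) and \(E\) are updated to decrease it.</p>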
<h3 id="contrastive-predictive-coding">Contrastive Predictive Coding</h3>
<p>The <mark><b>Contrastive Predictive Coding (CPC)</b></mark> (<a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>) is an approach for unsupervised learning from high-dimensional data by translating a generative modeling problem to a classification problem. The <em>contrastive loss</em> or <em>InfoNCE loss</em> in CPC, inspired by <a href="/lil-log/2017/10/15/learning-word-embedding.html#noise-contrastive-estimation-nce">Noise Contrastive Estimation (NCE)</a>, uses cross-entropy loss to measure how well the model can classify the “future” representation amongst a set of unrelated “negative” samples. Such a design is partially motivated by the fact that a unimodal loss like MSE does not have enough capacity, while learning a full generative model could be too expensive.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CPC-audio.png" alt="CPC on audio input" /></p>
<p><em>Fig. 11. Illustration of applying Contrastive Predictive Coding on the audio input. (Image source: <a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>)</em></p>
<p>CPC uses an encoder to compress the input data \(z_t = g_\text{enc}(x_t)\) and an <em>autoregressive</em> model to learn the high-level context that is potentially shared across future predictions, \(c_t = g_\text{ar}(z_{\leq t})\). The end-to-end training relies on the NCE-inspired contrastive loss.</p>
<p>While predicting future information, CPC is optimized to maximize the mutual information between input \(x\) and context vector \(c\):</p>
\[I(x; c) = \sum_{x, c} p(x, c) \log\frac{p(x, c)}{p(x)p(c)} = \sum_{x, c} p(x, c)\log\frac{p(x|c)}{p(x)}\]
<p>Rather than modeling the future observations \(p_k(x_{t+k} \vert c_t)\) directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between \(x_{t+k}\) and \(c_t\):</p>
\[f_k(x_{t+k}, c_t) = \exp(z_{t+k}^\top W_k c_t) \propto \frac{p(x_{t+k}|c_t)}{p(x_{t+k})}\]
<p>where \(f_k\) can be unnormalized, and a linear transformation \(W_k c_t\) is used for the prediction with a different \(W_k\) matrix for every step \(k\).</p>
<p>Given a set of \(N\) random samples \(X = \{x_1, \dots, x_N\}\) containing only one positive sample \(x_t \sim p(x_{t+k} \vert c_t)\) and \(N-1\) negative samples \(x_{i \neq t} \sim p(x_{t+k})\), the cross-entropy loss for classifying the positive sample (where \(\frac{f_k}{\sum f_k}\) is the prediction) correctly is:</p>
\[\mathcal{L}_N = - \mathbb{E}_X \Big[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{i=1}^N f_k (x_i, c_t)}\Big]\]
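<p>For a single context \(c_t\), \(\mathcal{L}_N\) is a softmax cross-entropy over the scores \(\log f_k(x_i, c_t) = z_i^\top W_k c_t\). A minimal NumPy sketch; the function name and the choice of passing pre-computed encodings \(z_i\) as rows of a matrix are assumptions for illustration.</p>

```python
import numpy as np

def info_nce_loss(z_candidates, c_t, W_k, pos_index):
    """L_N for one context.

    z_candidates: (N, d_z) encodings of x_1..x_N (one positive, N - 1 negatives)
    c_t:          (d_c,) context vector
    W_k:          (d_z, d_c) step-specific prediction matrix
    """
    z_pred = W_k @ c_t                      # linear prediction of z_{t+k}
    scores = z_candidates @ z_pred          # log f_k(x_i, c_t) = z_i^T W_k c_t
    scores = scores - scores.max()          # stabilize the softmax
    log_probs = scores - np.log(np.exp(scores).sum())
    return float(-log_probs[pos_index])     # cross-entropy of picking the positive
```

<p>When the scores carry no information (e.g. \(W_k = 0\)), the loss equals \(\log N\), the cost of guessing uniformly among the \(N\) candidates; a prediction aligned with the positive drives it toward 0.</p>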
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CPC-image.png" alt="CPC on images" /></p>
<p><em>Fig. 12. Illustration of applying Contrastive Predictive Coding on images. (Image source: <a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>)</em></p>
<p>When using CPC on images (<a href="https://arxiv.org/abs/1905.09272">Henaff, et al. 2019</a>), the predictor network should only access a masked feature set to avoid a trivial prediction. Precisely:</p>
<ol>
<li>Each input image is divided into a set of overlapping patches and each patch is encoded by a ResNet encoder, resulting in a compressed feature vector \(z_{i,j}\).</li>
<li>A masked conv net makes predictions with a mask such that the receptive field of a given output neuron can only see things above it in the image. Otherwise, the prediction problem would be trivial. The prediction can be made in both directions (top-down and bottom-up).</li>
<li>The prediction is made for \(z_{i+k, j}\) from context \(c_{i,j}\): \(\hat{z}_{i+k, j} = W_k c_{i,j}\).</li>
</ol>
<p>A contrastive loss quantifies this prediction with a goal to correctly identify the target among a set of negative representations \(\{z_l\}\) sampled from other patches in the same image and other images in the same batch:</p>
\[\mathcal{L}_\text{CPC}
= -\sum_{i,j,k} \log p(z_{i+k, j} \vert \hat{z}_{i+k, j}, \{z_l\})
= -\sum_{i,j,k} \log \frac{\exp(\hat{z}_{i+k, j}^\top z_{i+k, j})}{\exp(\hat{z}_{i+k, j}^\top z_{i+k, j}) + \sum_l \exp(\hat{z}_{i+k, j}^\top z_l)}\]
<h3 id="momentum-contrast">Momentum Contrast</h3>
<p><a name="moco"></a><mark><b>Momentum Contrast</b></mark> (<strong>MoCo</strong>; <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>) provides a framework for unsupervised learning of visual representations as a <em>dynamic dictionary look-up</em>. The dictionary is structured as a large FIFO queue of encoded representations of data samples.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/MoCo.png" alt="MoCo" /></p>
<p><em>Fig. 13. Illustration of how Momentum Contrast (MoCo) learns visual representations. (Image source: <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>)</em></p>
<p>Given a query sample \(x_q\), we get a query representation \(q\) through an encoder \(f_q\): \(q = f_q(x_q)\). Key samples are encoded by a momentum encoder \(k_i = f_k (x^k_i)\) to produce a list of key representations \(\{k_1, k_2, \dots \}\) in the dictionary. Let’s assume among them there is a single <em>positive</em> key \(k^+\) in the dictionary that matches \(q\). In the paper, \(k^+\) is created using a copy of \(x_q\) with different augmentation. Then the <a href="#contrastive-predictive-coding">InfoNCE</a> contrastive loss is applied for one positive and \(K\) negative samples:</p>
\[\mathcal{L}_q = - \log \frac{\exp(q \cdot k^+ / \tau)}{\sum_{i=0}^K \exp(q \cdot k_i / \tau)}\]
<p>where \(\tau\) is a temperature hyper-parameter.</p>
<p>Compared to a similar idea, the <strong>memory bank</strong> (<a href="https://arxiv.org/abs/1805.01978v1">Wu et al, 2018</a>), which stores representations of all the data points in the database and samples a random set of keys as negative examples, a queue-based dictionary in MoCo enables us to reuse representations of the immediately preceding mini-batches of data.</p>
<p>The MoCo dictionary, being a queue, is not differentiable, so we cannot rely on back-propagation to update the key encoder \(f_k\). One naive way might be to use the same encoder for both \(f_q\) and \(f_k\). Instead, MoCo proposes a momentum-based update. Let the parameters of \(f_q\) and \(f_k\) be \(\theta_q\) and \(\theta_k\), respectively.</p>
\[\theta_k \leftarrow m \theta_k + (1-m) \theta_q\]
<p>where \(m \in [0, 1)\) is a momentum coefficient. No gradient flows through \(f_k\)’s update.</p>
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/MoCo-algo.png" alt="MoCo Algorithm" /></p>
<p><em>Fig. 14. Pseudo code of MoCo in PyTorch style. (Image source: <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>)</em></p>
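<p>The three pieces of MoCo discussed above (the InfoNCE loss against a queue, the momentum update of the key encoder, and the FIFO queue update) can each be sketched in NumPy. This is a minimal illustration, not the reference implementation: parameters are represented as a hypothetical dict of arrays, keys are assumed to be already encoded (and normalized if desired), and the defaults <code>tau=0.07</code> and <code>m=0.999</code> are illustrative assumptions.</p>

```python
import numpy as np

def moco_loss(q, k_pos, queue, tau=0.07):
    """InfoNCE with one positive key k_pos and K negatives (the rows of queue)."""
    logits = np.concatenate([[q @ k_pos], queue @ q]) / tau   # positive sits at index 0
    logits = logits - logits.max()
    return float(np.log(np.exp(logits).sum()) - logits[0])

def momentum_update(theta_k, theta_q, m=0.999):
    """theta_k := m * theta_k + (1 - m) * theta_q for every parameter array; no gradients involved."""
    return {name: m * theta_k[name] + (1 - m) * theta_q[name] for name in theta_k}

def enqueue_dequeue(queue, new_keys):
    """FIFO update: the newest keys go in front, the oldest keys fall off the end."""
    return np.concatenate([new_keys, queue], axis=0)[: len(queue)]
```

<p>Because the queue size is fixed independently of the mini-batch, the number of negatives \(K\) can be much larger than the batch size.</p>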
<p><strong>SimCLR</strong> (<a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>) proposed a simple framework for contrastive learning of visual representations. It learns representations for visual inputs by maximizing agreement between differently augmented views of the same sample via a contrastive loss in the latent space.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/SimCLR.png" alt="SimCLR" /></p>
<p><em>Fig. 15. A simple framework for contrastive learning of visual representations. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>)</em></p>
<p>SimCLR works in the following three steps:</p>
<p>(1) Randomly sample a mini-batch of \(n\) samples and apply two different data augmentation operations to each sample, resulting in \(2n\) augmented samples in total.</p>
\[\tilde{\mathbf{x}}_i = t(\mathbf{x}),\quad\tilde{\mathbf{x}}_j = t'(\mathbf{x}),\quad t, t' \sim \mathcal{T}\]
<p>where two separate data augmentation operators, \(t\) and \(t'\), are sampled from the same family of augmentations \(\mathcal{T}\). Data augmentation includes random crop, resize with random flip, color distortions, and Gaussian blur.</p>
<p>(2) Given one positive pair, the other \(2(n-1)\) data points are treated as negative samples. The representation is produced by a base encoder \(f(.)\):</p>
\[\mathbf{h}_i = f(\tilde{\mathbf{x}}_i),\quad \mathbf{h}_j = f(\tilde{\mathbf{x}}_j)\]
<p>(3) The contrastive loss is defined using cosine similarity \(\text{sim}(.,.)\). Note that the loss operates on top of an extra projection of the representation via \(g(.)\) rather than on the representation \(\mathbf{h}\) directly. But only the representation \(\mathbf{h}\) is used for downstream tasks.</p>
\[\begin{aligned}
\mathbf{z}_i &= g(\mathbf{h}_i),\quad
\mathbf{z}_j = g(\mathbf{h}_j),\quad
\text{sim}(\mathbf{z}_i, \mathbf{z}_j) = \frac{\mathbf{z}_i^\top\mathbf{z}_j}{\|\mathbf{z}_i\| \|\mathbf{z}_j\|} \\
\mathcal{L}_{i,j} &= - \log\frac{\exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_j) / \tau)}{\sum_{k=1}^{2n} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_k) / \tau)}
\end{aligned}\]
<p>where \(\mathbf{1}_{[k \neq i]}\) is an indicator function: 1 if \(k\neq i\) and 0 otherwise. \(\tau\) is a temperature hyperparameter.</p>
<p style="width: 58%;" class="center"><img src="/lil-log/assets/images/SimCLR-algo.png" alt="SimCLR Algorithm" /></p>
<p><em>Fig. 16. The algorithm for SimCLR. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>).</em></p>
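<p>Step (3) can be vectorized over the whole batch of \(2n\) projections. A minimal NumPy sketch; the function name and the convention that rows \(2k\) and \(2k+1\) hold the two views of sample \(k\) are assumptions for illustration, and the loss is averaged over all \(2n\) anchors.</p>

```python
import numpy as np

def nt_xent_loss(z, tau=0.5):
    """Average of L_{i,j} over all 2n anchors.

    z: (2n, d) projections g(h); rows 2k and 2k + 1 are assumed to be the two views of sample k.
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # unit vectors, so z_i . z_k is cosine similarity
    sim = z @ z.T / tau
    np.fill_diagonal(sim, -np.inf)                      # indicator 1[k != i]: exclude k == i
    partner = np.arange(len(z)) ^ 1                     # positive index: pairs (0, 1), (2, 3), ...
    row_max = sim.max(axis=1)
    log_denom = row_max + np.log(np.exp(sim - row_max[:, None]).sum(axis=1))
    return float(np.mean(log_denom - sim[np.arange(len(z)), partner]))
```

<p>If all projections collapse to the same vector, every term equals \(\log(2n - 1)\): the positive is indistinguishable from the \(2n - 2\) negatives.</p>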
<p><a name="mocov2"></a>
The advantage of MoCo over SimCLR is that MoCo decouples the batch size from the number of negatives, whereas SimCLR requires a large batch size in order to have enough negative samples. Therefore, SimCLR suffers performance drops when its batch size is reduced.</p>
<p>Two designs in SimCLR, namely, (1) an MLP projection head and (2) stronger data augmentation (e.g. crop, blur and stronger color distortion), are shown to be very effective. Combining them with MoCo gives <strong>MoCo V2</strong> (<a href="https://arxiv.org/abs/2003.04297">Chen et al, 2020</a>), which achieves even better transfer performance with no dependency on a very large batch size.</p>
<p><a name="BYOL"></a> <strong>BYOL</strong> (“Bootstrap your own latent”; <a href="https://arxiv.org/abs/2006.07733">Grill, et al 2020</a>) claims to achieve new state-of-the-art results without using <em>negative samples</em>. It relies on two neural networks, referred to as <em>online</em> and <em>target</em> networks, that interact and learn from each other. The target network (parameterized by \(\xi\)) has the same architecture as the online one (parameterized by \(\theta\)), but with Polyak-averaged weights, \(\xi \leftarrow \tau \xi + (1-\tau) \theta\).</p>
<p>The goal is to learn a representation \(y\) that can be used in downstream tasks. The online network parameterized by \(\theta\) contains:</p>
<ul>
<li>an encoder \(f_\theta\),</li>
<li>a projector \(g_\theta\),</li>
<li>and a predictor \(q_\theta\).</li>
</ul>
<p>The target network has the same network architecture, but with different parameters \(\xi\), updated by Polyak averaging of \(\theta\): \(\xi \leftarrow \tau \xi + (1-\tau) \theta\).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BYOL.png" alt="BYOL" /></p>
<p><em>Fig. 17. The model architecture of BYOL (“Bootstrapping your own latent”). (Image source: <a href="https://arxiv.org/abs/2006.07733">Grill, et al 2020</a>). After training, we only care about \(f_\theta\) for producing representation, \(y=f_\theta(x)\), and everything else is discarded.</em></p>
<p>Given an image \(x\), the BYOL loss is constructed as follows:</p>
<ol>
<li>Create two augmented views: \(v=t(x); v'=t'(x)\) with augmentations sampled \(t \sim \mathcal{T}, t' \sim \mathcal{T}'\);</li>
<li>Then they are encoded into representations, \(y_\theta=f_\theta(v), y'=f_\xi(v')\);</li>
<li>Then they are projected into latent variables, \(z_\theta=g_\theta(y_\theta), z'=g_\xi(y')\);</li>
<li>The online network outputs a prediction \(q_\theta(z_\theta)\);</li>
<li>Both \(q_\theta(z_\theta)\) and \(z'\) are L2-normalized, giving us \(\bar{q}_\theta(z_\theta) = q_\theta(z_\theta) / \| q_\theta(z_\theta) \|\) and \(\bar{z'} = z' / \|z'\|\);</li>
<li>The loss \(\mathcal{L}^\text{BYOL}_\theta\) is MSE between L2-normalized prediction \(\bar{q}_\theta(z_\theta)\) and \(\bar{z'}\);</li>
<li>The other symmetric loss \(\tilde{\mathcal{L}}^\text{BYOL}_\theta\) can be generated by switching \(v'\) and \(v\); that is, feeding \(v'\) to online network and \(v\) to target network.</li>
<li>The final loss is \(\mathcal{L}^\text{BYOL}_\theta + \tilde{\mathcal{L}}^\text{BYOL}_\theta\) and only parameters \(\theta\) are optimized.</li>
</ol>
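<p>The loss construction above can be sketched in NumPy. This is a minimal illustration under assumptions: the networks are passed as hypothetical dicts of callables (<code>f</code>, <code>g</code>, <code>q</code> for the online network; <code>f</code>, <code>g</code> for the target), augmentations are callables <code>t</code> and <code>t_prime</code>, and <code>tau=0.99</code> is an illustrative default. Gradient stopping on the target branch is only noted in a comment, since NumPy has no autodiff.</p>

```python
import numpy as np

def l2_normalize(v):
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def byol_loss(pred, target):
    """MSE between L2-normalized vectors, which equals 2 - 2 * cosine similarity."""
    return np.sum((l2_normalize(pred) - l2_normalize(target)) ** 2, axis=-1)

def byol_symmetric_loss(x, t, t_prime, online, target):
    """L + L~ : the second term swaps which view goes to the online network."""
    v, v_prime = t(x), t_prime(x)

    def one_side(a, b):
        pred = online["q"](online["g"](online["f"](a)))   # q_theta(z_theta)
        z_target = target["g"](target["f"](b))             # z', treated as a constant (no gradient)
        return byol_loss(pred, z_target)

    return one_side(v, v_prime) + one_side(v_prime, v)

def polyak_update(xi, theta, tau=0.99):
    """xi := tau * xi + (1 - tau) * theta, per parameter array."""
    return {name: tau * xi[name] + (1 - tau) * theta[name] for name in xi}
```

<p>Only \(\theta\) receives gradient updates from this loss; \(\xi\) changes solely through <code>polyak_update</code>.</p>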
<p>Different from most popular contrastive learning based approaches, BYOL does not use negative pairs. Most bootstrapping approaches rely on pseudo-labels or cluster indices, but BYOL directly bootstraps the latent representation.</p>
<p>It is quite interesting and surprising that without negative samples, BYOL still works well. Later I ran into this <a href="https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html">post by Abe Fetterman & Josh Albrecht</a>, where they highlighted two surprising findings while they were trying to reproduce BYOL:</p>
<ol>
<li>BYOL generally performs no better than random when batch normalization is removed.</li>
<li>The presence of batch normalization implicitly causes a form of contrastive learning.</li>
</ol>
<p>They argue that negative samples are important for avoiding model collapse (i.e. what if your model outputs an all-zeros representation for every data point?). Batch normalization implicitly injects a dependency on negative samples: no matter how similar the inputs in a batch are, their values are re-distributed (spread out \(\sim \mathcal{N}(0, 1)\)), and therefore batch normalization prevents model collapse. I strongly recommend reading the <a href="https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html">full article</a> if you are working in this field.</p>
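<p>To make the batch-normalization argument concrete, here is a small pure-Python sketch (our own illustration, not code from the post or the paper; the batch values and the tiny <code>eps</code> are made up). Four representations that are almost collapsed onto a single point come out of per-feature normalization with zero mean and unit variance, because each output depends on the other samples in the batch.</p>

```python
import statistics


def batch_norm(batch, eps=1e-8):
    """Normalize each feature across the batch to zero mean and unit variance.

    No learnable scale/shift is applied in this sketch.
    """
    num_features = len(batch[0])
    out = [[0.0] * num_features for _ in batch]
    for d in range(num_features):
        column = [row[d] for row in batch]
        mean = statistics.fmean(column)
        var = statistics.pvariance(column, mu=mean)
        for i, x in enumerate(column):
            out[i][d] = (x - mean) / (var + eps) ** 0.5
    return out


# Nearly collapsed representations: every sample sits within ~0.02 of (1, 2).
nearly_collapsed = [[1.00, 2.00], [1.01, 2.02], [0.99, 1.98], [1.02, 1.99]]
spread_out = batch_norm(nearly_collapsed)
for row in spread_out:
    print([round(x, 3) for x in row])
```

Note that if every sample is exactly identical, the variance is zero and all outputs become 0, so normalization only re-spreads inputs that differ at least slightly.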
<p><strong>CURL</strong> (<a href="https://arxiv.org/abs/2004.04136">Srinivas & Laskin, et al. 2020</a>) applies the above ideas in Reinforcement Learning. It learns a visual representation for RL tasks by matching embeddings of two data-augmented versions, \(o_q\) and \(o_k\), of the raw observation \(o\) via contrastive loss. CURL primarily relies on random crop data augmentation. The key encoder is implemented as a momentum encoder with weights as EMA of the query encoder weights, same as in <a href="#moco">MoCo</a>.</p>
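<p>The momentum (EMA) update of the key encoder can be sketched as below, treating each encoder's weights as a flat list of floats. The momentum coefficient <code>m=0.99</code> and the function name are assumptions for illustration, not values taken from the CURL paper.</p>

```python
def momentum_update(key_params, query_params, m=0.99):
    """EMA update applied parameter-wise: theta_k <- m * theta_k + (1 - m) * theta_q.

    Only the query encoder receives gradients; the key encoder slowly tracks it.
    """
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


key = [0.0, 0.0]
query = [1.0, 2.0]
key = momentum_update(key, query, m=0.9)
print(key)  # approximately [0.1, 0.2]
```

A larger <code>m</code> makes the key encoder change more slowly, which keeps the keys consistent across training steps.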
<p>One significant difference between RL and supervised visual tasks is that RL depends on <em>temporal</em> consistency between consecutive frames. Therefore, CURL applies augmentation consistently on each stack of frames to retain information about the temporal structure of the observation.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/CURL.png" alt="CURL" /></p>
<p><em>Fig. 18. The architecture of CURL. (Image source: <a href="https://arxiv.org/abs/2004.04136">Srinivas & Laskin, et al. 2020</a>)</em></p>
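<p>A minimal sketch of temporally consistent augmentation, assuming each observation is a stack of equally sized 2-D frames stored as nested lists: one random crop offset is drawn per stack and reused for every frame in it, so the relative motion between frames is preserved. The function name and the crop sizes are hypothetical.</p>

```python
import random


def random_crop_stack(frames, out_h, out_w, rng=random):
    """Crop every frame of a stack at the same randomly chosen location.

    frames: list of 2-D frames (lists of rows), all of size H x W.
    """
    h, w = len(frames[0]), len(frames[0][0])
    top = rng.randint(0, h - out_h)    # randint is inclusive on both ends
    left = rng.randint(0, w - out_w)
    return [[row[left:left + out_w] for row in frame[top:top + out_h]]
            for frame in frames]


# Three 8 x 8 frames; pixel value encodes (frame, row, column) for inspection.
frames = [[[k * 100 + r * 10 + c for c in range(8)] for r in range(8)] for k in range(3)]
crops = random_crop_stack(frames, 4, 4, random.Random(0))
print([crop[0][0] for crop in crops])  # same (row, col) offset in every frame
```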
<h2 id="video-based">Video-Based</h2>
<p>A video contains a sequence of semantically related frames. Nearby frames are close in time and more correlated than frames further away. The order of frames reflects certain rules of reasoning and physical logic, such as that object motion should be smooth and gravity points down.</p>
<p>A common workflow is to train a model on one or multiple pretext tasks with unlabelled videos and then use one intermediate feature layer of this model to fine-tune a simple model on downstream tasks such as action classification, segmentation or object tracking.</p>
<h3 id="tracking">Tracking</h3>
<p>The movement of an object is traced by a sequence of video frames. The appearance of the same object usually changes little between nearby frames, typically due to small motion of the object or the camera. Therefore any visual representation learned for the same object across close frames should be close in the latent feature space. Motivated by this idea, <a href="https://arxiv.org/abs/1505.00687">Wang & Gupta, 2015</a> proposed a way of unsupervised learning of visual representation by <mark><b>tracking moving objects</b></mark> in videos.</p>
<p>Precisely, patches with motion are tracked over a small time window (e.g. 30 frames). The first patch \(\mathbf{x}\) and the last patch \(\mathbf{x}^+\) are selected and used as training data points. If we train the model directly to minimize the difference between feature vectors of two patches, the model may only learn to map everything to the same value. To avoid such a trivial solution, same as <a href="#counting-feature-loss">above</a>, a random third patch \(\mathbf{x}^-\) is added. The model learns the representation by enforcing the distance between two tracked patches to be smaller than the distance between the first patch and a random one in the feature space, \(D(\mathbf{x}, \mathbf{x}^-) > D(\mathbf{x}, \mathbf{x}^+)\), where \(D(.)\) is the cosine distance,</p>
\[D(\mathbf{x}_1, \mathbf{x}_2) = 1 - \frac{f(\mathbf{x}_1) \cdot f(\mathbf{x}_2)}{\|f(\mathbf{x}_1)\| \|f(\mathbf{x}_2)\|}\]
<p>The loss function is:</p>
\[\mathcal{L}(\mathbf{x}, \mathbf{x}^+, \mathbf{x}^-)
= \max\big(0, D(\mathbf{x}, \mathbf{x}^+) - D(\mathbf{x}, \mathbf{x}^-) + M\big) + \text{weight decay regularization term}\]
<p>where \(M\) is a scalar constant controlling the minimum gap between the two distances; \(M=0.5\) in the paper. The loss enforces \(D(\mathbf{x}, \mathbf{x}^-) \geq D(\mathbf{x}, \mathbf{x}^+) + M\) in the optimal case.</p>
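<p>The cosine distance and the ranking loss translate directly into code. In this sketch the feature vectors \(f(\cdot)\) are given as plain lists, the margin defaults to \(M=0.5\) as in the paper, and the weight decay term is left to the optimizer; the function names are our own.</p>

```python
import math


def cosine_distance(u, v):
    """D(x1, x2) = 1 - f(x1) . f(x2) / (|f(x1)| |f(x2)|)."""
    dot = sum(a * b for a, b in zip(u, v))
    return 1.0 - dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def ranking_loss(f_x, f_pos, f_neg, margin=0.5):
    """max(0, D(x, x+) - D(x, x-) + M), without the weight decay term."""
    return max(0.0, cosine_distance(f_x, f_pos) - cosine_distance(f_x, f_neg) + margin)


# The tracked patch matches the first patch and the random patch is orthogonal:
# the margin is already satisfied, so the loss is zero.
print(ranking_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0]))  # 0.0
```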
<p><a href="#triplet-loss"></a>This form of loss function is also known as <a href="https://arxiv.org/abs/1503.03832">triplet loss</a> in the face recognition task, in which the dataset contains images of multiple people from multiple camera angles. Let \(\mathbf{x}^a\) be an anchor image of a specific person, \(\mathbf{x}^p\) be a positive image of this same person from a different angle and \(\mathbf{x}^n\) be a negative image of a different person. In the embedding space, \(\mathbf{x}^a\) should be closer to \(\mathbf{x}^p\) than \(\mathbf{x}^n\):</p>
\[\mathcal{L}_\text{triplet}(\mathbf{x}^a, \mathbf{x}^p, \mathbf{x}^n) = \max(0, \|\phi(\mathbf{x}^a) - \phi(\mathbf{x}^p) \|_2^2 - \|\phi(\mathbf{x}^a) - \phi(\mathbf{x}^n) \|_2^2 + M)\]
<p><a href="#n-pair-loss"></a>A slightly different form of the triplet loss, named <a href="https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective">n-pair loss</a> is also commonly used for learning observation embedding in robotics tasks. See a <a href="#multi-view-metric-learning">later section</a> for more related content.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/tracking-videos.png" alt="tracking videos" /></p>
<p><em>Fig. 19. Overview of learning representation by tracking objects in videos. (a) Identify moving patches in short traces; (b) Feed two related patches and one random patch into a conv network with shared weights. (c) The loss function enforces the distance between related patches to be smaller than the distance between random patches. (Image source: <a href="https://arxiv.org/abs/1505.00687">Wang & Gupta, 2015</a>)</em></p>
<p>Relevant patches are tracked and extracted through a two-step unsupervised <a href="https://en.wikipedia.org/wiki/Optical_flow">optical flow</a> approach:</p>
<ol>
<li>Obtain <a href="https://www.vision.ee.ethz.ch/~surf/eccv06.pdf">SURF</a> interest points and use <a href="https://hal.inria.fr/hal-00873267v2/document">IDT</a> to obtain motion of each SURF point.</li>
<li>Given the trajectories of SURF interest points, classify these points as moving if the flow magnitude is more than 0.5 pixels.</li>
</ol>
<p>During training, given a pair of correlated patches \(\mathbf{x}\) and \(\mathbf{x}^+\), \(K\) random patches \(\{\mathbf{x}^-\}\) are sampled from the same batch to form \(K\) training triplets. After a couple of epochs, <em>hard negative mining</em> is applied to make the training harder and more efficient; that is, we search for the random patches that maximize the loss and use them for gradient updates.</p>
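<p>Hard negative mining can be sketched as scoring every candidate negative in the batch with the ranking loss and keeping the ones that violate the margin the most. This is an illustrative version with a hypothetical <code>top_k</code> parameter; it repeats the distance and loss helpers so that it runs on its own.</p>

```python
import math


def cosine_distance(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return 1.0 - dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def ranking_loss(f_x, f_pos, f_neg, margin=0.5):
    return max(0.0, cosine_distance(f_x, f_pos) - cosine_distance(f_x, f_neg) + margin)


def mine_hard_negatives(f_x, f_pos, candidates, top_k=1, margin=0.5):
    """Return indices of the top_k candidate negatives with the highest loss, hardest first."""
    losses = [ranking_loss(f_x, f_pos, f_neg, margin) for f_neg in candidates]
    order = sorted(range(len(candidates)), key=lambda i: losses[i], reverse=True)
    return order[:top_k]


# Candidate 1 points almost the same way as the anchor, so it is the hardest negative.
print(mine_hard_negatives([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0], [1.0, 0.1], [-1.0, 0.0]]))  # [1]
```

Only the selected negatives would then contribute to the gradient update, which focuses training on the triplets the model currently gets wrong.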
<h3 id="frame-sequence">Frame Sequence</h3>
<p>Video frames are naturally positioned in chronological order. Researchers have proposed several self-supervised tasks, motivated by the expectation that good representation should learn the <em>correct sequence</em> of frames.</p>
<p>One idea is to <mark><b>validate frame order</b></mark> (<a href="https://arxiv.org/abs/1603.08561">Misra, et al 2016</a>). The pretext task is to determine whether a sequence of frames from a video is placed in the correct temporal order (“temporal valid”). The model needs to track and reason about small motion of an object across frames to complete such a task.</p>
<p>The training frames are sampled from high-motion windows. Each time, 5 frames \((f_a, f_b, f_c, f_d, f_e)\) are sampled with timestamps in order \(a < b < c < d < e\). From these 5 frames, one positive tuple \((f_b, f_c, f_d)\) and two negative tuples, \((f_b, f_a, f_d)\) and \((f_b, f_e, f_d)\), are created. The parameter \(\tau_\max = \vert b-d \vert\) controls the difficulty of positive training instances (i.e. higher → harder) and the parameter \(\tau_\min = \min(\vert a-b \vert, \vert d-e \vert)\) controls the difficulty of negatives (i.e. lower → harder).</p>
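<p>The tuple construction can be written out as a short function. Here <code>frames</code> can be anything indexable by timestamp (e.g. decoded frames or just frame indices); choosing the high-motion window and the five timestamps is left to the caller, and the function name is hypothetical.</p>

```python
def make_order_tuples(frames, a, b, c, d, e):
    """Build one positive and two negative tuples from five ordered timestamps."""
    if not a < b < c < d < e:
        raise ValueError("timestamps must satisfy a < b < c < d < e")
    positive = (frames[b], frames[c], frames[d])     # correct temporal order
    negatives = [
        (frames[b], frames[a], frames[d]),           # middle frame taken from before b
        (frames[b], frames[e], frames[d]),           # middle frame taken from after d
    ]
    tau_max = abs(b - d)                    # larger -> harder positives
    tau_min = min(abs(a - b), abs(d - e))   # smaller -> harder negatives
    return positive, negatives, tau_max, tau_min


frame_ids = list(range(100))  # stand-ins for decoded frames
pos, negs, tau_max, tau_min = make_order_tuples(frame_ids, 10, 20, 25, 30, 45)
print(pos, negs, tau_max, tau_min)
# (20, 25, 30) [(20, 10, 30), (20, 45, 30)] 10 10
```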
<p>The pretext task of video frame order validation is shown to improve the performance on the downstream task of action recognition when used as a pretraining step.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/frame-order-validation.png" alt="frame order validation" /></p>
<p><em>Fig. 20. Overview of learning representation by validating the order of video frames. (a) the data sample process; (b) the model is a triplet siamese network, where all input frames have shared weights. (Image source: <a href="https://arxiv.org/abs/1603.08561">Misra, et al 2016</a>)</em></p>
<p>The task in <em>O3N</em> (Odd-One-Out Network; <a href="https://arxiv.org/abs/1611.06646">Fernando et al. 2017</a>) is based on video frame sequence validation too. Going one step further, the task is to <mark><b>pick the incorrect sequence</b></mark> from multiple video clips.</p>
<p>Given \(N+1\) input video clips, one of them has its frames shuffled, and is thus in the wrong order, while the remaining \(N\) clips keep the correct temporal order. O3N learns to predict the location of the odd video clip. In their experiments, there are 6 input clips and each contains 6 frames.</p>
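<p>A sketch of how one O3N training example could be generated: pick a clip at random, shuffle its frames until the order actually changes, and use the clip's position as the label. Clips are plain lists of frames (here integer stand-ins), and the function name is hypothetical.</p>

```python
import random


def make_o3n_example(clips, rng=random):
    """Shuffle the frames of one randomly chosen clip; return (clips, odd_index).

    Each clip must contain at least two distinct frames, otherwise shuffling
    cannot change its order.
    """
    odd = rng.randrange(len(clips))
    inputs = [list(clip) for clip in clips]  # copy so the caller's clips are untouched
    original = inputs[odd][:]
    while inputs[odd] == original:  # guarantee the order actually changes
        rng.shuffle(inputs[odd])
    return inputs, odd


# 6 clips of 6 frames each, matching the setting in the paper.
clips = [[10 * i + t for t in range(6)] for i in range(6)]
inputs, odd = make_o3n_example(clips, random.Random(1))
print(odd, inputs[odd])
```

The network then receives all six clips and is trained to classify which position holds the odd one.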
<p>The <mark><b>arrow of time</b></mark> in a video carries very informative signals, on both low-level physics (e.g. gravity pulls objects down to the ground; smoke rises up; water flows downward) and high-level event reasoning (e.g. fish swim forward; you can break an egg but cannot undo it). Inspired by this, another idea is to learn latent representation by predicting the arrow of time (AoT), i.e. whether the video is playing forwards or backwards (<a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">Wei et al., 2018</a>).</p>
<p>A classifier should capture both low-level physics and high-level semantics in order to predict the arrow of time. The proposed <em>T-CAM</em> (Temporal Class-Activation-Map) network accepts \(T\) groups, each containing a number of frames of optical flow. The conv layer outputs from each group are concatenated and fed into binary logistic regression for predicting the arrow of time.</p>
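<p>A toy sketch of the AoT setup under stated assumptions: training labels come for free by randomly reversing a clip (label 1 for forward is our own convention), and prediction concatenates per-group features before a binary logistic regression. The features, weights and bias below are placeholders; in T-CAM the features would come from conv layers applied to groups of optical-flow frames.</p>

```python
import math
import random


def make_aot_example(frames, rng=random):
    """Randomly play a clip forward (label 1) or backward (label 0)."""
    if rng.random() < 0.5:
        return list(frames), 1
    return list(reversed(frames)), 0


def predict_forward_prob(group_features, weights, bias):
    """Concatenate per-group features, then apply binary logistic regression."""
    x = [v for group in group_features for v in group]
    logit = sum(w * v for w, v in zip(weights, x)) + bias
    return 1.0 / (1.0 + math.exp(-logit))


clip, label = make_aot_example([1, 2, 3], random.Random(0))
print(clip, label)
print(predict_forward_prob([[0.0, 0.0], [0.0, 0.0]], [1.0] * 4, 0.0))  # 0.5
```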
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/learning-arrow-of-time.png" alt="Learning the arrow of time" /></p>
<p><em>Fig. 21. Overview of learning representation by predicting the arrow of time. (a) Conv features of multiple groups of frame sequences are concatenated. (b) The top level contains 3 conv layers and average pooling. (Image source: <a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">Wei et al, 2018</a>)</em></p>
<p>Interestingly, there exist a couple of artificial cues in the dataset. If not handled properly, they could lead to a trivial classifier without relying on the actual video content:</p>
<ul>
<li>Due to the video compression, the black framing might not be completely black but instead may contain certain information on the chronological order. Hence black framing should be removed in the experiments.</li>
<li>Large camera motion, like vertical translation or zoom-in/out, also provides strong signals for the arrow of time that are independent of content. The processing stage should stabilize the camera motion.</li>
</ul>
<p>The AoT pretext task is shown to improve the performance on action classification downstream task when used as a pretraining step. Note that fine-tuning is still needed.</p>
<h3 id="video-colorization">Video Colorization</h3>
<p><a href="https://arxiv.org/abs/1806.09594">Vondrick et al. (2018)</a> proposed <mark><b>video colorization</b></mark> as a self-supervised learning problem, resulting in a rich representation that can be used for video segmentation and unlabelled visual region tracking, <em>without extra fine-tuning</em>.</p>
<p>Unlike the image-based <a href="#colorization">colorization</a>, here the task is to copy colors from a normal reference frame in color to another target frame in grayscale by leveraging the natural temporal coherency of colors across video frames (thus these two frames shouldn’t be too far apart in time). In order to copy colors consistently, the model is designed to learn to keep track of correlated pixels in different frames.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/video-colorization.png" alt="Video colorization" /></p>
<p><em>Fig. 22. Video colorization by copying colors from a reference frame to target frames in grayscale. (Image source: <a href="https://arxiv.org/abs/1806.09594">Vondrick et al. 2018</a>)</em></p>
<p>The idea is quite simple and smart. Let \(c_i\) be the true color of the \(i\)-th pixel in the reference frame and \(c_j\) be the color of the \(j\)-th pixel in the target frame. The predicted color of the \(j\)-th pixel in the target, \(\hat{c}_j\), is a weighted sum of the colors of all the pixels in the reference frame, where the weighting term measures the similarity:</p>
\[\hat{c}_j = \sum_i A_{ij} c_i \text{ where } A_{ij} = \frac{\exp(f_i f_j)}{\sum_{i'} \exp(f_{i'} f_j)}\]
<p>where \(f\) are learned embeddings for corresponding pixels; \(i'\) indexes all the pixels in the reference frame. The weighting term implements an attention-based pointing mechanism, similar to <a href="/lil-log/2018/11/30/meta-learning.html#matching-networks">matching network</a> and <a href="/lil-log/2018/06/24/attention-attention.html#pointer-network">pointer network</a>. As the full similarity matrix could be really large, both frames are downsampled. The categorical cross-entropy loss between \(c_j\) and \(\hat{c}_j\) is used with quantized colors, just like in <a href="https://arxiv.org/abs/1603.08511">Zhang et al. 2016</a>.</p>
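<p>The attention-based color copying can be sketched as follows. Colors are represented as one-hot vectors over quantized color bins (the bin count and the embeddings are made up for illustration), so the prediction \(\hat{c}_j\) is a distribution over bins that can be plugged into a cross-entropy loss. The function names are hypothetical.</p>

```python
import math


def copy_colors(ref_embed, ref_colors, target_embed):
    """Predict target colors as attention-weighted sums of reference colors.

    A_ij = softmax over reference pixels i of (f_i . f_j); c_hat_j = sum_i A_ij c_i.
    """
    predictions = []
    num_bins = len(ref_colors[0])
    for f_j in target_embed:
        scores = [sum(a * b for a, b in zip(f_i, f_j)) for f_i in ref_embed]
        top = max(scores)  # subtract the max before exp for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        attention = [w / total for w in weights]
        c_hat = [sum(attention[i] * ref_colors[i][k] for i in range(len(ref_colors)))
                 for k in range(num_bins)]
        predictions.append(c_hat)
    return predictions


def cross_entropy(true_bin, predicted_distribution, eps=1e-12):
    """Categorical cross-entropy against a quantized ground-truth color bin."""
    return -math.log(predicted_distribution[true_bin] + eps)


# Two reference pixels with distinct embeddings and one-hot colors (2 bins).
pred = copy_colors([[10.0, 0.0], [0.0, 10.0]], [[1.0, 0.0], [0.0, 1.0]],
                   [[10.0, 0.0], [0.0, 10.0]])
print([[round(x, 3) for x in c] for c in pred])  # each target pixel copies its match
```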
<p>Depending on how the reference frame is marked, the model can be used to complete several color-based downstream tasks such as tracking segmentation masks or human pose over time. No fine-tuning is needed. See Fig. 23.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/video-colorization-examples.png" alt="Video colorization for tracking" /></p>
<p><em>Fig. 23. Use video colorization to track object segmentation and human pose in time. (Image source: <a href="https://arxiv.org/abs/1806.09594">Vondrick et al. (2018)</a>)</em></p>
<blockquote>
<p>A couple of common observations:</p>
<ul>
<li>Combining multiple pretext tasks improves performance;</li>
<li>Deeper networks improve the quality of representation;</li>
<li>Supervised learning baselines still beat all of them by far.</li>
</ul>
</blockquote>
<h2 id="control-based">Control-Based</h2>
<p>When running an RL policy in the real world, such as controlling a physical robot on visual inputs, it is non-trivial to properly track states, obtain reward signals or determine whether a goal is actually achieved. The visual data contains a lot of noise that is irrelevant to the true state, and thus the equivalence of states cannot be inferred from pixel-level comparison. Self-supervised representation learning has shown great potential in learning useful state embeddings that can be used directly as input to a control policy.</p>
<p>All the cases discussed in this section are in robotic learning, mainly for state representation from multiple camera views and goal representation.</p>
<h3 id="multi-view-metric-learning">Multi-View Metric Learning</h3>
<p>The concept of metric learning has been mentioned multiple times in the <a href="#counting-feature-loss">previous</a> <a href="#tracking">sections</a>. A common setting is: given a triple of samples, (<em>anchor</em> \(s_a\), <em>positive</em> sample \(s_p\), <em>negative</em> sample \(s_n\)), the representation embedding \(\phi(s)\) is learned such that \(s_a\) stays close to \(s_p\) but far away from \(s_n\) in the latent space.</p>
<p><a href="#grasp2vec"></a><mark><b>Grasp2Vec</b></mark> (<a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>) aims to learn an object-centric vision representation in the robot grasping task from free, unlabelled grasping activities. By object-centric, it means that, irrespective of how the environment or the robot looks like, if two images contain similar items, they should be mapped to similar representation; otherwise the embeddings should be far apart.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/grasp2vec.png" alt="Grasp2vec" /></p>
<p><em>Fig. 23. A conceptual illustration of how grasp2vec learns an object-centric state embedding. (Image source: <a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>)</em></p>
<p>The grasping system can tell whether it moves an object but cannot tell which object it is. Cameras are set up to take images of the entire scene and the grasped object. During early training, the robot attempts to grasp an arbitrary object \(o\) at random, producing a triple of images, \((s_\text{pre}, s_\text{post}, o)\):</p>
<ul>
<li>\(o\) is an image of the grasped object held up to the camera;</li>
<li>\(s_\text{pre}\) is an image of the scene <em>before</em> grasping, with the object \(o\) in the tray;</li>
<li>\(s_\text{post}\) is an image of the same scene <em>after</em> grasping, without the object \(o\) in the tray.</li>
</ul>
<p>To learn object-centric representation, we expect the difference between embeddings of \(s_\text{pre}\) and \(s_\text{post}\) to capture the removed object \(o\). The idea is quite interesting and similar to relationships that have been observed in <a href="/lil-log/2017/10/15/learning-word-embedding.html">word embedding</a>, <a href="https://developers.google.com/machine-learning/crash-course/embeddings/translating-to-a-lower-dimensional-space">e.g.</a> distance(“king”, “queen”) ≈ distance(“man”, “woman”).</p>
<p>Let \(\phi_s\) and \(\phi_o\) be the embedding functions for the scene and the object respectively. The model learns the representation by minimizing the distance between \(\phi_s(s_\text{pre}) - \phi_s(s_\text{post})\) and \(\phi_o(o)\) using <em>n-pair loss</em>:</p>
\[\begin{aligned}
\mathcal{L}_\text{grasp2vec} &= \text{NPair}(\phi_s(s_\text{pre}) - \phi_s(s_\text{post}), \phi_o(o)) + \text{NPair}(\phi_o(o), \phi_s(s_\text{pre}) - \phi_s(s_\text{post})) \\
\text{where }\text{NPair}(a, p) &= \sum_{i<B} \Big( -\log\frac{\exp(a_i^\top p_i)}{\sum_{j<B}\exp(a_i^\top p_j)} + \lambda (\|a_i\|_2^2 + \|p_i\|_2^2) \Big)
\end{aligned}\]
<p>where \(B\) refers to a batch of (anchor, positive) sample pairs.</p>
<p>When framing representation learning as metric learning, <a href="https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective"><strong>n-pair loss</strong></a> is a common choice. Rather than processing explicit (anchor, positive, negative) triplets, the n-pair loss treats the positive instances of all the other pairs in the same mini-batch as negatives.</p>
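<p>Below is a pure-Python sketch of the n-pair loss and of the symmetric grasp2vec objective built from it. Embeddings are plain lists, the regularization weight <code>lam</code> is a hypothetical value (0 by default, since the paper's value is not given here), and the log-sum-exp uses the usual max-subtraction for numerical stability.</p>

```python
import math


def npair_loss(anchors, positives, lam=0.0):
    """sum_i [ -log softmax_j(a_i . p_j)[i] + lam * (|a_i|^2 + |p_i|^2) ].

    Every other pair's positive p_j (j != i) acts as a negative for anchor a_i.
    """
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [sum(x * y for x, y in zip(a, p)) for p in positives]
        top = max(logits)
        log_denominator = top + math.log(sum(math.exp(l - top) for l in logits))
        total += -(logits[i] - log_denominator)
        total += lam * (sum(x * x for x in a) + sum(y * y for y in positives[i]))
    return total


def grasp2vec_loss(phi_pre, phi_post, phi_obj, lam=0.0):
    """NPair(scene difference, object) + NPair(object, scene difference)."""
    diff = [[p - q for p, q in zip(pre, post)] for pre, post in zip(phi_pre, phi_post)]
    return npair_loss(diff, phi_obj, lam) + npair_loss(phi_obj, diff, lam)


# The scene difference exactly matches the grasped object's embedding.
pre, post, obj = [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]
print(grasp2vec_loss(pre, post, obj))  # 4 * log(1 + e^-1), about 1.253
```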
<p>The embedding function \(\phi_o\) works great for representing a goal \(g\) with an image. The reward function, which quantifies how close the actually grasped object \(o\) is to the goal, is defined as \(r = \phi_o(g) \cdot \phi_o(o)\). Note that computing rewards only relies on the learned latent space and doesn’t involve ground truth positions, so it can be used for training on real robots.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/grasp2vec-attention-map.png" alt="Grasp2vec attention map" /></p>
<p><em>Fig. 24. Localization results of grasp2vec embedding. The heatmap of localizing a goal object in a pre-grasping scene is defined as \(\phi_o(o)^\top \phi_{s, \text{spatial}} (s_\text{pre})\), where \(\phi_{s, \text{spatial}}\) is the output of the last resnet block after ReLU. The fourth column is a failure case and the last three columns take real images as goals. (Image source: <a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>)</em></p>
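<p>Both the reward \(r = \phi_o(g) \cdot \phi_o(o)\) and the localization heatmap in Fig. 24 are dot products in the learned embedding space. A minimal sketch with made-up embeddings, where the spatial feature map stands in for \(\phi_{s, \text{spatial}}(s_\text{pre})\) as an H x W grid of vectors; the function names are hypothetical.</p>

```python
def grasp_reward(phi_goal, phi_grasped):
    """r = phi_o(g) . phi_o(o): similarity of goal and grasped-object embeddings."""
    return sum(a * b for a, b in zip(phi_goal, phi_grasped))


def localization_heatmap(phi_obj, spatial_features):
    """Dot product of the object embedding with each cell of an H x W x C feature map."""
    return [[grasp_reward(phi_obj, cell) for cell in row] for row in spatial_features]


# A 2 x 2 spatial map with 2-d features; the object matches the bottom-right cell best.
feature_map = [[[0.1, 0.0], [0.0, 0.2]],
               [[0.3, 0.1], [0.9, 0.8]]]
heatmap = localization_heatmap([1.0, 1.0], feature_map)
print(heatmap)
```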
<p>Other than the embedding-similarity-based reward function, there are a few other tricks for training the RL policy in the grasp2vec framework:</p>
<ul>
<li><em>Posthoc labeling</em>: Augment the dataset by labeling a randomly grasped object as a correct goal, like HER (Hindsight Experience Replay; <a href="https://papers.nips.cc/paper/7090-hindsight-experience-replay.pdf">Andrychowicz, et al., 2017</a>).</li>
<li><em>Auxiliary goal augmentation</em>: Augment the replay buffer even further by relabeling transitions with unachieved goals; precisely, in each iteration, two goals \((g, g')\) are sampled and both are used to add new transitions into the replay buffer.</li>
</ul>
<p><a href="#tcn"></a><strong>TCN</strong> (<mark><b>Time-Contrastive Networks</b></mark>; <a href="https://arxiv.org/abs/1704.06888">Sermanet, et al. 2018</a>) learn from multi-camera view videos with the intuition that different viewpoints at the same timestep of the same scene should share the same embedding (like in <a href="https://arxiv.org/abs/1503.03832">FaceNet</a>) while embedding should vary in time, even of the same camera viewpoint. Therefore embedding captures the semantic meaning of the underlying state rather than visual similarity. The TCN embedding is trained with <a href="#triplet-loss">triplet loss</a>.</p>
<p>The training data is collected by taking videos of the same scene simultaneously but from different angles. All the videos are unlabelled.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/TCN.png" alt="Time-contrastive network" /></p>
<p><em>Fig. 25. An illustration of time-contrastive approach for learning state embedding. The blue frames selected from two camera views at the same timestep are anchor and positive samples, while the red frame at a different timestep is the negative sample.</em></p>
<p>TCN embedding extracts visual features that are invariant to camera configurations. It can be used to construct a reward function for imitation learning based on the Euclidean distance between the demo video and the observations in the latent space.</p>
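<p>A sketch of TCN triplet sampling and of the simplest distance-based imitation reward. Triplets are returned as (view, timestep) indices; taking the negative from the anchor's own view, the <code>min_gap</code> parameter, and using the plain negative Euclidean distance as the reward are all assumptions made for illustration, not the exact choices of the paper.</p>

```python
import math
import random


def sample_tcn_triplet(num_views, num_steps, min_gap, rng=random):
    """Sample (view, timestep) indices for anchor, positive and negative.

    Anchor and positive share a timestep but come from different views; the
    negative comes from the anchor's view at a timestep at least min_gap away.
    num_steps >= 2 * min_gap guarantees that such a timestep exists.
    """
    t = rng.randrange(num_steps)
    view_a, view_p = rng.sample(range(num_views), 2)
    far_steps = [s for s in range(num_steps) if abs(s - t) >= min_gap]
    t_neg = rng.choice(far_steps)
    return (view_a, t), (view_p, t), (view_a, t_neg)


def imitation_reward(embed_obs, embed_demo):
    """Negative Euclidean distance between observation and demo embeddings."""
    return -math.dist(embed_obs, embed_demo)


anchor, positive, negative = sample_tcn_triplet(2, 50, 10, random.Random(0))
print(anchor, positive, negative)
print(imitation_reward([0.0, 0.0], [3.0, 4.0]))  # -5.0
```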
<p>A further improvement over TCN is to learn embedding over multiple frames jointly rather than a single frame, resulting in <strong>mfTCN</strong> (<b><mark>Multi-frame</mark> Time-Contrastive Networks</b>; <a href="https://arxiv.org/abs/1808.00928">Dwibedi et al., 2019</a>). Given a set of videos from several synchronized camera viewpoints, \(v_1, v_2, \dots, v_k\), the frame at time \(t\) and the previous \(n-1\) frames selected with stride \(s\) in each video are aggregated and mapped into one embedding vector, resulting in a lookback window of size \((n-1) \times s + 1\). Each frame first goes through a CNN to extract low-level features and then we use 3D temporal convolutions to aggregate frames in time. The model is trained with <a href="#n-pair-loss">n-pairs loss</a>.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/mfTCN.png" alt="mfTCN" /></p>
<p><em>Fig. 26. The sampling process for training mfTCN. (Image source: <a href="https://arxiv.org/abs/1808.00928">Dwibedi et al., 2019</a>)</em></p>
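<p>The lookback window is easy to check with a few lines of arithmetic. This sketch (function name hypothetical) lists the frame indices that feed one mfTCN embedding at time \(t\) and confirms that they span \((n-1) \times s + 1\) frames.</p>

```python
def lookback_indices(t, n, s):
    """Indices of frame t and the previous n-1 frames sampled with stride s, oldest first."""
    indices = [t - k * s for k in range(n - 1, -1, -1)]
    if indices[0] < 0:
        raise ValueError("not enough history before frame t")
    return indices


idx = lookback_indices(t=20, n=3, s=5)
print(idx)                   # [10, 15, 20]
print(idx[-1] - idx[0] + 1)  # 11 == (n - 1) * s + 1
```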
<p>The training data is sampled as follows:</p>
<ol>
<li>First we construct two pairs of video clips. Each pair contains two clips from different camera views but with synchronized timesteps. The two pairs should be far apart in time.</li>
<li>Sample a fixed number of frames from each video clip in the same pair simultaneously with the same stride.</li>
<li>Frames with the same timesteps are used as positive samples in the n-pair loss, while frames across pairs are negative samples.</li>
</ol>
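<p>The three sampling steps can be sketched as index bookkeeping: pick two end times far apart, pick two different camera views for each, and reuse the same strided frame indices within a pair. The <code>min_gap</code> parameter is a hypothetical stand-in for "far apart in time", and <code>num_steps</code> must be large enough that a second, distant end time exists.</p>

```python
import random


def sample_mftcn_batch(num_views, num_steps, n, s, min_gap, rng=random):
    """Sample two pairs of synchronized multi-frame clips that are far apart in time.

    Returns [[(view, frame_indices), (view, frame_indices)], [...]]: clips in the
    same pair share frame indices (positives), clips across pairs are negatives.
    """
    window = (n - 1) * s + 1
    valid_ends = range(window - 1, num_steps)  # end frames with enough history
    t1 = rng.choice(valid_ends)
    far_ends = [t for t in valid_ends if abs(t - t1) >= min_gap]
    t2 = rng.choice(far_ends)
    pairs = []
    for t in (t1, t2):
        view_1, view_2 = rng.sample(range(num_views), 2)      # two different cameras
        indices = [t - k * s for k in range(n - 1, -1, -1)]   # same stride in both views
        pairs.append([(view_1, indices), (view_2, indices)])
    return pairs


pairs = sample_mftcn_batch(num_views=3, num_steps=200, n=3, s=5, min_gap=50,
                           rng=random.Random(0))
print(pairs)
```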
<p>mfTCN embedding can capture the position and velocity of objects in the scene (e.g. in cartpole) and can also be used as inputs for policy.</p>
<h3 id="autonomous-goal-generation">Autonomous Goal Generation</h3>
<p><strong>RIG</strong> (<b>Reinforcement learning with <mark>Imagined Goals</mark></b>; <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>) described a way to train a goal-conditioned policy with unsupervised representation learning. A policy learns from self-supervised practice by first imagining “fake” goals and then trying to achieve them.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RIG.png" alt="RIG" /></p>
<p><em>Fig. 27. The workflow of RIG. (Image source: <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>)</em></p>
<p>The task is to control a robot arm to push a small puck on a table to a desired position. The desired position, or the goal, is given as an image. During training, it learns latent embeddings of both state \(s\) and goal \(g\) through a \(\beta\)-VAE encoder, and the control policy operates entirely in the latent space.</p>
<p>Let’s say a <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#beta-vae">\(\beta\)-VAE</a> has an encoder \(q_\phi\) mapping input states to latent variable \(z\) which is modeled by a Gaussian distribution and a decoder \(p_\psi\) mapping \(z\) back to the states. The state encoder in RIG is set to be the mean of \(\beta\)-VAE encoder.</p>
\[\begin{aligned}
z &\sim q_\phi(z \vert s) = \mathcal{N}(z; \mu_\phi(s), \sigma^2_\phi(s)) \\
\mathcal{L}_{\beta\text{-VAE}} &= - \mathbb{E}_{z \sim q_\phi(z \vert s)} [\log p_\psi (s \vert z)] + \beta D_\text{KL}(q_\phi(z \vert s) \| p(z)) \\
e(s) &\triangleq \mu_\phi(s)
\end{aligned}\]
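<p>With a diagonal Gaussian encoder, the KL term has a closed form. The sketch below assumes a standard normal prior \(p(z) = \mathcal{N}(0, I)\) (the usual choice), replaces the reconstruction log-likelihood with a squared error (a fixed-variance Gaussian decoder, up to scale and a constant), and uses a hypothetical <code>beta=4.0</code>; RIG's state encoding \(e(s)\) would simply be the encoder mean <code>mu</code>.</p>

```python
import math


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def beta_vae_loss(s, s_recon, mu, log_var, beta=4.0):
    """Squared-error reconstruction (Gaussian decoder up to scale/constant) + beta * KL."""
    recon = sum((a - b) ** 2 for a, b in zip(s, s_recon))
    return recon + beta * kl_to_standard_normal(mu, log_var)


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: posterior equals the prior
print(beta_vae_loss([1.0, 2.0], [1.0, 1.0], [1.0], [0.0], beta=2.0))  # 1 + 2 * 0.5 = 2.0
```

A larger \(\beta\) pushes the posterior closer to the prior, trading reconstruction quality for a more compact latent space.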
<p>The reward is the negative Euclidean distance between state and goal embedding vectors: \(r(s, g) = -\|e(s) - e(g)\|\). Similar to <a href="#grasp2vec">grasp2vec</a>, RIG also applies data augmentation via latent goal relabeling: precisely, half of the goals are generated from the prior at random and the other half are selected using HER. Also as in grasp2vec, rewards do not depend on any ground truth states but only on the learned state encoding, so it can be used for training on real robots.</p>
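<p>A sketch of the latent reward and of the half-and-half goal relabeling. It assumes the VAE prior is \(\mathcal{N}(0, I)\) and that, for each transition in a batch, the caller supplies latent states reached later in the same trajectory (HER-style "future" goals); the function and argument names are hypothetical.</p>

```python
import math
import random


def rig_reward(e_s, e_g):
    """r(s, g) = -||e(s) - e(g)||, computed only from learned latent encodings."""
    return -math.dist(e_s, e_g)


def relabel_goals(future_latents_per_transition, latent_dim, rng=random):
    """First half of the batch gets goals from the N(0, I) prior, second half HER-style.

    future_latents_per_transition[k] holds latent states reached later in the
    trajectory of transition k; one of them is reused as its achieved goal.
    """
    half = len(future_latents_per_transition) // 2
    goals = []
    for k, futures in enumerate(future_latents_per_transition):
        if k < half:
            goals.append([rng.gauss(0.0, 1.0) for _ in range(latent_dim)])
        else:
            goals.append(list(rng.choice(futures)))
    return goals


goals = relabel_goals([[[1.0, 1.0]]] * 4, latent_dim=2, rng=random.Random(0))
print(rig_reward([0.0, 0.0], [3.0, 4.0]))  # -5.0
```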
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RIG-algorithm.png" alt="RIG algorithm" /></p>
<p><em>Fig. 28. The algorithm of RIG. (Image source: <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>)</em></p>
<p>The problem with RIG is a lack of object variations in the imagined goal pictures. If \(\beta\)-VAE is only trained with a black puck, it would not be able to create a goal with other objects like blocks of different shapes and colors. A follow-up improvement replaces \(\beta\)-VAE with a <strong>CC-VAE</strong> (Context-Conditioned VAE; <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>), inspired by <strong>CVAE</strong> (Conditional VAE; <a href="https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models">Sohn, Lee & Yan, 2015</a>), for goal generation.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CC-RIG.png" alt="Context-conditional RIG" /></p>
<p><em>Fig. 29. The workflow of context-conditioned RIG. (Image source: <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>).</em></p>
<p>A CVAE conditions on a context variable \(c\). It trains an encoder \(q_\phi(z \vert s, c)\) and a decoder \(p_\psi (s \vert z, c)\) and note that both have access to \(c\). The CVAE loss penalizes information passing from the input state \(s\) through an information bottleneck but allows for <em>unrestricted</em> information flow from \(c\) to both encoder and decoder.</p>
\[\mathcal{L}_\text{CVAE} = - \mathbb{E}_{z \sim q_\phi(z \vert s,c)} [\log p_\psi (s \vert z, c)] + \beta D_\text{KL}(q_\phi(z \vert s, c) \| p(z))\]
<p>To create plausible goals, CC-VAE conditions on a starting state \(s_0\) so that the generated goal contains the same type of object as \(s_0\). This goal consistency is necessary; e.g. if the current scene contains a red puck but the goal has a blue block, it would confuse the policy.</p>
<p>Other than the state encoder \(e(s) \triangleq \mu_\phi(s)\), CC-VAE trains a second convolutional encoder \(e_0(.)\) to translate the starting state \(s_0\) into a compact context representation \(c = e_0(s_0)\). The two encoders, \(e(.)\) and \(e_0(.)\), intentionally do not share weights, as they are expected to encode different factors of image variation. In addition to the loss function of CVAE, CC-VAE adds an extra term to learn to reconstruct \(c\) back to \(s_0\), \(\hat{s}_0 = d_0(c)\).</p>
\[\mathcal{L}_\text{CC-VAE} = \mathcal{L}_\text{CVAE} - \log p(s_0\vert c)\]
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CC-RIG-goal-samples.png" alt="RIG goal samples" /></p>
<p><em>Fig. 30. Examples of imagined goals generated by CVAE that conditions on the context image (the first row), while VAE fails to capture the object consistency. (Image source: <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>).</em></p>
<h3 id="bisimulation">Bisimulation</h3>
<p>Task-agnostic representation (e.g. a model that intends to represent all the dynamics in the system) may distract the RL algorithms, as irrelevant information is also represented. For example, if we just train an auto-encoder to reconstruct the input image, there is no guarantee that the entire learned representation will be useful for RL. Therefore, if we only want to learn information relevant to control, we need to move away from reconstruction-based representation learning, as irrelevant details are still important for reconstruction.</p>
<p>Representation learning for control based on bisimulation does not depend on reconstruction, but aims to group states based on their behavioral similarity in an MDP.</p>
<p><strong>Bisimulation</strong> (<a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.61.2493&rep=rep1&type=pdf">Givan et al. 2003</a>) refers to an equivalence relation between two states with similar long-term behavior. <em>Bisimulation metrics</em> quantify such relation so that we can aggregate states to compress a high-dimensional state space into a smaller one for more efficient computation. The <em>bisimulation distance</em> between two states corresponds to how behaviorally different these two states are.</p>
<p>Given an <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#markov-decision-processes">MDP</a> \(\mathcal{M} = \langle \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R}, \gamma \rangle\) and a bisimulation relation \(B\), two states that are related under \(B\) (i.e. \(s_i B s_j\)) should have the same immediate reward for all actions and the same transition probabilities over the next bisimilar states:</p>
\[\begin{aligned}
\mathcal{R}(s_i, a) &= \mathcal{R}(s_j, a) \; \forall a \in \mathcal{A} \\
\mathcal{P}(G \vert s_i, a) &= \mathcal{P}(G \vert s_j, a) \; \forall a \in \mathcal{A} \; \forall G \in \mathcal{S}_B
\end{aligned}\]
<p>where \(\mathcal{S}_B\) is a partition of the state space under the relation \(B\).</p>
<p>Note that \(=\) is always a bisimulation relation. The most interesting one is the maximal bisimulation relation \(\sim\), which defines a partition \(\mathcal{S}_\sim\) with <em>fewest</em> groups of states.</p>
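<p>For a finite MDP, the maximal bisimulation relation can be computed by partition refinement: start with all states in one group, then repeatedly split groups whose members differ in immediate reward or in the probability mass they send to each current group, until nothing changes. The sketch below is our own illustration of that standard procedure (not code from the cited papers); rewards are compared exactly and probabilities after rounding to <code>decimals</code> places.</p>

```python
def maximal_bisimulation(R, P, decimals=9):
    """Partition the states of a finite MDP into maximal-bisimulation blocks.

    R[s][a]    : immediate reward of action a in state s
    P[s][a][t] : probability of moving from s to t under action a
    Returns a list mapping each state to a block id.
    """
    num_states, num_actions = len(R), len(R[0])
    block = [0] * num_states  # start from the coarsest partition: one block
    while True:
        num_blocks = max(block) + 1
        signatures = []
        for s in range(num_states):
            sig = [block[s]]  # keep the current block so blocks are only ever split
            for a in range(num_actions):
                mass = [0.0] * num_blocks  # probability mass sent to each current block
                for t, p in enumerate(P[s][a]):
                    mass[block[t]] += p
                sig.append((R[s][a], tuple(round(m, decimals) for m in mass)))
            signatures.append(tuple(sig))
        ids = {}
        new_block = [ids.setdefault(sig, len(ids)) for sig in signatures]
        if len(ids) == num_blocks:  # no block was split: the partition is stable
            return new_block
        block = new_block


# States 0 and 1 (reward 0) lead to states 2 and 3 (reward 1), which loop forever.
R = [[0.0], [0.0], [1.0], [1.0]]
P = [[[0.0, 0.0, 1.0, 0.0]], [[0.0, 0.0, 0.0, 1.0]],
     [[0.0, 0.0, 1.0, 0.0]], [[0.0, 0.0, 0.0, 1.0]]]
print(maximal_bisimulation(R, P))  # [0, 0, 1, 1]
```

If state 3 paid a different reward than state 2, states 0 and 1 would also be split, since they would then lead to behaviorally different successors.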
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DeepMDP.png" alt="DeepMDP" /></p>
<p><em>Fig. 31. DeepMDP learns a latent space model by minimizing two losses on a reward model and a dynamics model. (Image source: <a href="https://arxiv.org/abs/1906.02736">Gelada, et al. 2019</a>)</em></p>
<p>With a goal similar to that of bisimulation metrics, <strong>DeepMDP</strong> (<a href="https://arxiv.org/abs/1906.02736">Gelada, et al. 2019</a>) simplifies high-dimensional observations in RL tasks and learns a latent space model via minimizing two losses:</p>
<ol>
<li>prediction of rewards and</li>
<li>prediction of the distribution over next latent states.</li>
</ol>
\[\begin{aligned}
\mathcal{L}_{\bar{\mathcal{R}}}(s, a) &= \vert \mathcal{R}(s, a) - \bar{\mathcal{R}}(\phi(s), a) \vert \\
\mathcal{L}_{\bar{\mathcal{P}}}(s, a) &= D(\phi \mathcal{P}(s, a), \bar{\mathcal{P}}(\cdot \vert \phi(s), a))
\end{aligned}\]
<p>where \(\phi(s)\) is the embedding of state \(s\); symbols with a bar are functions (the reward function \(\bar{\mathcal{R}}\) and the transition function \(\bar{\mathcal{P}}\)) of the same MDP but operating in the low-dimensional latent space. Here the embedding representation \(\phi\) can be connected to bisimulation metrics, as the bisimulation distance is proved to be upper-bounded by the L2 distance in the latent space.</p>
<p>The function \(D\) quantifies the distance between two probability distributions and should be chosen carefully. DeepMDP focuses on the <em>Wasserstein-1</em> metric (also known as <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html#what-is-wasserstein-distance">“earth-mover distance”</a>). The Wasserstein-1 distance between distributions \(P\) and \(Q\) on a metric space \((M, d)\) (i.e., \(d: M \times M \to \mathbb{R}\)) is:</p>
\[W_d (P, Q) = \inf_{\lambda \in \Pi(P, Q)} \int_{M \times M} d(x, y) \lambda(x, y) \; \mathrm{d}x \mathrm{d}y\]
<p>where \(\Pi(P, Q)\) is the set of all <a href="https://en.wikipedia.org/wiki/Coupling_(probability)">couplings</a> of \(P\) and \(Q\). \(d(x, y)\) defines the cost of moving a particle from point \(x\) to point \(y\).</p>
<p>The Wasserstein metric has a dual form according to the Monge-Kantorovich duality:</p>
\[W_d (P, Q) = \sup_{f \in \mathcal{F}_d} \vert \mathbb{E}_{x \sim P} f(x) - \mathbb{E}_{y \sim Q} f(y) \vert\]
<p>where \(\mathcal{F}_d\) is the set of 1-Lipschitz functions under the metric \(d\): \(\mathcal{F}_d = \{ f: \vert f(x) - f(y) \vert \leq d(x, y) \}\).</p>
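<p>In one dimension with \(d(x, y) = \vert x - y \vert\), the optimal coupling has a closed form: \(W_1\) equals the area between the two cumulative distribution functions. The sketch below uses that fact for two discrete distributions on a shared, sorted support (an assumption made to keep it short). It also evaluates the dual objective with the 1-Lipschitz function \(f(x) = x\), which can only give a lower bound on \(W_1\) and happens to be tight in this example.</p>

```python
import numpy as np


def wasserstein1_1d(support, p, q):
    """W1 between discrete distributions p and q on the same sorted 1-D support,
    with ground metric d(x, y) = |x - y|. In 1-D, W1 = integral of |F_P - F_Q|."""
    support = np.asarray(support, dtype=float)
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]  # |F_P - F_Q| on each interval
    return float(np.sum(cdf_gap * np.diff(support)))


support = [0.0, 1.0, 2.0]
p = [0.5, 0.5, 0.0]
q = [0.0, 0.5, 0.5]
print(wasserstein1_1d(support, p, q))  # 1.0: every unit of mass moves by 1

# Dual view: f(x) = x is 1-Lipschitz, so |E_P f - E_Q f| lower-bounds W1 (tight here).
print(abs(np.dot(p, support) - np.dot(q, support)))  # 1.0
```

<p>For arbitrary 1-D samples and weights, <code>scipy.stats.wasserstein_distance</code> computes the same quantity; in higher dimensions there is no such shortcut, which is part of why the Wasserstein distance is expensive.</p>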
<p>DeepMDP generalizes the model to the Norm Maximum Mean Discrepancy (Norm-<a href="https://en.wikipedia.org/wiki/Kernel_embedding_of_distributions#Measuring_distance_between_distributions">MMD</a>) metrics to improve the tightness of the bounds of its deep value function and, at the same time, to save computation (the Wasserstein distance is computationally expensive). In their experiments, they found that the architecture of the transition prediction model can have a big impact on performance. Adding these DeepMDP losses as auxiliary losses when training model-free RL agents leads to good improvements on most of the Atari games.</p>
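<p>To make the two DeepMDP losses concrete, the sketch below computes them for a batch under two simplifying assumptions of ours: both the environment and the latent transition model are deterministic, so \(\phi \mathcal{P}(s, a)\) and \(\bar{\mathcal{P}}(\cdot \vert \phi(s), a)\) are point masses, and the latent metric is L2. The Wasserstein-1 distance between two point masses is just the distance between the points, so \(\mathcal{L}_{\bar{\mathcal{P}}}\) reduces to a latent-space prediction error. The array names are hypothetical, and the networks producing the predictions are left out.</p>

```python
import numpy as np


def deepmdp_losses(reward, pred_reward, phi_next, pred_next_latent):
    """Per-sample DeepMDP losses under a deterministic-transition assumption.

    reward:           true rewards R(s, a), shape (B,)
    pred_reward:      latent reward model output R_bar(phi(s), a), shape (B,)
    phi_next:         embeddings phi(s') of the observed next states, shape (B, k)
    pred_next_latent: deterministic latent transition output, shape (B, k)
    """
    loss_r = np.abs(reward - pred_reward)
    # Both "distributions" are point masses, so W1 reduces to the L2 distance.
    loss_p = np.linalg.norm(phi_next - pred_next_latent, axis=-1)
    return loss_r, loss_p


loss_r, loss_p = deepmdp_losses(
    reward=np.array([1.0, 0.0]),
    pred_reward=np.array([0.5, 0.0]),
    phi_next=np.array([[0.0, 0.0], [1.0, 1.0]]),
    pred_next_latent=np.array([[3.0, 4.0], [1.0, 1.0]]),
)
print(loss_r, loss_p)  # [0.5 0. ] [5. 0.]
```

<p>With stochastic dynamics the target is a genuine distribution, and \(D\) can no longer be collapsed this way.</p>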
<p><strong>Deep Bisimulation for Control</strong> (short for <strong>DBC</strong>; <a href="https://arxiv.org/abs/2006.10742">Zhang et al. 2020</a>) learns latent representations of observations that are good for control in RL tasks, without domain knowledge or pixel-level reconstruction.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/DBC-illustration.png" alt="DBC algorithm" /></p>
<p><em>Fig. 32. The Deep Bisimulation for Control algorithm learns a bisimulation metric representation via learning a reward model and a dynamics model. The model architecture is a siamese network. (Image source: <a href="https://arxiv.org/abs/2006.10742">Zhang et al. 2020</a>)</em></p>
<p>Similar to DeepMDP, DBC models the dynamics by learning a reward model and a transition model. Both models operate in the latent space, \(\phi(s)\). The optimization of the embedding \(\phi\) depends on one important conclusion from <a href="https://arxiv.org/abs/1207.4114">Ferns, et al. 2004</a> (Theorem 4.5) and <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.295.2114&rep=rep1&type=pdf">Ferns, et al. 2011</a> (Theorem 2.6):</p>
<blockquote>
<p>Given \(c \in (0, 1)\) a discounting factor, \(\pi\) a policy that is being improved continuously, and \(M\) the space of bounded <a href="https://mathworld.wolfram.com/Pseudometric.html">pseudometrics</a> on the state space \(\mathcal{S}\), we can define \(\mathcal{F}: M \to M\):</p>
\[\mathcal{F}(d; \pi)(s_i, s_j) = (1-c) \vert \mathcal{R}_{s_i}^\pi - \mathcal{R}_{s_j}^\pi \vert + c W_d (\mathcal{P}_{s_i}^\pi, \mathcal{P}_{s_j}^\pi)\]
<p>Then, \(\mathcal{F}\) has a unique fixed point \(\tilde{d}\) which is a \(\pi^*\)-bisimulation metric and \(\tilde{d}(s_i, s_j) = 0 \iff s_i \sim s_j\).</p>
</blockquote>
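<p>The theorem also suggests a way to compute \(\tilde{d}\) on a small problem: start from \(d = 0\) and apply \(\mathcal{F}\) repeatedly; because \(\mathcal{F}\) is a contraction with factor \(c\) in the sup norm, the iterates converge to the unique fixed point. The sketch below assumes a toy deterministic MDP under a fixed policy, so each \(\mathcal{P}_s^\pi\) is a point mass and \(W_d(\mathcal{P}_{s_i}^\pi, \mathcal{P}_{s_j}^\pi)\) is simply \(d\) evaluated at the two successor states. The function <code>bisim_metric</code> is a hypothetical helper, not code from the papers.</p>

```python
import numpy as np


def bisim_metric(rewards, next_state, c=0.9, tol=1e-10, max_iter=100_000):
    """Iterate F(d)(s_i, s_j) = (1 - c)|R_i - R_j| + c * W_d(P_i, P_j) to its fixed point.

    rewards[s]:    R_s^pi, the reward of state s under the fixed policy.
    next_state[s]: the deterministic successor of s under the policy, so
                   W_d(P_i, P_j) = d(next_state[i], next_state[j]).
    """
    rewards = np.asarray(rewards, dtype=float)
    next_state = np.asarray(next_state)
    reward_gap = np.abs(rewards[:, None] - rewards[None, :])
    d = np.zeros_like(reward_gap)
    for _ in range(max_iter):
        # d[np.ix_(ns, ns)][i, j] == d[ns[i], ns[j]]
        d_new = (1 - c) * reward_gap + c * d[np.ix_(next_state, next_state)]
        if np.max(np.abs(d_new - d)) < tol:
            return d_new
        d = d_new
    return d


# States 0 and 1 both earn reward 1 and move to the absorbing state 2 (reward 0).
d_a = bisim_metric([1.0, 1.0, 0.0], [2, 2, 2])
# Two states that swap into each other forever, with rewards 1 and 0.
d_b = bisim_metric([1.0, 0.0], [1, 0])
print(d_a.round(3))
print(d_b.round(3))
```

<p>In the first example, states 0 and 1 end up at distance 0 (they are bisimilar), and each is at distance \(1 - c = 0.1\) from state 2. In the second, the reward gap recurs at every step, so the distance solves \(d = (1 - c) + c\,d\), i.e. \(d = 1\).</p>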
<p>[The proof is not trivial. I may or may not add it in the future _(:3」∠)_ …]</p>
<p>Given batches of observation pairs, the training loss for \(\phi\), \(J(\phi)\), minimizes the mean squared error between the on-policy bisimulation metric and the L1 distance in the latent space:</p>
\[J(\phi) = \Big( \|\phi(s_i) - \phi(s_j)\|_1 - \vert \hat{\mathcal{R}}(\bar{\phi}(s_i)) - \hat{\mathcal{R}}(\bar{\phi}(s_j)) \vert - \gamma W_2(\hat{\mathcal{P}}(\cdot \vert \bar{\phi}(s_i), \bar{\pi}(\bar{\phi}(s_i))), \hat{\mathcal{P}}(\cdot \vert \bar{\phi}(s_j), \bar{\pi}(\bar{\phi}(s_j)))) \Big)^2\]
<p>where \(\bar{\phi}(s)\) denotes \(\phi(s)\) with stop gradient and \(\bar{\pi}\) is the mean policy output. The learned reward model \(\hat{\mathcal{R}}\) is deterministic and the learned forward dynamics model \(\hat{\mathcal{P}}\) outputs a Gaussian distribution.</p>
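<p>The loss can be sketched with NumPy for a batch of pairs. Two assumptions are ours: the Gaussian predicted by \(\hat{\mathcal{P}}\) has a diagonal covariance, in which case \(W_2\) has the closed form \(\sqrt{\|\mu_i - \mu_j\|_2^2 + \|\sigma_i - \sigma_j\|_2^2}\), and the rewards passed in are already the outputs of \(\hat{\mathcal{R}}\). NumPy has no stop-gradient, so the sketch only computes the loss value; in an autodiff framework, everything except <code>z_i</code> and <code>z_j</code> would be computed from detached tensors.</p>

```python
import numpy as np


def w2_diag_gaussians(mu1, sigma1, mu2, sigma2):
    """Closed-form W2 between N(mu1, diag(sigma1**2)) and N(mu2, diag(sigma2**2))."""
    return np.sqrt(np.sum((mu1 - mu2) ** 2, axis=-1)
                   + np.sum((sigma1 - sigma2) ** 2, axis=-1))


def dbc_loss(z_i, z_j, r_i, r_j, mu_i, sigma_i, mu_j, sigma_j, gamma=0.99):
    """Batch mean of the squared DBC residual in J(phi).

    z_i, z_j:      latent codes phi(s_i), phi(s_j), shape (B, k); the only gradient path.
    r_i, r_j:      predicted rewards from the stop-gradient branch, shape (B,).
    mu_*, sigma_*: Gaussian next-latent predictions (stop-gradient), shape (B, k).
    """
    latent_dist = np.sum(np.abs(z_i - z_j), axis=-1)  # L1 distance between embeddings
    target = np.abs(r_i - r_j) + gamma * w2_diag_gaussians(mu_i, sigma_i, mu_j, sigma_j)
    return float(np.mean((latent_dist - target) ** 2))


z_i, z_j = np.array([[0.0, 0.0]]), np.array([[1.0, 1.0]])
r_i, r_j = np.array([1.0]), np.array([0.0])
mu, sigma = np.zeros((1, 2)), np.ones((1, 2))
print(dbc_loss(z_i, z_j, r_i, r_j, mu, sigma, mu, sigma))  # 1.0: L1 gap 2 vs. target 1
```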
<p>DBC is built on SAC (soft actor-critic) but operates in the latent space:</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/DBC-algorithm.png" alt="DBC algorithm" /></p>
<p><em>Fig. 33. The algorithm of Deep Bisimulation for Control. (Image source: <a href="https://arxiv.org/abs/2006.10742">Zhang et al. 2020</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019selfsup,
title = "Self-Supervised Representation Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Alexey Dosovitskiy, et al. <a href="https://arxiv.org/abs/1406.6909">“Discriminative unsupervised feature learning with exemplar convolutional neural networks.”</a> IEEE transactions on pattern analysis and machine intelligence 38.9 (2015): 1734-1747.</p>
<p>[2] Spyros Gidaris, Praveer Singh & Nikos Komodakis. <a href="https://arxiv.org/abs/1803.07728">“Unsupervised Representation Learning by Predicting Image Rotations”</a> ICLR 2018.</p>
<p>[3] Carl Doersch, Abhinav Gupta, and Alexei A. Efros. <a href="https://arxiv.org/abs/1505.05192">“Unsupervised visual representation learning by context prediction.”</a> ICCV. 2015.</p>
<p>[4] Mehdi Noroozi & Paolo Favaro. <a href="https://arxiv.org/abs/1603.09246">“Unsupervised learning of visual representations by solving jigsaw puzzles.”</a> ECCV, 2016.</p>
<p>[5] Mehdi Noroozi, Hamed Pirsiavash, and Paolo Favaro. <a href="https://arxiv.org/abs/1708.06734">“Representation learning by learning to count.”</a> ICCV. 2017.</p>
<p>[6] Richard Zhang, Phillip Isola & Alexei A. Efros. <a href="https://arxiv.org/abs/1603.08511">“Colorful image colorization.”</a> ECCV, 2016.</p>
<p>[7] Pascal Vincent, et al. <a href="https://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf">“Extracting and composing robust features with denoising autoencoders.”</a> ICML, 2008.</p>
<p>[8] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. <a href="https://arxiv.org/abs/1605.09782">“Adversarial feature learning.”</a> ICLR 2017.</p>
<p>[9] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1604.07379">“Context encoders: Feature learning by inpainting.”</a> CVPR. 2016.</p>
<p>[10] Richard Zhang, Phillip Isola, and Alexei A. Efros. <a href="https://arxiv.org/abs/1611.09842">“Split-brain autoencoders: Unsupervised learning by cross-channel prediction.”</a> CVPR. 2017.</p>
<p>[11] Xiaolong Wang & Abhinav Gupta. <a href="https://arxiv.org/abs/1505.00687">“Unsupervised Learning of Visual Representations using Videos.”</a> ICCV. 2015.</p>
<p>[12] Carl Vondrick, et al. <a href="https://arxiv.org/pdf/1806.09594.pdf">“Tracking Emerges by Colorizing Videos”</a> ECCV. 2018.</p>
<p>[13] Ishan Misra, C. Lawrence Zitnick, and Martial Hebert. <a href="https://arxiv.org/abs/1603.08561">“Shuffle and learn: unsupervised learning using temporal order verification.”</a> ECCV. 2016.</p>
<p>[14] Basura Fernando, et al. <a href="https://arxiv.org/abs/1611.06646">“Self-Supervised Video Representation Learning With Odd-One-Out Networks”</a> CVPR. 2017.</p>
<p>[15] Donglai Wei, et al. <a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">“Learning and Using the Arrow of Time”</a> CVPR. 2018.</p>
<p>[16] Florian Schroff, Dmitry Kalenichenko and James Philbin. <a href="https://arxiv.org/abs/1503.03832">“FaceNet: A Unified Embedding for Face Recognition and Clustering”</a> CVPR. 2015.</p>
<p>[17] Pierre Sermanet, et al. <a href="https://arxiv.org/abs/1704.06888">“Time-Contrastive Networks: Self-Supervised Learning from Video”</a> CVPR. 2018.</p>
<p>[18] Debidatta Dwibedi, et al. <a href="https://arxiv.org/abs/1808.00928">“Learning actionable representations from visual observations.”</a> IROS. 2018.</p>
<p>[19] Eric Jang & Coline Devin, et al. <a href="https://arxiv.org/abs/1811.06964">“Grasp2Vec: Learning Object Representations from Self-Supervised Grasping”</a> CoRL. 2018.</p>
<p>[20] Ashvin Nair, et al. <a href="https://arxiv.org/abs/1807.04742">“Visual reinforcement learning with imagined goals”</a> NeurIPS. 2018.</p>
<p>[21] Ashvin Nair, et al. <a href="https://arxiv.org/abs/1910.11670">“Contextual imagined goals for self-supervised robotic learning”</a> CoRL. 2019.</p>
<p>[22] Aaron van den Oord, Yazhe Li & Oriol Vinyals. <a href="https://arxiv.org/abs/1807.03748">“Representation Learning with Contrastive Predictive Coding”</a> arXiv preprint arXiv:1807.03748, 2018.</p>
<p>[23] Olivier J. Henaff, et al. <a href="https://arxiv.org/abs/1905.09272">“Data-Efficient Image Recognition with Contrastive Predictive Coding”</a> arXiv preprint arXiv:1905.09272, 2019.</p>
<p>[24] Kaiming He, et al. <a href="https://arxiv.org/abs/1911.05722">“Momentum Contrast for Unsupervised Visual Representation Learning.”</a> CVPR 2020.</p>
<p>[25] Zhirong Wu, et al. <a href="https://arxiv.org/abs/1805.01978v1">“Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination.”</a> CVPR 2018.</p>
<p>[26] Ting Chen, et al. <a href="https://arxiv.org/abs/2002.05709">“A Simple Framework for Contrastive Learning of Visual Representations.”</a> arXiv preprint arXiv:2002.05709, 2020.</p>
<p>[27] Aravind Srinivas, Michael Laskin & Pieter Abbeel <a href="https://arxiv.org/abs/2004.04136">“CURL: Contrastive Unsupervised Representations for Reinforcement Learning.”</a> arXiv preprint arXiv:2004.04136, 2020.</p>
<p>[28] Carles Gelada, et al. <a href="https://arxiv.org/abs/1906.02736">“DeepMDP: Learning Continuous Latent Space Models for Representation Learning”</a> ICML 2019.</p>
<p>[29] Amy Zhang, et al. <a href="https://arxiv.org/abs/2006.10742">“Learning Invariant Representations for Reinforcement Learning without Reconstruction”</a> arXiv preprint arXiv:2006.10742, 2020.</p>
<p>[30] Xinlei Chen, et al. <a href="https://arxiv.org/abs/2003.04297">“Improved Baselines with Momentum Contrastive Learning”</a> arXiv preprint arXiv:2003.04297, 2020.</p>
<p>[31] Jean-Bastien Grill, et al. <a href="https://arxiv.org/abs/2006.07733">“Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning”</a> arXiv preprint arXiv:2006.07733, 2020.</p>
<p>[32] Abe Fetterman & Josh Albrecht. <a href="https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html">“Understanding self-supervised and contrastive learning with Bootstrap Your Own Latent (BYOL)”</a> Untitled blog. Aug 24, 2020.</p>Lilian WengSelf-supervised learning opens up a huge opportunity for better utilizing unlabelled data, while learning in a supervised learning manner. This post covers many interesting ideas of self-supervised learning tasks on images, videos, and control problems.Evolution Strategies2019-09-05T12:00:00+00:002019-09-05T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/09/05/evolution-strategies<blockquote>
<p>Gradient descent is not the only option when learning optimal model parameters. Evolution Strategies (ES) work well in cases where we don’t know the precise analytic form of an objective function or cannot compute the gradients directly. This post dives into several classic ES methods, as well as how ES can be used in deep reinforcement learning.</p>
</blockquote>
<!--more-->
<p>Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. Black-box optimization algorithms only need to evaluate a target function \(f(x): \mathbb{R}^n \to \mathbb{R}\), so they can be applied even when we don’t know the precise analytic form of \(f(x)\) and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include <a href="https://en.wikipedia.org/wiki/Simulated_annealing">Simulated Annealing</a>, <a href="https://en.wikipedia.org/wiki/Hill_climbing">Hill Climbing</a> and <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.</p>
<p><strong>Evolution Strategies (ES)</strong> is one type of black-box optimization algorithm, born in the family of <strong>Evolutionary Algorithms (EA)</strong>. In this post, I will dive into a couple of classic ES methods and introduce a few applications of how ES can play a role in deep reinforcement learning.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-are-evolution-strategies" id="markdown-toc-what-are-evolution-strategies">What are Evolution Strategies?</a></li>
<li><a href="#simple-gaussian-evolution-strategies" id="markdown-toc-simple-gaussian-evolution-strategies">Simple Gaussian Evolution Strategies</a></li>
<li><a href="#covariance-matrix-adaptation-evolution-strategies-cma-es" id="markdown-toc-covariance-matrix-adaptation-evolution-strategies-cma-es">Covariance Matrix Adaptation Evolution Strategies (CMA-ES)</a> <ul>
<li><a href="#updating-the-mean" id="markdown-toc-updating-the-mean">Updating the Mean</a></li>
<li><a href="#controlling-the-step-size" id="markdown-toc-controlling-the-step-size">Controlling the Step Size</a></li>
<li><a href="#adapting-the-covariance-matrix" id="markdown-toc-adapting-the-covariance-matrix">Adapting the Covariance Matrix</a></li>
</ul>
</li>
<li><a href="#natural-evolution-strategies" id="markdown-toc-natural-evolution-strategies">Natural Evolution Strategies</a> <ul>
<li><a href="#natural-gradients" id="markdown-toc-natural-gradients">Natural Gradients</a></li>
<li><a href="#estimation-using-fisher-information-matrix" id="markdown-toc-estimation-using-fisher-information-matrix">Estimation using Fisher Information Matrix</a></li>
<li><a href="#nes-algorithm" id="markdown-toc-nes-algorithm">NES Algorithm</a></li>
</ul>
</li>
<li><a href="#applications-es-in-deep-reinforcement-learning" id="markdown-toc-applications-es-in-deep-reinforcement-learning">Applications: ES in Deep Reinforcement Learning</a> <ul>
<li><a href="#openai-es-for-rl" id="markdown-toc-openai-es-for-rl">OpenAI ES for RL</a></li>
<li><a href="#exploration-with-es" id="markdown-toc-exploration-with-es">Exploration with ES</a></li>
<li><a href="#cem-rl" id="markdown-toc-cem-rl">CEM-RL</a></li>
</ul>
</li>
<li><a href="#extension-ea-in-deep-learning" id="markdown-toc-extension-ea-in-deep-learning">Extension: EA in Deep Learning</a> <ul>
<li><a href="#hyperparameter-tuning-pbt" id="markdown-toc-hyperparameter-tuning-pbt">Hyperparameter Tuning: PBT</a></li>
<li><a href="#network-topology-optimization-wann" id="markdown-toc-network-topology-optimization-wann">Network Topology Optimization: WANN</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-are-evolution-strategies">What are Evolution Strategies?</h2>
<p>Evolution strategies (ES) belong to the big family of evolutionary algorithms. The optimization targets of ES are vectors of real numbers, \(x \in \mathbb{R}^n\).</p>
<p>Evolutionary algorithms refer to a division of population-based optimization algorithms inspired by <em>natural selection</em>. According to natural selection, individuals with traits beneficial to their survival are more likely to live through generations and pass down the good characteristics to the next generation. Evolution happens gradually through the selection process, and the population grows better adapted to the environment.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/EA-illustration.png" alt="EA" /></p>
<p><em>Fig. 1. How natural selection works. (Image source: Khan Academy: <a href="https://www.khanacademy.org/science/biology/her/evolution-and-natural-selection/a/darwin-evolution-natural-selection">Darwin, evolution, & natural selection</a>)</em></p>
<p>Evolutionary algorithms can be summarized in the following <a href="https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/06-blackBoxOpt.pdf">format</a> as a general optimization solution:</p>
<p>Let’s say we want to optimize a function \(f(x)\) and we are not able to compute gradients directly, but we can still evaluate \(f(x)\) for any \(x\), and the result is deterministic. Our belief about which \(x\) is a good solution for optimizing \(f(x)\) is captured by a probability distribution \(p_\theta(x)\), parameterized by \(\theta\). The goal is to find an optimal configuration of \(\theta\).</p>
<blockquote>
<p>Here, given a fixed family of distributions (e.g. Gaussian), the parameter \(\theta\) carries the knowledge about the best solutions and is iteratively updated across generations.</p>
</blockquote>
<p>Starting with an initial value of \(\theta\), we can continuously update \(\theta\) by looping three steps as follows:</p>
<ol>
<li>Generate a population of samples \(D = \{(x_i, f(x_i))\}\) where \(x_i \sim p_\theta(x)\).</li>
<li>Evaluate the “fitness” of samples in \(D\).</li>
<li>Select the best subset of individuals and use them to update \(\theta\), generally based on fitness or rank.</li>
</ol>
<p>In <strong>Genetic Algorithms (GA)</strong>, another popular subcategory of EA, \(x\) is a sequence of binary codes, \(x \in \{0, 1\}^n\). While in ES, \(x\) is just a vector of real numbers, \(x \in \mathbb{R}^n\).</p>
<h2 id="simple-gaussian-evolution-strategies">Simple Gaussian Evolution Strategies</h2>
<p><a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">This</a> is the most basic and canonical version of evolution strategies. It models \(p_\theta(x)\) as an \(n\)-dimensional isotropic Gaussian distribution, in which \(\theta\) only tracks the mean \(\mu\) and standard deviation \(\sigma\).</p>
\[\theta = (\mu, \sigma),\;p_\theta(x) \sim \mathcal{N}(\mathbf{\mu}, \sigma^2 I) = \mu + \sigma \mathcal{N}(0, I)\]
<p>The process of Simple-Gaussian-ES, given \(x \in \mathbb{R}^n\):</p>
<ol>
<li>Initialize \(\theta = \theta^{(0)}\) and the generation counter \(t=0\)</li>
<li>Generate the offspring population of size \(\Lambda\) by sampling from the Gaussian distribution:<br /><br />\(D^{(t+1)}=\{ x^{(t+1)}_i \mid x^{(t+1)}_i = \mu^{(t)} + \sigma^{(t)} y^{(t+1)}_i \text{ where } y^{(t+1)}_i \sim \mathcal{N}(x \vert 0, \mathbf{I}),\;i = 1, \dots, \Lambda\}\).</li>
<li>Select the top subset of \(\lambda\) samples with the best \(f(x_i)\); this subset is called the <strong>elite</strong> set. Without loss of generality, we may consider the first \(\lambda\) samples in \(D^{(t+1)}\) to belong to the elite group. Let’s label them as<br /><br />\(D^{(t+1)}_\text{elite} = \{x^{(t+1)}_i \mid x^{(t+1)}_i \in D^{(t+1)}, i=1,\dots, \lambda, \lambda\leq \Lambda\}\).</li>
<li>Then we estimate the new mean and std for the next generation using the elite set:<br /><br />
\(\begin{aligned}
\mu^{(t+1)} &= \text{avg}(D^{(t+1)}_\text{elite}) = \frac{1}{\lambda}\sum_{i=1}^\lambda x_i^{(t+1)} \\
{\sigma^{(t+1)}}^2 &= \text{var}(D^{(t+1)}_\text{elite}) = \frac{1}{\lambda}\sum_{i=1}^\lambda (x_i^{(t+1)} -\mu^{(t)})^2
\end{aligned}\)<br /></li>
<li>Repeat steps (2)-(4) until the result is good enough ✌️</li>
</ol>
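<p>The loop above fits in a few lines of NumPy. This is a minimal sketch that minimizes \(f\) (so the elite are the samples with the <em>smallest</em> values) and follows step (4) by measuring deviations around the old mean \(\mu^{(t)}\). It also averages the squared deviations over dimensions so that \(\sigma\) stays a scalar, as the isotropic model requires; that detail is our assumption. The function name and default population sizes are arbitrary.</p>

```python
import numpy as np


def simple_gaussian_es(f, mu0, sigma0, pop_size=100, n_elite=10, n_gens=100, seed=0):
    """Minimize f with an isotropic Gaussian ES, theta = (mu, sigma)."""
    rng = np.random.default_rng(seed)
    mu, sigma = np.asarray(mu0, dtype=float), float(sigma0)
    for _ in range(n_gens):
        # Step 2: Lambda offspring x_i = mu + sigma * y_i with y_i ~ N(0, I).
        xs = mu + sigma * rng.standard_normal((pop_size, mu.size))
        # Step 3: keep the lambda samples with the lowest f(x).
        elite = xs[np.argsort([f(x) for x in xs])[:n_elite]]
        # Step 4: variance around the *old* mean first, then the new mean.
        sigma = float(np.sqrt(np.mean((elite - mu) ** 2)))
        mu = elite.mean(axis=0)
    return mu, sigma


def sphere(x):
    return float(np.sum(x ** 2))


mu, sigma = simple_gaussian_es(sphere, mu0=[3.0, -2.0], sigma0=1.0)
print(mu, sigma)  # both shrink toward 0 as the population closes in on the optimum
```

<p>Far from the optimum, the elite sit mostly on one side of the old mean, so \(\sigma\) grows; near it, the elite cluster around the mean and \(\sigma\) shrinks. This coupling between consecutive \(\sigma\)'s is the limitation CMA-ES addresses next.</p>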
<h2 id="covariance-matrix-adaptation-evolution-strategies-cma-es">Covariance Matrix Adaptation Evolution Strategies (CMA-ES)</h2>
<p>The standard deviation \(\sigma\) accounts for the level of exploration: the larger \(\sigma\), the bigger the search space from which we sample the offspring population. In <a href="#simple-gaussian-evolution-strategies">vanilla ES</a>, \(\sigma^{(t+1)}\) is highly correlated with \(\sigma^{(t)}\), so the algorithm is not able to rapidly adjust the exploration space when needed (e.g. when the confidence level changes).</p>
<p><a href="https://en.wikipedia.org/wiki/CMA-ES"><strong>CMA-ES</strong></a>, short for <em>“Covariance Matrix Adaptation Evolution Strategy”</em>, fixes the problem by tracking pairwise dependencies between the dimensions of the sampling distribution with a covariance matrix \(C\). The new distribution parameters become:</p>
\[\theta = (\mu, \sigma, C),\; p_\theta(x) \sim \mathcal{N}(\mu, \sigma^2 C) \sim \mu + \sigma \mathcal{N}(0, C)\]
<p>where \(\sigma\) controls the overall scale of the distribution, often known as the <em>step size</em>.</p>
<p>Before we dig into how the parameters are updated in CMA-ES, it is better to review how the covariance matrix works in the multivariate Gaussian distribution first. As a real symmetric matrix, the covariance matrix \(C\) has the following nice features (See <a href="http://s3.amazonaws.com/mitsloan-php/wp-faculty/sites/30/2016/12/15032137/Symmetric-Matrices-and-Eigendecomposition.pdf">proof</a> & <a href="http://control.ucsd.edu/mauricio/courses/mae280a/lecture11.pdf">proof</a>):</p>
<ul>
<li>It is always diagonalizable.</li>
<li>It is always positive semi-definite.</li>
<li>All of its eigenvalues are real non-negative numbers.</li>
<li>Eigenvectors belonging to distinct eigenvalues are orthogonal.</li>
<li>There is an orthonormal basis of \(\mathbb{R}^n\) consisting of its eigenvectors.</li>
</ul>
<p>Let the matrix \(C\) have an <em>orthonormal</em> basis of eigenvectors \(b_1, \dots, b_n\), stacked as the rows of \(B\), with corresponding eigenvalues \(\lambda_1^2, \dots, \lambda_n^2\). Let \(D=\text{diag}(\lambda_1, \dots, \lambda_n)\).</p>
\[C = B^\top D^2 B
= \begin{bmatrix}
\mid & \mid & & \mid \\
b_1 & b_2 & \dots & b_n\\
\mid & \mid & & \mid \\
\end{bmatrix}
\begin{bmatrix}
\lambda_1^2 & 0 & \dots & 0 \\
0 & \lambda_2^2 & \dots & 0 \\
\vdots & \dots & \ddots & \vdots \\
0 & \dots & 0 & \lambda_n^2
\end{bmatrix}
\begin{bmatrix}
- & b_1 & - \\
- & b_2 & - \\
& \dots & \\
- & b_n & - \\
\end{bmatrix}\]
<p>The square root of \(C\) is:</p>
\[C^{\frac{1}{2}} = B^\top D B\]
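<p>Numerically, \(C^{\frac{1}{2}}\) can be obtained from a symmetric eigendecomposition. <code>numpy.linalg.eigh</code> returns the eigenvectors as <em>columns</em>, so the sketch transposes them to match the row convention of \(B\) used here, and takes the diagonal of \(D\) as the square roots of the eigenvalues. The same decomposition gives \(C^{-\frac{1}{2}} = B^\top D^{-1} B\), which the step-size update uses later. It assumes \(C\) is strictly positive definite so that \(D\) is invertible.</p>

```python
import numpy as np


def sqrt_and_inv_sqrt(C):
    """Return C^{1/2} = B^T D B and C^{-1/2} = B^T D^{-1} B for a symmetric PD matrix C."""
    eigvals, V = np.linalg.eigh(C)     # C = V diag(eigvals) V^T, eigenvectors in columns
    B = V.T                            # rows of B are the eigenvectors b_1, ..., b_n
    lambdas = np.sqrt(eigvals)         # eigenvalues of C are lambda_i^2
    C_sqrt = B.T @ np.diag(lambdas) @ B
    C_inv_sqrt = B.T @ np.diag(1.0 / lambdas) @ B
    return C_sqrt, C_inv_sqrt


C = np.array([[2.0, 1.0], [1.0, 2.0]])
C_sqrt, C_inv_sqrt = sqrt_and_inv_sqrt(C)
print(np.allclose(C_sqrt @ C_sqrt, C))              # True
print(np.allclose(C_inv_sqrt @ C_sqrt, np.eye(2)))  # True
```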
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(x_i^{(t)} \in \mathbb{R}^n\)</td>
<td>the \(i\)-th sample at the generation (t)</td>
</tr>
<tr>
<td>\(y_i^{(t)} \in \mathbb{R}^n\)</td>
<td>\(x_i^{(t)} = \mu^{(t-1)} + \sigma^{(t-1)} y_i^{(t)}\)</td>
</tr>
<tr>
<td>\(\mu^{(t)}\)</td>
<td>mean of the generation (t)</td>
</tr>
<tr>
<td>\(\sigma^{(t)}\)</td>
<td>step size</td>
</tr>
<tr>
<td>\(C^{(t)}\)</td>
<td>covariance matrix</td>
</tr>
<tr>
<td>\(B^{(t)}\)</td>
<td>a matrix of \(C\)’s eigenvectors as row vectors</td>
</tr>
<tr>
<td>\(D^{(t)}\)</td>
<td>a diagonal matrix with the square roots of \(C\)’s eigenvalues on the diagonal</td>
</tr>
<tr>
<td>\(p_\sigma^{(t)}\)</td>
<td>evolution path for \(\sigma\) at the generation (t)</td>
</tr>
<tr>
<td>\(p_c^{(t)}\)</td>
<td>evolution path for \(C\) at the generation (t)</td>
</tr>
<tr>
<td>\(\alpha_\mu\)</td>
<td>learning rate for \(\mu\)’s update</td>
</tr>
<tr>
<td>\(\alpha_\sigma\)</td>
<td>learning rate for \(p_\sigma\)</td>
</tr>
<tr>
<td>\(d_\sigma\)</td>
<td>damping factor for \(\sigma\)’s update</td>
</tr>
<tr>
<td>\(\alpha_{cp}\)</td>
<td>learning rate for \(p_c\)</td>
</tr>
<tr>
<td>\(\alpha_{c\lambda}\)</td>
<td>learning rate for \(C\)’s rank-min(λ, n) update</td>
</tr>
<tr>
<td>\(\alpha_{c1}\)</td>
<td>learning rate for \(C\)’s rank-1 update</td>
</tr>
</tbody>
</table>
<h3 id="updating-the-mean">Updating the Mean</h3>
\[\mu^{(t+1)} = \mu^{(t)} + \alpha_\mu \frac{1}{\lambda}\sum_{i=1}^\lambda (x_i^{(t+1)} - \mu^{(t)})\]
<p>CMA-ES has a learning rate \(\alpha_\mu \leq 1\) to control how fast the mean \(\mu\) should be updated. It is usually set to 1, in which case the equation becomes the same as in vanilla ES, \(\mu^{(t+1)} = \frac{1}{\lambda}\sum_{i=1}^\lambda x_i^{(t+1)}\).</p>
<h3 id="controlling-the-step-size">Controlling the Step Size</h3>
<p>The sampling process can be decoupled from the mean and standard deviation:</p>
\[x^{(t+1)}_i = \mu^{(t)} + \sigma^{(t)} y^{(t+1)}_i \text{, where } y^{(t+1)}_i = \frac{x_i^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, C)\]
<p>The parameter \(\sigma\) controls the overall scale of the distribution. It is separated from the covariance matrix so that the step size can be adapted faster than the full covariance. A larger step size leads to faster parameter updates. In order to evaluate whether the current step size is proper, CMA-ES constructs an <em>evolution path</em> \(p_\sigma\) by summing up a consecutive sequence of moving steps, \(\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(j)}, j=1, \dots, t\). By comparing this path length with its expected length under random selection (meaning single steps are uncorrelated), we are able to adjust \(\sigma\) accordingly (See Fig. 2).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CMA-ES-step-size-path.png" alt="CMA-ES step size" /></p>
<p><em>Fig. 2. Three scenarios of how single steps are correlated in different ways and their impacts on step size update. (Image source: additional annotations on Fig 5 in <a href="https://arxiv.org/abs/1604.00772">CMA-ES tutorial</a> paper)</em></p>
<p>At each generation, the evolution path is updated with the average of the moving steps \(y_i\) in that generation:</p>
\[\begin{aligned}
&\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(t+1)}
= \frac{1}{\lambda} \frac{\sum_{i=1}^\lambda x_i^{(t+1)} - \lambda \mu^{(t)}}{\sigma^{(t)}}
= \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
&\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(t+1)}
\sim \frac{1}{\lambda}\mathcal{N}(0, \lambda C^{(t)})
\sim \frac{1}{\sqrt{\lambda}}{C^{(t)}}^{\frac{1}{2}}\mathcal{N}(0, I) \\
&\text{Thus } \sqrt{\lambda}\;{C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, I)
\end{aligned}\]
<blockquote>
<p>By multiplying with \(C^{-\frac{1}{2}}\), the evolution path is transformed to be independent of its direction. The transformation \({C^{(t)}}^{-\frac{1}{2}} = {B^{(t)}}^\top {D^{(t)}}^{-1} {B^{(t)}}\) works as follows:</p>
<ol>
<li>\({B^{(t)}}\) contains row vectors of \(C\)’s eigenvectors. It projects the original space onto the perpendicular principal axes.</li>
<li>Then \({D^{(t)}}^{-1} = \text{diag}(\frac{1}{\lambda_1}, \dots, \frac{1}{\lambda_n})\) scales the principal axes to equal length.</li>
<li>\({B^{(t)}}^\top\) transforms the space back to the original coordinate system.</li>
</ol>
</blockquote>
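<p>A quick Monte Carlo check of the derivation: under random selection (the \(\lambda\) “elite” samples are just unselected draws), the whitened mean shift \(\sqrt{\lambda}\,C^{-\frac{1}{2}}(\mu^{(t+1)} - \mu^{(t)})/\sigma^{(t)}\) should have zero mean and identity covariance for any \(C\). The covariance, step size, and sample counts below are arbitrary values picked for illustration.</p>

```python
import numpy as np


def whitened_mean_shifts(C, mu, sigma, n_elite, n_trials, seed=0):
    """Sample sqrt(lambda) C^{-1/2} (mu_new - mu) / sigma under random selection."""
    rng = np.random.default_rng(seed)
    n = len(mu)
    eigvals, V = np.linalg.eigh(C)
    C_sqrt = V @ np.diag(np.sqrt(eigvals)) @ V.T
    C_inv_sqrt = V @ np.diag(1.0 / np.sqrt(eigvals)) @ V.T
    # y_i ~ N(0, C): row vectors of N(0, I) draws times the symmetric C^{1/2}.
    y = rng.standard_normal((n_trials, n_elite, n)) @ C_sqrt
    # No selection: the new mean is the plain average of x_i = mu + sigma * y_i.
    mu_new = mu + sigma * y.mean(axis=1)
    return (np.sqrt(n_elite) * (mu_new - mu) / sigma) @ C_inv_sqrt


C = np.array([[4.0, 1.5], [1.5, 1.0]])
z = whitened_mean_shifts(C, mu=np.zeros(2), sigma=0.3, n_elite=5, n_trials=20_000)
print(z.mean(axis=0).round(2))  # close to [0, 0]
print(np.cov(z.T).round(2))     # close to the 2x2 identity
```

<p>With actual selection, the shifts are no longer \(\mathcal{N}(0, I)\), and that deviation is exactly the signal the step-size rule reacts to.</p>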
<p>In order to assign higher weights to recent generations, we use polyak averaging to update the evolution path with learning rate \(\alpha_\sigma\). Meanwhile, the weights are balanced so that, under random selection, \(p_\sigma \sim \mathcal{N}(0, I)\) both before and after one update: if \(p_\sigma^{(t)}\) and the new step term are independent standard normals, the sum has covariance \(\big((1 - \alpha_\sigma)^2 + 1 - (1 - \alpha_\sigma)^2\big) I = I\).</p>
\[\begin{aligned}
p_\sigma^{(t+1)}
& = (1 - \alpha_\sigma) p_\sigma^{(t)} + \sqrt{1 - (1 - \alpha_\sigma)^2}\;\sqrt{\lambda}\; {C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
& = (1 - \alpha_\sigma) p_\sigma^{(t)} + \sqrt{\alpha_\sigma (2 - \alpha_\sigma)\lambda}\;{C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}}
\end{aligned}\]
<p>The expected length of \(p_\sigma\) under random selection is \(\mathbb{E}\|\mathcal{N}(0,I)\|\), that is the expectation of the L2-norm of a \(\mathcal{N}(0,I)\) random variable. Following the idea in Fig. 2, we adjust the step size according to the ratio of \(\|p_\sigma^{(t+1)}\| / \mathbb{E}\|\mathcal{N}(0,I)\|\):</p>
\[\begin{aligned}
\ln\sigma^{(t+1)} &= \ln\sigma^{(t)} + \frac{\alpha_\sigma}{d_\sigma} \Big(\frac{\|p_\sigma^{(t+1)}\|}{\mathbb{E}\|\mathcal{N}(0,I)\|} - 1\Big) \\
\sigma^{(t+1)} &= \sigma^{(t)} \exp\Big(\frac{\alpha_\sigma}{d_\sigma} \Big(\frac{\|p_\sigma^{(t+1)}\|}{\mathbb{E}\|\mathcal{N}(0,I)\|} - 1\Big)\Big)
\end{aligned}\]
<p>where \(d_\sigma \approx 1\) is a damping parameter, scaling how fast \(\ln\sigma\) should be changed.</p>
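<p>Putting the two steps together, a sketch of one step-size update might look as follows. It follows the equations above, with \(\lambda\) passed in as <code>n_elite</code> and \(C^{-\frac{1}{2}}\) precomputed. The exact value of \(\mathbb{E}\|\mathcal{N}(0,I)\|\) is replaced by the commonly used approximation \(\sqrt{n}\,(1 - \frac{1}{4n} + \frac{1}{21n^2})\) from the <a href="https://arxiv.org/abs/1604.00772">CMA-ES tutorial</a>; the default \(d_\sigma = 1\) and the function names are our own choices.</p>

```python
import numpy as np


def expected_norm_std_normal(n):
    """Common approximation of E||N(0, I_n)|| used in CMA-ES."""
    return np.sqrt(n) * (1 - 1 / (4 * n) + 1 / (21 * n ** 2))


def update_step_size(p_sigma, sigma, mu_new, mu_old, C_inv_sqrt, n_elite,
                     alpha_sigma, d_sigma=1.0):
    """One update of the evolution path p_sigma and the step size sigma."""
    n = p_sigma.size
    # sqrt(lambda) C^{-1/2} (mu_new - mu_old) / sigma, ~ N(0, I) under random selection.
    z = np.sqrt(n_elite) * (C_inv_sqrt @ (mu_new - mu_old)) / sigma
    # Moving average; the sqrt(alpha (2 - alpha)) factor keeps p_sigma ~ N(0, I).
    p_sigma = (1 - alpha_sigma) * p_sigma + np.sqrt(alpha_sigma * (2 - alpha_sigma)) * z
    # Compare the path length with its expectation under random selection.
    ratio = np.linalg.norm(p_sigma) / expected_norm_std_normal(n)
    sigma = sigma * np.exp(alpha_sigma / d_sigma * (ratio - 1))
    return p_sigma, sigma


p_sigma, sigma = update_step_size(np.zeros(2), 1.0, mu_new=np.array([1.0, 1.0]),
                                  mu_old=np.zeros(2), C_inv_sqrt=np.eye(2),
                                  n_elite=4, alpha_sigma=0.5)
print(sigma > 1.0)  # True: a long path (correlated steps) enlarges sigma
```

<p>Conversely, if the mean does not move at all, \(\|p_\sigma\|\) stays at 0 and \(\sigma\) shrinks by a factor of \(\exp(-\alpha_\sigma / d_\sigma)\).</p>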
<h3 id="adapting-the-covariance-matrix">Adapting the Covariance Matrix</h3>
<p>The covariance matrix can be estimated from scratch using the \(y_i\) of elite samples (recall that \(y_i \sim \mathcal{N}(0, C)\)):</p>
\[C_\lambda^{(t+1)}
= \frac{1}{\lambda}\sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top
= \frac{1}{\lambda {\sigma^{(t)}}^2} \sum_{i=1}^\lambda (x_i^{(t+1)} - \mu^{(t)})(x_i^{(t+1)} - \mu^{(t)})^\top\]
<p>The above estimation is only reliable when the selected population is large enough. However, we do want to run <em>fast</em> iterations with a <em>small</em> population of samples in each generation. That’s why CMA-ES uses a more reliable but also more complicated way to update \(C\). It involves two independent routes:</p>
<ul>
<li><em>Rank-min(λ, n) update</em>: uses the history of \(\{C_\lambda\}\), each estimated from scratch in one generation.</li>
<li><em>Rank-one update</em>: estimates the moving steps \(y_i\) and the sign information from the history.</li>
</ul>
<p>The first route considers the estimation of \(C\) from the entire history of \(\{C_\lambda\}\). For example, if we have experienced a large number of generations, \(C^{(t+1)} \approx \text{avg}(C_\lambda^{(i)}; i=1,\dots,t)\) would be a good estimator. Similar to \(p_\sigma\), we also use polyak averaging with a learning rate to incorporate the history:</p>
\[C^{(t+1)}
= (1 - \alpha_{c\lambda}) C^{(t)} + \alpha_{c\lambda} C_\lambda^{(t+1)}
= (1 - \alpha_{c\lambda}) C^{(t)} + \alpha_{c\lambda} \frac{1}{\lambda} \sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top\]
<p>A common choice for the learning rate is \(\alpha_{c\lambda} \approx \min(1, \lambda/n^2)\).</p>
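<p>The rank-min(λ, n) route on its own is a one-line moving average. In this minimal sketch, the elite samples are the rows of <code>elite</code>, and the \(y_i\) are recovered from them using the previous mean and step size; the function name is illustrative.</p>

```python
import numpy as np


def rank_min_update(C, elite, mu_old, sigma, alpha_c_lambda):
    """C <- (1 - a) C + a * C_lambda, with C_lambda estimated from this generation."""
    y = (elite - mu_old) / sigma             # rows are the moving steps y_i
    C_lambda = y.T @ y / len(elite)          # (1 / lambda) * sum_i y_i y_i^T
    return (1 - alpha_c_lambda) * C + alpha_c_lambda * C_lambda


elite = np.array([[1.0, 0.5], [2.0, 1.0]])
mu_old = np.zeros(2)
C_a = rank_min_update(np.eye(2), elite, mu_old, 1.0, 0.5)
C_b = rank_min_update(np.eye(2), -elite, mu_old, 1.0, 0.5)  # every step flipped
print(np.allclose(C_a, C_b))  # True: y y^T cannot tell y from -y
```

<p>Flipping the sign of every step leaves the estimate unchanged, which is the lost sign information that the second route is designed to recover.</p>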
<p>The second route tries to solve the issue that \(y_i{y_i}^\top = (-y_i)(-y_i)^\top\) loses the sign information. Similar to how we adjust the step size \(\sigma\), an evolution path \(p_c\) is used to track the sign information, and it is constructed so that, under random selection, \(p_c \sim \mathcal{N}(0, C)\) both before and after a new generation.</p>
<p>We may consider \(p_c\) as another way to compute \(\sqrt{\lambda}\,\text{avg}_i(y_i)\) (notice that both are \(\sim \mathcal{N}(0, C)\)), except that the entire history is used and the sign information is maintained. Recall from the <a href="#controlling-the-step-size">last section</a> that \(\sqrt{\lambda}\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, C)\). Thus,</p>
\[\begin{aligned}
p_c^{(t+1)}
&= (1-\alpha_{cp}) p_c^{(t)} + \sqrt{1 - (1-\alpha_{cp})^2}\;\sqrt{\lambda}\;\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
&= (1-\alpha_{cp}) p_c^{(t)} + \sqrt{\alpha_{cp}(2 - \alpha_{cp})\lambda}\;\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}}
\end{aligned}\]
<p>Then the covariance matrix is updated according to \(p_c\):</p>
\[C^{(t+1)} = (1-\alpha_{c1}) C^{(t)} + \alpha_{c1}\;p_c^{(t+1)} {p_c^{(t+1)}}^\top\]
<p>The <em>rank-one update</em> approach is claimed to generate a significant improvement over the <em>rank-min(λ, n) update</em> when \(\lambda\) is small, because the signs of moving steps and correlations between consecutive steps are all utilized and passed down through generations.</p>
<p>Finally, we combine the two approaches:</p>
\[C^{(t+1)}
= (1 - \alpha_{c\lambda} - \alpha_{c1}) C^{(t)}
+ \alpha_{c1}\;\underbrace{p_c^{(t+1)} {p_c^{(t+1)}}^\top}_\textrm{rank-one update}
+ \alpha_{c\lambda} \underbrace{\frac{1}{\lambda} \sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top}_{\text{rank-min}(\lambda, n)\text{ update}}\]
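<p>The combined update, together with the evolution path \(p_c\), could be sketched as below. It mirrors the equations in this section with equal weights \(1/\lambda\); practical CMA-ES implementations add refinements (e.g. weighted recombination, and stalling the \(p_c\) update when \(\|p_\sigma\|\) is large) that are left out here. The function and argument names are our own.</p>

```python
import numpy as np


def update_covariance(C, p_c, elite, mu_old, mu_new, sigma,
                      alpha_cp, alpha_c1, alpha_c_lambda):
    """One covariance update combining the rank-one and rank-min(lambda, n) terms."""
    lam = len(elite)
    # Evolution path for C: no whitening, so it stays ~ N(0, C) under random selection.
    p_c = (1 - alpha_cp) * p_c + np.sqrt(alpha_cp * (2 - alpha_cp) * lam) * (mu_new - mu_old) / sigma
    # Rank-min(lambda, n) term estimated from this generation's steps y_i.
    y = (elite - mu_old) / sigma
    rank_min = y.T @ y / lam
    # Rank-one term: p_c carries sign and correlation information across generations.
    rank_one = np.outer(p_c, p_c)
    C = (1 - alpha_c_lambda - alpha_c1) * C + alpha_c1 * rank_one + alpha_c_lambda * rank_min
    return C, p_c


elite = np.array([[1.0, 1.0], [1.0, 1.0]])
C_new, p_c = update_covariance(np.eye(2), np.zeros(2), elite, np.zeros(2),
                               elite.mean(axis=0), 1.0,
                               alpha_cp=0.5, alpha_c1=0.2, alpha_c_lambda=0.3)
print(C_new.round(3))  # variance along the (1, 1) direction has grown
```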
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CMA-ES-algorithm.png" alt="CMA-ES Algorithm" /></p>
<p>In all my examples above, each elite sample contributes an equal weight, \(1/\lambda\). The process can easily be extended to the case where selected samples are assigned different weights, \(w_1, \dots, w_\lambda\), according to their performance. See the <a href="https://arxiv.org/abs/1604.00772">tutorial</a> for more details.</p>
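<p>The combined update, including optional per-sample weights, fits in a few lines. This is a sketch under the assumption that the \(y_i\) of the selected samples are already computed and stacked as rows; the function and argument names are made up for illustration.</p>

```python
import numpy as np

def cma_covariance_update(C, p_c, ys, alpha_c1, alpha_clam, weights=None):
    """Rank-one plus rank-min(lambda, n) update of the covariance matrix C.

    ys:      array of shape (lam, n); row i is y_i for the i-th elite sample
    weights: optional w_1..w_lam summing to 1; defaults to equal weights 1/lam
    """
    lam = ys.shape[0]
    if weights is None:
        weights = np.full(lam, 1.0 / lam)
    rank_one = np.outer(p_c, p_c)
    # sum_i w_i y_i y_i^T, written as a single matrix product
    rank_lam = (ys * weights[:, None]).T @ ys
    return (1 - alpha_c1 - alpha_clam) * C + alpha_c1 * rank_one + alpha_clam * rank_lam

rng = np.random.default_rng(0)
ys = rng.standard_normal((5, 3))   # stand-in for 5 elite samples in 3D
p_c = rng.standard_normal(3)
C_new = cma_covariance_update(np.eye(3), p_c, ys, alpha_c1=0.1, alpha_clam=0.2)
```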
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/CMA-ES-illustration.png" alt="CMA-ES Illustration" /></p>
<p><em>Fig. 3. Illustration of how CMA-ES works on a 2D optimization problem (the lighter the color, the better). Black dots are samples in one generation. The samples are spread out initially; in the late stage, when the model is more confident about where a good solution lies, they become concentrated around the global optimum. (Image source: <a href="https://en.wikipedia.org/wiki/CMA-ES">Wikipedia CMA-ES</a>)</em></p>
<h2 id="natural-evolution-strategies">Natural Evolution Strategies</h2>
<p>Natural Evolution Strategies (<strong>NES</strong>; <a href="https://arxiv.org/abs/1106.4487">Wierstra, et al, 2008</a>) optimize a search distribution over parameters and move the distribution in the direction of high fitness indicated by the <em>natural gradient</em>.</p>
<h3 id="natural-gradients">Natural Gradients</h3>
<p>Given an objective function \(\mathcal{J}(\theta)\) parameterized by \(\theta\), let’s say our goal is to find the optimal \(\theta\) to maximize the objective function value. A <em>plain gradient</em> finds the steepest direction within a small Euclidean distance from the current \(\theta\); the distance restriction is applied on the parameter space. In other words, we compute the plain gradient with respect to a small change of the absolute value of \(\theta\). The optimal step is:</p>
\[d^{*} = \operatorname*{argmax}_{\|d\| = \epsilon} \mathcal{J}(\theta + d)\text{, where }\epsilon \to 0\]
<p>In contrast, the <em>natural gradient</em> works with a probability <a href="https://arxiv.org/abs/1301.3584v7">distribution</a> <a href="https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/">space</a> parameterized by \(\theta\), \(p_\theta(x)\) (referred to as the “search distribution” in the NES <a href="https://arxiv.org/abs/1106.4487">paper</a>). It looks for the steepest direction within a small step in the distribution space, where the distance is measured by KL divergence. This constraint ensures that each update moves along the distributional manifold at a constant speed, without being slowed down by its curvature.</p>
\[d^{*}_\text{N} = \operatorname*{argmax}_{\text{KL}[p_\theta \| p_{\theta+d}] = \epsilon} \mathcal{J}(\theta + d)\]
<h3 id="estimation-using-fisher-information-matrix">Estimation using Fisher Information Matrix</h3>
<p>But how do we compute \(\text{KL}[p_\theta \| p_{\theta+d}]\)? By Taylor-expanding \(\log p_{\theta + d}\) around \(\theta\) to second order, we get:</p>
\[\begin{aligned}
& \text{KL}[p_\theta \| p_{\theta+d}] \\
&= \mathbb{E}_{x \sim p_\theta} [\log p_\theta(x) - \log p_{\theta+d}(x)] & \\
&\approx \mathbb{E}_{x \sim p_\theta} [ \log p_\theta(x) -( \log p_{\theta}(x) + \nabla_\theta \log p_{\theta}(x) d + \frac{1}{2}d^\top \nabla^2_\theta \log p_{\theta}(x) d)] & \scriptstyle{\text{; Taylor expand }\log p_{\theta+d}} \\
&\approx - \mathbb{E}_x [\nabla_\theta \log p_{\theta}(x)] d - \frac{1}{2}d^\top \mathbb{E}_x [\nabla^2_\theta \log p_{\theta}(x)] d &
\end{aligned}\]
<p>where</p>
\[\begin{aligned}
\mathbb{E}_x [\nabla_\theta \log p_{\theta}(x)]
&= \int_{x} p_\theta(x) \nabla_\theta \log p_\theta(x) dx & \\
&= \int_{x} p_\theta(x) \frac{1}{p_\theta(x)} \nabla_\theta p_\theta(x) dx & \\
&= \nabla_\theta \Big( \int_{x} p_\theta(x) dx \Big) & \scriptstyle{\textrm{; note that }p_\theta(x)\textrm{ is a probability distribution.}} \\
&= \nabla_\theta (1) = 0
\end{aligned}\]
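<p>Moreover, since \(\nabla^2_\theta \log p_\theta = \frac{\nabla^2_\theta p_\theta}{p_\theta} - (\nabla_\theta \log p_\theta)(\nabla_\theta \log p_\theta)^\top\) and \(\int_x \nabla^2_\theta p_\theta(x) dx = \nabla^2_\theta (1) = 0\), we have \(\mathbb{E}_x [\nabla^2_\theta \log p_{\theta}] = -\mathbb{E}_x [(\nabla_\theta \log p_{\theta}) (\nabla_\theta \log p_{\theta})^\top]\). Finally we have,</p>
\[\text{KL}[p_\theta \| p_{\theta+d}] \approx \frac{1}{2}d^\top \mathbf{F}_\theta d
\text{, where }\mathbf{F}_\theta = \mathbb{E}_x [(\nabla_\theta \log p_{\theta}) (\nabla_\theta \log p_{\theta})^\top]\]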
<p>where \(\mathbf{F}_\theta\) is called the <strong><a href="http://mathworld.wolfram.com/FisherInformationMatrix.html">Fisher Information Matrix</a></strong> and <a href="https://wiseodd.github.io/techblog/2018/03/11/fisher-information/">it is</a> the covariance matrix of \(\nabla_\theta \log p_\theta\) since \(\mathbb{E}[\nabla_\theta \log p_\theta] = 0\).</p>
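<p>A quick numerical check of the quadratic approximation \(\frac{1}{2}d^\top \mathbf{F}_\theta d\), which is non-negative as a KL divergence must be. The sketch assumes a 1D Gaussian search distribution with \(\theta = (\mu, \sigma)\) (written <code>mu</code>, <code>s</code> in code), whose Fisher matrix has the known closed form \(\text{diag}(1/\sigma^2, 2/\sigma^2)\); it also estimates \(\mathbf{F}_\theta\) by Monte Carlo as the covariance of the score \(\nabla_\theta \log p_\theta\).</p>

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    """Exact KL[N(mu1, s1^2) || N(mu2, s2^2)] for 1D Gaussians."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

def score(x, mu, s):
    """Gradient of log N(x; mu, s^2) w.r.t. theta = (mu, s); one row per sample."""
    z = (x - mu) / s
    return np.stack([z / s, (z**2 - 1) / s], axis=1)

mu, s = 0.3, 1.5
F_exact = np.diag([1 / s**2, 2 / s**2])   # closed-form Fisher matrix for (mu, s)

# Monte-Carlo estimate of F = E[score score^T]
rng = np.random.default_rng(0)
x = rng.normal(mu, s, size=200_000)
g = score(x, mu, s)
F_mc = g.T @ g / len(x)

# Exact KL vs. its second-order approximation for a small step d in (mu, s)
d = np.array([1e-2, -2e-2])
kl = kl_gauss(mu, s, mu + d[0], s + d[1])
kl_quad = 0.5 * d @ F_exact @ d
```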
<p>The solution to the following optimization problem:</p>
\[\max \mathcal{J}(\theta + d) \approx \max \big( \mathcal{J}(\theta) + {\nabla_\theta\mathcal{J}(\theta)}^\top d \big)\;\text{ s.t. }\text{KL}[p_\theta \| p_{\theta+d}] - \epsilon = 0\]
<p>can be found using a Lagrange multiplier,</p>
\[\begin{aligned}
\mathcal{L}(\theta, d, \beta) &= \mathcal{J}(\theta) + \nabla_\theta\mathcal{J}(\theta)^\top d - \beta (\frac{1}{2}d^\top \mathbf{F}_\theta d - \epsilon) \text{, where } \beta > 0 \\
\nabla_d \mathcal{L}(\theta, d, \beta) &= \nabla_\theta\mathcal{J}(\theta) - \beta\mathbf{F}_\theta d = 0 \\
\text{Thus } d_\text{N}^* &= \nabla_\theta^\text{N} \mathcal{J}(\theta) = \mathbf{F}_\theta^{-1} \nabla_\theta\mathcal{J}(\theta)
\end{aligned}\]
<p>where \(d_\text{N}^*\) only extracts the direction of the optimal moving step on \(\theta\), ignoring the scalar \(\beta^{-1}\).</p>
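<p>For the same kind of 1D Gaussian, the natural gradient \(\mathbf{F}_\theta^{-1} \nabla_\theta\mathcal{J}(\theta)\) can be computed in closed form. The toy objective below, \(\mathcal{J}(\mu, \sigma) = \mathbb{E}_x[-(x-3)^2] = -((\mu - 3)^2 + \sigma^2)\), is an assumption chosen because its plain gradient is easy to write down.</p>

```python
import numpy as np

def fisher_gauss(mu, s):
    """Fisher information of N(mu, s^2) w.r.t. theta = (mu, s)."""
    return np.diag([1.0 / s**2, 2.0 / s**2])

def plain_grad(mu, s, target=3.0):
    """Gradient of J(mu, s) = E_x[-(x - target)^2] = -((mu - target)^2 + s^2)."""
    return np.array([-2.0 * (mu - target), -2.0 * s])

def natural_grad(mu, s, target=3.0):
    """F^{-1} grad J, computed by solving F d = grad J rather than inverting F."""
    return np.linalg.solve(fisher_gauss(mu, s), plain_grad(mu, s, target))

for s in (0.1, 0.5, 1.0):
    print(s, plain_grad(0.0, s), natural_grad(0.0, s))
```

<p>With \(\sigma = 0.5\), the plain gradient is \((6, -1)\) while the natural gradient is \((1.5, -0.125)\). Since \(\mathbf{F}_\theta^{-1}\) scales with \(\sigma^2\) here, a narrow distribution takes smaller steps in \(\mu\): a small shift of its mean already moves it a long way in KL distance.</p>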
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/CMA-ES-coordinates.png" alt="Plain vs natural coordinates" /></p>
<p><em>Fig. 4. The natural gradient samples (black solid arrows) on the right are the plain gradient samples (black solid arrows) on the left multiplied by the inverse of their covariance. In this way, a gradient direction with high uncertainty (indicated by high covariance with other samples) is penalized with a small weight. The aggregated natural gradient (red dash arrow) is therefore more trustworthy than the plain gradient (green solid arrow). (Image source: additional annotations on Fig 2 in <a href="https://arxiv.org/abs/1106.4487">NES</a> paper)</em></p>
<h3 id="nes-algorithm">NES Algorithm</h3>
<p>The fitness associated with one sample is labeled as \(f(x)\) and the search distribution over \(x\) is parameterized by \(\theta\). NES is expected to optimize the parameter \(\theta\) to achieve maximum expected fitness:</p>
\[\mathcal{J}(\theta) = \mathbb{E}_{x\sim p_\theta(x)} [f(x)] = \int_x f(x) p_\theta(x) dx\]
<p>Using the same log-likelihood <a href="http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/">trick</a> in <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>:</p>
\[\begin{aligned}
\nabla_\theta\mathcal{J}(\theta)
&= \nabla_\theta \int_x f(x) p_\theta(x) dx \\
&= \int_x f(x) \frac{p_\theta(x)}{p_\theta(x)}\nabla_\theta p_\theta(x) dx \\
& = \int_x f(x) p_\theta(x) \nabla_\theta \log p_\theta(x) dx \\
& = \mathbb{E}_{x \sim p_\theta} [f(x) \nabla_\theta \log p_\theta(x)]
\end{aligned}\]
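<p>The last line turns the gradient into an expectation that can be estimated from samples alone. A minimal sketch, again assuming a 1D Gaussian \(p_\theta\) with \(\theta = (\mu, \sigma)\) and the toy fitness \(f(x) = -(x-3)^2\), compares the Monte-Carlo estimate with the exact gradient \((-2(\mu - 3), -2\sigma)\):</p>

```python
import numpy as np

def score_function_grad(f, mu, s, n_samples, rng):
    """Monte-Carlo estimate of grad_(mu, s) E_{x ~ N(mu, s^2)}[f(x)]
    as the sample mean of f(x) * grad log p(x) (the log-likelihood trick)."""
    x = rng.normal(mu, s, size=n_samples)
    z = (x - mu) / s
    grad_log_p = np.stack([z / s, (z**2 - 1) / s], axis=1)  # shape (n, 2)
    return (f(x)[:, None] * grad_log_p).mean(axis=0)

def f(x):
    # toy fitness; only its values are used, never its derivative
    return -(x - 3.0) ** 2

rng = np.random.default_rng(0)
est = score_function_grad(f, mu=0.0, s=0.5, n_samples=200_000, rng=rng)
exact = np.array([-2 * (0.0 - 3.0), -2 * 0.5])  # [6, -1]
```

<p>The estimator only needs fitness values \(f(x)\), never \(\nabla_x f\), which is what makes it usable for black-box objectives. Its variance is high, which is one motivation for the fitness shaping heuristic below.</p>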
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NES-algorithm.png" alt="NES" /></p>
<p>Besides natural gradients, NES adopts a couple of important heuristics to make the algorithm performance more robust.</p>
<ul>
<li><a name="fitness-shaping"></a>NES applies <strong>rank-based fitness shaping</strong>; that is, it uses the <em>rank</em> of each sample under monotonically increasing fitness values instead of \(f(x)\) directly, or a function of the rank (“utility function”), which is considered a free parameter of NES.</li>
<li>NES adopts <strong>adaptation sampling</strong> to adjust hyperparameters at run time. When changing \(\theta \to \theta'\), samples drawn from \(p_\theta\) are compared with samples from \(p_{\theta'}\) using the <a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Mann-Whitney U-test</a>; depending on the sign of the test result, the target hyperparameter is decreased or increased by a multiplicative constant. Note that the score of a sample \(x'_i \sim p_{\theta'}(x)\) has the importance sampling weight \(w'_i = p_\theta(x'_i) / p_{\theta'}(x'_i)\) applied.</li>
</ul>
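<p>A minimal sketch of one simple utility function, centered ranks, which maps the sorted fitness values linearly onto \([-0.5, 0.5]\). This is only one possible choice of utility (an assumption here, not necessarily the one used in the NES paper); ties are broken arbitrarily by the sort.</p>

```python
import numpy as np

def centered_ranks(fitness):
    """Replace raw fitness by its rank, rescaled to [-0.5, 0.5].

    The best sample gets +0.5 and the worst -0.5 regardless of how large the
    raw fitness gaps are, so a single outlier cannot dominate the update.
    Requires at least two samples.
    """
    fitness = np.asarray(fitness, dtype=float)
    ranks = np.empty(len(fitness))
    ranks[np.argsort(fitness)] = np.arange(len(fitness))  # 0 = worst
    return ranks / (len(fitness) - 1) - 0.5

print(centered_ranks([10.0, -3.0, 1e6, 0.0]))
```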
<h2 id="applications-es-in-deep-reinforcement-learning">Applications: ES in Deep Reinforcement Learning</h2>
<h3 id="openai-es-for-rl">OpenAI ES for RL</h3>
<p>The concept of using evolutionary algorithms in reinforcement learning can be traced back <a href="https://arxiv.org/abs/1106.0221">a long way</a>, but it was limited to tabular RL due to computational constraints.</p>
<p>Inspired by <a href="#natural-evolution-strategies">NES</a>, researchers at OpenAI (<a href="https://arxiv.org/abs/1703.03864">Salimans, et al. 2017</a>) proposed to use NES as a gradient-free black-box optimizer to find optimal policy parameters \(\theta\) that maximize the return function \(F(\theta)\). The key is to add Gaussian noise \(\epsilon\) to the model parameter \(\theta\) and then use the log-likelihood trick to write the gradient in terms of the Gaussian pdf. In the end, only the noise term is left, acting as a weight on each measured return.</p>
<p>Let’s say the current parameter value is \(\hat{\theta}\) (the added hat is to distinguish the value from the random variable \(\theta\)). The search distribution of \(\theta\) is designed to be an isotropic multivariate Gaussian with a mean \(\hat{\theta}\) and a fixed covariance matrix \(\sigma^2 I\),</p>
\[\theta \sim \mathcal{N}(\hat{\theta}, \sigma^2 I) \text{ equivalent to } \theta = \hat{\theta} + \sigma\epsilon, \epsilon \sim \mathcal{N}(0, I)\]
<p>The gradient with respect to the mean \(\hat{\theta}\), which is the quantity we update, is:</p>
\[\begin{aligned}
& \nabla_{\hat{\theta}} \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} F(\theta) \\
&= \nabla_{\hat{\theta}} \int_{\theta} p(\theta; \hat{\theta}) F(\theta) d\theta & \scriptstyle{\text{; Gaussian }p(\theta; \hat{\theta})=(2\pi\sigma^2)^{-\frac{n}{2}} \exp(-\frac{1}{2\sigma^2}\|\theta - \hat{\theta}\|^2)} \\
&= \int_{\theta} p(\theta; \hat{\theta}) \nabla_{\hat{\theta}} \log p(\theta; \hat{\theta})\;F(\theta) d\theta & \scriptstyle{\text{; log-likelihood trick}}\\
&= \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} \Big[ \frac{\theta - \hat{\theta}}{\sigma^2} F(\theta) \Big] & \scriptstyle{\text{; }\nabla_{\hat{\theta}} \log p(\theta; \hat{\theta}) = (\theta - \hat{\theta})/\sigma^2} \\
&= \frac{1}{\sigma}\mathbb{E}_{\epsilon\sim\mathcal{N}(0, I)} [ \epsilon F(\hat{\theta} + \sigma\epsilon) ] & \scriptstyle{\text{; substitute }\theta = \hat{\theta} + \sigma\epsilon}
\end{aligned}\]
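<p>The final expression is easy to estimate by sampling. In this sketch, the fitness \(F\) is a toy quadratic (an assumption standing in for an episode return) so that the exact gradient of the Gaussian-smoothed objective is known: \(\mathbb{E}_\epsilon F(\hat{\theta} + \sigma\epsilon) = -\|\hat{\theta} - \text{target}\|^2 - \sigma^2 n\), whose gradient at \(\hat{\theta} = 0\) is \(2 \cdot \text{target}\).</p>

```python
import numpy as np

def es_gradient(F, theta_hat, sigma, n_samples, rng):
    """Monte-Carlo estimate of (1/sigma) E_eps[eps * F(theta_hat + sigma * eps)].

    F must accept a batch of parameter vectors of shape (n_samples, dim)
    and return one fitness value per row.
    """
    eps = rng.standard_normal((n_samples, theta_hat.size))
    returns = F(theta_hat + sigma * eps)
    return (eps * returns[:, None]).mean(axis=0) / sigma

target = np.array([1.0, -2.0])

def F(thetas):
    # toy fitness, vectorized over rows: -||theta - target||^2
    return -np.sum((thetas - target) ** 2, axis=-1)

rng = np.random.default_rng(0)
est = es_gradient(F, np.zeros(2), sigma=0.5, n_samples=100_000, rng=rng)
```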
<p>In one generation, we can sample many \(\epsilon_i, i=1,\dots,N\) and evaluate the fitness <em>in parallel</em>. One elegant design choice is that no large model parameters need to be shared: workers only exchange random seeds (plus the scalar returns), which is enough for the master node to regenerate the noise vectors and perform the parameter update. This approach was later extended to adaptively learn a loss function; see my previous post on <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#meta-learning-the-loss-function">Evolved Policy Gradient</a>.</p>
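<p>The seed-sharing trick can be simulated in a single process. In this sketch (the function names are hypothetical), each "worker" regenerates its noise vector from an integer seed and returns one scalar; the update is then rebuilt from the list of seeds and returns only. Subtracting the mean return is an extra variance-reduction choice assumed here, not part of the derivation above.</p>

```python
import numpy as np

def perturbation(seed, dim):
    """Regenerate the noise vector for a seed; every node gets the same vector."""
    return np.random.default_rng(seed).standard_normal(dim)

def worker_evaluate(F, theta_hat, sigma, seed):
    """Runs on a worker: only this scalar return is sent back."""
    return F(theta_hat + sigma * perturbation(seed, theta_hat.size))

def es_step(F, theta_hat, sigma, alpha, seeds):
    # In a real system each call would run on a separate worker in parallel.
    returns = np.array([worker_evaluate(F, theta_hat, sigma, s) for s in seeds])
    returns = returns - returns.mean()  # baseline subtraction (an assumption)
    # Knowing only (seeds, returns), any node can rebuild the same update.
    grad = sum(r * perturbation(s, theta_hat.size) for s, r in zip(seeds, returns))
    return theta_hat + alpha * grad / (len(seeds) * sigma)

target = np.array([1.0, -2.0])
theta = np.zeros(2)
for gen in range(200):
    seeds = range(gen * 50, (gen + 1) * 50)  # fresh seeds each generation
    theta = es_step(lambda th: -np.sum((th - target) ** 2), theta, 0.1, 0.05, seeds)
```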
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/OpenAI-ES-algorithm.png" alt="ES for RL" /></p>
<p><em>Fig. 5. The algorithm for training a RL policy using evolution strategies. (Image source: <a href="https://arxiv.org/abs/1703.03864">ES-for-RL</a> paper)</em></p>
<p>To make the performance more robust, OpenAI ES adopts virtual batch normalization (batch normalization whose statistics are computed on a fixed reference mini-batch), mirrored sampling (evaluating each noise vector as a pair \((-\epsilon, \epsilon)\)), and <a href="#fitness-shaping">fitness shaping</a>.</p>
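<p>Mirrored sampling is a small change to the estimator: each noise vector is evaluated at both \(\hat{\theta} + \sigma\epsilon\) and \(\hat{\theta} - \sigma\epsilon\). A minimal sketch (argument names are illustrative):</p>

```python
import numpy as np

def mirrored_es_gradient(F, theta_hat, sigma, n_pairs, rng):
    """ES gradient estimate from n_pairs mirrored (antithetic) noise pairs."""
    eps = rng.standard_normal((n_pairs, theta_hat.size))
    f_plus = np.array([F(theta_hat + sigma * e) for e in eps])
    f_minus = np.array([F(theta_hat - sigma * e) for e in eps])
    # Average of eps * F(+) and (-eps) * F(-) over all 2 * n_pairs samples
    return (eps * (f_plus - f_minus)[:, None]).mean(axis=0) / (2 * sigma)

rng = np.random.default_rng(0)
a = np.array([1.0, 2.0, 3.0])
est = mirrored_es_gradient(lambda th: a @ th, np.zeros(3), 0.1, 50_000, rng)
```

<p>Because \(\epsilon F(\hat{\theta}+\sigma\epsilon) + (-\epsilon) F(\hat{\theta}-\sigma\epsilon) = \epsilon\,(F(\hat{\theta}+\sigma\epsilon) - F(\hat{\theta}-\sigma\epsilon))\), any constant offset in the returns cancels within each pair, which typically lowers the variance of the estimate.</p>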
<h3 id="exploration-with-es">Exploration with ES</h3>
<p>Exploration (<a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#exploitation-vs-exploration">vs exploitation</a>) is an important topic in RL. The optimization direction in the ES algorithm <a href="#openai-es-for-rl">above</a> is extracted only from the cumulative return \(F(\theta)\). Without explicit exploration, the agent might get trapped in a local optimum.</p>
<p>Novelty-Search ES (<strong>NS-ES</strong>; <a href="https://arxiv.org/abs/1712.06560">Conti et al, 2018</a>) encourages exploration by updating the parameters in the direction that maximizes the <em>novelty</em> score. The novelty score depends on a domain-specific behavior characterization function \(b(\pi_\theta)\). The choice of \(b(\pi_\theta)\) is specific to the task and seems to be a bit arbitrary; for example, in the Humanoid locomotion task in the paper, \(b(\pi_\theta)\) is the final \((x,y)\) location of the agent.</p>
<ol>
<li>Every policy’s \(b(\pi_\theta)\) is pushed to an archive set \(\mathcal{A}\).</li>
<li>Novelty of a policy \(\pi_\theta\) is measured as the k-nearest neighbor score between \(b(\pi_\theta)\) and all other entries in \(\mathcal{A}\).
(The use case of the archive set sounds quite similar to <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#episodic-control">episodic memory</a>.)</li>
</ol>
\[N(\theta, \mathcal{A}) = \frac{1}{k} \sum_{i=1}^k \| b(\pi_\theta) - b^\text{knn}_i \|_2
\text{, where }b^\text{knn}_i \in \text{kNN}(b(\pi_\theta), \mathcal{A})\]
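<p>The novelty score is a plain k-nearest-neighbor average. A minimal sketch, assuming behaviors are stored as fixed-length vectors (e.g. the final \((x, y)\) location) and the archive holds at least \(k\) entries:</p>

```python
import numpy as np

def novelty(b, archive, k):
    """Mean Euclidean distance from behavior b to its k nearest neighbors.

    b:       behavior characterization of the current policy, shape (d,)
    archive: previously stored behaviors, shape (m, d) with m >= k
    """
    dists = np.linalg.norm(np.asarray(archive) - b, axis=1)
    return np.sort(dists)[:k].mean()

archive = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]])
print(novelty(np.array([3.0, 0.0]), archive, k=2))  # 2.5
```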
<p>The ES optimization step relies on the novelty score instead of fitness:</p>
\[\nabla_{\hat{\theta}} \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} N(\theta, \mathcal{A})
= \frac{1}{\sigma}\mathbb{E}_{\epsilon\sim\mathcal{N}(0, I)} [ \epsilon N(\hat{\theta} + \sigma\epsilon, \mathcal{A}) ]\]
<p>NS-ES maintains a group of \(M\) independently trained agents (“meta-population”), \(\mathcal{M} = \{\theta_1, \dots, \theta_M \}\), and at each step picks one to advance with probability proportional to its novelty score. Eventually we select the best policy. This process is similar in spirit to ensembling; see also a related idea in <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#svpg">SVPG</a>.</p>
\[\begin{aligned}
m &\leftarrow \text{pick } i=1,\dots,M\text{ according to probability}\frac{N(\theta_i, \mathcal{A})}{\sum_{j=1}^M N(\theta_j, \mathcal{A})} \\
\theta_m^{(t+1)} &\leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) \text{ where }\epsilon_i \sim \mathcal{N}(0, I)
\end{aligned}\]
<p>where \(N\) is the number of Gaussian perturbation noise vectors and \(\alpha\) is the learning rate.</p>
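<p>Putting the pieces together, one NS-ES iteration on the meta-population might look like the sketch below. It is a simplification under stated assumptions: the <code>behavior</code> function is a hypothetical stand-in for rolling out the policy, a small novelty helper is included so the block runs on its own, and details of the paper's implementation (e.g. how often the archive is updated) are not reproduced.</p>

```python
import numpy as np

def novelty(b, archive, k):
    """Mean distance from behavior b to its k nearest neighbors in the archive."""
    dists = np.linalg.norm(np.asarray(archive) - b, axis=1)
    return np.sort(dists)[:k].mean()

def ns_es_step(thetas, behavior, archive, k, sigma, alpha, n_noise, rng):
    """One NS-ES iteration (simplified sketch).

    thetas:   list of M parameter vectors (the meta-population), updated in place
    behavior: hypothetical function mapping parameters to b(pi_theta)
    archive:  list of past behaviors; the updated agent's behavior is appended
    """
    nov = np.array([novelty(behavior(th), archive, k) for th in thetas])
    m = rng.choice(len(thetas), p=nov / nov.sum())  # pick agent m prop. to novelty
    eps = rng.standard_normal((n_noise, thetas[m].size))
    scores = np.array([novelty(behavior(thetas[m] + sigma * e), archive, k) for e in eps])
    thetas[m] = thetas[m] + alpha / sigma * (eps * scores[:, None]).sum(axis=0)
    archive.append(behavior(thetas[m]))
    return m

rng = np.random.default_rng(0)
archive = [np.array([0.0, 0.0]), np.array([0.1, 0.0]), np.array([0.0, 0.1])]
thetas = [np.zeros(2), np.ones(2)]
# Toy behavior: the parameter vector itself plays the role of a final (x, y) location.
m = ns_es_step(thetas, lambda th: th, archive, k=2, sigma=0.1, alpha=0.01, n_noise=20, rng=rng)
```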
<p>NS-ES completely discards the reward function and only optimizes for novelty to avoid deceptive local optima. To incorporate the fitness back into the formula, two more variants are proposed.</p>
<p><strong>NSR-ES</strong>:</p>
\[\theta_m^{(t+1)} \leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i \frac{N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) + F(\theta^{(t)}_m + \sigma\epsilon_i)}{2}\]
<p><strong>NSRAdapt-ES (NSRA-ES)</strong>: the adaptive weighting parameter is initialized to \(w = 1.0\). We start decreasing \(w\) if performance stays flat for a number of generations. Then, when the performance starts to increase, we stop decreasing \(w\) and increase it instead. In this way, novelty is preferred when the performance stops growing, and fitness is preferred otherwise.</p>
\[\theta_m^{(t+1)} \leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i \big((1-w) N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) + w F(\theta^{(t)}_m + \sigma\epsilon_i)\big)\]
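<p>The adaptation rule for \(w\) is described only qualitatively above, so the schedule below is a hypothetical sketch: the step size <code>delta</code> and the <code>patience</code> (number of generations without improvement before \(w\) starts to drop) are made-up values, not the paper's settings.</p>

```python
def nsra_score(novelty_score, fitness_score, w):
    """Per-sample weight in the NSRA-ES update: (1 - w) * N + w * F."""
    return (1 - w) * novelty_score + w * fitness_score

def adapt_w(w, best_fitness, best_so_far, stalled, delta=0.05, patience=10):
    """Hypothetical schedule for w; returns (w, best_so_far, stalled).

    delta and patience are illustrative values only.
    """
    if best_fitness > best_so_far:
        # performance improved: move w back toward fitness, reset the stall counter
        return min(1.0, w + delta), best_fitness, 0
    stalled += 1
    if stalled >= patience:
        # performance has stayed flat for `patience` generations: favor novelty
        return max(0.0, w - delta), best_so_far, stalled
    return w, best_so_far, stalled

w, best, stalled = 1.0, 0.0, 0
for gen_best in [0.0] * 12 + [1.0]:  # 12 flat generations, then an improvement
    w, best, stalled = adapt_w(w, gen_best, best, stalled)
```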
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NS-ES-experiments.png" alt="NS-ES Experiments" /></p>
<p><em>Fig. 6. (Left) The environment is Humanoid locomotion with a three-sided wall, which acts as a deceptive trap that creates a local optimum. (Right) Experiments compare the ES baseline with other variants that encourage exploration. (Image source: <a href="https://arxiv.org/abs/1712.06560">NS-ES</a> paper)</em></p>
<h3 id="cem-rl">CEM-RL</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CEM-RL.png" alt="CEM-RL" /></p>
<p><em>Fig. 7. Architectures of the (a) CEM-RL and (b) <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">ERL</a> algorithms (Image source: <a href="https://arxiv.org/abs/1810.01222">CEM-RL</a> paper)</em></p>
<p>The CEM-RL method (<a href="https://arxiv.org/abs/1810.01222">Pourchot & Sigaud, 2019</a>) combines the Cross Entropy Method (CEM) with either <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#ddpg">DDPG</a> or <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#td3">TD3</a>. CEM here works pretty much the same as the simple Gaussian ES described <a href="#simple-gaussian-evolution-strategies">above</a>, and therefore it can be replaced by CMA-ES. CEM-RL is built on the framework of <em>Evolutionary Reinforcement Learning</em> (<em>ERL</em>; <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">Khadka & Tumer, 2018</a>), in which a standard EA selects and evolves a population of actors, and the rollout experience generated in the process is added into the replay buffer for training both the RL-actor and RL-critic networks.</p>
<p>Workflow:</p>
<ul>
<li>1) The mean actor of the CEM population, \(\pi_\mu\), is initialized with a random actor network.</li>
<li>2) The critic network \(Q\) is initialized too; it will be updated by DDPG/TD3.</li>
<li>3) Repeat until happy:
<ul>
<li>a. Sample a population of actors \(\sim \mathcal{N}(\pi_\mu, \Sigma)\).</li>
<li>b. Half of the population is evaluated. Their fitness scores are used as the cumulative reward \(R\), and the collected experience is added into the replay buffer.</li>
<li>c. The other half are updated together with the critic.</li>
<li>d. The new \(\pi_\mu\) and \(\Sigma\) are computed using the top-performing elite samples. <a href="#covariance-matrix-adaptation-evolution-strategies-cma-es">CMA-ES</a> can be used for the parameter update too.</li>
</ul>
</li>
</ul>
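<p>The CEM part of the loop (steps a and d) can be sketched on its own with a diagonal Gaussian. This leaves out the critic and the DDPG/TD3 updates entirely; the toy fitness is a stand-in for an actor's return, and recomputing the variance around the previous mean plus adding a small constant <code>extra_var</code> to keep it from collapsing are assumptions of this sketch.</p>

```python
import numpy as np

def cem_step(mu, var, fitness_fn, pop_size, n_elite, rng, extra_var=1e-3):
    """One CEM update of a diagonal Gaussian N(mu, diag(var)) over actor parameters."""
    pop = mu + np.sqrt(var) * rng.standard_normal((pop_size, mu.size))
    scores = np.array([fitness_fn(x) for x in pop])
    elites = pop[np.argsort(scores)[-n_elite:]]              # highest-fitness samples
    new_mu = elites.mean(axis=0)
    new_var = ((elites - mu) ** 2).mean(axis=0) + extra_var  # spread around old mean
    return new_mu, new_var

target = np.array([1.0, -2.0])  # toy stand-in for the best actor parameters
rng = np.random.default_rng(0)
mu, var = np.zeros(2), np.ones(2)
for _ in range(50):
    mu, var = cem_step(mu, var, lambda x: -np.sum((x - target) ** 2), 50, 10, rng)
```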
<h2 id="extension-ea-in-deep-learning">Extension: EA in Deep Learning</h2>
<p>(This section is not on evolution strategies, but still an interesting and relevant reading.)</p>
<p><em>Evolutionary algorithms</em> have been applied to many deep learning problems. POET (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>) is a framework based on EA that attempts to generate a variety of different tasks while the problems themselves are being solved. POET has been introduced in my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#task-generation-by-domain-randomization">last post</a> on meta-RL. Evolutionary Reinforcement Learning (ERL) is another example; see Fig. 7 (b).</p>
<p>Below I would like to introduce two applications in more detail, <em>Population-Based Training (PBT)</em> and <em>Weight-Agnostic Neural Networks (WANN)</em>.</p>
<h3 id="hyperparameter-tuning-pbt">Hyperparameter Tuning: PBT</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PBT.png" alt="PBT" /></p>
<p><em>Fig. 8. Paradigms of comparing different ways of hyperparameter tuning. (Image source: <a href="https://arxiv.org/abs/1711.09846">PBT</a> paper)</em></p>
<p>Population-Based Training (<a href="https://arxiv.org/abs/1711.09846">Jaderberg, et al, 2017</a>), short for <strong>PBT</strong>, applies EA to the problem of hyperparameter tuning. It jointly trains a population of models and their corresponding hyperparameters for optimal performance.</p>
<p>PBT starts with a set of random candidates, each containing a pair of model weights initialization and hyperparameters, \(\{(\theta_i, h_i)\mid i=1, \dots, N\}\). Every member is trained in parallel and asynchronously evaluates its own performance periodically. Whenever a member is deemed ready (i.e. after taking enough gradient update steps, or when the performance is good enough), it has a chance to be updated by comparing with the whole population:</p>
<ul>
<li><strong><code class="language-plaintext highlighter-rouge">exploit()</code></strong>: When this model is under-performing, its weights could be replaced with those of a better-performing model.</li>
<li><strong><code class="language-plaintext highlighter-rouge">explore()</code></strong>: If the model weights are overwritten, the <code class="language-plaintext highlighter-rouge">explore</code> step perturbs the hyperparameters with random noise.</li>
</ul>
<p>In this process, only promising model and hyperparameter pairs can survive and keep on evolving, achieving better utilization of computational resources.</p>
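<p>A minimal sketch of <code>exploit()</code> and <code>explore()</code> on a list of members. The truncation fraction (the bottom 20% copy from the top 20%) and the perturbation factors 0.8/1.2 are illustrative assumptions, and the dictionary layout of a member is hypothetical.</p>

```python
import copy
import random

def exploit_and_explore(population, frac=0.2, factors=(0.8, 1.2), rng=random):
    """One PBT-style update, modifying members in place.

    Each member is a dict with keys "weights", "hparams" (dict of floats) and "score".
    """
    ranked = sorted(population, key=lambda m: m["score"], reverse=True)
    n_cut = max(1, int(len(ranked) * frac))
    top, bottom = ranked[:n_cut], ranked[-n_cut:]
    for member in bottom:
        donor = rng.choice(top)
        member["weights"] = copy.deepcopy(donor["weights"])       # exploit()
        member["hparams"] = {k: v * rng.choice(factors)            # explore()
                             for k, v in donor["hparams"].items()}
    return population

pop = [{"weights": [float(i)], "hparams": {"lr": 1.0}, "score": float(i)} for i in range(5)]
exploit_and_explore(pop, rng=random.Random(0))
```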
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PBT-algorithm.png" alt="PBT Algorithm" /></p>
<p><em>Fig. 9. The algorithm of population-based training. (Image source: <a href="https://arxiv.org/abs/1711.09846">PBT</a> paper)</em></p>
<h3 id="network-topology-optimization-wann">Network Topology Optimization: WANN</h3>
<p><em>Weight Agnostic Neural Networks</em> (short for <strong>WANN</strong>; <a href="https://arxiv.org/abs/1906.04358">Gaier & Ha 2019</a>) searches for the smallest network topologies that can perform well without training the network weights. By not considering the best configuration of network weights, WANN puts much more emphasis on the architecture itself, making the focus different from <a href="http://openaccess.thecvf.com/content_cvpr_2018/papers/Zoph_Learning_Transferable_Architectures_CVPR_2018_paper.pdf">NAS</a>. WANN is heavily inspired by a classic genetic algorithm for evolving network topologies, called <em>NEAT</em> (“Neuroevolution of Augmenting Topologies”; <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.gecco02_1.pdf">Stanley & Miikkulainen 2002</a>).</p>
<p>The workflow of WANN looks pretty much the same as standard GA:</p>
<ol>
<li>Initialize: Create a population of minimal networks.</li>
<li>Evaluation: Test with a range of <em>shared</em> weight values.</li>
<li>Rank and Selection: Rank by performance and complexity.</li>
<li>Mutation: Create new population by varying best networks.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/WANN-mutations.png" alt="Mutation operations in WANN" /></p>
<p><em>Fig. 10. Mutation operations for searching for new network topologies in WANN. (Image source: <a href="https://arxiv.org/abs/1906.04358">WANN</a> paper)</em></p>
<p>At the “evaluation” stage, all the network weights are set to the same value. In this way, WANN is effectively searching for networks that can be described with a minimal description length. In the “selection” stage, both the network connections and the model performance are considered.</p>
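<p>The evaluation and ranking stages can be sketched with toy stand-ins: each candidate topology is represented by a hypothetical <code>forward(x, w)</code> that runs the network with every connection set to the shared weight <code>w</code>. The list of shared weight values and the lexicographic ranking (mean loss first, then number of connections) are simplifying assumptions, not the exact procedure from the paper.</p>

```python
import numpy as np

SHARED_WEIGHTS = (-2.0, -1.0, -0.5, 0.5, 1.0, 2.0)  # illustrative choice

def evaluate_shared(forward, task_loss, weights=SHARED_WEIGHTS):
    """Loss of one topology when every connection uses the same weight w.

    Returns (mean loss over w, best loss over w).
    """
    losses = np.array([task_loss(lambda x, w=w: forward(x, w)) for w in weights])
    return losses.mean(), losses.min()

def rank_networks(candidates, task_loss):
    """Lexicographic stand-in for selection: lower mean loss first, then fewer connections."""
    def key(c):
        mean_loss, _ = evaluate_shared(c["forward"], task_loss)
        return (mean_loss, c["n_connections"])
    return sorted(candidates, key=key)

xs = np.linspace(-1.0, 1.0, 5)

def task_loss(f):
    # toy task: reproduce the identity mapping y = x
    return np.mean((f(xs) - xs) ** 2)

candidates = [
    {"name": "sensitive", "forward": lambda x, w: w * x, "n_connections": 1},
    {"name": "robust_big", "forward": lambda x, w: x, "n_connections": 3},
    {"name": "robust", "forward": lambda x, w: x, "n_connections": 1},
]
print([c["name"] for c in rank_networks(candidates, task_loss)])
```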
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/WANN-results.png" alt="WANN results" /></p>
<p><em>Fig. 11. The performance of network topologies found by WANN on different RL tasks is compared with baseline FF networks commonly used in the literature. “Tuned Shared Weight” only requires adjusting one weight value. (Image source: <a href="https://arxiv.org/abs/1906.04358">WANN</a> paper)</em></p>
<p>As shown in Fig. 11, WANN results are evaluated with both random weights and shared weights (single weight). It is interesting that even when enforcing weight-sharing on all weights and tuning this single parameter, WANN can discover topologies that achieve non-trivially good performance.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019ES,
title = "Evolution Strategies",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "https://lilianweng.github.io/lil-log/2019/09/05/evolution-strategies.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Nikolaus Hansen. <a href="https://arxiv.org/abs/1604.00772">“The CMA Evolution Strategy: A Tutorial”</a> arXiv preprint arXiv:1604.00772 (2016).</p>
<p>[2] Marc Toussaint. <a href="https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/06-blackBoxOpt.pdf">Slides: “Introduction to Optimization”</a></p>
<p>[3] David Ha. <a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">“A Visual Guide to Evolution Strategies”</a> blog.otoro.net. Oct 2017.</p>
<p>[4] Daan Wierstra, et al. <a href="https://arxiv.org/abs/1106.4487">“Natural evolution strategies.”</a> IEEE World Congress on Computational Intelligence, 2008.</p>
<p>[5] Agustinus Kristiadi. <a href="https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/">“Natural Gradient Descent”</a> Mar 2018.</p>
<p>[6] Razvan Pascanu & Yoshua Bengio. <a href="https://arxiv.org/abs/1301.3584v7">“Revisiting Natural Gradient for Deep Networks.”</a> arXiv preprint arXiv:1301.3584 (2013).</p>
<p>[7] Tim Salimans, et al. <a href="https://arxiv.org/abs/1703.03864">“Evolution strategies as a scalable alternative to reinforcement learning.”</a> arXiv preprint arXiv:1703.03864 (2017).</p>
<p>[8] Edoardo Conti, et al. <a href="https://arxiv.org/abs/1712.06560">“Improving exploration in evolution strategies for deep reinforcement learning via a population of novelty-seeking agents.”</a> NIPS 2018.</p>
<p>[9] Aloïs Pourchot & Olivier Sigaud. <a href="https://arxiv.org/abs/1810.01222">“CEM-RL: Combining evolutionary and gradient-based methods for policy search.”</a> ICLR 2019.</p>
<p>[10] Shauharda Khadka & Kagan Tumer. <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">“Evolution-guided policy gradient in reinforcement learning.”</a> NIPS 2018.</p>
<p>[11] Max Jaderberg, et al. <a href="https://arxiv.org/abs/1711.09846">“Population based training of neural networks.”</a> arXiv preprint arXiv:1711.09846 (2017).</p>
<p>[12] Adam Gaier & David Ha. <a href="https://arxiv.org/abs/1906.04358">“Weight Agnostic Neural Networks.”</a> arXiv preprint arXiv:1906.04358 (2019).</p>Lilian WengGradient descent is not the only option when learning optimal model parameters. Evolution Strategies (ES) works out well in the cases where we don’t know the precise analytic form of an objective function or cannot compute the gradients directly. This post dives into several classic ES methods, as well as how ES can be used in deep reinforcement learning.Meta Reinforcement Learning2019-06-23T12:00:00+00:002019-06-23T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/06/23/meta-reinforcement-learning<blockquote>
<p>Meta-RL is meta-learning on reinforcement learning tasks. After being trained over a distribution of tasks, the agent is able to solve a new task by developing a new RL algorithm with its internal activity dynamics. This post starts with the origin of meta-RL and then dives into three key components of meta-RL.</p>
</blockquote>
<!--more-->
<p>In my earlier post on <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a>, the problem is mainly defined in the context of few-shot classification. Here I would like to explore cases where we try to “meta-learn” <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">Reinforcement Learning (RL)</a> tasks by developing an agent that can solve unseen tasks quickly and efficiently.</p>
<p>To recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a <em>mini learning session</em>, happens at test time with limited exposure to the new configurations. Even without any explicit fine-tuning (no gradient backpropagation on trainable variables), the meta-learning model autonomously adjusts internal hidden states to learn.</p>
<p>Training RL algorithms can be notoriously difficult. If the meta-learning agent could become so smart that the distribution of solvable unseen tasks grows extremely broad, we are on track towards <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">general purpose methods</a>, essentially building a “brain” which would solve all kinds of RL problems without much human intervention or manual feature engineering. Sounds amazing, right? 💖</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#on-the-origin-of-meta-rl" id="markdown-toc-on-the-origin-of-meta-rl">On the Origin of Meta-RL</a> <ul>
<li><a href="#back-in-2001" id="markdown-toc-back-in-2001">Back in 2001</a></li>
<li><a href="#proposal-in-2016" id="markdown-toc-proposal-in-2016">Proposal in 2016</a></li>
</ul>
</li>
<li><a href="#define-meta-rl" id="markdown-toc-define-meta-rl">Define Meta-RL</a> <ul>
<li><a href="#formulation" id="markdown-toc-formulation">Formulation</a></li>
<li><a href="#main-differences-from-rl" id="markdown-toc-main-differences-from-rl">Main Differences from RL</a></li>
<li><a href="#key-components" id="markdown-toc-key-components">Key Components</a></li>
</ul>
</li>
<li><a href="#meta-learning-algorithms-for-meta-rl" id="markdown-toc-meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</a> <ul>
<li><a href="#optimizing-model-weights-for-meta-learning" id="markdown-toc-optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</a></li>
<li><a href="#meta-learning-hyperparameters" id="markdown-toc-meta-learning-hyperparameters">Meta-learning Hyperparameters</a></li>
<li><a href="#meta-learning-the-loss-function" id="markdown-toc-meta-learning-the-loss-function">Meta-learning the Loss Function</a></li>
<li><a href="#meta-learning-the-exploration-strategies" id="markdown-toc-meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</a></li>
<li><a href="#episodic-control" id="markdown-toc-episodic-control">Episodic Control</a></li>
</ul>
</li>
<li><a href="#training-task-acquisition" id="markdown-toc-training-task-acquisition">Training Task Acquisition</a> <ul>
<li><a href="#task-generation-by-domain-randomization" id="markdown-toc-task-generation-by-domain-randomization">Task Generation by Domain Randomization</a></li>
<li><a href="#evolutionary-algorithm-on-environment-generation" id="markdown-toc-evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</a></li>
<li><a href="#learning-with-random-rewards" id="markdown-toc-learning-with-random-rewards">Learning with Random Rewards</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="on-the-origin-of-meta-rl">On the Origin of Meta-RL</h2>
<h3 id="back-in-2001">Back in 2001</h3>
<p>I encountered a paper written in 2001 by <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al.</a> when reading <a href="https://arxiv.org/pdf/1611.05763.pdf">Wang et al., 2016</a>. Although the idea was proposed for supervised learning, there are so many resemblances to the current approach to meta-RL.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/Hochreiter-meta-learning.png" alt="Hochreiter 2001" /></p>
<p><em>Fig. 1. The meta-learning system consists of the supervisory and the subordinate systems. The subordinate system is a recurrent neural network that takes as input both the observation at the current time step, \(x_t\) and the label at the last time step, \(y_{t-1}\). (Image source: <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al., 2001</a>)</em></p>
<p>Hochreiter’s meta-learning model is a recurrent network with LSTM cells. LSTM is a good choice because it can internalize a history of inputs and tune its own weights effectively through <a href="https://en.wikipedia.org/wiki/Backpropagation_through_time">BPTT</a>. The training data contains \(K\) sequences, and each sequence consists of \(N\) samples generated by a target function \(f_k(.), k=1, \dots, K\),</p>
\[\{\text{input: }(\mathbf{x}^k_i, \mathbf{y}^k_{i-1}) \to \text{label: }\mathbf{y}^k_i\}_{i=1}^N
\text{ where }\mathbf{y}^k_i = f_k(\mathbf{x}^k_i)\]
<p>Note that <em>the last label</em> \(\mathbf{y}^k_{i-1}\) is also provided as an auxiliary input, so that the network can learn the mapping presented by the current sequence.</p>
<p>In an experiment on decoding two-dimensional quadratic functions, \(a x_1^2 + b x_2^2 + c x_1 x_2 + d x_1 + e x_2 + f\), whose coefficients \(a\)-\(f\) are randomly sampled from [-1, 1], this meta-learning system was able to approximate the function after seeing only ~35 examples.</p>
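<p>To make the data layout concrete, here is a minimal NumPy sketch of how such training sequences could be assembled. The input range of \(x_1, x_2\) (taken as [-1, 1]), the zero placeholder for the missing previous label at the first step, and the helper names <code>sample_quadratic</code> and <code>make_sequence</code> are assumptions for illustration, not details from the 2001 paper.</p>

```python
import numpy as np

def sample_quadratic(rng):
    """Draw one target function f_k(x) = a x1^2 + b x2^2 + c x1 x2 + d x1 + e x2 + f,
    with all six coefficients sampled uniformly from [-1, 1]."""
    a, b, c, d, e, f = rng.uniform(-1.0, 1.0, size=6)
    return lambda x: (a * x[:, 0] ** 2 + b * x[:, 1] ** 2 + c * x[:, 0] * x[:, 1]
                      + d * x[:, 0] + e * x[:, 1] + f)

def make_sequence(rng, n_samples, x_low=-1.0, x_high=1.0):
    """Build one sequence of N samples for a freshly sampled f_k (assumed input range).
    Returns inputs of shape (N, 3) = [x1, x2, y_{i-1}] and labels y_i of shape (N,)."""
    f_k = sample_quadratic(rng)
    x = rng.uniform(x_low, x_high, size=(n_samples, 2))
    y = f_k(x)
    # The previous label is fed back as an auxiliary input; step 1 has none, so use 0.
    y_prev = np.concatenate([[0.0], y[:-1]])
    inputs = np.concatenate([x, y_prev[:, None]], axis=1)
    return inputs, y

rng = np.random.default_rng(0)
K, N = 4, 35
dataset = [make_sequence(rng, N) for _ in range(K)]  # one sequence per target function
```

A recurrent network trained on many such sequences has to infer \(f_k\) from the \((\mathbf{x}, \mathbf{y}_{i-1})\) pairs it has seen so far, since the coefficients change from one sequence to the next.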
<h3 id="proposal-in-2016">Proposal in 2016</h3>
<p>In the modern days of DL, <a href="https://arxiv.org/abs/1611.05763">Wang et al.</a> (2016) and <a href="https://arxiv.org/abs/1611.02779">Duan et al.</a> (2017) almost simultaneously proposed a very similar idea, <strong>Meta-RL</strong> (called <strong>RL^2</strong> in the second paper). A meta-RL model is trained over a distribution of MDPs, and at test time, it is able to learn to solve a new task quickly. The goal of meta-RL is ambitious: it takes one step further towards general algorithms.</p>
<h2 id="define-meta-rl">Define Meta-RL</h2>
<p><em>Meta Reinforcement Learning</em>, in short, is to do <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a> in the field of <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>. Usually the train and test tasks are different but drawn from the same family of problems; e.g., experiments in the papers included multi-armed bandits with different reward probabilities, mazes with different layouts, the same robots with different physical parameters in a simulator, and many others.</p>
<h3 id="formulation">Formulation</h3>
<p>Let’s say we have a distribution of tasks, each formalized as an <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#markov-decision-processes">MDP</a> (Markov Decision Process), \(M_i \in \mathcal{M}\). An MDP is determined by a 4-tuple, \(M_i= \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle\):</p>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td>\(\mathcal{S}\)</td>
<td>A set of states.</td>
</tr>
<tr>
<td>\(\mathcal{A}\)</td>
<td>A set of actions.</td>
</tr>
<tr>
<td>\(P_i: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}_{+}\)</td>
<td>Transition probability function.</td>
</tr>
<tr>
<td>\(R_i: \mathcal{S} \times \mathcal{A} \to \mathbb{R}\)</td>
<td>Reward function.</td>
</tr>
</tbody>
</table>
<p>(The RL^2 paper adds an extra parameter, the horizon \(T\), to the MDP tuple to emphasize that each MDP should have a finite horizon.)</p>
<p>Note that a common state space \(\mathcal{S}\) and action space \(\mathcal{A}\) are shared across tasks, so that a (stochastic) policy \(\pi_\theta: \mathcal{S} \times \mathcal{A} \to \mathbb{R}_{+}\) gets inputs that are compatible across different tasks. The test tasks are sampled from the same distribution \(\mathcal{M}\) or a slightly modified version of it.</p>
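<p>As a concrete, hypothetical instance of such a task distribution, consider the multi-armed bandit family mentioned above: every task shares the same action space (the arms) and a trivial single-state \(\mathcal{S}\), and only the reward function \(R_i\) changes. The sketch below assumes Bernoulli rewards with arm probabilities drawn uniformly from [0, 1]; the class and function names are made up for illustration.</p>

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class BanditTask:
    """One MDP M_i: a single dummy state, n_arms actions, Bernoulli rewards."""
    reward_probs: np.ndarray  # defines R_i: P(reward = 1 | arm)

    def step(self, action, rng):
        # Reward is 1.0 with probability reward_probs[action], else 0.0.
        return float(rng.random() < self.reward_probs[action])

def sample_task(rng, n_arms=5):
    """Draw M_i ~ M. All tasks share the same (trivial) S and the same action set;
    only the reward function differs between tasks."""
    return BanditTask(reward_probs=rng.uniform(0.0, 1.0, size=n_arms))

rng = np.random.default_rng(42)
train_tasks = [sample_task(rng) for _ in range(100)]
test_task = sample_task(rng)  # an unseen task from the same distribution
```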
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/meta-RL-illustration.png" alt="Illustration of meta-RL" /></p>
<p><em>Fig. 2. Illustration of meta-RL, containing two optimization loops. The outer loop samples a new environment in every iteration and adjusts parameters that determine the agent’s behavior. In the inner loop, the agent interacts with the environment and optimizes for the maximal reward. (Image source: <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick, et al. 2019</a>)</em></p>
<h3 id="main-differences-from-rl">Main Differences from RL</h3>
<p>The overall configuration of meta-RL is very similar to an ordinary RL algorithm, except that <strong>the last reward</strong> \(r_{t-1}\) and <strong>the last action</strong> \(a_{t-1}\) are also incorporated into the policy observation in addition to the current state \(s_t\).</p>
<ul>
<li>In RL: \(\pi_\theta(s_t) \to\) a distribution over \(\mathcal{A}\)</li>
<li>In meta-RL: \(\pi_\theta(a_{t-1}, r_{t-1}, s_t) \to\) a distribution over \(\mathcal{A}\)</li>
</ul>
<p>The intention of this design is to feed a history into the model so that the policy can internalize the dynamics between states, rewards, and actions in the current MDP and adjust its strategy accordingly. This is well aligned with the setup in <a href="#back-in-2001">Hochreiter’s system</a>. Both meta-RL and RL^2 implemented a recurrent policy (an LSTM in meta-RL and a GRU in RL^2), and the recurrent hidden state serves as a <em>memory</em> for tracking characteristics of the trajectories. Because the policy is recurrent, there is no need to feed the last state as an explicit input.</p>
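<p>A minimal sketch of how the meta-RL input \((a_{t-1}, r_{t-1}, s_t)\) could be packed into one vector for the recurrent policy. The one-hot action encoding and the zero placeholders used at the very first step are assumptions; the papers' exact encodings may differ (e.g. the observation may first pass through an encoder).</p>

```python
import numpy as np

def meta_rl_input(prev_action, prev_reward, state, n_actions):
    """Concatenate one-hot(a_{t-1}), r_{t-1} and s_t into the recurrent policy's input.
    At the first step of a new task, pass prev_action=None and prev_reward=0.0."""
    a_onehot = np.zeros(n_actions)
    if prev_action is not None:
        a_onehot[prev_action] = 1.0
    return np.concatenate([a_onehot, [prev_reward], np.asarray(state, dtype=float)])
```

For example, <code>meta_rl_input(2, 1.0, [0.5, -0.5], 3)</code> yields the vector <code>[0, 0, 1, 1.0, 0.5, -0.5]</code>.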
<p>The training procedure works as follows:</p>
<ol>
<li>Sample a new MDP, \(M_i \sim \mathcal{M}\);</li>
<li><strong>Reset the hidden state</strong> of the model;</li>
<li>Collect multiple trajectories and update the model weights;</li>
<li>Repeat from step 1.</li>
</ol>
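<p>The sketch below covers only the data-collection side of steps 1, 2 and 4, to show where the memory is reset: a hypothetical tiny RNN policy is run on freshly sampled Bernoulli bandit tasks, its hidden state is zeroed once per task and then carried across episodes within that task (as in RL^2). The weights are random, and the gradient update of step 3 is deliberately left out; a real implementation would train the network with a policy-gradient method (e.g. an actor-critic algorithm).</p>

```python
import numpy as np

class TinyRNNPolicy:
    """Hypothetical minimal recurrent policy: h_t = tanh(W x_t + U h_{t-1}), logits = V h_t.
    Weights are random; training them (step 3) is omitted in this sketch."""
    def __init__(self, n_in, n_hidden, n_actions, rng):
        self.W = rng.normal(0, 0.1, (n_hidden, n_in))
        self.U = rng.normal(0, 0.1, (n_hidden, n_hidden))
        self.V = rng.normal(0, 0.1, (n_actions, n_hidden))
        self.n_hidden = n_hidden

    def initial_state(self):
        return np.zeros(self.n_hidden)

    def step(self, x, h, rng):
        h = np.tanh(self.W @ x + self.U @ h)
        logits = self.V @ h
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        return rng.choice(len(probs), p=probs), h

def collect_trial(policy, reward_probs, n_episodes, horizon, rng):
    """Run several episodes on ONE sampled bandit task. The hidden state is reset once
    at the start of the task (step 2) and carried across episodes within it."""
    n_actions = len(reward_probs)
    h = policy.initial_state()            # step 2: reset memory for the new MDP
    prev_a, prev_r = None, 0.0
    trajectory = []
    for _ in range(n_episodes):
        for _ in range(horizon):
            x = np.zeros(n_actions + 1)   # bandit: no state, input is (a_{t-1}, r_{t-1})
            if prev_a is not None:
                x[prev_a] = 1.0
            x[-1] = prev_r
            a, h = policy.step(x, h, rng)
            r = float(rng.random() < reward_probs[a])
            trajectory.append((a, r))
            prev_a, prev_r = a, r
    return trajectory

rng = np.random.default_rng(0)
policy = TinyRNNPolicy(n_in=4, n_hidden=16, n_actions=3, rng=rng)
for _ in range(5):                         # step 4: repeat
    task = rng.uniform(0, 1, size=3)       # step 1: sample M_i ~ M (arm probabilities)
    traj = collect_trial(policy, task, n_episodes=2, horizon=10, rng=rng)
    # Step 3 would compute a policy-gradient loss on `traj` and update W, U, V here.
```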
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/L2RL.png" alt="L2RL" /></p>
<p><em>Fig. 3. In the meta-RL paper, different actor-critic architectures all use a recurrent model. Last reward and last action are additional inputs. The observation is fed into the LSTM either as a one-hot vector or as an embedding vector after passed through an encoder model. (Image source: <a href="https://arxiv.org/abs/1611.05763">Wang et al., 2016</a>)</em></p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RL_2.png" alt="RL^2" /></p>
<p><em>Fig. 4. Illustration of how the model interacts with a series of MDPs at training time, as described in the RL^2 paper. (Image source: <a href="https://arxiv.org/abs/1611.02779">Duan et al., 2017</a>)</em></p>
<h3 id="key-components">Key Components</h3>
<p>There are three key components in Meta-RL:</p>
<blockquote>
<p>⭐ <strong>A Model with Memory</strong>
<br />
A recurrent neural network maintains a hidden state. Thus, it could acquire and memorize the knowledge about the current task by updating the hidden state during rollouts. Without memory, meta-RL would not work.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>Meta-learning Algorithm</strong>
<br />
A meta-learning algorithm refers to how we update the model weights so that the model can solve an unseen task fast at test time. In both the Meta-RL and RL^2 papers, the meta-learning algorithm is an ordinary gradient-descent update of the recurrent network, with the hidden state reset whenever the MDP switches.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>A Distribution of MDPs</strong>
<br />
Because the agent is exposed to a variety of environments and tasks during training, it has to learn how to adapt to different MDPs.</p>
</blockquote>
<p>According to <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick et al.</a> (2019), one source of slowness in RL training is <em>weak <a href="https://en.wikipedia.org/wiki/Inductive_bias">inductive bias</a></em> ( = “a set of assumptions that the learner uses to predict outputs given inputs that it has not encountered”). As a general ML rule, a learning algorithm with weak inductive bias is able to master a wider range of variance but is usually less sample-efficient. Therefore, narrowing down the hypotheses with stronger inductive biases helps improve the learning speed.</p>
<p>In meta-RL, we impose certain types of inductive biases from the <em>task distribution</em> and store them in <em>memory</em>. Which inductive bias to adopt at test time depends on the <em>algorithm</em>. Together, these three key components depict a compelling view of meta-RL: Adjusting the weights of a recurrent network is slow but it allows the model to work out a new task fast with its own RL algorithm implemented in its internal activity dynamics.</p>
<p>Interestingly, and perhaps not surprisingly, meta-RL matches the ideas in the <a href="https://arxiv.org/abs/1905.10985">AI-GAs</a> (“AI-Generating Algorithms”) paper by Jeff Clune (2019). He proposed that one efficient way towards building general AI is to make learning as automatic as possible. The AI-GAs approach involves three pillars: (1) meta-learning architectures, (2) meta-learning algorithms, and (3) automatically generated environments for effective learning.</p>
<hr />
<p>The topic of designing good recurrent network architectures is a bit too broad to be discussed here, so I will skip it. Next, let’s look further into the other two components: meta-learning algorithms in the context of meta-RL and how to acquire a variety of training MDPs.</p>
<h2 id="meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</h2>
<p>My previous <a href="/lil-log/2018/11/30/meta-learning.html">post</a> on meta-learning has covered several classic meta-learning algorithms. Here I’ll include more that are related to RL.</p>
<h3 id="optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</h3>
<p>Both MAML (<a href="https://arxiv.org/abs/1703.03400">Finn, et al. 2017</a>) and Reptile (<a href="https://arxiv.org/abs/1803.02999">Nichol et al., 2018</a>) are methods for updating model parameters in order to achieve good generalization performance on new tasks. See the <a href="/lil-log/2018/11/30/meta-learning.html#optimization-based">section</a> on MAML and Reptile in an earlier post.</p>
<h3 id="meta-learning-hyperparameters">Meta-learning Hyperparameters</h3>
<p>The <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function">return</a> function in an RL problem, \(G_t^{(n)}\) or \(G_t^\lambda\), involves a few hyperparameters that are often set heuristically, like the discount factor <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function">\(\gamma\)</a> and the bootstrapping parameter <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#combining-td-and-mc-learning">\(\lambda\)</a>.
Meta-gradient RL (<a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">Xu et al., 2018</a>) considers them as <em>meta-parameters</em>, \(\eta=\{\gamma, \lambda \}\), that can be tuned and learned <em>online</em> while an agent is interacting with the environment. Therefore, the return becomes a function of \(\eta\) and dynamically adapts itself to a specific task over time.</p>
\[\begin{aligned}
G_\eta^{(n)}(\tau_t) &= R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1}R_{t+n} + \gamma^n v_\theta(s_{t+n}) & \scriptstyle{\text{; n-step return}} \\
G_\eta^{\lambda}(\tau_t) &= (1-\lambda) \sum_{n=1}^\infty \lambda^{n-1} G_\eta^{(n)} & \scriptstyle{\text{; λ-return, mixture of n-step returns}}
\end{aligned}\]
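<p>For a finite trajectory, both returns can be computed from a reward sequence and value estimates. The NumPy sketch below (hypothetical helpers, not code from the paper) uses the standard backward recursion \(G^\lambda_t = R_{t+1} + \gamma \big((1-\lambda) v(s_{t+1}) + \lambda G^\lambda_{t+1}\big)\), which for a truncated trajectory equals the \(\lambda\)-weighted mixture of n-step returns above, with the leftover weight placed on the longest available return. The indexing convention (<code>rewards[k]</code> holds \(R_{k+1}\), and <code>values</code> has one extra bootstrap entry, 0 for a terminal state) is an assumption.</p>

```python
import numpy as np

def n_step_return(rewards, values, t, n, gamma):
    """G_t^{(n)} = R_{t+1} + ... + gamma^{n-1} R_{t+n} + gamma^n v(s_{t+n}).
    rewards[k] holds R_{k+1}; values[k] holds v(s_k); len(values) == len(rewards) + 1."""
    T = len(rewards)
    n = min(n, T - t)  # truncate at the end of the trajectory
    G = sum(gamma ** k * rewards[t + k] for k in range(n))
    return G + gamma ** n * values[t + n]

def lambda_returns(rewards, values, gamma, lam):
    """All lambda-returns G_t^lambda of a trajectory via the backward recursion."""
    T = len(rewards)
    G = np.zeros(T)
    next_G = values[T]  # bootstrap value after the last step
    for t in reversed(range(T)):
        G[t] = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * next_G)
        next_G = G[t]
    return G
```

With \(\lambda = 0\) this reduces to the one-step TD target \(R_{t+1} + \gamma v(s_{t+1})\), and with \(\lambda = 1\) to the (bootstrapped) Monte-Carlo return, which is why \(\lambda\) is a natural knob for meta-gradient RL to tune.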
<p>During training, we would like to update the policy parameters with gradients as a function of all the information in hand, \(\theta' = \theta + f(\tau, \theta, \eta)\), where \(\theta\) are the current model weights, \(\tau\) is a sequence of trajectories, and \(\eta\) are the meta-parameters.</p>
<p>Meanwhile, let’s say we have a meta-objective function \(J(\tau, \theta, \eta)\) as a performance measure. The training process follows the principle of online cross-validation, using a sequence of consecutive experiences:</p>
<ol>
<li>Starting with parameter \(\theta\), the policy \(\pi_\theta\) is updated on the first batch of samples \(\tau\), resulting in \(\theta'\).</li>
<li>Then we continue running the policy \(\pi_{\theta'}\) to collect a new set of experiences \(\tau'\), just following \(\tau\) consecutively in time. The performance is measured as \(J(\tau', \theta', \bar{\eta})\) with a fixed meta-parameter \(\bar{\eta}\).</li>
<li>The gradient of meta-objective \(J(\tau', \theta', \bar{\eta})\) w.r.t. \(\eta\) is used to update \(\eta\):</li>
</ol>
\[\begin{aligned}
\Delta \eta
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{d\theta'}{d\eta} & \scriptstyle{\text{ ; single variable chain rule.}} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial (\theta + f(\tau, \theta, \eta))}{\partial\eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big(\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\frac{d\eta}{d\eta} \Big) & \scriptstyle{\text{; multivariable chain rule.}}\\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big( \color{red}{\big(\mathbf{I} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\big)}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\Big) & \scriptstyle{\text{; secondary gradient term in red.}}
\end{aligned}\]
<p>where \(\beta\) is the learning rate for \(\eta\).</p>
<p>The meta-gradient RL algorithm simplifies the computation by setting the secondary gradient term to zero, \(\mathbf{I} + \partial f(\tau, \theta, \eta)/\partial\theta = 0\). This choice keeps only the immediate effect of the meta-parameters \(\eta\) on the parameters \(\theta\). Eventually we get:</p>
\[\Delta \eta = -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\]
<p>Experiments in the paper adopted the same objective as the \(TD(\lambda)\) algorithm for the meta-objective function, minimizing the squared error between the approximated value function \(v_\theta(s)\) and the \(\lambda\)-return:</p>
\[\begin{aligned}
J(\tau, \theta, \eta) &= (G^\lambda_\eta(\tau) - v_\theta(s))^2 \\
J(\tau', \theta', \bar{\eta}) &= (G^\lambda_{\bar{\eta}}(\tau') - v_{\theta'}(s'))^2
\end{aligned}\]
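<p>The approximate meta-gradient can be checked on a toy problem. In this hypothetical setup, \(\eta = \gamma\) only, the value function is a table \(\theta\), the inner update \(f\) is a single TD(0) step, and the meta-objective on \(\tau'\) uses a fixed Monte-Carlo target (no bootstrapping) to keep the derivative simple. Because \(\theta\) itself does not depend on \(\gamma\) for a single step, \(d\theta/d\eta = 0\) here and the simplified formula is exact, which the test verifies against a finite difference. None of this is the paper's code.</p>

```python
import numpy as np

def td_update(theta, s, r, s_next, gamma, alpha):
    """f(tau, theta, eta) for eta = gamma: one TD(0) step on a tabular value function.
    Returns the update f and its derivative df/dgamma."""
    delta = r + gamma * theta[s_next] - theta[s]
    f = np.zeros_like(theta)
    f[s] = alpha * delta
    df_dgamma = np.zeros_like(theta)
    df_dgamma[s] = alpha * theta[s_next]
    return f, df_dgamma

def meta_objective(theta_new, s_eval, G_eval):
    """J(tau', theta', eta_bar) = (G - v_theta'(s'))^2, with G computed from tau'
    under the fixed meta-parameter eta_bar (a fixed Monte-Carlo target here)."""
    return (G_eval - theta_new[s_eval]) ** 2

def meta_gradient_step(theta, gamma, transition, eval_sample, alpha=0.5, beta=0.05):
    s, r, s_next = transition
    s_eval, G_eval = eval_sample
    f, df_dgamma = td_update(theta, s, r, s_next, gamma, alpha)
    theta_new = theta + f                                  # theta' = theta + f(tau, theta, eta)
    dJ_dtheta_new = np.zeros_like(theta)
    dJ_dtheta_new[s_eval] = -2.0 * (G_eval - theta_new[s_eval])
    meta_grad = dJ_dtheta_new @ df_dgamma                  # dJ/dtheta' . df/deta
    return gamma - beta * meta_grad, theta_new, meta_grad  # Delta eta = -beta * meta_grad
```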
<h3 id="meta-learning-the-loss-function">Meta-learning the Loss Function</h3>
<p>In policy gradient algorithms, the expected total reward is maximized by updating the policy parameters \(\theta\) in the direction of estimated gradient (<a href="https://arxiv.org/abs/1506.02438">Schulman et al., 2016</a>),</p>
\[g = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \nabla_\theta \log \pi_\theta (a_t \mid s_t)]\]
<p>where the candidates for \(\Psi_t\) include the trajectory return \(G_t\), the Q value \(Q(s_t, a_t)\), or the advantage value \(A(s_t, a_t)\). The corresponding surrogate objective for the policy gradient, whose gradient is \(g\), can be reverse-engineered (a loss to be minimized would be its negative):</p>
\[L_\text{pg} = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \log \pi_\theta (a_t \mid s_t)]\]
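<p>A quick sanity check that differentiating this surrogate recovers \(g\) for a sampled trajectory: the NumPy sketch below assumes a tabular softmax policy \(\pi_\theta(a \mid s) = \text{softmax}(\theta[s])_a\), treats \(\Psi_t\) as a constant, computes \(\sum_t \Psi_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\) in closed form, and can be compared against a finite-difference gradient of the surrogate. In practice one would minimize \(-L_\text{pg}\) with an autodiff framework instead.</p>

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def surrogate(theta, states, actions, psi):
    """Single-trajectory estimate of L_pg = sum_t Psi_t * log pi_theta(a_t | s_t)
    for a tabular softmax policy pi_theta(a | s) = softmax(theta[s])[a]."""
    logp = log_softmax(theta[states])                      # shape (T, n_actions)
    return np.sum(psi * logp[np.arange(len(actions)), actions])

def surrogate_grad(theta, states, actions, psi):
    """Closed form: sum_t Psi_t * grad log pi(a_t | s_t), where for a softmax policy
    grad_{theta[s_t]} log pi(a_t | s_t) = onehot(a_t) - pi(. | s_t)."""
    grad = np.zeros_like(theta)
    probs = np.exp(log_softmax(theta[states]))
    for t, (s, a) in enumerate(zip(states, actions)):
        g = -probs[t]
        g[a] += 1.0
        grad[s] += psi[t] * g
    return grad
```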
<p>This loss function is a measure over a history of trajectories, \((s_0, a_0, r_0, \dots, s_t, a_t, r_t, \dots)\). <strong>Evolved Policy Gradient</strong> (<strong>EPG</strong>; <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">Houthooft, et al, 2018</a>) takes a step further by defining the policy gradient loss function as a temporal convolution (1-D convolution) over the agent’s past experience, \(L_\phi\). The parameters \(\phi\) of the loss function network are evolved so that agents trained with this loss achieve higher returns.</p>
<p>Similar to many meta-learning algorithms, EPG has two optimization loops:</p>
<ul>
<li>In the internal loop, an agent learns to improve its policy \(\pi_\theta\).</li>
<li>In the outer loop, the model updates the parameters \(\phi\) of the loss function \(L_\phi\). Because there is no explicit way to write down a differentiable equation between the return and the loss, EPG turns to <a href="https://en.wikipedia.org/wiki/Evolution_strategy"><em>Evolution Strategies</em></a> (ES).</li>
</ul>
<p>The general idea is to train a population of \(N\) agents, each trained with the loss function \(L_{\phi + \sigma \epsilon_i}\), whose parameters are \(\phi\) perturbed by small Gaussian noise \(\sigma \epsilon_i\), where \(\epsilon_i \sim \mathcal{N}(0, \mathbf{I})\) and \(\sigma\) is the standard deviation. During the inner loop’s training, EPG tracks a history of experience and updates each agent’s policy parameters according to its loss function \(L_{\phi + \sigma\epsilon_i}\):</p>
\[\theta_i \leftarrow \theta_i - \alpha_\text{in} \nabla_{\theta_i} L_{\phi + \sigma \epsilon_i} (\pi_{\theta_i}, \tau_{t-K, \dots, t})\]
<p>where \(\alpha_\text{in}\) is the learning rate of the inner loop and \(\tau_{t-K, \dots, t}\) is the sequence of recent transitions from time step \(t-K\) up to the current time step \(t\).</p>
<p>Once the inner-loop policy is mature enough, it is evaluated by the mean return \(\bar{G}_{\phi+\sigma\epsilon_i}\) over multiple randomly sampled trajectories. The gradient with respect to \(\phi\) can then be estimated numerically using <a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">NES</a> (<a href="https://arxiv.org/abs/1703.03864">Salimans et al, 2017</a>). As this process repeats, both the policy parameters \(\theta\) and the loss function weights \(\phi\) are updated simultaneously to achieve higher returns.</p>
\[\phi \leftarrow \phi + \alpha_\text{out} \frac{1}{\sigma N} \sum_{i=1}^N \epsilon_i \bar{G}_{\phi+\sigma\epsilon_i}\]
<p>where \(\alpha_\text{out}\) is the learning rate of the outer loop.</p>
<p>In practice, the loss \(L_\phi\) is bootstrapped with an ordinary policy gradient (such as REINFORCE or PPO) surrogate loss \(L_\text{pg}\), \(\hat{L} = (1-\alpha) L_\phi + \alpha L_\text{pg}\). The weight \(\alpha\) is annealed gradually from 1 to 0 during training. At test time, the loss function parameters \(\phi\) stay fixed, and the loss value is computed over a history of experience to update the policy parameters \(\theta\).</p>
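<p>The outer-loop ES update on \(\phi\) can be sketched in a few lines of NumPy. Running a full inner loop is out of scope here, so <code>mean_return</code> is a hypothetical stand-in for "train an agent with \(L_{\phi+\sigma\epsilon_i}\) and report \(\bar{G}_{\phi+\sigma\epsilon_i}\)". Subtracting the population's mean return is an added variance-reduction baseline, not part of the update written above; it only rescales the expected update by \((N-1)/N\).</p>

```python
import numpy as np

def es_outer_step(phi, mean_return, rng, n_agents=64, sigma=0.1, alpha_out=0.02):
    """One ES update of the loss-function parameters phi.
    mean_return(params) is a hypothetical stand-in for: run the inner loop with the
    loss L_params, then return the trained agent's average return."""
    eps = rng.standard_normal((n_agents, phi.size))
    returns = np.array([mean_return(phi + sigma * e) for e in eps])
    returns = returns - returns.mean()  # added baseline (not in the formula above)
    return phi + alpha_out / (sigma * n_agents) * eps.T @ returns
```

For example, with a toy <code>mean_return = lambda p: -np.sum((p - target) ** 2)</code>, repeated calls move <code>phi</code> towards <code>target</code>; in EPG each call to <code>mean_return</code> would be a full inner-loop training run, which is why the population is evaluated in parallel.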
<h3 id="meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</h3>
<p>The <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#exploitation-vs-exploration">exploitation vs exploration</a> dilemma is a critical problem in RL. Common ways to do exploration include \(\epsilon\)-greedy, random noise on actions, or stochastic policy with built-in randomness on the action space.</p>
<p><strong>MAESN</strong> (<a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al, 2018</a>) is an algorithm that learns structured action noise from prior experience for more effective exploration. Simply adding random noise to actions cannot capture task-dependent or time-correlated exploration strategies. MAESN changes the policy to condition on a per-task random variable \(z_i \sim \mathcal{N}(\mu_i, \sigma_i)\) for the \(i\)-th task \(M_i\), so we have a policy \(a \sim \pi_\theta(a\mid s, z_i)\).
The latent variable \(z_i\) is sampled once and fixed during one episode. Intuitively, the latent variable determines one type of behavior (or skill) that should be explored more at the beginning of a rollout, and the agent adjusts its actions accordingly. Both the policy parameters and the latent space parameters are optimized to maximize the total task rewards. Meanwhile, the policy learns to make use of the latent variables for exploration.</p>
<p>In addition, the loss function includes a KL divergence between the learned latent variable distribution and a unit Gaussian prior, \(D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i)\|\mathcal{N}(0, \mathbf{I}))\). On one hand, it keeps the learned latent space close to a common prior. On the other hand, it creates the variational evidence lower bound (<a href="http://users.umiacs.umd.edu/~xyang35/files/understanding-variational-lower.pdf">ELBO</a>) for the reward function. Interestingly, the paper found that \((\mu_i, \sigma_i)\) for each task are usually close to the prior at convergence.</p>
<p style="width: 82%;" class="center"><img src="/lil-log/assets/images/MAESN.png" alt="MAESN" /></p>
<p><em>Fig. 5. The policy is conditioned on a latent variable \(z_i \sim \mathcal{N}(\mu_i, \sigma_i)\) that is sampled once every episode. Each task has different hyperparameters for the latent variable distribution, \((\mu_i, \sigma_i)\), and they are optimized in the outer loop. (Image source: <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al, 2018</a>)</em></p>
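<p>The two MAESN-specific pieces, sampling \(z_i\) once per episode and the KL penalty against \(\mathcal{N}(0, \mathbf{I})\), are small enough to write out. The sketch below assumes a diagonal Gaussian parameterized by \(\mu_i\) and \(\log\sigma_i\), which is a common choice but an assumption here; the policy network that consumes \(z_i\) is omitted.</p>

```python
import numpy as np

def sample_episode_latent(mu, log_sigma, rng):
    """Draw z_i ~ N(mu_i, diag(sigma_i^2)) once at the start of an episode
    (reparameterized, so gradients could flow to mu and log_sigma in an autodiff setting)."""
    return mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)

def kl_to_unit_gaussian(mu, log_sigma):
    """Closed-form D_KL(N(mu, diag(sigma^2)) || N(0, I))
    = 0.5 * sum(sigma^2 + mu^2 - 1 - 2 log sigma)."""
    return 0.5 * np.sum(np.exp(2 * log_sigma) + mu ** 2 - 1.0 - 2.0 * log_sigma)
```

The KL term is zero exactly when \((\mu_i, \sigma_i) = (0, 1)\), which matches the observation above that the learned per-task distributions tend to stay close to the prior.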
<h3 id="episodic-control">Episodic Control</h3>
<p>A major criticism of RL is its sample inefficiency. A large number of samples and small learning steps are required for incremental parameter adjustment in RL in order to maximize generalization and avoid catastrophic forgetting of earlier learning (<a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613(19)30061-0">Botvinick et al., 2019</a>).</p>
<p><strong>Episodic control</strong> (<a href="http://papers.nips.cc/paper/3311-hippocampal-contributions-to-control-the-third-way.pdf">Lengyel & Dayan, 2008</a>) is proposed as a solution to avoid forgetting and improve generalization while training at a faster speed. It is partially inspired by hypotheses on instance-based <a href="https://en.wikipedia.org/wiki/Hippocampus">hippocampal</a> learning.</p>
<p>An <em>episodic memory</em> keeps explicit records of past events and uses these records directly as points of reference for making new decisions (i.e. just like <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning). In <strong>MFEC</strong> (Model-Free Episodic Control; <a href="https://arxiv.org/abs/1606.04460">Blundell et al., 2016</a>), the memory is modeled as a big table, storing the state-action pair \((s, a)\) as key and the corresponding Q-value \(Q_\text{EC}(s, a)\) as value. When receiving a new observation \(s\), the Q value is estimated in a non-parametric way: it is read directly if \((s, a)\) is already in the table, and is otherwise the average Q-value of the top \(k\) most similar samples:</p>
\[\hat{Q}_\text{EC}(s, a) =
\begin{cases}
Q_\text{EC}(s, a) & \text{if } (s,a) \in Q_\text{EC}, \\
\frac{1}{k} \sum_{i=1}^k Q_\text{EC}(s^{(i)}, a) & \text{otherwise}
\end{cases}\]
<p>where \(s^{(i)}, i=1, \dots, k\) are the top \(k\) states with the smallest distances to the state \(s\). The action that yields the highest estimated Q value is then selected, and the memory table is updated according to the return \(G_t\) received from \(s_t\):</p>
\[Q_\text{EC}(s_t, a_t) \leftarrow
\begin{cases}
\max\{Q_\text{EC}(s_t, a_t), G_t\} & \text{if } (s_t, a_t) \in Q_\text{EC}, \\
G_t & \text{otherwise}
\end{cases}\]
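<p>A minimal sketch of the MFEC lookup and update rules above. It assumes one memory table per action (so each key is just a state vector, possibly after an embedding), exact matching for "\((s, a) \in Q_\text{EC}\)", Euclidean distance for the \(k\) nearest neighbours, and a default of 0 for an empty table; the class and helper names are hypothetical, and there is no eviction policy.</p>

```python
import numpy as np

class MFECTable:
    """Hypothetical episodic memory for a single action a: keys are states, values Q_EC(s, a)."""
    def __init__(self, k=3):
        self.k = k
        self.keys = []    # stored state vectors
        self.values = []  # matching Q_EC values

    def _find(self, s):
        for i, key in enumerate(self.keys):
            if np.array_equal(key, s):
                return i
        return None

    def estimate(self, s):
        """Exact lookup if s is stored, otherwise the mean value of the k nearest keys."""
        i = self._find(s)
        if i is not None:
            return self.values[i]
        if not self.keys:
            return 0.0  # assumption: an empty memory estimates 0
        dists = np.linalg.norm(np.stack(self.keys) - s, axis=1)
        nearest = np.argsort(dists)[: self.k]
        return float(np.mean([self.values[j] for j in nearest]))

    def update(self, s, G):
        """Keep the best return seen so far for a known state, else insert a new entry."""
        i = self._find(s)
        if i is not None:
            self.values[i] = max(self.values[i], G)
        else:
            self.keys.append(np.asarray(s, dtype=float))
            self.values.append(G)

def select_action(tables, s):
    """Greedy action: argmax over a of the estimated Q_EC(s, a), one table per action."""
    return int(np.argmax([t.estimate(s) for t in tables]))
```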
<p>As a tabular RL method, MFEC suffers from large memory consumption and lacks a way to generalize among similar states. The first issue can be addressed with an LRU cache. Inspired by <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning, especially Matching Networks (<a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">Vinyals et al., 2016</a>), the generalization problem is improved in a follow-up algorithm, <strong>NEC</strong> (Neural Episodic Control; <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>).</p>
<p>The episodic memory in NEC is a Differentiable Neural Dictionary (<strong>DND</strong>), where the key is a convolutional embedding vector of input image pixels and the value stores the estimated Q value. Given a query key, the output is a weighted sum of the values of the most similar keys, where each weight is a normalized kernel measure between the query key and the selected key in the dictionary. This works like an <a href="/lil-log/2018/06/24/attention-attention.html">attention</a> mechanism over a hard-selected subset of nearest keys.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/neural-episodic-control.png" alt="Neural episodic control" /></p>
<p><em>Fig. 6. Illustration of the episodic memory module in NEC and two operations on a differentiable neural dictionary. (Image source: <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>)</em></p>
<p>Further, <strong>Episodic LSTM</strong> (<a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>) enhances the basic LSTM architecture with a DND episodic memory, which stores task context embeddings as keys and the LSTM cell states as values. The stored cell states are retrieved and added directly to the current cell state through a gate that works like the existing LSTM gates:</p>
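<p>A sketch of a DND read in NumPy. The inverse-distance kernel \(k(\mathbf{h}, \mathbf{h}_i) = 1 / (\|\mathbf{h} - \mathbf{h}_i\|_2^2 + \delta)\) follows my reading of the NEC paper; the defaults <code>p=50</code> and <code>delta=1e-3</code> are illustrative assumptions. Writing to the dictionary and the gradient flow into keys and values are left out.</p>

```python
import numpy as np

def dnd_lookup(query, keys, values, p=50, delta=1e-3):
    """DND-style read: pick the p nearest keys (a hard selection), weight them with the
    inverse-distance kernel 1 / (||h - h_i||^2 + delta), normalize the weights, and
    return the weighted sum of their values."""
    d2 = np.sum((keys - query) ** 2, axis=1)
    p = min(p, len(keys))
    idx = np.argpartition(d2, p - 1)[:p]  # indices of the p smallest distances
    w = 1.0 / (d2[idx] + delta)
    w /= w.sum()
    return float(w @ values[idx])
```

With <code>p=1</code> the read simply returns the value of the nearest key; larger <code>p</code> blends neighbours, with near-exact matches dominating because of the small \(\delta\).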
<p style="width: 77%;" class="center"><img src="/lil-log/assets/images/episodic-LSTM.png" alt="Episodic LSTM" /></p>
<p><em>Fig. 7. Illustration of the episodic LSTM architecture. The additional structure of episodic memory is in bold. (Image source: <a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>)</em></p>
\[\begin{aligned}
\mathbf{c}_t &= \mathbf{i}_t \circ \mathbf{c}_\text{in} + \mathbf{f}_t \circ \mathbf{c}_{t-1} + \color{green}{\mathbf{r}_t \circ \mathbf{c}_\text{ep}} &\\
\mathbf{i}_t &= \sigma(\mathbf{W}_{i} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) & \scriptstyle{\text{; input gate}} \\
\mathbf{f}_t &= \sigma(\mathbf{W}_{f} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) & \scriptstyle{\text{; forget gate}} \\
\color{green}{\mathbf{r}_t} & \color{green}{=} \color{green}{\sigma(\mathbf{W}_{r} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_r)} & \scriptstyle{\text{; reinstatement gate}}
\end{aligned}\]
<p>where \(\mathbf{c}_t\) and \(\mathbf{h}_t\) are the cell and hidden states at time \(t\); \(\mathbf{i}_t\), \(\mathbf{f}_t\) and \(\mathbf{r}_t\) are the input, forget and reinstatement gates, respectively; \(\mathbf{c}_\text{in}\) is the candidate cell input, as in a standard LSTM; \(\mathbf{c}_\text{ep}\) is the retrieved cell state from episodic memory. The newly added episodic memory components are marked in green.</p>
<p>This architecture provides a shortcut to prior experience through context-based retrieval. Meanwhile, explicitly saving the task-dependent experience in an external memory avoids forgetting. In the paper, all the experiments use manually designed context vectors. How to construct an effective and efficient format of task context embeddings for more free-form tasks would be an interesting topic.</p>
<p>Overall, the capacity of episodic control is limited by the complexity of the environment. It is very rare for an agent to repeatedly visit exactly the same states in a real-world task, so properly encoding the states is critical. The learned embedding space compresses the observation data into a lower-dimensional space, and states that are close in this space are expected to call for similar strategies.</p>
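<p>One step of this cell in NumPy, as a sketch. The output gate \(\mathbf{o}_t\) and \(\mathbf{h}_t = \mathbf{o}_t \circ \tanh(\mathbf{c}_t)\) come from the standard LSTM and are not shown in the equations above, and the <code>params</code> layout is a hypothetical convention. Closing the reinstatement gate recovers an ordinary LSTM step.</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def episodic_lstm_step(x, h_prev, c_prev, c_ep, params):
    """One LSTM step with a reinstatement gate r_t that adds the retrieved episodic
    cell state c_ep: c_t = i_t * c_in + f_t * c_{t-1} + r_t * c_ep.
    params maps gate names to (W, b); each W acts on the concatenation [h_{t-1}, x_t]."""
    hx = np.concatenate([h_prev, x])
    i = sigmoid(params["i"][0] @ hx + params["i"][1])     # input gate
    f = sigmoid(params["f"][0] @ hx + params["f"][1])     # forget gate
    r = sigmoid(params["r"][0] @ hx + params["r"][1])     # reinstatement gate
    o = sigmoid(params["o"][0] @ hx + params["o"][1])     # output gate (standard LSTM)
    c_in = np.tanh(params["c"][0] @ hx + params["c"][1])  # candidate cell input
    c = i * c_in + f * c_prev + r * c_ep
    h = o * np.tanh(c)
    return h, c
```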
<h2 id="training-task-acquisition">Training Task Acquisition</h2>
<p>Among the three key components, how to design a proper distribution of tasks is the least studied and probably the most specific to meta-RL itself. As described <a href="#formulation">above</a>, each task is an MDP: \(M_i = \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle \in \mathcal{M}\). We can build a distribution of MDPs by modifying:</p>
<ul>
<li>The <em>reward configuration</em>: Across tasks, the same behavior might be rewarded differently according to \(R_i\).</li>
<li>Or, the <em>environment</em>: The transition function \(P_i\) can be reshaped by initializing the environment with different dynamics, so that the same action leads to different shifts between states.</li>
</ul>
<h3 id="task-generation-by-domain-randomization">Task Generation by Domain Randomization</h3>
<p>Randomizing parameters in a simulator is an easy way to obtain tasks with modified transition functions. If interested in learning further, check my previous <a href="/lil-log/2019/05/05/domain-randomization.html">post</a> on <strong>domain randomization</strong>.</p>
<h3 id="evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</h3>
<p><a href="https://en.wikipedia.org/wiki/Evolutionary_algorithm">Evolutionary algorithm</a> is a gradient-free heuristic-based optimization method, inspired by natural selection. A population of solutions follows a loop of evaluation, selection, reproduction, and mutation. Eventually, good solutions survive and thus get selected.</p>
<p><strong>POET</strong> (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>), a framework based on the evolutionary algorithm, attempts to generate tasks while the problems themselves are being solved. The implementation of POET is only specifically designed for a simple 2D <a href="https://gym.openai.com/envs/BipedalWalkerHardcore-v2/">bipedal walker</a> environment but points out an interesting direction. It is noteworthy that the evolutionary algorithm has had some compelling applications in Deep Learning like <a href="#meta-learning-the-loss-function">EPG</a> and PBT (Population-Based Training; <a href="https://arxiv.org/abs/1711.09846"> Jaderberg et al, 2017</a>).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/POET.png" alt="POET" /></p>
<p><em>Fig. 8. An example bipedal walking environment (top) and an overview of POET (bottom). (Image source: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">POET blog post</a>)</em></p>
<p>The 2D bipedal walking environment is evolving: from a simple flat surface to a much more difficult trail with potential gaps, stumps, and rough terrains. POET pairs the generation of environmental challenges and the optimization of agents together so as to (a) select agents that can resolve current challenges and (b) evolve environments to be solvable. The algorithm maintains a list of <em>environment-agent pairs</em> and repeats the following:</p>
<ol>
<li><em>Mutation</em>: Generate new environments from currently active environments. Note that the mutation operations here are designed just for the bipedal walker; a new environment would demand a new set of configurations.</li>
<li><em>Optimization</em>: Train paired agents within their respective environments.</li>
<li><em>Selection</em>: Periodically attempt to transfer current agents from one environment to another. Copy and update the best performing agent for every environment. The intuition is that skills learned in one environment might be helpful for a different environment.</li>
</ol>
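<p>The three steps can be illustrated with a deliberately tiny, hypothetical toy: an "environment" is just a difficulty number, an "agent" is a skill number, <code>score</code> rewards matching the two, and <code>optimize</code> stands in for RL training. The admission thresholds for new environments (neither trivial nor hopeless for the parent's agent) and the capacity of five pairs are made-up settings; as far as I know, the real POET trains neural policies with ES on parameterized walker terrain.</p>

```python
import numpy as np

def score(agent, env):
    """Hypothetical toy: an agent does well when its skill matches the env's difficulty."""
    return -abs(agent - env)

def optimize(agent, env, lr=0.2):
    """Hypothetical stand-in for training an agent inside its paired environment."""
    return agent + lr * (env - agent)

def poet(n_iters=30, max_pairs=5, mc_low=-1.5, mc_high=-0.1, seed=0):
    rng = np.random.default_rng(seed)
    pairs = [(0.0, 0.0)]  # list of (environment difficulty, agent skill)
    for it in range(n_iters):
        # 1. Mutation: propose harder children of active environments; keep a child only
        #    if the parent's agent finds it neither trivial nor hopeless.
        if it % 5 == 0:
            children = []
            for env, agent in pairs:
                child = env + abs(rng.normal(0, 0.5))
                if mc_low <= score(agent, child) <= mc_high:
                    children.append((child, agent))  # child inherits the parent's agent
            pairs = (pairs + children)[-max_pairs:]
        # 2. Optimization: train each agent within its own environment.
        pairs = [(env, optimize(agent, env)) for env, agent in pairs]
        # 3. Transfer: every environment takes the best-scoring agent among all pairs.
        agents = [agent for _, agent in pairs]
        pairs = [(env, max(agents, key=lambda a: score(a, env))) for env, _ in pairs]
    return pairs
```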
<p>The procedure above is quite similar to <a href="https://arxiv.org/abs/1711.09846">PBT</a>, but PBT mutates and evolves hyperparameters instead. To some extent, POET is doing <a href="/lil-log/2019/05/05/domain-randomization.html">domain randomization</a>, as all the gaps, stumps and terrain roughness are controlled by some randomization probability parameters. Unlike DR, the agents are not exposed to a fully randomized, difficult environment all at once; instead, they learn gradually with a curriculum configured by the evolutionary algorithm.</p>
<h3 id="learning-with-random-rewards">Learning with Random Rewards</h3>
<p>An MDP without a reward function \(R\) is known as a <em>Controlled Markov process</em> (CMP). Given a predefined CMP, \(\langle \mathcal{S}, \mathcal{A}, P\rangle\), we can acquire a variety of tasks by generating a collection of reward functions \(\mathcal{R}\) that encourage the training of an effective meta-learning policy.</p>
<p><a href="https://arxiv.org/abs/1806.04640">Gupta et al. (2018)</a> proposed two unsupervised approaches for growing the task distribution in the context of CMP. Assuming there is an underlying latent variable \(z \sim p(z)\) associated with every task, it parameterizes/determines a reward function: \(r_z(s) = \log D(z|s)\), where a “discriminator” function \(D(.)\) is used to extract the latent variable from the state. The paper described two ways to construct a discriminator function:</p>
<ul>
<li>Sample random weights \(\phi_\text{rand}\) of the discriminator, \(D_{\phi_\text{rand}}(z \mid s)\).</li>
<li>Learn a discriminator function to encourage diversity-driven exploration. This method is introduced in more detail in a sister paper, “DIAYN” (<a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2018</a>).</li>
</ul>
<p>DIAYN, short for “Diversity is all you need”, is a framework to encourage a policy to learn useful skills without a reward function. It explicitly models the latent variable \(z\) as a <em>skill</em> embedding and makes the policy conditioned on \(z\) in addition to state \(s\), \(\pi_\theta(a \mid s, z)\). (Ok, unsurprisingly, this part is the same as in <a href="#meta-learning-the-exploration-strategies">MAESN</a>, as the papers are from the same group.) The design of DIAYN is motivated by a few hypotheses:</p>
<ul>
<li>Skills should be diverse and lead to visitations of different states. → maximize the mutual information between states and skills, \(I(S; Z)\)</li>
<li>Skills should be distinguishable by states, not actions. → minimize the mutual information between actions and skills, conditioned on states \(I(A; Z \mid S)\)</li>
</ul>
<p>The objective function to maximize is as follows, where the policy entropy is also added to encourage diversity:</p>
\[\begin{aligned}
\mathcal{F}(\theta)
&= I(S; Z) + H[A \mid S] - I(A; Z \mid S) & \\
&= (H(Z) - H(Z \mid S)) + H[A \mid S] - (H[A\mid S] - H[A\mid S, Z]) & \\
&= H[A\mid S, Z] \color{green}{- H(Z \mid S) + H(Z)} & \\
&= H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\log p(z \mid s)] - \mathbb{E}_{z\sim p(z)}[\log p(z)] & \scriptstyle{\text{; can infer skills from states & p(z) is diverse.}} \\
&\ge H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\color{red}{\log D_\phi(z \mid s) - \log p(z)}] & \scriptstyle{\text{; according to Jensen's inequality; "pseudo-reward" in red.}}
\end{aligned}\]
<p>where \(I(.)\) is mutual information and \(H[.]\) is entropy. We cannot integrate over all states to compute \(p(z \mid s)\), so we approximate it with \(D_\phi(z \mid s)\), the diversity-driven discriminator function.</p>
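<p>To make the pseudo-reward concrete, here is a small Python sketch. It assumes a linear-softmax discriminator trained by SGD on (state, skill) pairs and a uniform prior \(p(z) = 1/K\); the class and function names are hypothetical, and the paper uses neural networks for both the policy and the discriminator. The reward is \(\log D_\phi(z \mid s) - \log p(z)\), the red term in the bound above.</p>

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


class SkillDiscriminator:
    """Linear-softmax approximation D_phi(z | s) of the intractable posterior p(z | s)."""

    def __init__(self, state_dim, num_skills):
        self.num_skills = num_skills
        self.weights = [[0.0] * state_dim for _ in range(num_skills)]

    def log_probs(self, state):
        logits = [sum(w * x for w, x in zip(row, state)) for row in self.weights]
        return log_softmax(logits)

    def update(self, state, z, lr=0.1):
        """One SGD step that increases log D_phi(z | s), i.e. lowers the cross-entropy."""
        probs = [math.exp(lp) for lp in self.log_probs(state)]
        for k, row in enumerate(self.weights):
            grad_scale = (1.0 if k == z else 0.0) - probs[k]
            for i, x in enumerate(state):
                row[i] += lr * grad_scale * x


def pseudo_reward(disc, state, z):
    """DIAYN pseudo-reward log D_phi(z | s) - log p(z), with p(z) uniform over K skills."""
    log_p_z = -math.log(disc.num_skills)
    return disc.log_probs(state)[z] - log_p_z
```

<p>With a uniform prior, the pseudo-reward is bounded above by \(\log K\) (reached when the discriminator is certain of the skill) and is zero when the discriminator assigns the skill probability \(1/K\), i.e. no better than chance.</p>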
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DIAYN.png" alt="DIAYN" /></p>
<p><em>Fig. 9. DIAYN Algorithm. (Image source: <a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2019</a>)</em></p>
<p>Once the discriminator function is learned, sampling a new MDP for training is straightforward: first, sample a latent variable \(z \sim p(z)\) and construct a reward function \(r_z(s) = \log(D(z \vert s))\). Pairing the reward function with a predefined CMP creates a new MDP.</p>
<!--
---
So far, experiments of meta-RL are still limited to a collection of very similar tasks, originated from the same family; such as multi-armed bandit with different reward probabilities, mazes with different layouts, or same robots but with different physical parameters in simulator. I'm looking forward to more research demonstrating the power of meta-RL over a more diverse set of tasks.
-->
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019metaRL,
title = "Meta Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/06/23/meta-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Richard S. Sutton. <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">“The Bitter Lesson.”</a> March 13, 2019.</p>
<p>[2] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">“Learning to learn using gradient descent.”</a> Intl. Conf. on Artificial Neural Networks. 2001.</p>
<p>[3] Jane X Wang, et al. <a href="https://arxiv.org/abs/1611.05763">“Learning to reinforcement learn.”</a> arXiv preprint arXiv:1611.05763 (2016).</p>
<p>[4] Yan Duan, et al. <a href="https://arxiv.org/abs/1611.02779">“RL $^ 2$: Fast Reinforcement Learning via Slow Reinforcement Learning.”</a> ICLR 2017.</p>
<p>[5] Matthew Botvinick, et al. <a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613\(19\)30061-0">“Reinforcement Learning, Fast and Slow”</a> Trends in Cognitive Sciences, Volume 23, Issue 5, P408-422, May 01, 2019.</p>
<p>[6] Jeff Clune. <a href="https://arxiv.org/abs/1905.10985">“AI-GAs: AI-generating algorithms, an alternate paradigm for producing general artificial intelligence”</a> arXiv preprint arXiv:1905.10985 (2019).</p>
<p>[7] Zhongwen Xu, et al. <a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">“Meta-Gradient Reinforcement Learning”</a> NIPS 2018.</p>
<p>[8] Rein Houthooft, et al. <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">“Evolved Policy Gradients.”</a> NIPS 2018.</p>
<p>[9] Tim Salimans, et al. <a href="https://arxiv.org/abs/1703.03864">“Evolution strategies as a scalable alternative to reinforcement learning.”</a> arXiv preprint arXiv:1703.03864 (2017).</p>
<p>[10] Abhishek Gupta, et al. <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">“Meta-Reinforcement Learning of Structured Exploration Strategies.”</a> NIPS 2018.</p>
<p>[11] Alexander Pritzel, et al. <a href="https://arxiv.org/abs/1703.01988">“Neural episodic control.”</a> Proc. Intl. Conf. on Machine Learning, Volume 70, 2017.</p>
<p>[12] Charles Blundell, et al. <a href="https://arxiv.org/abs/1606.04460">“Model-free episodic control.”</a> arXiv preprint arXiv:1606.04460 (2016).</p>
<p>[13] Samuel Ritter, et al. <a href="https://arxiv.org/abs/1805.09692">“Been there, done that: Meta-learning with episodic recall.”</a> ICML, 2018.</p>
<p>[14] Rui Wang et al. <a href="https://arxiv.org/abs/1901.01753">“Paired Open-Ended Trailblazer (POET): Endlessly Generating Increasingly Complex and Diverse Learning Environments and Their Solutions”</a> arXiv preprint arXiv:1901.01753 (2019).</p>
<p>[15] Uber Engineering Blog: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">“POET: Endlessly Generating Increasingly Complex and Diverse Learning Environments and their Solutions through the Paired Open-Ended Trailblazer.”</a> Jan 8, 2019.</p>
<p>[16] Abhishek Gupta, et al. <a href="https://arxiv.org/abs/1806.04640">“Unsupervised meta-learning for Reinforcement Learning”</a> arXiv preprint arXiv:1806.04640 (2018).</p>
<p>[17] Eysenbach, Benjamin, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a> ICLR 2019.</p>
<p>[18] Max Jaderberg, et al. <a href="https://arxiv.org/abs/1711.09846">“Population Based Training of Neural Networks.”</a> arXiv preprint arXiv:1711.09846 (2017).</p>Lilian WengMeta-RL is meta-learning on reinforcement learning tasks. After trained over a distribution of tasks, the agent is able to solve a new task by developing a new RL algorithm with its internal activity dynamics. This post starts with the origin of meta-RL and then dives into three key components of meta-RL.Domain Randomization for Sim2Real Transfer2019-05-05T00:00:00+00:002019-05-05T00:00:00+00:00https://lilianweng.github.io/lil-log/2019/05/05/domain-randomization<blockquote>
<p>If a model or policy is mainly trained in a simulator but expected to work on a real robot, it would surely face the sim2real gap. <em>Domain Randomization</em> (DR) is a simple but powerful idea for closing this gap by randomizing properties of the training environment.</p>
</blockquote>
<!--more-->
<p>In robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator, which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots. The gap is caused by inconsistencies in physical parameters (e.g. friction, kp, damping, mass, density) and, more fatally, by incorrect physical modeling (e.g. collisions between soft surfaces).</p>
<p>To close the sim2real gap, we need to improve the simulator and make it closer to reality. A couple of approaches:</p>
<ul>
<li><strong>System identification</strong>
<ul>
<li><em>System identification</em> is to build a mathematical model for a physical system; in the context of RL, the mathematical model is the simulator. To make the simulator more realistic, careful calibration is necessary.</li>
<li>Unfortunately, calibration is expensive. Furthermore, many physical parameters of the same machine might vary significantly due to temperature, humidity, positioning, or wear and tear over time.</li>
</ul>
</li>
<li><strong>Domain adaptation</strong>
<ul>
<li><em>Domain adaptation (DA)</em> refers to a set of transfer learning techniques developed to update the data distribution in sim to match the real one through a mapping or regularization enforced by the task model.</li>
<li>Many DA models, especially for image classification or end-to-end image-based RL tasks, are built on adversarial loss or <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a>.</li>
</ul>
</li>
<li><strong>Domain randomization</strong>
<ul>
<li>With <em>domain randomization (DR)</em>, we are able to create a variety of simulated environments with randomized properties and train a model that works across all of them.</li>
<li>This model is likely to adapt to the real-world environment, as the real system is expected to be one sample in that rich distribution of training variations.</li>
</ul>
</li>
</ul>
<p>Both DA and DR are unsupervised. Compared to DA, which requires a decent amount of real data samples to capture the distribution, DR may need <em>little or no</em> real data. DR is the focus of this post.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sim2real-transfer.png" alt="Approaches for sim2real transfer" /></p>
<p><em>Fig. 1. Conceptual illustrations of three approaches for sim2real transfer.</em></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-domain-randomization" id="markdown-toc-what-is-domain-randomization">What is Domain Randomization?</a></li>
<li><a href="#uniform-domain-randomization" id="markdown-toc-uniform-domain-randomization">Uniform Domain Randomization</a></li>
<li><a href="#why-does-domain-randomization-work" id="markdown-toc-why-does-domain-randomization-work">Why does Domain Randomization Work?</a> <ul>
<li><a href="#dr-as-optimization" id="markdown-toc-dr-as-optimization">DR as Optimization</a></li>
<li><a href="#dr-as-meta-learning" id="markdown-toc-dr-as-meta-learning">DR as Meta-Learning</a></li>
</ul>
</li>
<li><a href="#guided-domain-randomization" id="markdown-toc-guided-domain-randomization">Guided Domain Randomization</a> <ul>
<li><a href="#optimization-for-task-performance" id="markdown-toc-optimization-for-task-performance">Optimization for Task Performance</a></li>
<li><a href="#match-real-data-distribution" id="markdown-toc-match-real-data-distribution">Match Real Data Distribution</a></li>
<li><a href="#guided-by-data-in-simulator" id="markdown-toc-guided-by-data-in-simulator">Guided by Data in Simulator</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-domain-randomization">What is Domain Randomization?</h2>
<p>To make the definition more general, let us call the environment that we have full access to (i.e. simulator) <strong>source domain</strong> and the environment that we would like to transfer the model to <strong>target domain</strong> (i.e. physical world). Training happens in the source domain. We can control a set of \(N\) randomization parameters in the source domain \(e_\xi\) with a configuration \(\xi\), sampled from a randomization space, \(\xi \in \Xi \subset \mathbb{R}^N\).</p>
<p>During policy training, episodes are collected from the source domain with randomization applied. Thus the policy is exposed to a variety of environments and learns to generalize. The policy parameter \(\theta\) is trained to maximize the expected reward \(R(.)\) averaged across a distribution of configurations:</p>
\[\theta^* = \arg\max_\theta \mathbb{E}_{\xi \sim \Xi} [\mathbb{E}_{\pi_\theta, \tau \sim e_\xi} [R(\tau)]]\]
<p>where \(\tau\) is a trajectory collected in the source domain randomized with \(\xi\). In a way, <em>“discrepancies between the source and target domains are modeled as variability in the source domain.”</em> (quote from <a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>).</p>
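<p>In practice the nested expectation is estimated by sampling. The Python sketch below averages returns over sampled configurations and episodes for a fixed policy; <code>rollout_return</code> and <code>sample_config</code> are hypothetical callables standing in for a real simulator rollout and a randomization sampler.</p>

```python
import random


def estimate_dr_objective(rollout_return, sample_config, num_configs=100,
                          episodes_per_config=5, seed=0):
    """Monte Carlo estimate of E_xi[ E_{tau ~ e_xi}[ R(tau) ] ] for a fixed policy.

    rollout_return(xi, rng) is assumed to run one episode of the policy in the
    source domain randomized with config xi and return the episode's total reward.
    sample_config(rng) is assumed to draw one config xi from the randomization space.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_configs):
        xi = sample_config(rng)
        for _ in range(episodes_per_config):
            total += rollout_return(xi, rng)
    return total / (num_configs * episodes_per_config)
```

<p>For instance, with <code>sample_config = lambda rng: rng.uniform(0.0, 1.0)</code> and a toy return \(-(\xi - 0.3)^2\), the estimate approaches the exact value \(-(1/12 + 0.04) \approx -0.123\) as the number of sampled configurations grows.</p>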
<h2 id="uniform-domain-randomization">Uniform Domain Randomization</h2>
<p>In the original form of DR (<a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>; <a href="https://arxiv.org/pdf/1611.04201.pdf">Sadeghi et al. 2016</a>), each randomization parameter \(\xi_i\) is bounded by an interval, \(\xi_i \in [\xi_i^\text{low}, \xi_i^\text{high}], i=1,\dots,N\) and each parameter is uniformly sampled within the range.</p>
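<p>A minimal Python sketch of uniform DR sampling, assuming each configuration is a dictionary of named scalar parameters. The parameter names and ranges in <code>BOUNDS</code> are hypothetical placeholders, not values from Tobin et al. or Sadeghi et al.</p>

```python
import random


def sample_uniform_config(bounds, rng):
    """Sample xi with xi_i ~ Uniform[low_i, high_i], independently for every parameter i."""
    return {name: rng.uniform(low, high) for name, (low, high) in bounds.items()}


# Hypothetical ranges; in practice they are hand-picked per task.
BOUNDS = {
    "light_intensity": (0.5, 1.5),
    "camera_fov_deg": (40.0, 60.0),
    "object_hue": (0.0, 1.0),
}

if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        print(sample_uniform_config(BOUNDS, rng))  # one fresh config per episode
```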
<p>The randomization parameters can control appearances of the scene, including but not limited to the following (see Fig. 2). A model trained on simulated and randomized images is able to transfer to real non-randomized images.</p>
<ul>
<li>Position, shape, and color of objects,</li>
<li>Material texture,</li>
<li>Lighting condition,</li>
<li>Random noise added to images,</li>
<li>Position, orientation, and field of view of the camera in the simulator.</li>
</ul>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/DR.png" alt="Domain Randomization" /></p>
<p><em>Fig. 2. Images captured in the training environment are randomized. (Image source: <a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>)</em></p>
<p>Physical dynamics in the simulator can also be randomized (<a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>). Studies have shown that a <em>recurrent</em> policy can adapt to different physical dynamics, including the partially observable reality. Physical dynamics features that can be randomized include but are not limited to:</p>
<ul>
<li>Mass and dimensions of objects,</li>
<li>Mass and dimensions of robot bodies,</li>
<li>Damping, kp, friction of the joints,</li>
<li>Gains for the PID controller (P term),</li>
<li>Joint limit,</li>
<li>Action delay,</li>
<li>Observation noise.</li>
</ul>
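<p>As an illustration of dynamics randomization, here is a toy one-dimensional point mass in Python whose mass, damping, action delay, and observation noise are resampled at every reset. The environment, its parameter ranges, and the class name <code>RandomizedPointMass</code> are hypothetical; a real setup would randomize the corresponding parameters of a physics simulator.</p>

```python
import collections
import random


class RandomizedPointMass:
    """Toy 1-D point mass whose dynamics are re-randomized at every reset.

    Per-episode randomization (hypothetical ranges): mass, damping,
    action delay in control steps, and observation noise scale.
    """

    def __init__(self, seed=0, dt=0.05):
        self.rng = random.Random(seed)
        self.dt = dt

    def reset(self):
        self.mass = self.rng.uniform(0.5, 2.0)
        self.damping = self.rng.uniform(0.0, 0.5)
        self.delay = self.rng.randint(0, 3)  # inclusive range: 0 to 3 steps
        self.obs_noise = self.rng.uniform(0.0, 0.02)
        # Pre-fill the buffer so the first `delay` steps apply zero force.
        self.pending = collections.deque([0.0] * self.delay)
        self.pos, self.vel = 0.0, 0.0
        return self._observe()

    def step(self, force):
        self.pending.append(force)
        applied = self.pending.popleft()  # the force commanded `delay` steps ago
        acc = (applied - self.damping * self.vel) / self.mass
        self.vel += acc * self.dt
        self.pos += self.vel * self.dt
        return self._observe()

    def _observe(self):
        # Only the observation is noisy; the true state is left untouched.
        return self.pos + self.rng.gauss(0.0, self.obs_noise)
```

<p>A policy trained across many resets of such an environment never sees the same dynamics twice, which is what pushes a recurrent policy to infer the current dynamics from its recent history.</p>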
<p>With visual and dynamics DR, at OpenAI Robotics, we were able to learn a policy that works on a real dexterous robot hand (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>). Our manipulation task is to teach the robot hand to rotate an object continuously to achieve 50 successive random target orientations. The sim2real gap in this task is very large, due to (a) a high number of simultaneous contacts between the robot and the object and (b) imperfect simulation of object collision and other motions. At first, the policy could barely survive for more than 5 seconds without dropping the object. But with the help of DR, the policy eventually evolved to work surprisingly well in reality.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/DKe8FumoD4E" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
<h2 id="why-does-domain-randomization-work">Why does Domain Randomization Work?</h2>
<p>Now you may ask, why does domain randomization work so well? The idea sounds really simple. Here are two non-exclusive explanations I found most convincing.</p>
<h3 id="dr-as-optimization">DR as Optimization</h3>
<p>One idea (<a href="https://arxiv.org/abs/1903.11774">Vuong, et al, 2019</a>) is to view learning randomization parameters in DR as a <em>bilevel optimization</em>. Assuming we have access to the real environment \(e_\text{real}\) and the randomization config is sampled from a distribution parameterized by \(\phi\), \(\xi \sim P_\phi(\xi)\), we would like to learn a distribution such that a policy \(\pi_\theta\) trained on it can achieve maximal performance in \(e_\text{real}\):</p>
\[\begin{aligned}
&\phi^* = \arg\min_{\phi} \mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real}) \\
\text{where } &\theta^*(\phi) = \arg\min_\theta \mathbb{E}_{\xi \sim P_\phi(\xi)}[\mathcal{L}(\pi_\theta; e_\xi)]
\end{aligned}\]
<p>where \(\mathcal{L}(\pi; e)\) is the loss function of policy \(\pi\) evaluated in the environment \(e\).</p>
<p>In uniform DR, the randomization ranges are hand-picked, which often involves domain knowledge and a couple of rounds of trial-and-error adjustment based on the transfer performance. Essentially, this is a manual optimization process that tunes \(\phi\) to minimize \(\mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real})\).</p>
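<p>The bilevel structure can be seen in a toy problem where both levels have simple solutions. In this Python sketch (an illustrative assumption, not the setup of Vuong et al.), \(\xi \sim \mathcal{N}(\phi, 0.5^2)\), the inner loss is \((\theta - \xi)^2\), and the “real” loss is \((\theta - \xi_\text{real})^2\), with \(\xi_\text{real}\) hidden from the inner level. The outer level is solved by brute-force search over candidate values of \(\phi\), mimicking manual tuning.</p>

```python
import random


def inner_solve(phi, rng, n=2000, sigma=0.5):
    """Inner level: theta*(phi) = argmin_theta mean_i (theta - xi_i)^2, xi_i ~ N(phi, sigma^2).

    For this squared loss the empirical minimizer is exactly the sample mean.
    """
    samples = [rng.gauss(phi, sigma) for _ in range(n)]
    return sum(samples) / n


def real_loss(theta, xi_real):
    """Outer-level loss L(pi_theta; e_real); the inner level never sees xi_real."""
    return (theta - xi_real) ** 2


def outer_grid_search(candidate_phis, xi_real, seed=0):
    """Outer level: phi* = argmin_phi L(pi_{theta*(phi)}; e_real), by exhaustive search."""
    rng = random.Random(seed)
    return min(candidate_phis, key=lambda phi: real_loss(inner_solve(phi, rng), xi_real))


if __name__ == "__main__":
    phis = [0.25 * i for i in range(-8, 9)]
    print(outer_grid_search(phis, xi_real=0.7))
```

<p>With \(\xi_\text{real} = 0.7\) and candidates spaced 0.25 apart, the search should pick \(\phi = 0.75\), the candidate closest to \(\xi_\text{real}\). Each outer evaluation requires a full inner solve, which is exactly why doing this by hand with real policy training is so costly.</p>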
<p>Guided domain randomization in the next section is largely inspired by this view, aiming to do bilevel optimization and learn the best parameter distribution automatically.</p>
<h3 id="dr-as-meta-learning">DR as Meta-Learning</h3>
<p>In our learning dexterity project (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>), we trained an LSTM policy to generalize across different environmental dynamics. We observed that once a robot achieved the first rotation, the time it needed for the following successes was much shorter. Also, a feed-forward (FF) policy without memory was unable to transfer to a physical robot. Both are evidence of the policy dynamically learning and adapting to a new environment.</p>
<p>In some ways, domain randomization composes a collection of different tasks. Memory in the recurrent network empowers the policy to achieve <a href="/lil-log/2018/11/30/meta-learning.html"><em>meta-learning</em></a> across tasks and thus to work in a real-world setting.</p>
<h2 id="guided-domain-randomization">Guided Domain Randomization</h2>
<p>The vanilla DR assumes no access to the real data, and thus the randomization config is sampled as broadly and uniformly as possible in sim, hoping that the real environment is covered by this broad distribution. A natural extension is a more sophisticated strategy that replaces uniform sampling with guidance from <em>task performance</em>, <em>real data</em>, or the <em>simulator</em>.</p>
<p>One motivation for guided DR is to save computation resources by avoiding training models in unrealistic environments. Another is to avoid infeasible solutions that might arise from overly wide randomization distributions and thus might hinder successful policy learning.</p>
<h3 id="optimization-for-task-performance">Optimization for Task Performance</h3>
<p>Say we train a family of policies with different randomization parameters \(\xi \sim P_\phi(\xi)\), where \(P_\phi\) is the distribution for \(\xi\) parameterized by \(\phi\). We then try each of them on the downstream task in the target domain (e.g. controlling a robot in reality or evaluating on a validation set) to collect feedback. This feedback tells us how good a configuration \(\xi\) is and provides signals for optimizing \(\phi\).</p>
<p>Inspired by <a href="https://ai.google/research/pubs/pub45826">NAS</a>, <strong>AutoAugment</strong> (<a href="https://arxiv.org/abs/1805.09501">Cubuk, et al. 2018</a>) frames the problem of learning the best data augmentation operations (e.g. shearing, rotation, inversion) for image classification as an RL problem. Note that AutoAugment is not proposed for sim2real transfer, but it falls in the bucket of DR guided by task performance. Each augmentation configuration is tested on the evaluation set and the performance improvement is used as a reward to train a PPO policy. This policy outputs different augmentation strategies for different datasets; for example, for CIFAR-10 AutoAugment mostly picks color-based transformations, while for ImageNet it prefers geometric ones.</p>
<p><a href="https://arxiv.org/abs/1810.02513">Ruiz et al. (2019)</a> considered the <em>task feedback</em> as the <em>reward</em> in an RL problem and proposed an RL-based method, named “learning to simulate”, for adjusting \(\xi\). A policy, modeled as a multivariate Gaussian, is trained to predict \(\xi\), using performance metrics on the validation data of the main task as rewards. Overall the idea is similar to AutoAugment, applying NAS on data generation. According to their experiments, even if the main task model has not converged, it can still provide a reasonable signal to the data generation policy.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/learning-to-simulate.png" alt="Learning to simulate" /></p>
<p><em>Fig. 3. An overview of the “learning to simulate” approach. (Image source: <a href="https://arxiv.org/abs/1810.02513">Ruiz (2019)</a>)</em></p>
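<p>A simplified Python sketch of the “learning to simulate” loop: REINFORCE with a mean-reward baseline updates the mean of a one-dimensional Gaussian over \(\xi\). The scalar setting, the fixed \(\sigma\), and the cheap <code>reward_fn</code> are assumptions made for illustration; in the paper, the policy is a multivariate Gaussian and each reward requires training and validating the main-task model on generated data.</p>

```python
import random


def reinforce_gaussian(reward_fn, mu0, sigma=0.3, lr=0.05, iters=500, batch=16, seed=0):
    """REINFORCE on the mean of a 1-D Gaussian sampling distribution over xi.

    reward_fn(xi) is a hypothetical stand-in for the validation score of a
    main-task model trained on data generated with config xi.
    """
    rng = random.Random(seed)
    mu = mu0
    for _ in range(iters):
        xis = [rng.gauss(mu, sigma) for _ in range(batch)]
        rewards = [reward_fn(xi) for xi in xis]
        baseline = sum(rewards) / batch  # mean-reward baseline reduces variance
        # Score function of N(xi; mu, sigma^2) w.r.t. mu is (xi - mu) / sigma^2.
        grad = sum((r - baseline) * (xi - mu) / sigma ** 2
                   for r, xi in zip(rewards, xis)) / batch
        mu += lr * grad  # gradient ascent on the expected reward
    return mu
```

<p>With a toy reward such as \(-(\xi - 1.5)^2\) and \(\mu_0 = 0\), the mean drifts toward 1.5. The point of the score-function estimator is that it only needs reward values, so the data generator and the main-task training never have to be differentiable.</p>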
<p>An evolutionary algorithm is another way to go, where the <em>feedback</em> is treated as <em>fitness</em> for guiding evolution (<a href="https://openreview.net/forum?id=H1g6osRcFQ">Yu et al, 2019</a>). In this study, they used <a href="https://en.wikipedia.org/wiki/CMA-ES">CMA-ES</a> (covariance matrix adaptation evolution strategy), where fitness is the performance of a \(\xi\)-conditional policy in the target environment. In the appendix, they compared CMA-ES with other ways of modeling the dynamics of \(\xi\), including Bayesian optimization and a neural network. The main claim was that those methods are not as stable or sample-efficient as CMA-ES. Interestingly, when modeling \(P(\xi)\) as a neural network, an LSTM was found to notably outperform an FF network.</p>
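<p>CMA-ES itself adapts a full covariance matrix and a step size, which is too long for a short example. The Python sketch below swaps in a simpler cross-entropy-method style evolution strategy that refits a diagonal Gaussian to the best (“elite”) samples; it only illustrates the idea of treating feedback as fitness. The function name and the toy fitness are hypothetical.</p>

```python
import random
import statistics


def simple_es(fitness, mean, std, iters=50, popsize=32, num_elite=8, seed=0):
    """Maximize fitness(xi) by sampling a population and refitting a diagonal Gaussian to the elites."""
    rng = random.Random(seed)
    dim = len(mean)
    for _ in range(iters):
        population = [[rng.gauss(mean[d], std[d]) for d in range(dim)]
                      for _ in range(popsize)]
        elites = sorted(population, key=fitness, reverse=True)[:num_elite]
        mean = [sum(e[d] for e in elites) / num_elite for d in range(dim)]
        # A small floor on the std keeps the search from collapsing completely.
        std = [max(statistics.pstdev([e[d] for e in elites]), 1e-3) for d in range(dim)]
    return mean
```

<p>In the setting of Yu et al., each fitness evaluation would be a rollout of the \(\xi\)-conditional policy in the target environment, so the population size directly controls how many real trials each generation costs.</p>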
<p>Some believe that the sim2real gap is a combination of an appearance gap and a content gap, and most GAN-inspired DA models focus on the appearance gap. <strong>Meta-Sim</strong> (<a href="https://arxiv.org/abs/1904.11621">Kar, et al. 2019</a>) aims to close the content gap by generating task-specific synthetic datasets. Meta-Sim uses self-driving car training as an example, so the scene can be very complicated. In this case, the synthetic scenes are parameterized by a hierarchy of objects with properties (e.g., location, color) as well as relationships between objects. The hierarchy is specified by a probabilistic scene grammar akin to structured domain randomization (<strong>SDR</strong>; <a href="https://arxiv.org/abs/1810.10093">Prakash et al., 2018</a>) and is assumed to be known beforehand. A model \(G\) is trained to augment the distribution of scene properties \(s\) as follows:</p>
<ol>
<li>Learn the prior first: pre-train \(G\) to learn the identity function \(G(s) = s\).</li>
<li>Minimize the MMD loss between the real and sim data distributions. This involves backpropagation through a non-differentiable renderer; the paper computes the gradient numerically by perturbing the attributes of \(G(s)\).</li>
<li>Minimize the task loss with REINFORCE, where the task model is trained on synthetic data but evaluated on real data. Again, this is very similar to AutoAugment.</li>
</ol>
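<p>The MMD term and the numerical gradient can be sketched in a few lines of Python. This is a simplified illustration, assuming scene outputs are summarized as small feature vectors, a Gaussian kernel with a fixed bandwidth, and a hypothetical <code>loss_fn</code> that stands in for “render with these attributes, then compare to real data”; it is not Meta-Sim’s implementation.</p>

```python
import math


def gaussian_kernel(x, y, bandwidth=1.0):
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd2(xs, ys, bandwidth=1.0):
    """Biased estimate of the squared MMD between two sets of feature vectors."""
    def mean_kernel(a, b):
        return sum(gaussian_kernel(u, v, bandwidth) for u in a for v in b) / (len(a) * len(b))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


def finite_diff_grad(loss_fn, attrs, eps=1e-2):
    """Central-difference gradient of loss_fn w.r.t. each attribute; no backprop through a renderer."""
    grad = []
    for i in range(len(attrs)):
        up, down = list(attrs), list(attrs)
        up[i] += eps
        down[i] -= eps
        grad.append((loss_fn(up) - loss_fn(down)) / (2.0 * eps))
    return grad


if __name__ == "__main__":
    base = [0.0, 1.0, 2.0]
    real = [[x + 1.0] for x in base]  # "real" features: sim features shifted by +1

    def loss_fn(attrs):
        # Toy stand-in for rendering: attribute 0 shifts every generated feature.
        return mmd2([[x + attrs[0]] for x in base], real)

    print(finite_diff_grad(loss_fn, [0.0]))  # negative: increasing the shift lowers the MMD
```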
<p>Unfortunately, this family of methods is not well suited to the sim2real case. Either an RL policy or an EA model requires a large number of real samples, and it is really expensive to include real-time feedback collection on a physical robot in the training loop. Whether you want to trade real data collection for lower computation cost depends on your task.</p>
<h3 id="match-real-data-distribution">Match Real Data Distribution</h3>
<p>Using real data to guide domain randomization feels a lot like doing system identification or DA. The core idea behind DA is to improve the synthetic data to match the real data distribution. In the case of real-data-guided DR, we would like to learn the randomization parameters \(\xi\) that bring the state distribution in simulator close to the state distribution in the real world.</p>
<p>The <strong>SimOpt</strong> model (<a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>) is trained under an initial randomization distribution \(P_\phi(\xi)\) first, getting a policy \(\pi_{\theta, P_\phi}\). Then this policy is deployed on both simulator and physical robot to collect trajectories \(\tau_\xi\) and \(\tau_\text{real}\) respectively. The optimization objective is to minimize the discrepancy between sim and real trajectories:</p>
\[\phi^* = \arg\min_{\phi}\mathbb{E}_{\xi \sim P_\phi(\xi)} [\mathbb{E}_{\pi_{\theta, P_\phi}} [D(\tau_\xi, \tau_\text{real})]]\]
<p>where \(D(.)\) is a trajectory-based discrepancy measure. Like the “Learning to simulate” paper, SimOpt also has to solve the tricky problem of how to propagate gradients through a non-differentiable simulator. It uses a method called <a href="https://www.aaai.org/ocs/index.php/AAAI/AAAI10/paper/viewFile/1851/2264">relative entropy policy search</a> (REPS); see the paper for more details.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/simopt.png" alt="SimOpt" /></p>
<p><em>Fig. 4. An overview of the SimOpt framework. (Image source: <a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>)</em></p>
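<p>A rough Python sketch of the SimOpt outer loop for a single scalar parameter. The weighted L1 + L2 discrepancy, the toy exponential-decay “simulator”, and the fixed temperature are simplifying assumptions. The sample weights \(\exp(-D/\eta)\) have the same exponential form as REPS-style updates, but REPS chooses \(\eta\) by solving a dual problem, which is omitted here.</p>

```python
import math
import random


def trajectory_discrepancy(traj_sim, traj_real, w_l1=1.0, w_l2=1.0):
    """Weighted L1 + L2 distance between two equal-length observation sequences."""
    diffs = [s - r for s, r in zip(traj_sim, traj_real)]
    return w_l1 * sum(abs(d) for d in diffs) + w_l2 * sum(d * d for d in diffs)


def update_mean(simulate, traj_real, mu, sigma, num_samples=64, temperature=0.5, rng=None):
    """One update of the mean of P_phi(xi) = N(mu, sigma^2) toward low-discrepancy configs."""
    if rng is None:
        rng = random.Random(0)
    xis = [rng.gauss(mu, sigma) for _ in range(num_samples)]
    costs = [trajectory_discrepancy(simulate(xi), traj_real) for xi in xis]
    best = min(costs)
    # Subtracting the best cost keeps exp() in range; the best sample gets weight 1.
    weights = [math.exp(-(c - best) / temperature) for c in costs]
    return sum(w * xi for w, xi in zip(weights, xis)) / sum(weights)


def toy_simulate(xi, steps=20):
    """Hypothetical simulator: an observation decaying at a rate set by xi."""
    return [math.exp(-0.1 * xi * t) for t in range(steps)]
```

<p>Here a “real” trajectory would come from <code>toy_simulate(0.8)</code>, and repeated calls to <code>update_mean</code> starting from a wrong mean move it toward 0.8. In SimOpt, the real trajectory comes from the physical robot running the current policy, which is why only a handful of such updates are affordable.</p>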
<p><strong>RCAN</strong> (<a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>), short for “Randomized-to-Canonical Adaptation Networks”, is a nice combination of DA and DR for end-to-end RL tasks. An image-conditional GAN (<a href="https://arxiv.org/abs/1611.07004">cGAN</a>) is trained in sim to translate a domain-randomized image into a non-randomized version (aka “canonical version”). Later the same model is used to translate real images into their corresponding simulated versions, so that the agent receives observations consistent with those it encountered in training. Still, the underlying assumption is that the distribution of domain-randomized sim images is broad enough to cover real-world samples.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/RCAN.png" alt="RCAN" /></p>
<p><em>Fig. 5. RCAN is an image-conditional generator that can convert a domain-randomized or real image into its corresponding non-randomized simulator version. (Image source: <a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>)</em></p>
<p>The RL model is trained end-to-end in a simulator to do vision-based robot arm grasping. Randomization is applied at each timestep, including the position of tray divider, objects to grasp, random textures, as well as the position, direction, and color of the lighting. The canonical version is the default simulator look. RCAN is trying to learn a generator</p>
<p>\(G\): randomized image \(\to\) {canonical image, segmentation, depth}</p>
<p>where segmentation masks and depth images are used as auxiliary tasks. RCAN achieved better zero-shot transfer than uniform DR, although both were shown to be worse than the model trained on only real images. Conceptually, RCAN operates in the reverse direction of <a href="https://arxiv.org/abs/1709.07857">GraspGAN</a>, which translates synthetic images into real ones by domain adaptation.</p>
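<p>One way to write the generator’s supervised objective, as a Python sketch: because a randomized image and its canonical counterpart can be rendered from the same simulator state, all three targets come for free in sim. The choice of mean squared error for the canonical image and depth, cross-entropy for segmentation, and equal weights are assumptions for illustration, and the cGAN adversarial term used by RCAN is omitted.</p>

```python
import math


def rcan_generator_loss(pred, target, w_canonical=1.0, w_seg=1.0, w_depth=1.0):
    """Composite loss for G: randomized image -> {canonical image, segmentation, depth}.

    pred/target are dicts of flat per-pixel lists: 'canonical' and 'depth' hold pixel
    values; pred['seg'] holds per-pixel class-probability lists and target['seg']
    holds per-pixel class ids.
    """
    def mse(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

    # Clamp probabilities away from zero so log() stays finite.
    seg_ce = -sum(math.log(max(p[c], 1e-12))
                  for p, c in zip(pred["seg"], target["seg"])) / len(target["seg"])
    return (w_canonical * mse(pred["canonical"], target["canonical"])
            + w_seg * seg_ce
            + w_depth * mse(pred["depth"], target["depth"]))
```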
<h3 id="guided-by-data-in-simulator">Guided by Data in Simulator</h3>
<p>Network-driven domain randomization (<a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>), also known as <strong>DeceptionNet</strong>, is motivated by learning which randomizations are actually useful to bridge the domain gap for image classification tasks.</p>
<p>Randomization is applied through a set of deception modules with an encoder-decoder architecture. The deception modules are specifically designed to transform images, such as changing backgrounds, adding distortion, or changing lighting. A separate recognition network handles the main task by running classification on the transformed images.</p>
<p>The training involves two steps:</p>
<ol>
<li>With the recognition network fixed, <em>maximize the difference</em> between the prediction and the labels by applying reversed gradients during backpropagation, so that the deception modules learn the most confusing tricks.</li>
<li>With the deception modules fixed, train the recognition network with input images altered.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/deception-net.png" alt="DeceptionNet" /></p>
<p><em>Fig. 6. How DeceptionNet works. (Image source: <a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>)</em></p>
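<p>The two-step alternation can be illustrated with a toy Python sketch under heavy simplifications: the “deception module” is a single bounded shift \(\delta \in [-\epsilon, \epsilon]\) added to every input, and the recognition network is a one-feature logistic classifier. Gradient ascent on the loss with respect to \(\delta\) plays the role of the reversed gradient. All names and values are hypothetical; DeceptionNet’s modules are encoder-decoder networks that apply learned image transformations.</p>

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def deception_training(data, epsilon=0.3, lr=0.1, rounds=100):
    """Alternate a deception step (maximize loss) and a recognition step (minimize loss).

    data is a list of (x, y) pairs with scalar x and label y in {0, 1}.
    """
    w, b, delta = 0.0, 0.0, 0.0
    n = len(data)
    for _ in range(rounds):
        # Step 1: classifier fixed. Reversed gradient = gradient ascent on the loss w.r.t. delta.
        grad_delta = sum((sigmoid(w * (x + delta) + b) - y) * w for x, y in data) / n
        delta = max(-epsilon, min(epsilon, delta + lr * grad_delta))
        # Step 2: deception fixed. Ordinary gradient descent for the classifier on altered inputs.
        errors = [(sigmoid(w * (x + delta) + b) - y, x + delta) for x, y in data]
        w -= lr * sum(err * x_alt for err, x_alt in errors) / n
        b -= lr * sum(err for err, _ in errors) / n
    return w, b, delta
```

<p>The bound \(\epsilon\) matters: without it, the deception step could simply destroy the label information, which is why the real deception modules are restricted to specific, plausible transformations.</p>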
<p>The feedback for training deception modules is provided by the downstream classifier. But rather than trying to maximize the task performance like the methods in <a href="#optimization-for-task-performance">the section</a> above, the randomization modules aim to create harder cases. One big disadvantage is that you need to manually design different deception modules for different datasets or tasks, which makes the approach hard to scale. Since it is zero-shot, its results on MNIST and LineMOD are still worse than SOTA DA methods.</p>
<p>Similarly, Active domain randomization (<strong>ADR</strong>; <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>) also relies on sim data to create harder training samples. ADR searches for the <em>most informative</em> environment variations within the given randomization ranges, where the <em>informativeness</em> is measured as the discrepancy between policy rollouts in randomized and reference (original, non-randomized) environment instances. Sounds a bit like <a href="#match-real-data-distribution">SimOpt</a>? Well, note that SimOpt measures the discrepancy between sim and real rollouts, while ADR measures it between randomized and non-randomized sim, avoiding the expensive real data collection part.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/ADR.png" alt="ADR" /></p>
<p><em>Fig. 7. How active domain randomization (ADR) works. (Image source: <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>)</em></p>
<p>More precisely, training proceeds as follows:</p>
<ol>
<li>Given a policy, run it on both reference and randomized envs and collect two sets of trajectories respectively.</li>
<li>Train a discriminator model to tell whether a rollout trajectory comes from a randomized environment or from the reference run. The predicted \(\log p\) (log probability of being randomized) is used as the reward. The more the randomized and reference rollouts differ, the easier the prediction and the higher the reward.
<ul>
<li>The intuition is that if an environment is easy, the same policy agent can produce similar trajectories as in the reference one. Then the model should reward and explore hard environments by encouraging different behaviors.</li>
</ul>
</li>
<li>The reward by discriminator is fed into <em>Stein Variational Policy Gradient</em> (<a href="https://arxiv.org/abs/1704.02399">SVPG</a>) particles, outputting a diverse set of randomization configurations.</li>
</ol>
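<p>Step 2 can be sketched in Python with a logistic-regression discriminator over a single scalar summary of each trajectory (for example, a final distance; the feature choice is a hypothetical simplification). The SVPG step is omitted: the sketch only shows that rollouts which look more different from the reference receive a higher reward \(\log p\).</p>

```python
import math


def train_discriminator(ref_feats, rand_feats, lr=0.5, epochs=500):
    """Logistic regression predicting p(randomized | trajectory feature)."""
    w, b = 0.0, 0.0
    data = [(f, 0.0) for f in ref_feats] + [(f, 1.0) for f in rand_feats]
    for _ in range(epochs):
        gw = gb = 0.0
        for f, y in data:
            p = 1.0 / (1.0 + math.exp(-(w * f + b)))
            gw += (p - y) * f
            gb += p - y
        w -= lr * gw / len(data)
        b -= lr * gb / len(data)
    return w, b


def adr_reward(w, b, feat):
    """Reward for a randomized rollout: log p(randomized | trajectory)."""
    z = w * feat + b
    return -math.log1p(math.exp(-z))  # log(sigmoid(z)); fine for moderate z
```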
<p>The idea of ADR is very appealing, with two small concerns. First, the similarity between trajectories might not be a good way to measure environment difficulty when running a stochastic policy. Second, the sim2real results are unfortunately not as exciting, although the paper points out that ADR wins by exploring a smaller range of randomization parameters.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019DR,
title = "Domain Randomization for Sim2Real Transfer",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/05/05/domain-randomization.html"
}
</code></pre></div></div>
<p>Overall, after reading this post, I hope you like domain randomization as much as I do :).</p>
<h2 id="references">References</h2>
<p>[1] Josh Tobin, et al. <a href="https://arxiv.org/pdf/1703.06907.pdf">“Domain randomization for transferring deep neural networks from simulation to the real world.”</a> IROS, 2017.</p>
<p>[2] Fereshteh Sadeghi and Sergey Levine. <a href="https://arxiv.org/abs/1611.04201">“CAD2RL: Real single-image flight without a single real image.”</a> arXiv:1611.04201 (2016).</p>
<p>[3] Xue Bin Peng, et al. <a href="https://arxiv.org/abs/1710.06537">“Sim-to-real transfer of robotic control with dynamics randomization.”</a> ICRA, 2018.</p>
<p>[4] Nataniel Ruiz, et al. <a href="https://openreview.net/forum?id=HJgkx2Aqt7">“Learning to Simulate.”</a> ICLR 2019.</p>
<p>[5] OpenAI. <a href="https://arxiv.org/abs/1808.00177">“Learning Dexterous In-Hand Manipulation.”</a> arXiv:1808.00177 (2018).</p>
<p>[6] OpenAI Blog. <a href="https://openai.com/blog/learning-dexterity/">“Learning dexterity”</a> July 30, 2018.</p>
<p>[7] Quan Vuong, et al. <a href="https://arxiv.org/abs/1903.11774">“How to pick the domain randomization parameters for sim-to-real transfer of reinforcement learning policies?.”</a> arXiv:1903.11774 (2019).</p>
<p>[8] Ekin D. Cubuk, et al. <a href="https://arxiv.org/abs/1805.09501">“AutoAugment: Learning augmentation policies from data.”</a> arXiv:1805.09501 (2018).</p>
<p>[9] Wenhao Yu, et al. <a href="https://openreview.net/forum?id=H1g6osRcFQ">“Policy Transfer with Strategy Optimization.”</a> ICLR, 2019.</p>
<p>[10] Yevgen Chebotar, et al. <a href="https://arxiv.org/abs/1810.05687">“Closing the Sim-to-Real Loop: Adapting Simulation Randomization with Real World Experience.”</a> arXiv:1810.05687 (2019).</p>
<p>[11] Stephen James, et al. <a href="https://arxiv.org/abs/1812.07252">“Sim-to-real via sim-to-sim: Data-efficient robotic grasping via randomized-to-canonical adaptation networks.”</a> CVPR, 2019.</p>
<p>[12] Bhairav Mehta, et al. <a href="https://arxiv.org/abs/1904.04762">“Active Domain Randomization.”</a> arXiv:1904.04762 (2019).</p>
<p>[13] Sergey Zakharov, et al. <a href="https://arxiv.org/abs/1904.02750">“DeceptionNet: Network-Driven Domain Randomization.”</a> arXiv:1904.02750 (2019).</p>
<p>[14] Amlan Kar, et al. <a href="https://arxiv.org/abs/1904.11621">“Meta-Sim: Learning to Generate Synthetic Datasets.”</a> arXiv:1904.11621 (2019).</p>
<p>[15] Aayush Prakash, et al. <a href="https://arxiv.org/abs/1810.10093">“Structured Domain Randomization: Bridging the Reality Gap by Context-Aware Synthetic Data.”</a> arXiv:1810.10093 (2018).</p>Lilian WengIf a model or policy is mainly trained in a simulator but expected to work on a real robot, it would surely face the sim2real gap. Domain Randomization (DR) is a simple but powerful idea of closing this gap by randomizing properties of the training environment.Are Deep Neural Networks Dramatically Overfitted?2019-03-14T12:00:00+00:002019-03-14T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted<blockquote>
<p>If you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-05-27: add the <a href="#the-lottery-ticket-hypothesis">section</a> on Lottery Ticket Hypothesis.]</span></p>
<p>If you are like me, entering the field of deep learning with experience in traditional machine learning, you may often ponder this question: since a typical deep neural network has so many parameters and its training error can easily reach zero, it should surely suffer from substantial overfitting. How could it ever generalize to out-of-sample data points?</p>
<p>The effort to understand why deep neural networks generalize somehow reminds me of this interesting paper on Systems Biology: <a href="https://www.cell.com/cancer-cell/pdf/S1535-6108(02)00133-2.pdf">“Can a biologist fix a radio?”</a> (Lazebnik, 2002). If a biologist tried to fix a radio the way she studies a biological system, life would be hard. Because the full mechanism of the radio is not revealed, poking at small local functions might give some hints, but it can hardly reveal all the interactions within the system, let alone the entire workflow. Whether or not you think it is relevant to DL, it is a very fun read.</p>
<p>In this post, I would like to discuss a couple of papers on the generalizability and complexity measurement of deep learning models. Hopefully, they can shed light on your thinking about why DNNs can generalize.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-theorems-on-compression-and-model-selection" id="markdown-toc-classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</a> <ul>
<li><a href="#occams-razor" id="markdown-toc-occams-razor">Occam’s Razor</a></li>
<li><a href="#minimum-description-length-principle" id="markdown-toc-minimum-description-length-principle">Minimum Description Length principle</a></li>
<li><a href="#kolmogorov-complexity" id="markdown-toc-kolmogorov-complexity">Kolmogorov Complexity</a></li>
<li><a href="#solomonoffs-inference-theory" id="markdown-toc-solomonoffs-inference-theory">Solomonoff’s Inference Theory</a></li>
</ul>
</li>
<li><a href="#expressive-power-of-dl-models" id="markdown-toc-expressive-power-of-dl-models">Expressive Power of DL Models</a> <ul>
<li><a href="#universal-approximation-theorem" id="markdown-toc-universal-approximation-theorem">Universal Approximation Theorem</a></li>
<li><a href="#proof-finite-sample-expressivity-of-two-layer-nn" id="markdown-toc-proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</a></li>
<li><a href="#deep-nn-can-learn-random-noise" id="markdown-toc-deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</a></li>
</ul>
</li>
<li><a href="#are-deep-learning-models-dramatically-overfitted" id="markdown-toc-are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</a> <ul>
<li><a href="#modern-risk-curve-for-deep-learning" id="markdown-toc-modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</a></li>
<li><a href="#regularization-is-not-the-key-to-generalization" id="markdown-toc-regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</a></li>
<li><a href="#intrinsic-dimension" id="markdown-toc-intrinsic-dimension">Intrinsic Dimension</a></li>
<li><a href="#heterogeneous-layer-robustness" id="markdown-toc-heterogeneous-layer-robustness">Heterogeneous Layer Robustness</a></li>
<li><a href="#the-lottery-ticket-hypothesis" id="markdown-toc-the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</a></li>
</ul>
</li>
<li><a href="#experiments" id="markdown-toc-experiments">Experiments</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</h2>
<p>Say we have a classification problem and a dataset. We can develop many models to solve it, from fitting a simple linear regression to memorizing the full dataset on disk. Which one is better? If we only care about accuracy on the training data (especially given that the test data is likely unknown), the memorization approach seems to be the best, which doesn’t sound right.</p>
<p>There are many classic theorems to guide us when deciding what types of properties a good model should possess in such scenarios.</p>
<h3 id="occams-razor">Occam’s Razor</h3>
<p><a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> is an informal principle for problem-solving, proposed by <a href="https://en.wikipedia.org/wiki/William_of_Ockham">William of Ockham</a> in the 14th century:</p>
<blockquote>
<p>“Simpler solutions are more likely to be correct than complex ones.”</p>
</blockquote>
<p>The statement is extremely powerful when we face multiple candidate theories to explain the world and have to pick one. An explanation with too many unnecessary assumptions might seem plausible for one problem, but it is harder to generalize to other complications or to eventually lead to the basic principles of the universe.</p>
<p>Think about this: it took people hundreds of years to figure out that the sky being blue in the daytime and reddish at sunset has the same cause (<a href="https://en.wikipedia.org/wiki/Rayleigh_scattering">Rayleigh scattering</a>), although the two phenomena look very different. People must have proposed many other explanations for them separately, but the unified and simpler version eventually won.</p>
<h3 id="minimum-description-length-principle">Minimum Description Length principle</h3>
<p>The principle of Occam’s Razor can be similarly applied to machine learning models. A formalized version of this concept is the <em>Minimum Description Length (MDL)</em> principle, used for comparing competing models / explanations given the observed data.</p>
<blockquote>
<p>“Comprehension is compression.”</p>
</blockquote>
<p>The fundamental idea in MDL is to <em>view learning as data compression</em>. To compress the data, we need to discover regularities or patterns in it that are likely to generalize to unseen samples. <a href="/lil-log/2017/09/28/anatomize-deep-learning-with-information-theory.html">Information bottleneck</a> theory holds that a deep neural network is first trained to represent the data by minimizing the generalization error and then learns to compress this representation by trimming noise.</p>
<p>Meanwhile, MDL counts the model description as part of the compressed message, so the model cannot be arbitrarily large.</p>
<p>A <em>two-part version</em> of the MDL principle states: let \(\mathcal{H}^{(1)}, \mathcal{H}^{(2)}, \dots\) be a list of models that can explain the dataset \(\mathcal{D}\). The best hypothesis among them is the one that minimizes the sum:</p>
\[\mathcal{H}^\text{best} = \arg\min_\mathcal{H} [L(\mathcal{H}) + L(\mathcal{D}\vert\mathcal{H})]\]
<ul>
<li>\(L(\mathcal{H})\) is the length of the description of model \(\mathcal{H}\) in bits.</li>
<li>\(L(\mathcal{D}\vert\mathcal{H})\) is the length of the description of the data \(\mathcal{D}\) in bits when encoded with \(\mathcal{H}\).</li>
</ul>
<p>In simple words, the <em>best</em> model is the one that gives the <em>shortest</em> total description, counting both the encoded data and the model itself. By this criterion, the memorization approach I proposed at the beginning of the section sounds horrible, no matter how high an accuracy it achieves on the training data.</p>
<p>People might argue that Occam’s Razor is wrong: given that the real world can be arbitrarily complicated, why do we have to find simple models? One interesting view from MDL is to consider models as <strong>“languages”</strong> instead of fundamental generative theorems. We would like to find good compression strategies to describe regularity in a small set of samples, and they <strong>do not have to be the “real” generative model</strong> for explaining the phenomenon. Models can be wrong but still useful (e.g., think of any Bayesian prior).</p>
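<p>To make the two-part score concrete, here is a minimal sketch that scores three hypothetical models of a binary string. The models are "memorize the string", "fair coin" and "biased coin". The cost model is an illustrative assumption, not part of the MDL literature cited here. Memorization is charged \(n\) raw bits for \(L(\mathcal{H})\). The fair coin is charged nothing for \(L(\mathcal{H})\). Stating the biased coin's parameter is charged the common asymptotic \(\frac{1}{2}\log_2 n\) bits. \(L(\mathcal{D}\vert\mathcal{H})\) is the ideal code length \(-\sum_i \log_2 P(x_i)\).</p>

```python
import math


def bernoulli_code_length(bits, p):
    """L(D|H) in bits: -sum_i log2 P(x_i) under a Bernoulli(p) model."""
    return sum(-math.log2(p if b == 1 else 1.0 - p) for b in bits)


def two_part_mdl_scores(bits):
    """Total description length L(H) + L(D|H), in bits, for three hypothetical models."""
    n, k = len(bits), sum(bits)
    p_hat = min(max(k / n, 1e-12), 1 - 1e-12)  # avoid log2(0) for all-0 or all-1 data
    return {
        # Memorize: the "model" is the raw string itself (n bits); nothing is left to encode.
        "memorize": n + 0.0,
        # Fair coin: no free parameter (L(H) ~ 0); every bit costs exactly 1 bit.
        "fair_coin": 0.0 + bernoulli_code_length(bits, 0.5),
        # Biased coin: assume ~0.5 * log2(n) bits to state p_hat, then entropy-code the data.
        "biased_coin": 0.5 * math.log2(n) + bernoulli_code_length(bits, p_hat),
    }


data = [1] * 90 + [0] * 10
scores = two_part_mdl_scores(data)
print(scores, "->", min(scores, key=scores.get))
```

<p>On strongly skewed data, the small biased-coin model wins because it compresses the data. Memorization never beats the trivial fair-coin code, which is the MDL version of "memorization does not explain anything".</p>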
<h3 id="kolmogorov-complexity">Kolmogorov Complexity</h3>
<p>Kolmogorov Complexity relies on the concept of modern computers to define the algorithmic (descriptive) complexity of an object: It is <em>the length of the shortest binary computer program that describes the object</em>. From the MDL perspective, a computer is essentially the most general form of data decompressor.</p>
<p>The formal definition of Kolmogorov Complexity states that: Given a universal computer \(\mathcal{U}\) and a program \(p\), let’s denote \(\mathcal{U}(p)\) as the output of the computer processing the program and \(L(p)\) as the descriptive length of the program. Then Kolmogorov Complexity \(K_\mathcal{U}\) of a string \(s\) with respect to a universal computer \(\mathcal{U}\) is:</p>
\[K_\mathcal{U}(s) = \min_{p: \mathcal{U}(p)=s} L(p)\]
<p>Note that a universal computer is one that can mimic the actions of any other computer. All modern computers are universal, as they can all be reduced to Turing machines. The definition holds no matter which universal computer we use, because any other universal computer can always be programmed to clone the behavior of \(\mathcal{U}\), and encoding this clone program only adds a constant to the description length.</p>
<p>There are a lot of connections between Kolmogorov Complexity and Shannon Information Theory, as both are tied to universal coding. It is an amazing fact that the expected Kolmogorov Complexity of a random variable is approximately equal to its Shannon entropy (see Sec 2.3 of <a href="https://homepages.cwi.nl/~paulv/papers/info.pdf">the report</a>). More on this topic is out of the scope here, but there are many interesting readings online. Help yourself :)</p>
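<p>Kolmogorov Complexity itself is uncomputable. A common practical stand-in is the output length of a real compressor, which only gives an <em>upper bound</em> on \(K(s)\), up to the constant size of the decompressor. The sketch below uses Python's <code>zlib</code> for that purpose; the choice of compressor is an assumption for illustration. A highly regular string compresses to a tiny fraction of its raw size. Random bytes stay essentially incompressible, in line with their expected complexity being close to their Shannon entropy (8 bits per byte).</p>

```python
import os
import zlib


def compressed_bits(data: bytes) -> int:
    """Upper-bound proxy for Kolmogorov complexity: zlib-compressed size in bits.

    The true K(s) is uncomputable; a concrete compressor only bounds it from above
    (up to the constant-size decompressor).
    """
    return 8 * len(zlib.compress(data, 9))


regular = b"01" * 5000  # highly regular: a very short program can print it
random_data = os.urandom(10000)  # incompressible with overwhelming probability

for name, s in [("regular", regular), ("random", random_data)]:
    print(f"{name:8s} raw={8 * len(s)} bits  compressed={compressed_bits(s)} bits")
```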
<h3 id="solomonoffs-inference-theory">Solomonoff’s Inference Theory</h3>
<p>Another mathematical formalization of Occam’s Razor is Solomonoff’s theory of universal inductive inference (<a href="https://www.sciencedirect.com/science/article/pii/S0019995864902232">Solomonoff</a>, <a href="https://www.sciencedirect.com/science/article/pii/S0019995864901317">1964</a>). The principle is to favor models that correspond to the “shortest program” producing the training data, based on its Kolmogorov complexity.</p>
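<p>True Solomonoff induction sums over all programs for a universal computer and is uncomputable. The toy sketch below only conveys the weighting idea: every program consistent with the observed data votes on the next bit, with prior weight \(2^{-L(p)}\). The program class ("repeat this binary pattern forever") and the assumption that a pattern of length \(k\) costs \(k\) bits are hypothetical simplifications. Shorter consistent programs dominate the prediction.</p>

```python
from itertools import product


def toy_solomonoff_next_bit(observed, max_period=8):
    """P(next bit = 1) under a toy Solomonoff-style mixture.

    Hypothetical program class: "repeat this binary pattern forever", where a pattern
    of length k is assumed to cost k bits, so its prior weight is 2**-k.
    """
    weight = {0: 0.0, 1: 0.0}
    for k in range(1, max_period + 1):
        for pattern in product((0, 1), repeat=k):
            # Keep only programs whose output starts with the observed data.
            if all(bit == pattern[i % k] for i, bit in enumerate(observed)):
                # The program's next output bit receives its prior weight.
                weight[pattern[len(observed) % k]] += 2.0 ** (-k)
    total = weight[0] + weight[1]
    if total == 0.0:  # no program in this tiny class explains the data
        return 0.5
    return weight[1] / total


print(toy_solomonoff_next_bit([0, 1, 0, 1, 0, 1]))  # short program "repeat 01" dominates
print(toy_solomonoff_next_bit([1, 1, 1, 1]))
```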
<h2 id="expressive-power-of-dl-models">Expressive Power of DL Models</h2>
<p>Deep neural networks have an extremely large number of parameters compared to traditional statistical models. If we used MDL to measure the complexity of a deep neural network and took the number of parameters as the model description length, it would look awful. The model description \(L(\mathcal{H})\) can easily grow out of control.</p>
<p>However, having numerous parameters is <em>necessary</em> for a neural network to obtain high expressive power. Thanks to their great capability to capture flexible data representations, deep neural networks have achieved great success in many applications.</p>
<h3 id="universal-approximation-theorem">Universal Approximation Theorem</h3>
<p>The <em>Universal Approximation Theorem</em> states that a feedforward network with: 1) a linear output layer, 2) at least one hidden layer containing a finite number of neurons, and 3) a suitable activation function can approximate <strong>any</strong> continuous function on a compact subset of \(\mathbb{R}^n\) to arbitrary accuracy. The theorem was first proved for the sigmoid activation function (<a href="https://pdfs.semanticscholar.org/05ce/b32839c26c8d2cb38d5529cf7720a68c3fab.pdf">Cybenko, 1989</a>). Later it was shown that the universal approximation property is not specific to the choice of activation (<a href="http://zmjones.com/static/statistical-learning/hornik-nn-1991.pdf">Hornik, 1991</a>) but rather stems from the multilayer feedforward architecture itself.</p>
<p>Although a feedforward network with a single hidden layer is sufficient to represent any function, its width may have to be exponentially large. The universal approximation theorem also does not guarantee that the model can be learned or that it generalizes properly. Often, adding more layers helps reduce the number of hidden neurons that a shallow network would need.</p>
<p>By the universal approximation theorem, we can always find a neural network that represents the target function with error under any desired threshold, but we need to pay the price: the network might grow super large.</p>
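<p>The theorem is an existence statement and says nothing about how to find the weights. The sketch below uses a swapped-in technique, random ReLU features, to illustrate the width/accuracy trade-off. The hidden weights and biases are drawn at random and frozen, and only the linear output layer is fit by least squares. The target function \(\sin(x)\) on \([-\pi, \pi]\), the sampling distributions and the widths are all arbitrary illustrative choices. The maximum approximation error shrinks as the hidden layer gets wider.</p>

```python
import numpy as np


def fit_random_relu_net(x, y, width, seed=0):
    """One hidden ReLU layer with frozen random weights; only the output layer is fit."""
    rng = np.random.default_rng(seed)
    w = rng.normal(size=width)  # frozen hidden-layer weights
    b = rng.uniform(-np.pi, np.pi, size=width)  # frozen hidden-layer biases

    def hidden(xs):
        # ReLU activations, shape (len(xs), width)
        return np.maximum(np.outer(xs, w) + b, 0.0)

    v, *_ = np.linalg.lstsq(hidden(x), y, rcond=None)  # linear output layer
    return lambda xs: hidden(xs) @ v


x = np.linspace(-np.pi, np.pi, 400)
y = np.sin(x)
errors = {}
for width in (2, 8, 32, 128):
    f = fit_random_relu_net(x, y, width)
    errors[width] = np.max(np.abs(f(x) - y))
    print(f"width={width:4d}  max |error| = {errors[width]:.4f}")
```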
<h3 id="proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</h3>
<p>The Universal Approximation Theorem we have discussed so far does not consider a finite sample set. <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a> provided a neat proof of the finite-sample expressivity of two-layer neural networks.</p>
<p>A neural network \(C\) can represent any function given a sample size \(n\) in \(d\) dimensions if: for every finite sample set \(S \subseteq \mathbb{R}^d\) with \(\vert S \vert = n\) and every function defined on this sample set, \(f: S \mapsto \mathbb{R}\), we can find a weight configuration for \(C\) such that \(C(\boldsymbol{x}) = f(\boldsymbol{x}), \forall \boldsymbol{x} \in S\).</p>
<p>The paper proposed a theorem:</p>
<blockquote>
<p>There exists a two-layer neural network with ReLU activations and \(2n + d\) weights that can represent any function on a sample of size \(n\) in \(d\) dimensions.</p>
</blockquote>
<p><em>Proof.</em> First we would like to construct a two-layer neural network \(C: \mathbb{R}^d \mapsto \mathbb{R}\). The input is a \(d\)-dimensional vector, \(\boldsymbol{x} \in \mathbb{R}^d\). The hidden layer has \(h\) hidden units, associated with a weight matrix \(\mathbf{W} \in \mathbb{R}^{d\times h}\), a bias vector \(-\mathbf{b} \in \mathbb{R}^h\) and ReLU activation function. The second layer outputs a scalar value with weight vector \(\boldsymbol{v} \in \mathbb{R}^h\) and zero biases.</p>
<p>The output of network \(C\) for an input vector \(\boldsymbol{x}\) can be represented as follows:</p>
\[C(\boldsymbol{x})
= \boldsymbol{v} \max\{ \boldsymbol{x}\mathbf{W} - \boldsymbol{b}, 0\}^\top
= \sum_{i=1}^h v_i \max\{\boldsymbol{x}\boldsymbol{W}_{(:,i)} - b_i, 0\}\]
<p>where \(\boldsymbol{W}_{(:,i)}\) is the \(i\)-th column in the \(d \times h\) matrix.</p>
<p>Given a sample set \(S = \{\boldsymbol{x}_1, \dots, \boldsymbol{x}_n\}\) and target values \(\boldsymbol{y} = \{y_1, \dots, y_n \}\), we would like to find proper weights \(\mathbf{W} \in \mathbb{R}^{d\times h}\), \(\boldsymbol{b}, \boldsymbol{v} \in \mathbb{R}^h\) so that \(C(\boldsymbol{x}_i) = y_i, \forall i=1,\dots,n\).</p>
<p>Let’s combine all sample points into one batch as one input matrix \(\mathbf{X} \in \mathbb{R}^{n \times d}\). If we set \(h=n\), \(\mathbf{X}\mathbf{W} - \boldsymbol{b}\) would be a square matrix of size \(n \times n\) (with \(\boldsymbol{b}\) subtracted from every row).</p>
\[\mathbf{M}_\text{ReLU}
= \max\{\mathbf{X}\mathbf{W} - \boldsymbol{b}, 0 \}
= \begin{bmatrix}
\max\{\boldsymbol{x}_1\mathbf{W} - \boldsymbol{b}, 0\} \\
\dots \\
\max\{\boldsymbol{x}_n\mathbf{W} - \boldsymbol{b}, 0\} \\
\end{bmatrix}
= [\max\{\boldsymbol{x}_i\mathbf{W}_{(:,j)} - b_j, 0\}]_{i \times j}\]
<p>We can simplify \(\mathbf{W}\) by letting all of its columns be the same vector:</p>
\[\mathbf{W}_{(:,j)} = \boldsymbol{w} \in \mathbb{R}^{d}, \forall j = 1, \dots, n\]
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/nn-expressivity-proof.png" alt="finite-sample expressivity proof" /></p>
<p>Let \(a_i = \boldsymbol{x}_i \boldsymbol{w}\). We would like to find a suitable \(\boldsymbol{w}\) and \(\boldsymbol{b}\) such that \(b_1 < a_1 < b_2 < a_2 < \dots < b_n < a_n\). This is always achievable because we solve for \(n+d\) unknown variables under \(n\) constraints, and the \(\boldsymbol{x}_i\) are distinct (e.g., pick a random \(\boldsymbol{w}\), sort the samples by \(\boldsymbol{x}_i \boldsymbol{w}\), and then set the \(b_j\)’s to values in between). Then \(\mathbf{M}_\text{ReLU}\) becomes a lower triangular matrix:</p>
\[\mathbf{M}_\text{ReLU} = [\max\{a_i - b_j, 0\}]_{i \times j}
= \begin{bmatrix}
a_1 - b_1 & 0 & 0 & \dots & 0 \\
\vdots & \ddots & & & \vdots \\
a_i - b_1 & \dots & a_i - b_i & \dots & 0\\
\vdots & & & \ddots & \vdots \\
a_n - b_1 & a_n - b_2 & \dots & \dots & a_n - b_n \\
\end{bmatrix}\]
<p>It is a nonsingular square matrix, since its diagonal entries \(a_i - b_i\) are all positive and thus \(\det(\mathbf{M}_\text{ReLU}) = \prod_i (a_i - b_i) \neq 0\). So we can always find a suitable \(\boldsymbol{v}\) to solve \(\mathbf{M}_\text{ReLU}\boldsymbol{v}^\top = \boldsymbol{y}^\top\) (in other words, the column space of \(\mathbf{M}_\text{ReLU}\) is all of \(\mathbb{R}^n\), and we can find a linear combination of column vectors to obtain any \(\boldsymbol{y}\)).</p>
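<p>The construction in the proof can be checked numerically. Below is a minimal NumPy sketch, assuming distinct random inputs. It draws a random shared column \(\boldsymbol{w}\), sorts the samples by \(a_i = \boldsymbol{x}_i\boldsymbol{w}\), and interleaves the biases so that \(b_1 < a_1 < b_2 < \dots < b_n < a_n\). It then solves the lower-triangular system for \(\boldsymbol{v}\). The free parameters are exactly \(\boldsymbol{w}\) (\(d\)), \(\boldsymbol{b}\) (\(n\)) and \(\boldsymbol{v}\) (\(n\)), i.e. \(2n + d\) weights. The function names and the bias offset of 1.0 are illustrative choices.</p>

```python
import numpy as np


def construct_two_layer_relu(X, y, seed=0):
    """Build W (d x n), b (n,), v (n,) such that C(x_i) = y_i for every row x_i of X."""
    n, d = X.shape
    rng = np.random.default_rng(seed)
    w = rng.normal(size=d)  # every column of W equals w
    a = X @ w  # a_i = x_i w; distinct with probability 1 if the x_i are distinct
    order = np.argsort(a)  # relabel samples so that a_1 < a_2 < ... < a_n
    a_sorted = a[order]
    b = np.empty(n)
    b[0] = a_sorted[0] - 1.0  # b_1 < a_1
    b[1:] = (a_sorted[:-1] + a_sorted[1:]) / 2.0  # a_{j-1} < b_j < a_j
    W = np.tile(w[:, None], (1, n))
    M = np.maximum(X[order] @ W - b, 0.0)  # lower triangular with positive diagonal
    v = np.linalg.solve(M, y[order])  # M v^T = y^T (rows of M are samples)
    return W, b, v


def two_layer_relu(X, W, b, v):
    """C(x) = sum_i v_i max{x W_(:,i) - b_i, 0}, evaluated for every row of X."""
    return np.maximum(X @ W - b, 0.0) @ v


rng = np.random.default_rng(42)
n, d = 20, 5
X = rng.normal(size=(n, d))
y = rng.normal(size=n)  # arbitrary, even random, target values
W, b, v = construct_two_layer_relu(X, y)
print(np.max(np.abs(two_layer_relu(X, W, b, v) - y)))
```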
<h3 id="deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</h3>
<p>Since two-layer neural networks are universal approximators, it is less surprising that they are able to learn unstructured random noise perfectly, as shown in <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a>. If the labels of an image classification dataset are randomly shuffled, the high expressive power of deep neural networks still enables them to achieve near-zero training loss. These results do not change when regularization terms are added.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/fit-random-labels.png" alt="Fitting random labels" /></p>
<p><em>Fig. 1. Fit models on CIFAR10 with random labels or random pixels: (a) learning curves; (b-c) label corruption ratio is the percentage of randomly shuffled labels. (Image source: <a href="https://arxiv.org/abs/1611.03530">Zhang’s paper</a>)</em></p>
<h2 id="are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</h2>
<p>Deep learning models are heavily over-parameterized and can often get perfect results on training data. In the traditional view, such as the bias-variance trade-off, this would be a disaster: nothing would generalize to unseen test data. However, as is often the case, such “overfitted” (training error = 0) deep learning models still perform decently on out-of-sample test data. Hmm… interesting, and why?</p>
<h3 id="modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</h3>
<p>Traditional machine learning uses the following U-shaped risk curve to measure the bias-variance trade-off and quantify how generalizable a model is. If I get asked how to tell whether a model is overfitted, this would be the first thing popping into my mind.</p>
<p>As the model grows larger (more parameters added), the training error decreases to close to zero, but the test error (generalization error) starts to increase once the model complexity passes the threshold between “underfitting” and “overfitting”. In a way, this is well aligned with Occam’s Razor.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bias-variance-risk-curve.png" alt="Bias-variance risk curve" /></p>
<p><em>Fig. 2. U-shaped bias-variance risk curve. (Image source: (left) <a href="https://arxiv.org/abs/1812.11118">paper</a> (right) <a href="http://scott.fortmann-roe.com/docs/BiasVariance.html">fig. 6 of this post</a>)</em></p>
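<p>The classic U-shape can be reproduced with a toy experiment: polynomial regression of growing degree on 20 noisy samples of \(\sin(\pi x)\). The data-generating function, noise level and degrees are illustrative assumptions. Training error keeps falling as capacity grows. Test error first drops (underfitting to a good fit) and then blows up once the polynomial starts interpolating the noise.</p>

```python
import numpy as np
from numpy.polynomial import Polynomial


def risk_curve(degrees=(1, 3, 5, 9, 15, 19), n_train=20, n_test=2000, noise=0.3, seed=0):
    """Return {degree: (train MSE, test MSE)} for least-squares polynomial fits."""
    rng = np.random.default_rng(seed)

    def sample(n):
        x = rng.uniform(-1, 1, size=n)
        return x, np.sin(np.pi * x) + rng.normal(scale=noise, size=n)

    x_tr, y_tr = sample(n_train)
    x_te, y_te = sample(n_test)
    curve = {}
    for deg in degrees:
        p = Polynomial.fit(x_tr, y_tr, deg)  # model capacity grows with the degree
        curve[deg] = (np.mean((p(x_tr) - y_tr) ** 2), np.mean((p(x_te) - y_te) ** 2))
    return curve


curve = risk_curve()
for deg, (train_mse, test_mse) in curve.items():
    print(f"degree={deg:2d}  train MSE={train_mse:.4f}  test MSE={test_mse:.4f}")
```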
<p>Unfortunately, this does not apply to deep learning models. <a href="https://arxiv.org/abs/1812.11118">Belkin et al. (2018)</a> reconciled the traditional bias-variance trade-off with modern practice and proposed a new double-U-shaped risk curve for deep neural networks. Once the number of network parameters is high enough, the risk curve enters another regime.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/new-bias-variance-risk-curve.png" alt="new risk curve" /></p>
<p><em>Fig. 3. A new double-U-shaped bias-variance risk curve for deep neural networks. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
<p>The paper attributed this to two likely reasons:</p>
<ul>
<li>The number of parameters is not a good measure of <em>inductive bias</em>, defined as the set of assumptions a learning algorithm uses to make predictions on unseen samples. See more discussion on DL model complexity in <a href="#intrinsic-dimension">later</a> <a href="#heterogeneous-layer-robustness">sections</a>.</li>
<li>Equipped with a larger model, we might be able to search over larger function classes and further find interpolating functions that have smaller norm and are thus “simpler”.</li>
</ul>
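<p>The second point can be illustrated with nested random-feature models, a stand-in for growing a network rather than Belkin et al.'s exact setup. The toy data and cosine features below are assumptions, with a frequency scale chosen so that even the narrowest model can interpolate. Each wider model contains every narrower one, so its set of interpolating solutions only grows. As a result, the <em>minimum-norm</em> interpolant (which <code>np.linalg.lstsq</code> returns for underdetermined systems) can only get smaller or stay the same.</p>

```python
import numpy as np


def min_norm_interpolants(n=40, widths=(100, 400, 2000), seed=0):
    """Fit noisy 1-D data exactly with random cosine features of growing width.

    Returns {width: (norm of min-norm interpolating weights, max training residual)}.
    """
    rng = np.random.default_rng(seed)
    x = np.linspace(-1, 1, n)
    y = np.sin(np.pi * x) + rng.normal(scale=0.3, size=n)  # noisy targets
    max_width = max(widths)
    freqs = rng.normal(scale=50.0, size=max_width)  # random frequencies
    phases = rng.uniform(0, 2 * np.pi, size=max_width)  # random phases
    features = np.cos(np.outer(x, freqs) + phases)  # shape (n, max_width)
    results = {}
    for width in sorted(widths):
        phi = features[:, :width]  # nested: each wider model contains the narrower one
        coef, *_ = np.linalg.lstsq(phi, y, rcond=None)  # min-norm solution since width > n
        results[width] = (np.linalg.norm(coef), np.max(np.abs(phi @ coef - y)))
    return results


results = min_norm_interpolants()
for width, (norm, residual) in results.items():
    print(f"width={width:5d}  ||coef||={norm:.3f}  max train residual={residual:.1e}")
```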
<p>The double-U-shaped risk curve was observed empirically, as shown in the paper. However, I struggled quite a bit to reproduce the results. There are some signs of life, but generating a smooth curve similar to the theoretical one requires taking care of <a href="#experiments">many details</a> in the experiment.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/new-risk-curve-mnist.png" alt="New risk curve on MNIST" /></p>
<p><em>Fig. 4. Training and evaluation errors of one-hidden-layer fc networks with different numbers of hidden units, trained on 4000 data points sampled from MNIST. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
<h3 id="regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</h3>
<p>Regularization is a common way to control overfitting and improve model generalization performance. Interestingly, some research (<a href="https://arxiv.org/abs/1611.03530">Zhang, et al. 2017</a>) has shown that explicit regularization (e.g., data augmentation, weight decay, and dropout) is neither necessary nor sufficient for reducing generalization error.</p>
<p>Taking the Inception model trained on CIFAR10 as an example (see Fig. 5), regularization techniques help with out-of-sample generalization, but not by much. No single regularizer seems to be critical independently of the others. Thus, it is unlikely that regularizers are the <em>fundamental reason</em> for generalization.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/regularization-generalization-test.png" alt="regularization test" /></p>
<p><em>Fig. 5. The accuracy of the Inception model trained on CIFAR10 with data augmentation and weight decay switched on or off in different combinations. (Image source: Table 1 in the <a href="https://arxiv.org/abs/1611.03530">original paper</a>)</em></p>
<h3 id="intrinsic-dimension">Intrinsic Dimension</h3>
<p>In deep learning, the number of parameters does not correlate well with model overfitting, suggesting that parameter counting cannot indicate the true complexity of deep neural networks.</p>
<p>Apart from parameter counting, researchers have proposed many ways to quantify the complexity of these models, such as the number of degrees of freedom of models (<a href="https://arxiv.org/abs/1603.09260">Gao & Jojic, 2016</a>), or prequential code (<a href="https://arxiv.org/abs/1802.07044">Blier & Ollivier, 2018</a>).</p>
<p>I would like to discuss a recent method on this matter, named <strong>intrinsic dimension</strong> (<a href="https://arxiv.org/abs/1804.08838">Li et al., 2018</a>). Intrinsic dimension is intuitive and easy to measure, while still revealing many interesting properties of models of different sizes.</p>
<p>A neural network with a great number of parameters forms a high-dimensional parameter space, and learning happens on this high-dimensional <em>objective landscape</em>.
The shape of the parameter space manifold is critical. For example, a smoother manifold is beneficial for optimization by providing more predictive gradients and allowing for larger learning rates; this was claimed to be the reason why batch normalization succeeds in stabilizing training (<a href="https://arxiv.org/abs/1805.11604">Santurkar, et al., 2018</a>).</p>
<p>Even though the parameter space is huge, fortunately we don’t have to worry too much about the optimization process getting stuck in local optima, as it has been <a href="https://arxiv.org/abs/1406.2572">shown</a> that critical points in a high-dimensional objective landscape are almost always saddle points rather than valleys (local minima). In other words, there is almost always a subset of dimensions containing paths that lead away from the critical point, so the optimization can keep exploring.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/optimization-landscape-shape.png" alt="parameter landscape shape" /></p>
<p><em>Fig. 6. Illustrations of various types of critical points on the parameter optimization landscape. (Image source: <a href="https://www.offconvex.org/2016/03/22/saddlepoints/">here</a>)</em></p>
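<p>A heuristic random-matrix argument gives some intuition for why saddle points dominate in high dimensions. The strong simplifying assumption is that the Hessian at a critical point behaves like a random symmetric Gaussian matrix. A critical point is a strict local minimum only if <em>all</em> Hessian eigenvalues are positive, and the sketch below shows how quickly that becomes rare as the dimension grows.</p>

```python
import numpy as np


def fraction_local_minima(dim, trials=20000, seed=0):
    """Fraction of random symmetric matrices (stand-ins for Hessians at critical points)
    whose eigenvalues are all positive, i.e. that would mark a strict local minimum."""
    rng = np.random.default_rng(seed)
    a = rng.normal(size=(trials, dim, dim))
    hessians = (a + np.swapaxes(a, 1, 2)) / 2.0  # symmetrize each matrix in the batch
    eigvals = np.linalg.eigvalsh(hessians)  # batched eigenvalues, shape (trials, dim)
    return np.mean(np.all(eigvals > 0, axis=1))


for dim in (1, 2, 4, 8):
    print(f"dim={dim}  fraction of local minima = {fraction_local_minima(dim):.4f}")
```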
<p>One intuition behind measuring intrinsic dimension is that, since the parameter space has such high dimensionality, it is probably not necessary to exploit all the dimensions to learn efficiently. If we only travel through a slice of the objective landscape and can still learn a good solution, the complexity of the resulting model is likely lower than what parameter counting suggests. This is essentially what intrinsic dimension tries to assess.</p>
<p>Say a model has \(D\) parameters, denoted as \(\theta^{(D)}\). For learning, a random \(d\)-dimensional subspace is sampled and parameterized by \(\theta^{(d)}\), where \(d < D\). In each optimization update, rather than taking a gradient step in all \(D\) dimensions, only the \(d\) coordinates \(\theta^{(d)}\) are updated and then mapped back to update the full model parameters.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-illustration.png" alt="illustration" /></p>
<p><em>Fig. 7. Illustration of parameter vectors for direct optimization when \(D=3\). (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The full parameters are computed from the subspace coordinates as follows:</p>
\[\theta^{(D)} = \theta_0^{(D)} + \mathbf{P} \theta^{(d)}\]
<p>where \(\theta_0^{(D)}\) are the initialization values and \(\mathbf{P}\) is a \(D \times d\) projection matrix that is randomly sampled before training. Neither \(\theta_0^{(D)}\) nor \(\mathbf{P}\) is trainable; both stay fixed during training. \(\theta^{(d)}\) is initialized as all zeros.</p>
<p>By searching through \(d = 1, 2, \dots, D\), the smallest \(d\) at which a good solution emerges (e.g., 90% of the best performance, as in Fig. 8) is defined as the <em>intrinsic dimension</em>.</p>
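<p>Here is a minimal sketch of the measurement on a toy linear-regression problem. The data, the \(R^2\) metric and the 90% threshold are illustrative assumptions. The model has \(D = 200\) parameters, but the inputs have rank 10, so only a 10-dimensional slice of the parameter space matters. Because the toy model is linear, the optimal \(\theta^{(d)}\) is found in closed form by least squares instead of SGD, which is a deliberate simplification of the paper's training procedure. The columns of \(\mathbf{P}\) are normalized to unit length here.</p>

```python
import numpy as np


def r2_score(pred, y):
    return 1.0 - np.sum((pred - y) ** 2) / np.sum((y - y.mean()) ** 2)


def subspace_r2(X, y, d, rng):
    """Train only theta^(d) in theta^(D) = theta_0^(D) + P theta^(d); theta_0 and P are frozen."""
    D = X.shape[1]
    theta0 = rng.normal(scale=0.1, size=D)  # frozen random initialization
    P = rng.normal(size=(D, d))
    P /= np.linalg.norm(P, axis=0)  # frozen random projection with unit-length columns
    # Closed-form optimum of ||X (theta0 + P z) - y||^2 over z (stands in for SGD from z = 0).
    z, *_ = np.linalg.lstsq(X @ P, y - X @ theta0, rcond=None)
    return r2_score(X @ (theta0 + P @ z), y)


def intrinsic_dimension(X, y, fraction=0.9, seed=0):
    """Smallest d whose subspace solution reaches `fraction` of the full-space performance."""
    rng = np.random.default_rng(seed)
    theta_full, *_ = np.linalg.lstsq(X, y, rcond=None)
    baseline = r2_score(X @ theta_full, y)
    for d in range(1, X.shape[1] + 1):
        if subspace_r2(X, y, d, rng) >= fraction * baseline:
            return d
    return X.shape[1]


rng = np.random.default_rng(1)
n, D, r = 500, 200, 10
X = rng.normal(size=(n, r)) @ rng.normal(size=(r, D))  # 200 parameters, rank-10 inputs
y = X @ rng.normal(size=D)
d_int = intrinsic_dimension(X, y)
print(f"D = {D}, intrinsic dimension = {d_int}")
```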
<p>It turns out many problems have much smaller intrinsic dimensions than the number of parameters. For example, on CIFAR10 image classification, a fully-connected network with 650k+ parameters has only 9k intrinsic dimension and a convolutional network containing 62k parameters has an even lower intrinsic dimension of 2.9k.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension.png" alt="intrinsic dimension results" /></p>
<p><em>Fig. 8. The measured intrinsic dimensions \(d\) for various models achieving 90% of the best performance. (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The measurement of intrinsic dimensions suggests that deep learning models are significantly simpler than what they might appear to be.</p>
<h3 id="heterogeneous-layer-robustness">Heterogeneous Layer Robustness</h3>
<p><a href="https://arxiv.org/abs/1902.01996">Zhang et al. (2019)</a> investigated the role of parameters in different layers. The fundamental question raised by the paper is: <em>“are all layers created equal?”</em> The short answer is: No. The model is more sensitive to changes in some layers than in others.</p>
<p>The paper proposed two types of operations that can be applied to the parameters \(\theta^{(\ell)}_t\) of the \(\ell\)-th layer, \(\ell = 1, \dots, L\), at training time \(t\) to test their impact on model robustness:</p>
<ul>
<li>
<p><strong>Re-initialization</strong>: Reset the parameters to the initial values, \(\theta^{(\ell)}_t \leftarrow \theta^{(\ell)}_0\). The performance of a network in which layer \(\ell\) was re-initialized is referred to as the <em>re-initialization robustness</em> of layer \(\ell\).</p>
</li>
<li>
<p><strong>Re-randomization</strong>: Re-sample the layer’s parameters at random, \(\theta^{(\ell)}_t \leftarrow \tilde{\theta}^{(\ell)} \sim \mathcal{P}^{(\ell)}\). The corresponding network performance is called the <em>re-randomization robustness</em> of layer \(\ell\).</p>
</li>
</ul>
<p>Layers can be sorted into two categories with the help of these two operations:</p>
<ul>
<li><strong>Robust Layers</strong>: The network has no or only negligible performance degradation after re-initializing or re-randomizing the layer.</li>
<li><strong>Critical Layers</strong>: Otherwise.</li>
</ul>
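<p>The two operations and the robust/critical split can be sketched as plain functions over a list of per-layer weight arrays. This is a hypothetical helper API, not the paper's code. The re-randomization distribution \(\mathcal{P}^{(\ell)}\) is assumed to be a Gaussian with standard deviation \(1/\sqrt{\text{fan-in}}\); in practice it should be the network's own initializer. The "negligible drop" tolerance is also an assumed threshold.</p>

```python
import copy

import numpy as np


def reinitialize(params, init_params, layer):
    """Re-initialization: theta_t^(l) <- theta_0^(l); every other layer keeps its trained values."""
    new_params = copy.deepcopy(params)
    new_params[layer] = init_params[layer].copy()
    return new_params


def rerandomize(params, layer, rng):
    """Re-randomization: theta_t^(l) <- a fresh draw from P^(l).

    P^(l) is assumed to be N(0, 1/fan_in) here (hypothetical; use the network's real init).
    """
    new_params = copy.deepcopy(params)
    fan_in = params[layer].shape[0]
    new_params[layer] = rng.normal(scale=1.0 / np.sqrt(fan_in), size=params[layer].shape)
    return new_params


def classify_layer(perf_trained, perf_after_op, tolerance=0.01):
    """'robust' if the performance drop is negligible (tolerance is a hypothetical threshold)."""
    return "robust" if perf_trained - perf_after_op <= tolerance else "critical"
```

<p>With a hypothetical <code>evaluate(params)</code> that returns test accuracy, each layer <code>l</code> would be labeled by <code>classify_layer(evaluate(params), evaluate(reinitialize(params, init_params, l)))</code>, and likewise for <code>rerandomize</code>.</p>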
<p>Similar patterns are observed on fully-connected and convolutional networks. Re-randomizing any of the layers <em>completely destroys</em> the model performance, as the prediction immediately drops to random guessing. More interestingly and surprisingly, when applying re-initialization, only the first or the first few layers (those closest to the input layer) are critical, while re-initializing higher layers causes <em>only a negligible decrease</em> in performance.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer-robustness-results.png" alt="Re-initialization robustness" /></p>
<p><em>Fig. 9. (a) An fc network trained on MNIST. Each row corresponds to one layer in the network. The first column shows the re-randomization robustness of each layer, and the remaining columns show re-initialization robustness at different training times. (b) A VGG11 model (conv net) trained on CIFAR10, shown as in (a) but with rows and columns transposed. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>ResNet can use shortcuts between non-adjacent layers to redistribute the sensitive layers across the network rather than concentrating them at the bottom. With the help of the residual block architecture, the network becomes <em>evenly robust to re-randomization</em>. Only the first layer of each residual block is still sensitive to both re-initialization and re-randomization. If we consider each residual block as a local sub-network, the robustness pattern resembles that of the fc and conv nets above.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/layer-robustness-resnet.png" alt="ResNet robustness" /></p>
<p><em>Fig. 10. Re-randomization (first row) and re-initialization (remaining rows) robustness of layers in a ResNet-50 model trained on CIFAR-10. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>Because many top layers in deep neural networks are not critical to model performance after re-initialization, the paper loosely concluded that:</p>
<blockquote>
<p>“Over-capacitated deep networks trained with stochastic gradient have low-complexity due to self-restricting the number of critical layers.”</p>
</blockquote>
<p>We can view re-initialization as a way to reduce the effective number of parameters, so this observation is consistent with what the intrinsic dimension experiments demonstrated.</p>
<h3 id="the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</h3>
<p>The lottery ticket hypothesis (<a href="https://arxiv.org/abs/1803.03635">Frankle & Carbin, 2019</a>) is another intriguing and inspiring discovery, supporting the idea that only a subset of network parameters has an impact on model performance, and thus the network is not overfitted. It states that a randomly initialized, dense, feed-forward network contains a pool of subnetworks, and among them only a subset are <em>“winning tickets”</em>, which reach test accuracy comparable to the original network when <em>trained in isolation</em>.</p>
<p>The idea is motivated by network pruning techniques, which remove unnecessary weights (i.e. tiny weights that are almost negligible) without harming model performance. Although the final network size can be reduced dramatically, it is hard to train such a pruned architecture successfully from scratch. It seems that we need a large number of parameters to train a neural network successfully, but not that many parameters to keep the accuracy high once the model is trained. Why is that?</p>
<p>The lottery ticket paper ran the following experiment:</p>
<ol>
<li>Randomly initialize a dense feed-forward network with initialization values \(\theta_0\).</li>
<li>Train the network for multiple iterations until it reaches good performance, yielding parameters \(\theta\).</li>
<li>Prune \(\theta\) and create a mask \(m\).</li>
<li>The “winning ticket” initialization config is \(m \odot \theta_0\).</li>
</ol>
<p>When only the small “winning ticket” subset of parameters is trained, starting from the initial values found in step 1, the model achieves the same level of accuracy as in step 2. It turns out that a large parameter space is not needed to represent the final solution, but it is needed for training, because it provides a big pool of initialization configs for many much smaller subnetworks.</p>
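<p>Steps 3 and 4 can be sketched in NumPy, assuming one-shot magnitude pruning of a single weight matrix (prune the given fraction of smallest-magnitude trained weights) as the pruning method; the original paper also uses iterative pruning over several rounds, which this sketch does not cover. The function names are hypothetical, and the "trained" weights are a random stand-in for an actual SGD run.</p>

```python
import numpy as np


def magnitude_mask(weights, prune_fraction):
    """Boolean mask m that removes the `prune_fraction` smallest-magnitude
    entries of `weights` (one-shot magnitude pruning)."""
    magnitudes = np.abs(weights).ravel()
    k = int(round(prune_fraction * magnitudes.size))  # number of weights to prune
    if k == 0:
        return np.ones(weights.shape, dtype=bool)
    threshold = np.partition(magnitudes, k - 1)[k - 1]  # k-th smallest magnitude
    return np.abs(weights) > threshold


def winning_ticket(theta_0, theta_trained, prune_fraction):
    """Prune the trained weights (step 3), then rewind survivors to theta_0 (step 4)."""
    m = magnitude_mask(theta_trained, prune_fraction)
    return m, m * theta_0  # m ⊙ theta_0: pruned entries are exactly zero


# Usage with stand-in weights (a real run would first train theta_0 with SGD).
rng = np.random.default_rng(42)
theta_0 = rng.normal(size=(100, 100))
theta = theta_0 + 0.5 * rng.normal(size=theta_0.shape)  # pretend "trained"
m, ticket = winning_ticket(theta_0, theta, prune_fraction=0.8)
print(m.mean())  # fraction of weights kept
```

<p>When retraining the ticket, the mask must stay fixed, for example by multiplying each gradient update by \(m\), so that pruned weights remain zero.</p>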
<p>The lottery ticket hypothesis opens a new perspective on interpreting and dissecting deep neural network results. Many interesting follow-up works are on the way.</p>
<h2 id="experiments">Experiments</h2>
<p>After seeing all the interesting findings above, it should be pretty fun to reproduce them. Some results are easier to reproduce than others. Details are described below. My code is available on GitHub: <a href="https://github.com/lilianweng/generalization-experiment">lilianweng/generalization-experiment</a>.</p>
<p><strong>New Risk Curve for DL Models</strong></p>
<p>This is the trickiest one to reproduce. The authors gave me a lot of good advice, which I really appreciate. Here are a few notable settings in their experiments:</p>
<ul>
<li>There are no regularization terms such as weight decay or dropout.</li>
<li>In Fig. 3, the training set contains 4k samples. It is sampled only once and kept fixed for all models. The evaluation uses the full MNIST test set.</li>
<li>Each network is trained for a long time to achieve near-zero training risk. The learning rate is adjusted differently for models of different sizes.</li>
<li>To make the model less sensitive to the initialization in the under-parameterization region, their experiments adopted a <em>“weight reuse”</em> scheme: the parameters obtained from training a smaller neural network are used as initialization for training larger networks.</li>
</ul>
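<p>The weight-reuse scheme in the last bullet can be sketched for one-hidden-layer networks (first layer with bias, output layer without, matching the parameter-count formula used in this section). The function name is hypothetical, and initializing the extra units with small Gaussian noise (<code>std=0.01</code>) and zero biases is my assumption rather than a detail confirmed above.</p>

```python
import numpy as np


def grow_network(W1_small, b1_small, W2_small, new_units, rng, std=0.01):
    """Weight reuse: build a wider one-hidden-layer net whose first h hidden units
    copy a trained narrower net; the extra units get small random weights."""
    input_dim, h = W1_small.shape
    n_classes = W2_small.shape[1]
    if new_units < h:
        raise ValueError("the new network must be at least as wide as the old one")
    W1 = rng.normal(0.0, std, size=(input_dim, new_units))  # input -> hidden
    b1 = np.zeros(new_units)                                 # hidden bias
    W2 = rng.normal(0.0, std, size=(new_units, n_classes))  # hidden -> output, no bias
    W1[:, :h] = W1_small  # reuse the trained weights for the first h units
    b1[:h] = b1_small
    W2[:h, :] = W2_small
    return W1, b1, W2


# Stand-in for a trained 8-unit network on flattened 28x28 inputs and 10 classes.
rng = np.random.default_rng(0)
small = (rng.normal(size=(784, 8)), rng.normal(size=8), rng.normal(size=(8, 10)))
W1, b1, W2 = grow_network(*small, new_units=16, rng=rng)
```

<p>Because the copied units start from a trained solution, each larger model begins close to the smaller model's function, which is the intended effect of reducing sensitivity to initialization.</p>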
<p>I did not train or tune each model long enough to reach perfect training performance, but the evaluation error does show a distinctive twist around the interpolation threshold, unlike the training error. For MNIST, the threshold is the number of training samples times the number of classes (10), i.e. 4000 × 10 = 40000.</p>
<p>The x-axis is the number of model parameters, \((28 \times 28 + 1) \times \text{num. units} + \text{num. units} \times 10\), on a log scale.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/risk_curve_loss-mse_sample-4000_epoch-500.png" alt="risk curve experiment 1" /></p>
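<p>As a quick check of the arithmetic, the x-axis parameter count is linear in the hidden width, \(795 \times \text{num. units}\), so the interpolation threshold of 40000 is first reached at 51 hidden units. A small helper (hypothetical names) reproduces these numbers:</p>

```python
import math

N_TRAIN = 4000        # training samples in the risk-curve experiment
N_CLASSES = 10        # MNIST classes
INPUT_DIM = 28 * 28   # flattened MNIST image


def num_params(num_units: int) -> int:
    """Parameter count from the x-axis formula: (28*28 + 1) * units + units * 10."""
    return (INPUT_DIM + 1) * num_units + num_units * N_CLASSES


def interpolation_threshold(n_train: int = N_TRAIN, n_classes: int = N_CLASSES) -> int:
    """Number of training samples times the number of classes."""
    return n_train * n_classes


def units_at_threshold() -> int:
    """Smallest hidden width whose parameter count reaches the threshold."""
    return math.ceil(interpolation_threshold() / num_params(1))


width = units_at_threshold()
print(interpolation_threshold())   # 40000
print(width, num_params(width))    # 51 40545
```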
<p><br /></p>
<p><strong>Layers are not Created Equal</strong></p>
<p>This one is fairly easy to reproduce. See my implementation <a href="https://github.com/lilianweng/generalization-experiment/blob/master/layer_equality.py">here</a>.</p>
<p>In the first experiment, I used a three-layer fc network with 256 units in each layer. Layer 0 is the input layer, while layer 3 is the output. The network is trained on MNIST for 100 epochs.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_256x3.png" alt="Layer equality experiment 1" /></p>
<p>In the second experiment, I used a four-layer fc network with 128 units in each layer. Other settings are the same as in experiment 1.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_128x4.png" alt="Layer equality experiment 2" /></p>
<p><br /></p>
<p><strong>Intrinsic Dimension Measurement</strong></p>
<p>To map the \(d\)-dimensional subspace correctly to the full parameter space, the projection matrix \(\mathbf{P}\) should have orthogonal columns. Because the product \(\mathbf{P}\theta^{(d)}\) is the sum of the columns of \(\mathbf{P}\) scaled by the corresponding scalar values in the \(d\)-dim vector, \(\sum_{i=1}^d \theta^{(d)}_i \mathbf{P}_{(:,i)}\), orthogonal columns let us fully utilize the subspace.</p>
<p>My implementation follows a naive approach: sampling a large matrix with independent entries from a standard normal distribution. In a high-dimensional space, independently sampled columns are nearly orthogonal with high probability. This works when the dimension is not too large. For a large \(d\), there are methods for creating sparse projection matrices, which is what the intrinsic dimension paper suggested.</p>
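<p>A minimal NumPy sketch of the subspace mapping, assuming the intrinsic-dimension reparameterization \(\theta^{(D)} = \theta^{(D)}_0 + \mathbf{P}\theta^{(d)}\) with only \(\theta^{(d)}\) trainable. It compares the naive Gaussian \(\mathbf{P}\) (columns rescaled to unit norm) against an exactly orthonormal \(\mathbf{P}\) obtained from a QR decomposition, which is a swapped-in alternative rather than the method used in my experiments. The sizes \(D = 20000\), \(d = 50\) are arbitrary choices for illustration.</p>

```python
import numpy as np


def gaussian_projection(D, d, rng):
    """Naive P: i.i.d. standard-normal entries, columns rescaled to unit norm.
    In high dimension the columns are nearly (not exactly) orthogonal."""
    P = rng.standard_normal((D, d))
    return P / np.linalg.norm(P, axis=0, keepdims=True)


def orthonormal_projection(D, d, rng):
    """Exactly orthonormal columns via a reduced QR decomposition."""
    Q, _ = np.linalg.qr(rng.standard_normal((D, d)))  # Q has shape (D, d)
    return Q


def to_full_space(theta_0, P, theta_d):
    """theta^(D) = theta_0^(D) + P theta^(d) = theta_0 + sum_i theta_d[i] * P[:, i]."""
    return theta_0 + P @ theta_d


def max_abs_cosine(P):
    """Largest |cos| between two distinct columns (0 means exactly orthogonal)."""
    Pn = P / np.linalg.norm(P, axis=0, keepdims=True)
    G = Pn.T @ Pn
    np.fill_diagonal(G, 0.0)
    return np.abs(G).max()


rng = np.random.default_rng(0)
D, d = 20_000, 50
P_naive = gaussian_projection(D, d, rng)
P_qr = orthonormal_projection(D, d, rng)
print(max_abs_cosine(P_naive), max_abs_cosine(P_qr))

theta_0 = rng.normal(size=D)  # frozen initial weights (flattened)
theta_d = np.zeros(d)         # trainable subspace coordinates, start at zero
theta_D = to_full_space(theta_0, P_naive, theta_d)
```

<p>The dense \(D \times d\) matrix is what limits this approach: memory grows with \(D \cdot d\), which is why sparse or structured projections become attractive for large \(d\).</p>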
<p>Here are experiment runs on two networks: (left) a two-layer fc network with 64 units in each layer and (right) a one-layer fc network with 128 hidden units, both trained on 10% of MNIST. For every \(d\), the model is trained for 100 epochs. See the code: <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions.py">intrinsic_dimensions.py</a> and <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions_measurement.py">intrinsic_dimensions_measurement.py</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-net-64-64-and-128.png" alt="intrinsic dimension experiment 1" /></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019overfit,
title = "Are Deep Neural Networks Dramatically Overfitted?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Wikipedia page on <a href="https://en.wikipedia.org/wiki/Occam%27s_razor">Occam’s Razor</a>.</p>
<p>[2] <a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> on Principia Cybernetica Web.</p>
<p>[3] Peter Grunwald. <a href="https://arxiv.org/abs/math/0406077">“A Tutorial Introduction to the Minimum Description Length Principle”</a>. 2004.</p>
<p>[4] Ian Goodfellow, et al. <a href="https://www.deeplearningbook.org/">Deep Learning</a>. 2016. <a href="https://www.deeplearningbook.org/contents/mlp.html">Sec 6.4.1</a>.</p>
<p>[5] Chiyuan Zhang, et al. <a href="https://arxiv.org/abs/1611.03530">“Understanding deep learning requires rethinking generalization.”</a> ICLR 2017.</p>
<p>[6] Shibani Santurkar, et al. <a href="https://arxiv.org/abs/1805.11604">“How does batch normalization help optimization?”</a> NIPS 2018.</p>
<p>[7] Mikhail Belkin, et al. <a href="https://arxiv.org/abs/1812.11118">“Reconciling modern machine learning and the bias-variance trade-off.”</a> arXiv:1812.11118, 2018.</p>
<p>[8] Chiyuan Zhang, et al. <a href="https://arxiv.org/abs/1902.01996">“Are All Layers Created Equal?”</a> arXiv:1902.01996, 2019.</p>
<p>[9] Chunyuan Li, et al. <a href="https://arxiv.org/abs/1804.08838">“Measuring the intrinsic dimension of objective landscapes.”</a> ICLR 2018.</p>
<p>[10] Jonathan Frankle and Michael Carbin. <a href="https://arxiv.org/abs/1803.03635">“The lottery ticket hypothesis: Finding sparse, trainable neural networks.”</a> ICLR 2019.</p>Lilian WengIf you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.