<p>Exploration Strategies in Deep Reinforcement Learning. Lilian Weng, Lil’Log, 2020-06-07. <a href="https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning">https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning</a></p>
<blockquote>
<p>Exploitation versus exploration is a critical topic in reinforcement learning. This post introduces several common approaches for better exploration in Deep RL.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-06-17: Add <a href="#exploration-via-disagreement">“exploration via disagreement”</a> in the “Forward Dynamics” <a href="#forward-dynamics">section</a>.]</span></p>
<p><a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html">Exploitation versus exploration</a> is a critical topic in Reinforcement Learning. We’d like the RL agent to find the best solution as fast as possible. However, in the meantime, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">RL</a> <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html">algorithms</a> that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more like an open topic.</p>
<p>I would like to discuss several common exploration strategies in Deep RL here. As this is a very big topic, my post by no means can cover all the important subtopics. I plan to update it periodically and enrich the content over time.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-exploration-strategies" id="markdown-toc-classic-exploration-strategies">Classic Exploration Strategies</a></li>
<li><a href="#key-exploration-problems" id="markdown-toc-key-exploration-problems">Key Exploration Problems</a> <ul>
<li><a href="#the-hard-exploration-problem" id="markdown-toc-the-hard-exploration-problem">The Hard-Exploration Problem</a></li>
<li><a href="#the-noisy-tv-problem" id="markdown-toc-the-noisy-tv-problem">The Noisy-TV Problem</a></li>
</ul>
</li>
<li><a href="#intrinsic-rewards-as-exploration-bonuses" id="markdown-toc-intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</a> <ul>
<li><a href="#count-based-exploration" id="markdown-toc-count-based-exploration">Count-based Exploration</a> <ul>
<li><a href="#counting-by-density-model" id="markdown-toc-counting-by-density-model">Counting by Density Model</a></li>
<li><a href="#counting-after-hashing" id="markdown-toc-counting-after-hashing">Counting after Hashing</a></li>
</ul>
</li>
<li><a href="#prediction-based-exploration" id="markdown-toc-prediction-based-exploration">Prediction-based Exploration</a> <ul>
<li><a href="#forward-dynamics" id="markdown-toc-forward-dynamics">Forward Dynamics</a></li>
<li><a href="#random-networks" id="markdown-toc-random-networks">Random Networks</a></li>
<li><a href="#physical-properties" id="markdown-toc-physical-properties">Physical Properties</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#memory-based-exploration" id="markdown-toc-memory-based-exploration">Memory-based Exploration</a> <ul>
<li><a href="#episodic-memory" id="markdown-toc-episodic-memory">Episodic Memory</a></li>
<li><a href="#direct-exploration" id="markdown-toc-direct-exploration">Direct Exploration</a></li>
</ul>
</li>
<li><a href="#q-value-exploration" id="markdown-toc-q-value-exploration">Q-Value Exploration</a></li>
<li><a href="#varitional-options" id="markdown-toc-varitional-options">Variational Options</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="classic-exploration-strategies">Classic Exploration Strategies</h2>
<p>As a quick recap, let’s first go through several classic exploration algorithms that work out pretty well in the multi-armed bandit problem or simple tabular RL.</p>
<ul>
<li><strong>Epsilon-greedy</strong>: The agent explores randomly with probability <script type="math/tex">\epsilon</script> and takes the action with the highest estimated value the rest of the time, with probability <script type="math/tex">1-\epsilon</script>.</li>
<li><strong>Upper confidence bounds</strong>: The agent selects the action that maximizes the upper confidence bound <script type="math/tex">\hat{Q}_t(a) + \hat{U}_t(a)</script>, where <script type="math/tex">\hat{Q}_t(a)</script> is the average reward associated with action <script type="math/tex">a</script> up to time <script type="math/tex">t</script> and <script type="math/tex">\hat{U}_t(a)</script> is a function inversely related to how many times action <script type="math/tex">a</script> has been taken. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#upper-confidence-bounds">here</a> for more details.</li>
<li><strong>Boltzmann exploration</strong>: The agent draws actions from a <a href="https://en.wikipedia.org/wiki/Boltzmann_distribution">Boltzmann distribution</a> (softmax) over the learned Q values, regulated by a temperature parameter <script type="math/tex">\tau</script>.</li>
<li><strong>Thompson sampling</strong>: The agent keeps track of a belief over the probability of optimal actions and samples from this distribution. See <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">here</a> for more details.</li>
</ul>
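<p>As a rough illustration, the four strategies above can be sketched for a small bandit with a list of estimated action values. This is a minimal sketch and not code from any of the referenced posts: the function names are made up for this example, the UCB bonus uses the UCB1 form <script type="math/tex">\sqrt{c \ln t / N(a)}</script> as one possible choice of <script type="math/tex">\hat{U}_t(a)</script>, and Thompson sampling assumes Bernoulli rewards with a uniform Beta prior.</p>

```python
import math
import random

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

def ucb1(q_values, counts, t, c=2.0):
    """Pick the action maximizing Q(a) + U(a), with U(a) = sqrt(c * ln(t) / N(a))."""
    for a, n in enumerate(counts):
        if n == 0:  # an untried action has an unbounded bonus, so try it first
            return a
    return max(range(len(q_values)),
               key=lambda a: q_values[a] + math.sqrt(c * math.log(t) / counts[a]))

def boltzmann(q_values, tau, rng):
    """Sample an action from softmax(Q / tau); subtracting the max keeps exp() stable."""
    m = max(q_values)
    weights = [math.exp((q - m) / tau) for q in q_values]
    return rng.choices(range(len(q_values)), weights=weights)[0]

def thompson_bernoulli(successes, failures, rng):
    """Sample a success rate per arm from its Beta posterior, then act greedily on the samples."""
    samples = [rng.betavariate(s + 1, f + 1) for s, f in zip(successes, failures)]
    return max(range(len(samples)), key=lambda a: samples[a])
```

<p>Each function takes a <code>random.Random</code> instance where it needs randomness, so runs can be made reproducible by fixing the seed.</p>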
<p>The following strategies could be used for better exploration in deep RL training when neural networks are used for function approximation:</p>
<ul>
<li><strong>Entropy loss term</strong>: Add an entropy term <script type="math/tex">H(\pi(a \vert s))</script> into the loss function, encouraging the policy to take diverse actions.</li>
<li><strong>Noise-based Exploration</strong>: Add noise into the observation, action or even parameter space (<a href="https://arxiv.org/abs/1706.10295">Fortunato, et al. 2017</a>, <a href="https://arxiv.org/abs/1706.01905">Plappert, et al. 2017</a>).</li>
</ul>
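<p>To make the entropy term concrete, here is a minimal sketch for a discrete action distribution. The coefficient <code>coef</code> and the function names are assumptions chosen for illustration; the sign convention assumes the loss is minimized, so the entropy is subtracted.</p>

```python
import math

def entropy(probs):
    """Shannon entropy H(pi(.|s)) in nats of a discrete action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def loss_with_entropy_bonus(policy_loss, probs, coef=0.01):
    """Subtract coef * entropy so that minimizing the loss favors more diverse actions."""
    return policy_loss - coef * entropy(probs)
```

<p>For the same policy loss, a near-deterministic policy ends up with a higher total loss than a uniform one, which is the pressure toward diverse actions described above.</p>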
<h2 id="key-exploration-problems">Key Exploration Problems</h2>
<p>Good exploration becomes especially hard when the environment rarely provides rewards as feedback or the environment has distracting noise. Many exploration strategies have been proposed to solve one or both of the following problems.</p>
<h3 id="the-hard-exploration-problem">The Hard-Exploration Problem</h3>
<p>The “hard-exploration” problem refers to exploration in an environment with very sparse or even deceptive reward. It is difficult because random exploration in such scenarios can rarely discover successful states or obtain meaningful feedback.</p>
<p><a href="https://en.wikipedia.org/wiki/Montezuma%27s_Revenge_(video_game)">Montezuma’s Revenge</a> is a concrete example of the hard-exploration problem. It remains one of the few Atari games that are still challenging for DRL to solve. Many papers use Montezuma’s Revenge to benchmark their results.</p>
<h3 id="the-noisy-tv-problem">The Noisy-TV Problem</h3>
<p>The “Noisy-TV” problem started as a thought experiment in <a href="https://arxiv.org/abs/1810.12894">Burda, et al (2018)</a>. Imagine that an RL agent is rewarded for seeking novel experience. A TV with uncontrollable and unpredictable random noise outputs would then be able to attract the agent’s attention forever. The agent consistently obtains new rewards from the noisy TV, but it fails to make any meaningful progress and becomes a “couch potato”.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/the-noisy-TV-problem.gif" alt="The noisy-TV problem" /></p>
<p><em>Fig. 1. An agent is rewarded for novel experience in the experiment. If a maze has a noisy TV set up, the agent would be attracted and stop moving in the maze. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<h2 id="intrinsic-rewards-as-exploration-bonuses">Intrinsic Rewards as Exploration Bonuses</h2>
<p>One common approach to better exploration, especially for solving the <a href="#the-hard-exploration-problem">hard-exploration</a> problem, is to augment the environment reward with an additional bonus signal to encourage extra exploration. The policy is thus trained with a reward composed of two terms, <script type="math/tex">r_t = r^e_t + \beta r^i_t</script>, where <script type="math/tex">\beta</script> is a hyperparameter adjusting the balance between exploitation and exploration.</p>
<ul>
<li><script type="math/tex">r^e_t</script> is an <em>extrinsic</em> reward from the environment at time <script type="math/tex">t</script>, defined according to the task at hand.</li>
<li><script type="math/tex">r^i_t</script> is an <em>intrinsic</em> exploration bonus at time <script type="math/tex">t</script>.</li>
</ul>
<p>This intrinsic reward is somewhat inspired by <em>intrinsic motivation</em> in psychology (<a href="https://www.researchgate.net/profile/Pierre-Yves_Oudeyer/publication/29614795_How_can_we_define_intrinsic_motivation/links/09e415107f1b4c8041000000/How-can-we-define-intrinsic-motivation.pdf">Oudeyer & Kaplan, 2008</a>). Exploration driven by curiosity might be an important way for children to grow and learn. In other words, exploratory activities should be rewarding intrinsically in the human mind to encourage such behavior. The intrinsic rewards could be correlated with curiosity, surprise, familiarity of the state, and many other factors.</p>
<p>The same ideas can be applied to RL algorithms. In the following sections, methods of bonus-based exploration rewards are roughly grouped into two categories:</p>
<ol>
<li>Discovery of novel states</li>
<li>Improvement of the agent’s knowledge about the environment.</li>
</ol>
<h3 id="count-based-exploration">Count-based Exploration</h3>
<p>If we consider intrinsic rewards as rewarding conditions that surprise us, we need a way to measure whether a state is novel or appears often. One intuitive way is to count how many times a state has been encountered and to assign a bonus accordingly. The bonus guides the agent’s behavior to prefer rarely visited states to common states. This is known as the <strong>count-based exploration</strong> method.</p>
<p>Let <script type="math/tex">N_n(s)</script> be the <em>empirical count</em> function that tracks the real number of visits of a state <script type="math/tex">s</script> in the sequence of <script type="math/tex">s_{1:n}</script>. Unfortunately, using <script type="math/tex">N_n(s)</script> for exploration directly is not practical, because most of the states would have <script type="math/tex">N_n(s)=0</script>, especially considering that the state space is often continuous or high-dimensional. We need a non-zero count for most states, even when they haven’t been seen before.</p>
<h4 id="counting-by-density-model">Counting by Density Model</h4>
<p><a href="https://arxiv.org/abs/1606.01868">Bellemare, et al. (2016)</a> used a <strong>density model</strong> to approximate the frequency of state visits and a novel algorithm for deriving a <em>pseudo-count</em> from this density model. Let’s first define a conditional probability over the state space, <script type="math/tex">\rho_n(s) = \rho(s \vert s_{1:n})</script> as the probability of the <script type="math/tex">(n+1)</script>-th state being <script type="math/tex">s</script> given the first <script type="math/tex">n</script> states are <script type="math/tex">s_{1:n}</script>. To measure this empirically, we can simply use <script type="math/tex">N_n(s)/n</script>.</p>
<p>Let’s also define a <em>recoding probability</em> of a state <script type="math/tex">s</script> as the probability assigned by the density model to <script type="math/tex">s</script> <em>after observing a new occurrence of</em> <script type="math/tex">s</script>, <script type="math/tex">\rho'_n(s) = \rho(s \vert s_{1:n}s)</script>.</p>
<p>The paper introduced two concepts to better regulate the density model, a <em>pseudo-count</em> function <script type="math/tex">\hat{N}_n(s)</script> and a <em>pseudo-count total</em> <script type="math/tex">\hat{n}</script>. As they are designed to imitate an empirical count function, we would have:</p>
<script type="math/tex; mode=display">\rho_n(s) = \frac{\hat{N}_n(s)}{\hat{n}} \leq \rho'_n(s) = \frac{\hat{N}_n(s) + 1}{\hat{n} + 1}</script>
<p>The relationship between <script type="math/tex">\rho_n(s)</script> and <script type="math/tex">\rho'_n(s)</script> requires the density model to be <em>learning-positive</em>: for all <script type="math/tex">s_{1:n} \in \mathcal{S}^n</script> and all <script type="math/tex">s \in \mathcal{S}</script>, <script type="math/tex">\rho_n(s) \leq \rho'_n(s)</script>. In other words, after observing one instance of <script type="math/tex">s</script>, the density model’s prediction of that same <script type="math/tex">s</script> should increase. Apart from being learning-positive, the density model should be trained completely <em>online</em> with non-randomized mini-batches of experienced states, so naturally we have <script type="math/tex">\rho'_n = \rho_{n+1}</script>.</p>
<p>The pseudo-count can be computed from <script type="math/tex">\rho_n(s)</script> and <script type="math/tex">\rho'_n(s)</script> after solving the above linear system:</p>
<script type="math/tex; mode=display">\hat{N}_n(s) = \hat{n} \rho_n(s) = \frac{\rho_n(s)(1 - \rho'_n(s))}{\rho'_n(s) - \rho_n(s)}</script>
<p>Or estimated by the <em>prediction gain (PG)</em>:</p>
<script type="math/tex; mode=display">\hat{N}_n(s) \approx (e^{\text{PG}_n(s)} - 1)^{-1} = (e^{\log \rho'_n(s) - \log \rho_n(s)} - 1)^{-1}</script>
<p>A common choice of a count-based intrinsic bonus is <script type="math/tex">r^i_t = N(s_t, a_t)^{-1/2}</script> (as in MBIE-EB; <a href="https://www.ics.uci.edu/~dechter/courses/ics-295/fall-2019/papers/2008-littman-aij-main.pdf">Strehl & Littman, 2008</a>). The pseudo-count-based exploration bonus is shaped in a similar form, <script type="math/tex">r^i_t = \big(\hat{N}_n(s_t, a_t) + 0.01 \big)^{-1/2}</script>.</p>
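<p>To see that the pseudo-count formula behaves like a real count, here is a small sketch that plugs in the simplest possible density model, the empirical frequency <script type="math/tex">N_n(s)/n</script>. The class and function names are made up for this example, and the bonus is computed on states only rather than state-action pairs; a real application would use a learned density model such as the ones discussed next.</p>

```python
from collections import Counter

class EmpiricalDensity:
    """Toy density model: rho_n(s) = N_n(s) / n, updated online."""
    def __init__(self):
        self.counts = Counter()
        self.n = 0

    def prob(self, s):
        return self.counts[s] / self.n if self.n else 0.0

    def update(self, s):
        self.counts[s] += 1
        self.n += 1

def pseudo_count(rho, rho_prime):
    """N_hat = rho * (1 - rho') / (rho' - rho); needs rho' > rho to be well defined."""
    if rho_prime <= rho:
        raise ValueError("the density must strictly increase after observing the state")
    return rho * (1.0 - rho_prime) / (rho_prime - rho)

def observe(model, s):
    """Return (pseudo-count, bonus) for s; the model is updated on s as a side effect."""
    rho = model.prob(s)
    model.update(s)
    rho_prime = model.prob(s)  # recoding probability after one more occurrence of s
    n_hat = pseudo_count(rho, rho_prime)
    return n_hat, (n_hat + 0.01) ** -0.5
```

<p>With this empirical model the pseudo-count recovers the true visit count before the update, so a never-seen state gets the largest bonus. The formula is undefined when the probability does not change, for example when the same state is the only one ever observed.</p>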
<p>Experiments in <a href="https://arxiv.org/abs/1606.01868">Bellemare et al., (2016)</a> adopted a simple <a href="http://proceedings.mlr.press/v32/bellemare14.html">CTS</a> (Context Tree Switching) density model to estimate pseudo-counts. The CTS model takes as input a 2D image and assigns to it a probability according to the product of location-dependent L-shaped filters, where the prediction of each filter is given by a CTS algorithm trained on past images. The CTS model is simple but limited in expressiveness, scalability, and data efficiency. In a follow-up paper, <a href="https://arxiv.org/abs/1703.01310">Ostrovski, et al. (2017)</a> improved the approach by training a PixelCNN (<a href="https://arxiv.org/abs/1606.05328">van den Oord et al., 2016</a>) as the density model.</p>
<p>The density model can also be a Gaussian Mixture Model as in <a href="https://arxiv.org/abs/1902.08039">Zhao & Tresp (2018)</a>. They used a variational GMM to estimate the density of trajectories (e.g. concatenation of a sequence of states) and its predicted probabilities to guide prioritization in experience replay in the off-policy setting.</p>
<h4 id="counting-after-hashing">Counting after Hashing</h4>
<p>Another idea to make it possible to count high-dimensional states is to map states into <strong>hash codes</strong> so that the occurrences of states become trackable (<a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>). The state space is discretized with a hash function <script type="math/tex">\phi: \mathcal{S} \mapsto \mathbb{Z}^k</script>. An exploration bonus <script type="math/tex">r^{i}: \mathcal{S} \mapsto \mathbb{R}</script> is added to the reward function, defined as <script type="math/tex">r^{i}(s) = {N(\phi(s))}^{-1/2}</script>, where <script type="math/tex">N(\phi(s))</script> is an empirical count of occurrences of <script type="math/tex">\phi(s)</script>.</p>
<p><a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> proposed to use <em>Locality-Sensitive Hashing</em> (<a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing"><em>LSH</em></a>) to convert continuous, high-dimensional data to discrete hash codes. LSH is a popular class of hash functions for querying nearest neighbors based on certain similarity metrics. A hashing scheme <script type="math/tex">x \mapsto h(x)</script> is locality-sensitive if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. (See how LSH is used in <a href="/lil-log/2020/04/07/the-transformer-family.html#LSH">Transformer improvement</a> if interested.) <a href="https://www.cs.princeton.edu/courses/archive/spr04/cos598B/bib/CharikarEstim.pdf">SimHash</a> is a type of computationally efficient LSH and it measures similarity by angular distance:</p>
<script type="math/tex; mode=display">\phi(s) = \text{sgn}(A g(s)) \in \{-1, 1\}^k</script>
<p>where <script type="math/tex">A \in \mathbb{R}^{k \times D}</script> is a matrix with each entry drawn i.i.d. from a standard Gaussian and <script type="math/tex">g: \mathcal{S} \mapsto \mathbb{R}^D</script> is an optional preprocessing function. The dimension of binary codes is <script type="math/tex">k</script>, controlling the granularity of the state space discretization. A higher <script type="math/tex">k</script> leads to higher granularity and fewer collisions.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/count-hashing-exploration.png" alt="#Exploration" /></p>
<p><em>Fig. 2. Algorithm of count-based exploration through hashing high-dimensional states by SimHash. (Image source: <a href="https://arxiv.org/abs/1611.04717">Tang et al. 2017</a>)</em></p>
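<p>The SimHash counting scheme is short enough to sketch in plain Python. This is an illustrative sketch rather than the paper’s implementation: the class name is made up, the preprocessing function <script type="math/tex">g</script> is assumed to be the identity (the caller passes an already flattened feature vector), and a projection of exactly zero is mapped to <script type="math/tex">+1</script>.</p>

```python
import random
from collections import Counter

class SimHashCounter:
    def __init__(self, dim, k, seed=0):
        rng = random.Random(seed)
        # A in R^{k x D}, each entry drawn i.i.d. from a standard Gaussian
        self.A = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(k)]
        self.counts = Counter()

    def code(self, g_s):
        """phi(s) = sgn(A g(s)), returned as a tuple in {-1, 1}^k."""
        return tuple(1 if sum(a * x for a, x in zip(row, g_s)) >= 0 else -1
                     for row in self.A)

    def bonus(self, g_s):
        """Increment N(phi(s)) and return the bonus N(phi(s))^(-1/2)."""
        c = self.code(g_s)
        self.counts[c] += 1
        return self.counts[c] ** -0.5
```

<p>Because only the sign of each projection is kept, rescaling a state by a positive factor leaves its hash code unchanged, which is what “similarity by angular distance” means in practice.</p>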
<p>For high-dimensional images, SimHash may not work well on the raw pixel level. <a href="https://arxiv.org/abs/1611.04717">Tang et al. (2017)</a> designed an autoencoder (AE) which takes as input states <script type="math/tex">s</script> to learn hash codes. It has one special dense layer composed of <script type="math/tex">k</script> sigmoid functions as the latent state in the middle and then the sigmoid activation values <script type="math/tex">b(s)</script> of this layer are binarized by rounding to their closest binary numbers <script type="math/tex">\lfloor b(s)\rceil \in \{0, 1\}^k</script> as the binary hash codes for state <script type="math/tex">s</script>. The AE loss over <script type="math/tex">N</script> states includes two terms:</p>
<script type="math/tex; mode=display">\mathcal{L}(\{s_n\}_{n=1}^N) = \underbrace{-\frac{1}{N} \sum_{n=1}^N \log p(s_n)}_\text{reconstruction loss} + \underbrace{\frac{1}{N} \frac{\lambda}{k} \sum_{n=1}^N\sum_{i=1}^k \min \big \{ (1-b_i(s_n))^2, b_i(s_n)^2 \big\}}_\text{sigmoid activation being closer to binary}</script>
<p>One problem with this approach is that dissimilar inputs <script type="math/tex">s_i, s_j</script> may be mapped to identical hash codes but the AE still reconstructs them perfectly. One can imagine replacing the bottleneck layer <script type="math/tex">b(s)</script> with the hash codes <script type="math/tex">\lfloor b(s)\rceil</script>, but then gradients cannot be back-propagated through the rounding function. Injecting uniform noise could mitigate this effect, as the AE has to learn to push the latent variable far apart to counteract the noise.</p>
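<p>The rounding step and the second loss term can be written down directly. The sketch below leaves out the autoencoder and the reconstruction loss entirely and only operates on given sigmoid activations; the function names and the default <code>lam</code> are assumptions made for illustration.</p>

```python
def binarize(b):
    """Round sigmoid activations b(s) in [0, 1] to the nearest binary code."""
    return tuple(1 if x >= 0.5 else 0 for x in b)

def binary_penalty(batch, lam=1.0):
    """Average over states and code bits of lam * min((1 - b_i)^2, b_i^2)."""
    n, k = len(batch), len(batch[0])
    total = sum(min((1.0 - x) ** 2, x ** 2) for b in batch for x in b)
    return lam * total / (n * k)
```

<p>The penalty is zero when every activation is already 0 or 1 and largest when activations sit at 0.5, so minimizing it pushes the latent layer toward binary values before rounding.</p>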
<h3 id="prediction-based-exploration">Prediction-based Exploration</h3>
<p>The second category of intrinsic exploration bonuses is awarded for improvement of the agent’s knowledge about the environment. The agent’s familiarity with the environment dynamics can be estimated through a prediction model. This idea of using a prediction model to measure <em>curiosity</em> was actually proposed quite a long time ago (<a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">Schmidhuber, 1991</a>).</p>
<h4 id="forward-dynamics">Forward Dynamics</h4>
<p>Learning a <strong>forward dynamics prediction model</strong> is a great way to approximate how much knowledge our model has obtained about the environment and the task MDPs. It captures an agent’s capability of predicting the consequence of its own behavior, <script type="math/tex">f: (s_t, a_t) \mapsto s_{t+1}</script>. Since such a model cannot be perfect (e.g. due to partial observation), the error <script type="math/tex">e(s_t, a_t) = \| f(s_t, a_t) - s_{t+1} \|^2_2</script> can be used for providing intrinsic exploration rewards. The higher the prediction error, the less familiar we are with that state. The faster the error rate drops, the more learning progress signals we acquire.</p>
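<p>As a toy illustration of a prediction-error bonus, the sketch below uses a linear forward model trained online with plain SGD. The linear model, the learning rate, and all names are assumptions made to keep the example self-contained; the methods discussed in this section use neural networks, often on encoded states.</p>

```python
class LinearForwardModel:
    """Toy forward model f(s, a) = W [s; a; 1], trained online with SGD."""
    def __init__(self, state_dim, action_dim, lr=0.05):
        self.W = [[0.0] * (state_dim + action_dim + 1) for _ in range(state_dim)]
        self.lr = lr

    def predict(self, s, a):
        x = list(s) + list(a) + [1.0]
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.W]

    def error(self, s, a, s_next):
        """e(s_t, a_t) = || f(s_t, a_t) - s_{t+1} ||_2^2"""
        return sum((p - y) ** 2 for p, y in zip(self.predict(s, a), s_next))

    def update(self, s, a, s_next):
        """One gradient step on the squared prediction error."""
        x = list(s) + list(a) + [1.0]
        pred = self.predict(s, a)
        for row, p, y in zip(self.W, pred, s_next):
            g = 2.0 * (p - y)
            for j, xi in enumerate(x):
                row[j] -= self.lr * g * xi

def intrinsic_reward(model, s, a, s_next):
    """Use the error as the bonus, computed before the model learns from the transition."""
    r = model.error(s, a, s_next)
    model.update(s, a, s_next)
    return r
```

<p>Replaying the same transition yields a shrinking bonus, which matches the intuition that familiar transitions stop being rewarding once the model can predict them.</p>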
<p><em>Intelligent Adaptive Curiosity</em> (<strong>IAC</strong>; <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>) sketched an idea of using a forward dynamics prediction model to estimate learning progress and assigned intrinsic exploration reward accordingly.</p>
<p>IAC relies on a memory which stores all the experiences encountered by the robot, <script type="math/tex">M=\{(s_t, a_t, s_{t+1})\}</script> and a forward dynamics model <script type="math/tex">f</script>. IAC incrementally splits the state space (i.e. sensorimotor space in the context of robotics, as discussed in the paper) into separate regions based on the transition samples, using a process similar to how a decision tree is split: The split happens when the number of samples is larger than a threshold, and the variance of states in each leaf should be minimal. Each tree node is characterized by its exclusive set of samples and has its own forward dynamics predictor <script type="math/tex">f</script>, named “expert”.</p>
<p>The prediction error <script type="math/tex">e_t</script> of an expert is pushed into a list associated with each region. The <em>learning progress</em> is then measured as the difference between the mean error rate of a moving window with offset <script type="math/tex">\tau</script> and that of the current moving window. The intrinsic reward is defined for tracking the learning progress: <script type="math/tex">r^i_t = \frac{1}{k}\sum_{i=0}^{k-1}(e_{t-i-\tau} - e_{t-i})</script>, where <script type="math/tex">k</script> is the moving window size. So the larger the decrease in prediction error we can achieve, the higher the intrinsic reward we assign to the agent. In other words, the agent is encouraged to take actions to quickly learn about the environment.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/IAC.png" alt="IAC" /></p>
<p><em>Fig. 3. Architecture of the IAC (Intelligent Adaptive Curiosity) module: the intrinsic reward is assigned w.r.t the learning progress in reducing prediction error of the dynamics model. (Image source: <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">Oudeyer, et al. 2007</a>)</em></p>
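<p>The learning-progress reward of IAC is a difference of two window averages over one region’s error list, which can be sketched as follows. The function name is made up, and returning zero when the history is too short is a choice made for this sketch, not something specified in the paper.</p>

```python
def learning_progress(errors, k, tau):
    """r_t = (1/k) * sum_{i=0}^{k-1} (e_{t-i-tau} - e_{t-i}), with t the last index of errors."""
    if len(errors) < k + tau:
        return 0.0  # not enough history in this region yet
    t = len(errors) - 1
    return sum(errors[t - i - tau] - errors[t - i] for i in range(k)) / k
```

<p>A region whose errors keep dropping gets a positive reward, a region with flat errors (already learned, or unlearnable noise) gets zero, and a region where errors grow gets a negative one.</p>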
<p><a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a> trained a forward dynamics model in the encoding space defined by <script type="math/tex">\phi</script>, <script type="math/tex">f_\phi: (\phi(s_t), a_t) \mapsto \phi(s_{t+1})</script>. The model’s prediction error at time <script type="math/tex">t</script> is normalized by the maximum error up to time <script type="math/tex">t</script>, <script type="math/tex">\bar{e}_t = \frac{e_t}{\max_{i \leq t} e_i}</script>, so it is always between 0 and 1. The intrinsic reward is defined accordingly: <script type="math/tex">r^i_t = (\frac{\bar{e}_t(s_t, a_t)}{t \cdot C})</script>, where <script type="math/tex">C > 0</script> is a decay constant.</p>
<p>Encoding the state space via <script type="math/tex">\phi(.)</script> is necessary, as experiments in the paper have shown that a dynamics model trained directly on raw pixels has <em>very poor</em> behavior, assigning the same exploration bonuses to all the states. In <a href="https://arxiv.org/abs/1507.00814">Stadie et al. (2015)</a>, the encoding function <script type="math/tex">\phi</script> is learned via an autoencoder (AE) and <script type="math/tex">\phi(.)</script> is the output of one of the layers in the AE. The AE can be statically trained using a set of images collected by a random agent, or dynamically trained together with the policy where the early frames are gathered using <a href="#classic-exploration-strategies"><script type="math/tex">\epsilon</script>-greedy</a> exploration.</p>
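<p>The normalized, decaying bonus of Stadie et al. only needs a running maximum and a step counter. A minimal sketch, with a made-up class name and with the prediction error assumed to be computed elsewhere by the forward model in encoding space:</p>

```python
class NormalizedErrorBonus:
    def __init__(self, decay_c=1.0):
        self.c = decay_c        # decay constant C > 0
        self.max_error = 0.0    # max_{i <= t} e_i
        self.t = 0

    def __call__(self, error):
        """Return e_bar_t / (t * C), where e_bar_t = e_t / max_{i <= t} e_i."""
        self.t += 1
        self.max_error = max(self.max_error, error)
        e_bar = error / self.max_error if self.max_error > 0 else 0.0
        return e_bar / (self.t * self.c)
```

<p>Since the normalized error never exceeds 1, the bonus is bounded by <script type="math/tex">1/(t \cdot C)</script> and fades out as training proceeds.</p>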
<p><a name="ICM"></a>Instead of an autoencoder, the <em>Intrinsic Curiosity Module</em> (<strong>ICM</strong>; <a href="https://arxiv.org/abs/1705.05363">Pathak, et al., 2017</a>) learns the state space encoding <script type="math/tex">\phi(.)</script> with a self-supervised <strong>inverse dynamics</strong> model. Predicting the next state given the agent’s own action is not easy, especially considering that some factors in the environment cannot be controlled by the agent or do not affect the agent. ICM argues that a good state feature space should exclude such factors because <em>they cannot influence the agent’s behavior and thus the agent has no incentive for learning them</em>. By learning an inverse dynamics model <script type="math/tex">g: (\phi(s_t), \phi(s_{t+1})) \mapsto a_t</script>, the feature space only captures those changes in the environment related to the actions of our agent, and ignores the rest.</p>
<p>Given a forward model <script type="math/tex">f</script>, an inverse dynamics model <script type="math/tex">g</script> and an observation <script type="math/tex">(s_t, a_t, s_{t+1})</script>:</p>
<script type="math/tex; mode=display">g_{\psi_I}(\phi(s_t), \phi(s_{t+1})) = \hat{a}_t \quad
f_{\psi_F}(\phi(s_t), a_t) = \hat{\phi}(s_{t+1}) \quad
r_t^i = \| \hat{\phi}(s_{t+1}) - \phi(s_{t+1}) \|_2^2</script>
<p>Such <script type="math/tex">\phi(.)</script> is expected to be robust to uncontrollable aspects of the environment.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ICM.png" alt="ICM" /></p>
<p><em>Fig. 4. ICM (Intrinsic Curiosity Module) assigns the forward dynamics prediction error to the agent as the intrinsic reward. This dynamics model operates in a state encoding space learned through an inverse dynamics model to exclude environmental factors that do not affect the agent’s behavior. (Image source: <a href="https://arxiv.org/abs/1705.05363">Pathak, et al. 2017</a>)</em></p>
<p><a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. (2018)</a> did a set of large-scale comparison experiments on purely curiosity-driven learning, meaning that only intrinsic rewards are provided to the agent. In this study, the reward is <script type="math/tex">r_t = r^i_t = \| f(s_t, a_t) - \phi(s_{t+1})\|_2^2</script>. A good choice of <script type="math/tex">\phi</script> is crucial to learning forward dynamics: the encoding is expected to be <em>compact</em>, <em>sufficient</em> and <em>stable</em>, making the prediction task more tractable and filtering out irrelevant observations.</p>
<p>They compared 4 encoding functions:</p>
<ol>
<li>Raw image pixels: No encoding, <script type="math/tex">\phi(x) = x</script>.</li>
<li><a name="random-feature"></a>Random features (RF): Each state is compressed through a fixed random neural network.</li>
<li><a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#vae-variational-autoencoder">VAE</a>: The probabilistic encoder is used for encoding, <script type="math/tex">\phi(x) = q(z \vert x)</script>.</li>
<li>Inverse dynamic features (IDF): The same feature space as used in <a href="#ICM">ICM</a>.</li>
</ol>
<p>All the experiments have the reward signals normalized by a running estimate of the standard deviation of the cumulative returns. All the experiments also run in an infinite horizon setting to avoid the “done” flag leaking information.</p>
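<p>One way such a reward normalization could look is sketched below: keep a discounted running return, track its standard deviation with Welford’s online algorithm, and divide each reward by it. This is an assumption-laden sketch and not the code used in the paper; the class name, the discount <code>gamma</code>, the population variance, and the choice to skip scaling on the very first step are all illustrative.</p>

```python
class RewardNormalizer:
    def __init__(self, gamma=0.99, eps=1e-8):
        self.gamma, self.eps = gamma, eps
        self.ret = 0.0                            # running discounted return
        self.n, self.mean, self.m2 = 0, 0.0, 0.0  # Welford accumulators over returns

    def __call__(self, reward):
        self.ret = self.gamma * self.ret + reward
        self.n += 1
        delta = self.ret - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (self.ret - self.mean)
        # With a single sample the deviation is undefined, so leave the reward unscaled.
        std = (self.m2 / self.n) ** 0.5 if self.n > 1 else 1.0
        return reward / (std + self.eps)
```

<p>The practical effect is that the normalized rewards become largely insensitive to the raw scale of the intrinsic reward, which differs a lot between encoders and games.</p>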
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/large-scale-curiosity-learning.png" alt="Large-scale curiosity learning" /></p>
<p><em>Fig. 5. The mean reward in different games when training with only curiosity signals, generated by different state encoding functions.
(Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>Interestingly, <em>random features</em> turn out to be quite competitive, but in feature transfer experiments (i.e. train an agent in Super Mario Bros level 1-1 and then test it in another level), learned IDF features can generalize better.</p>
<p>They also compared RF and IDF in an environment with a <a href="#the-noisy-tv-problem">noisy TV</a> turned on. Unsurprisingly, the noisy TV drastically slows down the learning and extrinsic rewards are much lower over time.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/noisy-TV-experiment.png" alt="Noisy TV experiment" /></p>
<p><em>Fig. 6. Experiments using RF and IDF feature encoding in an environment with noisy TV on or off. The plot tracks extrinsic reward per episode as the training progresses. (Image source: <a href="https://arxiv.org/abs/1808.04355">Burda, Edwards & Pathak, et al. 2018</a>)</em></p>
<p>The forward dynamics optimization can be modeled via variational inference as well. <strong>VIME</strong> (short for <em>“Variational information maximizing exploration”</em>; <a href="https://arxiv.org/abs/1605.09674">Houthooft, et al. 2017</a>) is an exploration strategy based on maximization of <em>information gain</em> about the agent’s belief of environment dynamics. How much additional information has been obtained about the forward dynamics can be measured as the reduction in entropy.</p>
<p>Let <script type="math/tex">\mathcal{P}</script> be the environment transition function, <script type="math/tex">p(s_{t+1}\vert s_t, a_t; \theta)</script> be the forward prediction model, parameterized by <script type="math/tex">\theta \in \Theta</script>, and <script type="math/tex">\xi_t = \{s_1, a_1, \dots, s_t\}</script> be the trajectory history. We would like to reduce the entropy after taking a new action and observing the next state, which is to maximize the following:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
&\sum_t H(\Theta \vert \xi_t, a_t) - H(\Theta \vert S_{t+1}, \xi_t, a_t) \\
=& I(\Theta; S_{t+1} \vert \xi_t, a_t) \quad \scriptstyle{\text{; because } I(X; Y) = H(X) - H(X \vert Y)} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t, a_t))] \quad \scriptstyle{\text{; because } I(X; Y) = \mathbb{E}_Y [D_\text{KL} (p_{X \vert Y} \| p_X)]} \\
=& \mathbb{E}_{s_{t+1} \sim \mathcal{P}(.\vert\xi_t,a_t)} [D_\text{KL}(p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t))] \quad \scriptstyle{\text{; because } \theta \text{ does not depend on } a_t}
\end{aligned} %]]></script>
<p>While taking expectation over the new possible states, the agent is expected to take a new action to increase the KL divergence (<em>“information gain”</em>) between its new belief over the prediction model and the old one. This term can be added into the reward function as an intrinsic reward: <script type="math/tex">r^i_t = D_\text{KL} [p(\theta \vert \xi_t, a_t, s_{t+1}) \| p(\theta \vert \xi_t)]</script>.</p>
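<p>The two identities used in the derivation above, entropy reduction equals mutual information, and mutual information equals an expected KL divergence, can be checked numerically on a small discrete example. In the sketch below <code>X</code> plays the role of <script type="math/tex">\Theta</script> and <code>Y</code> the role of <script type="math/tex">S_{t+1}</script>; the joint table and the function names are made up for illustration, and every value of <code>Y</code> is assumed to have non-zero probability.</p>

```python
import math

def entropy(p):
    return -sum(x * math.log(x) for x in p if x > 0)

def kl(p, q):
    return sum(x * math.log(x / y) for x, y in zip(p, q) if x > 0)

def information_gain_two_ways(joint):
    """joint[y][x] = p(X=x, Y=y). Returns (H(X) - H(X|Y), E_Y[KL(p(X|Y) || p(X))])."""
    p_y = [sum(row) for row in joint]
    p_x = [sum(col) for col in zip(*joint)]
    cond = [[v / py for v in row] for row, py in zip(joint, p_y)]  # p(X | Y=y)
    h_x_given_y = sum(py * entropy(c) for py, c in zip(p_y, cond))
    entropy_drop = entropy(p_x) - h_x_given_y
    expected_kl = sum(py * kl(c, p_x) for py, c in zip(p_y, cond))
    return entropy_drop, expected_kl
```

<p>Both returned numbers agree up to floating-point error, and both are zero when <code>X</code> and <code>Y</code> are independent, i.e. when observing the next state teaches the agent nothing about the model parameters.</p>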
<p>However, computing the posterior <script type="math/tex">p(\theta \vert \xi_t, a_t, s_{t+1})</script> is generally intractable.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p(\theta \vert \xi_t, a_t, s_{t+1})
&= \frac{p(\theta \vert \xi_t, a_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} \\
&= \frac{p(\theta \vert \xi_t) p(s_{t+1} \vert \xi_t, a_t; \theta)}{p(s_{t+1}\vert\xi_t, a_t)} & \scriptstyle{\text{; because action doesn't affect the belief.}} \\
&= \frac{\color{red}{p(\theta \vert \xi_t)} p(s_{t+1} \vert \xi_t, a_t; \theta)}{\int_\Theta p(s_{t+1}\vert\xi_t, a_t; \theta) \color{red}{p(\theta \vert \xi_t)} d\theta} & \scriptstyle{\text{; red part is hard to compute directly.}}
\end{aligned} %]]></script>
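<p>To make the Bayes update above concrete, here is a minimal sketch for the special case where <code>theta</code> ranges over a small finite set of candidate models, so that the integral in the denominator becomes a sum and everything is tractable. The two candidate models and their probabilities are made-up numbers for illustration only.</p>

```python
import math

def posterior(prior, likelihood):
    """Bayes update over a finite set of candidate models theta."""
    unnormalized = [p * l for p, l in zip(prior, likelihood)]
    evidence = sum(unnormalized)  # p(s_{t+1} | xi_t, a_t): the integral becomes a sum
    return [u / evidence for u in unnormalized]

def kl_divergence(p, q):
    """D_KL(p || q) for two discrete distributions on the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Two hypothetical dynamics models, equally likely under the old belief p(theta | xi_t).
prior = [0.5, 0.5]
# Probability that each model assigns to the next state that was actually observed.
likelihood = [0.9, 0.1]

new_belief = posterior(prior, likelihood)
info_gain = kl_divergence(new_belief, prior)  # plays the role of r^i_t
```

<p>An observation that both models predict equally well leaves the belief unchanged and yields zero information gain.</p>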
<p>Since it is difficult to compute <script type="math/tex">p(\theta\vert\xi_t)</script> directly, a natural choice is to approximate it with an alternative distribution <script type="math/tex">q_\phi(\theta)</script>. With the variational lower bound, we know that optimizing <script type="math/tex">q_\phi(\theta)</script> is equivalent to maximizing the likelihood <script type="math/tex">p(\xi_t\vert\theta)</script> under <script type="math/tex">q_\phi</script> while minimizing <script type="math/tex">D_\text{KL}[q_\phi(\theta) \| p(\theta)]</script>, the divergence from the prior.</p>
<p>Using the approximation distribution <script type="math/tex">q</script>, the intrinsic reward becomes:</p>
<script type="math/tex; mode=display">r^i_t = D_\text{KL} [q_{\phi_{t+1}}(\theta) \| q_{\phi_t}(\theta)]</script>
<p>where <script type="math/tex">\phi_{t+1}</script> represents <script type="math/tex">q</script>’s parameters associated with the new belief after seeing <script type="math/tex">a_t</script> and <script type="math/tex">s_{t+1}</script>. When used as an exploration bonus, it is normalized by dividing by the moving median of this KL divergence value.</p>
<p>Here the dynamics model is parameterized as a <a href="https://link.springer.com/book/10.1007/978-1-4612-0745-0">Bayesian neural network</a> (BNN), as it maintains a distribution over its weights. The BNN weight distribution <script type="math/tex">q_\phi(\theta)</script> is modeled as a fully <em>factorized</em> Gaussian with <script type="math/tex">\phi = \{\mu, \sigma\}</script> and we can easily sample <script type="math/tex">\theta \sim q_\phi(.)</script>. After applying a second-order Taylor expansion, the KL term <script type="math/tex">D_\text{KL}[q_{\phi + \lambda \Delta\phi}(\theta) \| q_{\phi}(\theta)]</script> can be estimated using <a href="/lil-log/2019/09/05/evolution-strategies.html#estimation-using-fisher-information-matrix">Fisher Information Matrix</a> <script type="math/tex">\mathbf{F}_\phi</script>, which is easy to compute, because <script type="math/tex">q_\phi</script> is factorized Gaussian and thus the covariance matrix is only a diagonal matrix. See more details in <a href="https://arxiv.org/abs/1605.09674">the paper</a>, especially section 2.3-2.5.</p>
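<p>Because <code>q</code> is a fully factorized Gaussian, the KL divergence between two beliefs also has a simple closed form, summed over weight dimensions. The sketch below computes that exact closed form rather than the Fisher-based second-order approximation used in the paper, and the <code>MedianNormalizer</code> helper is a hypothetical stand-in for the moving-median normalization (the window size and the choice to include the current value in the median are my assumptions).</p>

```python
import numpy as np

def kl_diag_gaussians(mu_new, sigma_new, mu_old, sigma_old):
    """D_KL( N(mu_new, diag(sigma_new^2)) || N(mu_old, diag(sigma_old^2)) )."""
    mu_new, sigma_new = np.asarray(mu_new, float), np.asarray(sigma_new, float)
    mu_old, sigma_old = np.asarray(mu_old, float), np.asarray(sigma_old, float)
    per_dim = (
        np.log(sigma_old / sigma_new)
        + (sigma_new**2 + (mu_new - mu_old) ** 2) / (2.0 * sigma_old**2)
        - 0.5
    )
    return float(per_dim.sum())

class MedianNormalizer:
    """Divides each KL value by the median of the most recent KL values."""

    def __init__(self, window=100):
        self.window = window
        self.history = []

    def __call__(self, kl):
        self.history = (self.history + [kl])[-self.window:]
        return kl / max(float(np.median(self.history)), 1e-8)

# Belief unchanged -> no information gain.
same = kl_diag_gaussians([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
# One weight mean moved by 1 with unit variance -> KL of 0.5.
shifted = kl_diag_gaussians([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])

normalize = MedianNormalizer()
first = normalize(2.0)   # median of [2.0] is 2.0
second = normalize(4.0)  # median of [2.0, 4.0] is 3.0
```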
<p><a name="exploration-via-disagreement"></a>All the methods above depend on a single prediction model. If we have multiple such models, we could use the disagreement among models to set the exploration bonus (<a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>). High disagreement indicates low confidence in prediction and thus requires more exploration. <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. (2019)</a> proposed to train a set of forward dynamics models and to use the variance over the ensemble of model outputs as <script type="math/tex">r_t^i</script>. Precisely, they encode the state space with <a href="#random-feature">random feature</a> and learn 5 models in the ensemble.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/exploration-via-disagreement.png" alt="Disagreement" /></p>
<p><em>Fig. 7. Illustration of training architecture for self-supervised exploration via disagreement. (Image source: <a href="https://arxiv.org/abs/1906.04161">Pathak, et al. 2019</a>)</em></p>
<p>Because <script type="math/tex">r^i_t</script> is differentiable, the intrinsic reward in the model could be directly optimized through gradient descent so as to inform the policy agent to change actions. This differentiable exploration approach is very efficient but limited by having a short exploration horizon.</p>
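<p>The ensemble-variance bonus described above can be sketched in a few lines. Here I assume each model's prediction of the next-state features is already available as one row of an array, and I average the per-dimension variance into a scalar; both the function name and this particular reduction are illustrative choices, not necessarily the exact form used in the paper.</p>

```python
import numpy as np

def disagreement_bonus(predictions):
    """Intrinsic reward from ensemble disagreement.

    predictions: array of shape (n_models, feature_dim); each row is one
    model's predicted next-state features for the same (s_t, a_t).
    """
    predictions = np.asarray(predictions, dtype=float)
    # Variance across ensemble members, averaged over feature dimensions.
    return float(predictions.var(axis=0).mean())

# Five models that agree perfectly: nothing left to learn here.
agree = disagreement_bonus([[1.0, 2.0]] * 5)
# Two models that disagree: the transition is worth exploring.
disagree = disagreement_bonus([[0.0, 0.0], [2.0, 2.0]])
```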
<h4 id="random-networks">Random Networks</h4>
<p>But, what if the prediction task is not about the environment dynamics at all? It turns out when the prediction is for a random task, it still can help exploration.</p>
<p><strong>DORA</strong> (short for <em>“Directed Outreaching Reinforcement Action-Selection”</em>; <a href="https://arxiv.org/abs/1804.04012">Fox & Choshen, et al. 2018</a>) is a novel framework that injects exploration signals based on a newly introduced, <strong>task-independent</strong> MDP. The idea of DORA depends on two parallel MDPs:</p>
<ul>
<li>One is the original task MDP;</li>
<li>The other is an identical MDP but with <em>no reward attached</em>: Rather, every state-action pair is designed to have value 0. The Q-value learned for the second MDP is called <em>E-value</em>. If the model cannot perfectly predict E-value to be zero, it is still missing information.</li>
</ul>
<p>Initially E-value is assigned with value 1. Such positive initialization can encourage directed exploration for better E-value prediction. State-action pairs with high E-value estimation don’t have enough information gathered yet, at least not enough to exclude their high E-values. To some extent, the logarithm of E-values can be considered as a generalization of <em>visit counters</em>.</p>
<p>When using a neural network to do function approximation for E-value, another value head is added to predict E-value and it is simply expected to predict zero. Given a predicted E-value <script type="math/tex">E(s_t, a_t)</script>, the exploration bonus is <script type="math/tex">r^i_t = \frac{1}{\sqrt{-\log E(s_t, a_t)}}</script>.</p>
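<p>A tabular toy version helps to see why E-values behave like visit counters. The sketch below assumes a SARSA-style update on the reward-free MDP with learning rate <code>alpha</code>; under that assumption, repeatedly visiting a terminal state-action pair gives <code>E = (1 - alpha)^n</code>, so the logarithm of E in base <code>1 - alpha</code> recovers the count <code>n</code>. The class name, the clamping constants and the use of <code>s_next=None</code> to mark terminal transitions are mine.</p>

```python
import math
from collections import defaultdict

class TabularEValues:
    """Tabular E-values: Q-values of a copy of the MDP whose rewards are all zero."""

    def __init__(self, alpha=0.5, gamma_e=0.9):
        self.alpha = alpha        # learning rate
        self.gamma_e = gamma_e    # discount factor of the reward-free MDP
        self.table = defaultdict(lambda: 1.0)  # E-values are initialized to 1

    def update(self, s, a, s_next=None, a_next=None):
        # The reward is always 0, so the target is the discounted next E-value
        # (or 0 when the transition is terminal, signalled by s_next=None).
        target = 0.0 if s_next is None else self.gamma_e * self.table[(s_next, a_next)]
        self.table[(s, a)] += self.alpha * (target - self.table[(s, a)])

    def _clamped(self, s, a):
        # Keep E strictly inside (0, 1) so the logarithms below stay finite.
        return min(max(self.table[(s, a)], 1e-12), 1.0 - 1e-8)

    def generalized_count(self, s, a):
        return math.log(self._clamped(s, a)) / math.log(1.0 - self.alpha)

    def bonus(self, s, a):
        return 1.0 / math.sqrt(-math.log(self._clamped(s, a)))

e_values = TabularEValues(alpha=0.5)
for _ in range(3):
    e_values.update("s", "a")          # three terminal visits

visited_count = e_values.generalized_count("s", "a")
visited_bonus = e_values.bonus("s", "a")
unvisited_bonus = e_values.bonus("s", "b")
```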
<p><a name="RND"></a>Similar to DORA, <strong>Random Network Distillation</strong> (<strong>RND</strong>; <a href="https://arxiv.org/abs/1810.12894">Burda, et al. 2018</a>) introduces a prediction task <em>independent of the main task</em>. The RND exploration bonus is defined as the error of a neural network <script type="math/tex">\hat{f}(s_t)</script> predicting features of the observations given by a <em>fixed randomly initialized</em> neural network <script type="math/tex">f(s_t)</script>. The motivation is that given a new state, if similar states have been visited many times in the past, the prediction should be easier and thus has lower error. The exploration bonus is <script type="math/tex">r^i(s_t) = \|\hat{f}(s_t; \theta) - f(s_t) \|_2^2</script>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RND.png" alt="RND" /></p>
<p><em>Fig. 8. How RND (Random Network Distillation) works for providing an intrinsic reward. The features <script type="math/tex">O_{i+1} \mapsto f_{i+1}</script> are generated by a fixed random neural network. (Image source: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">OpenAI Blog: “Reinforcement Learning with Prediction-Based Rewards”</a>)</em></p>
<p>Two factors are important in RND experiments:</p>
<ol>
<li>Non-episodic setting results in better exploration, especially when not using any extrinsic rewards. It means that the return is not truncated at “Game over” and intrinsic return can spread across multiple episodes.</li>
<li>Normalization is important since the scale of the reward is tricky to adjust given a random neural network as a prediction target. The intrinsic reward is normalized by division by a running estimate of the standard deviations of the intrinsic return.</li>
</ol>
<p>The RND setup works well for resolving the hard-exploration problem. For example, maximizing the RND exploration bonus consistently finds more than half of the rooms in Montezuma’s Revenge.</p>
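<p>Here is a tiny NumPy sketch of the RND bonus. To keep it short, the fixed random target is a single <code>tanh</code> layer and the predictor is just a linear map trained by plain gradient descent; the sizes, learning rate and number of updates are arbitrary assumptions, and real implementations use deep networks plus the normalization discussed above.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
STATE_DIM, FEATURE_DIM = 8, 4

# Fixed, randomly initialized target network f (never trained).
W_target = rng.normal(size=(FEATURE_DIM, STATE_DIM))

def target_features(s):
    return np.tanh(W_target @ s)

# Predictor f_hat (here only a linear map), trained to imitate f.
W_pred = np.zeros((FEATURE_DIM, STATE_DIM))

def rnd_bonus(s):
    """r^i(s) = || f_hat(s) - f(s) ||_2^2"""
    return float(np.sum((W_pred @ s - target_features(s)) ** 2))

def train_predictor(s, lr=0.01):
    """One gradient step on the squared prediction error for state s."""
    global W_pred
    error = W_pred @ s - target_features(s)
    W_pred = W_pred - lr * 2.0 * np.outer(error, s)

familiar = rng.normal(size=STATE_DIM)
novel = rng.normal(size=STATE_DIM)

bonus_before = rnd_bonus(familiar)
for _ in range(500):  # "visit" the same state many times
    train_predictor(familiar)
bonus_after = rnd_bonus(familiar)
bonus_novel = rnd_bonus(novel)
```

<p>After training, the bonus on the familiar state shrinks towards zero while a state that was never visited still produces a noticeable prediction error.</p>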
<h4 id="physical-properties">Physical Properties</h4>
<p>Different from games in simulators, some RL applications like Robotics need to understand objects and intuitive reasoning in the physical world. Some prediction tasks require the agent to perform a sequence of interactions with the environment and to observe the corresponding consequences, such as estimating some hidden properties in physics (e.g. mass, friction, etc).</p>
<p>Motivated by such ideas, <a href="https://arxiv.org/abs/1611.01843">Denil, et al. (2017)</a> found that DRL agents can learn to perform necessary exploration to discover such hidden properties. Precisely they considered two experiments:</p>
<ol>
<li><em>“Which is heavier?”</em> — The agent has to interact with the blocks and infer which one is heavier.</li>
<li><em>“Towers”</em> — The agent needs to infer how many rigid bodies a tower is composed of by knocking it down.</li>
</ol>
<p>The agent in the experiments first goes through an exploration phase to interact with the environment and to collect information. Once the exploration phase ends, the agent is asked to output a <em>labeling</em> action to answer the question. Then a positive reward is assigned to the agent if the answer is correct; otherwise a negative one is assigned. Because the answer requires a decent amount of interactions with items in the scene, the agent has to learn to efficiently play around so as to figure out the physics and the correct answer. The exploration naturally happens.</p>
<p>In their experiments, the agent is able to learn in both tasks, with performance varying by the difficulty of the task. The paper didn’t use the physics prediction task to provide an intrinsic reward bonus along with the extrinsic reward of another learning task; rather, it focused on the exploration tasks themselves. Still, I do enjoy the idea of encouraging sophisticated exploration behavior by predicting hidden physics properties in the environment.</p>
<h2 id="memory-based-exploration">Memory-based Exploration</h2>
<p>Reward-based exploration suffers from several drawbacks:</p>
<ul>
<li>Function approximation is slow to catch up.</li>
<li>Exploration bonus is non-stationary.</li>
<li>Knowledge fading, meaning that states cease to be novel and cannot provide intrinsic reward signals in time.</li>
</ul>
<p>Methods in this section rely on external memory to resolve disadvantages of reward bonus-based exploration.</p>
<h3 id="episodic-memory">Episodic Memory</h3>
<p>As mentioned above, <a href="#RND">RND</a> works better in a non-episodic setting, meaning the prediction knowledge is accumulated across multiple episodes. The exploration strategy, <strong>Never Give Up</strong> (<strong>NGU</strong>; <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>), combines an episodic novelty module that can rapidly adapt within one episode with RND as a lifelong novelty module.</p>
<p>Precisely, the intrinsic reward in NGU consists of two exploration bonuses from two modules, <em>within one episode</em> and <em>across multiple episodes</em>, respectively.</p>
<p>The short-term per-episode reward is provided by an <em>episodic novelty module</em>. It contains an episodic memory <script type="math/tex">M</script>, a dynamically-sized slot-based memory, and an IDF (inverse dynamics features) embedding function <script type="math/tex">\phi</script>, same as the feature encoding in <a href="#ICM">ICM</a>.</p>
<ol>
<li>At every step the current state embedding <script type="math/tex">\phi(s_t)</script> is added into <script type="math/tex">M</script>.</li>
<li>The intrinsic bonus is determined by comparing how similar the current observation is to the content of <script type="math/tex">M</script>. A larger difference results in a larger bonus.
<br />
<script type="math/tex">r^\text{episodic}_t \approx \frac{1}{\sqrt{\sum_{\phi_i \in N_k} K(\phi(x_t), \phi_i)} + c}</script>
<br />
where <script type="math/tex">K(x, y)</script> is a kernel function for measuring the distance between two samples. <script type="math/tex">N_k</script> is a set of <script type="math/tex">k</script> nearest neighbors in <script type="math/tex">M</script> according to <script type="math/tex">K(., .)</script>. <script type="math/tex">c</script> is a small constant to keep the denominator non-zero. In the paper, <script type="math/tex">K(x, y)</script> is configured to be the inverse kernel:
<br />
<script type="math/tex">K(x, y) = \frac{\epsilon}{\frac{d^2(x, y)}{d^2_m} + \epsilon}</script>
<br />
where <script type="math/tex">d(.,.)</script> is Euclidean distance between two samples and <script type="math/tex">d^2_m</script> is a running average of the squared Euclidean distance of the k-th nearest neighbors for better robustness. <script type="math/tex">\epsilon</script> is a small constant.</li>
</ol>
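<p>The episodic bonus above can be sketched as follows. I pass the running average <code>d_m^2</code> in as a plain argument instead of tracking it, and the values of <code>k</code>, <code>eps</code> and <code>c</code> are illustrative defaults; the paper's full pseudocode contains a few extra details that are left out here.</p>

```python
import numpy as np

def inverse_kernel(sq_dist, sq_dist_running_mean, eps=1e-3):
    """K(x, y) = eps / (d^2(x, y) / d_m^2 + eps)."""
    return eps / (sq_dist / sq_dist_running_mean + eps)

def episodic_bonus(embedding, memory, k=10, c=1e-3, sq_dist_running_mean=1.0):
    """Approximate 1 / sqrt(pseudo-count) from the k nearest memory entries."""
    embedding = np.asarray(embedding, dtype=float)
    memory = np.asarray(memory, dtype=float)
    sq_dists = np.sum((memory - embedding) ** 2, axis=1)
    nearest = np.sort(sq_dists)[:k]  # squared distances to the neighbors N_k
    similarity = inverse_kernel(nearest, sq_dist_running_mean).sum()
    return float(1.0 / (np.sqrt(similarity) + c))

# The memory already holds three copies of the same embedding.
memory = [[0.0, 0.0]] * 3
revisit = episodic_bonus([0.0, 0.0], memory, k=3)    # same place again
far_away = episodic_bonus([10.0, 10.0], memory, k=3)  # somewhere new
```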
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NGU.png" alt="NGU" /></p>
<p><em>Fig. 9. The architecture of NGU’s embedding function (left) and reward generator (right). (Image source: <a href="https://arxiv.org/abs/2002.06038">Badia, et al. 2020a</a>)</em></p>
<p>The long-term across-episode novelty relies on RND prediction error in <em>life-long novelty module</em>. The exploration bonus is <script type="math/tex">\alpha_t = 1 + \frac{e^\text{RND}(s_t) - \mu_e}{\sigma_e}</script> where <script type="math/tex">\mu_e</script> and <script type="math/tex">\sigma_e</script> are running mean and std dev for RND error <script type="math/tex">e^\text{RND}(s_t)</script>.</p>
<blockquote>
<p>However in the conclusion section of the <a href="https://arxiv.org/abs/1810.12894">RND paper</a>, I noticed the following statement:</p>
<p>“We find that the RND exploration bonus is sufficient to deal with local exploration, i.e. exploring the consequences of short-term decisions, like whether to interact with a particular object, or avoid it. However global exploration that involves coordinated decisions over long time horizons is beyond the reach of our method. “</p>
<p>And this confuses me a bit how RND can be used as a good life-long novelty bonus provider. If you know why, feel free to leave a comment below.</p>
</blockquote>
<p>The final combined intrinsic reward is <script type="math/tex">r^i_t = r^\text{episodic}_t \cdot \text{clip}(\alpha_t, 1, L)</script>, where <script type="math/tex">L</script> is a constant maximum reward scalar.</p>
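<p>Putting the two modules together is a one-liner. In this sketch the running mean and standard deviation of the RND error are given as inputs, and <code>max_scale=5.0</code> is only an example value for <code>L</code>.</p>

```python
def lifelong_modulator(rnd_error, running_mean, running_std):
    """alpha_t = 1 + (e_RND - mu_e) / sigma_e"""
    return 1.0 + (rnd_error - running_mean) / running_std

def ngu_intrinsic_reward(r_episodic, alpha, max_scale=5.0):
    """r^i_t = r_episodic * clip(alpha, 1, L)"""
    return r_episodic * min(max(alpha, 1.0), max_scale)

alpha = lifelong_modulator(rnd_error=3.0, running_mean=1.0, running_std=1.0)
boosted = ngu_intrinsic_reward(0.5, alpha)     # rarely seen state: bonus scaled up
unchanged = ngu_intrinsic_reward(0.5, 0.2)     # familiar state: multiplier floors at 1
capped = ngu_intrinsic_reward(0.5, 100.0)      # multiplier never exceeds L
```

<p>Since the multiplier is at least 1, the lifelong module can only amplify the episodic bonus, never suppress it.</p>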
<p>The design of NGU enables it to have two nice properties:</p>
<ol>
<li><em>Rapidly discourages</em> revisiting the same state <em>within</em> the same episode;</li>
<li><em>Slowly discourages</em> revisiting states that have been visited many times <em>across</em> episodes.</li>
</ol>
<p>Later, built on top of NGU, DeepMind proposed “Agent57” (<a href="https://arxiv.org/abs/2003.13350">Badia, et al. 2020b</a>), the first deep RL agent that outperforms the standard human benchmark on <em>all</em> 57 Atari games. Two major improvements in Agent57 over NGU are:</p>
<ol>
<li>A <em>population</em> of policies is trained in Agent57, each equipped with a different exploration parameter pair <script type="math/tex">\{(\beta_j, \gamma_j)\}_{j=1}^N</script>. Recall that given <script type="math/tex">\beta_j</script>, the reward is constructed as <script type="math/tex">r_{j,t} = r_t^e + \beta_j r^i_t</script> and <script type="math/tex">\gamma_j</script> is the reward discounting factor. It is natural to expect policies with higher <script type="math/tex">\beta_j</script> and lower <script type="math/tex">\gamma_j</script> to make more progress early in training, while the opposite would be expected as training progresses. A meta-controller (<a href="https://arxiv.org/pdf/0805.3415.pdf">sliding-window UCB bandit algorithm</a>) is trained to select which policies should be prioritized.</li>
<li>The second improvement is a new parameterization of Q-value function that decomposes the contributions of the intrinsic and extrinsic rewards in a similar form as the bundled reward: <script type="math/tex">Q(s, a; \theta_j) = Q(s, a; \theta_j^e) + \beta_j Q(s, a; \theta_j^i)</script>. During training, <script type="math/tex">Q(s, a; \theta_j^e)</script> and <script type="math/tex">Q(s, a; \theta_j^i)</script> are optimized separately with rewards <script type="math/tex">r_j^e</script> and <script type="math/tex">r_j^i</script>, respectively.</li>
</ol>
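<p>To give a feeling for the meta-controller, below is a generic sliding-window UCB bandit with one arm per policy index <code>j</code>. It is a textbook-style sketch under my own assumptions (window size, exploration weight <code>beta</code>, deterministic toy rewards) and not the exact variant used in Agent57.</p>

```python
import math
from collections import deque

class SlidingWindowUCB:
    """UCB bandit that only remembers the most recent `window` plays."""

    def __init__(self, n_arms, window=50, beta=1.0):
        self.n_arms = n_arms
        self.beta = beta
        self.history = deque(maxlen=window)  # recent (arm, reward) pairs

    def select(self):
        counts = [0] * self.n_arms
        sums = [0.0] * self.n_arms
        for arm, reward in self.history:
            counts[arm] += 1
            sums[arm] += reward
        # Any arm that is absent from the window gets tried first.
        for arm in range(self.n_arms):
            if counts[arm] == 0:
                return arm
        total = len(self.history)
        scores = [
            sums[a] / counts[a] + self.beta * math.sqrt(math.log(total) / counts[a])
            for a in range(self.n_arms)
        ]
        return max(range(self.n_arms), key=scores.__getitem__)

    def update(self, arm, reward):
        self.history.append((arm, reward))

# Toy run: arm 1 (think of it as one (beta_j, gamma_j) pair) always returns more.
bandit = SlidingWindowUCB(n_arms=2)
picks = [0, 0]
for _ in range(200):
    arm = bandit.select()
    bandit.update(arm, reward=float(arm))
    picks[arm] += 1
```

<p>The window is what lets the controller change its mind when the best exploration setting shifts during training, because old returns simply fall out of the statistics.</p>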
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/agent57.png" alt="Agent57" /></p>
<p><em>Fig. 10. A pretty cool illustration of techniques developed in time since DQN in 2015, eventually leading to Agent57. (Image source: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">DeepMind Blog: “Agent57: Outperforming the human Atari benchmark”</a>)</em></p>
<p>Instead of using the Euclidean distance to measure closeness of states in episodic memory, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> took the transition between states into consideration and proposed a method to measure the number of steps needed to visit one state from other states in memory, named <strong>Episodic Curiosity (EC)</strong> module. The novelty bonus depends on reachability between states.</p>
<ol>
<li>At the beginning of each episode, the agent starts with an empty episodic memory <script type="math/tex">M</script>.</li>
<li>At every step, the agent compares the current state with saved states in memory to determine novelty bonus: If the current state is novel (i.e., takes more steps to reach from observations in memory than a threshold), the agent gets a bonus.</li>
<li>The current state is added into the episodic memory if the novelty bonus is high enough. (Imagine that if all the states were added into memory, any new state could be reached within 1 step and nothing would look novel.)</li>
<li>Repeat 1-3 until the end of this episode.</li>
</ol>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/transition-graph.png" alt="Transition graph" /></p>
<p><em>Fig. 11. The nodes in the graph are states, the edges are possible transitions. The blue nodes are states in memory. The green nodes are reachable from the memory within <script type="math/tex">k = 2</script> steps (not novel). The orange nodes are further away, so they are considered as novel states. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<p>In order to estimate reachability between states, we need to access the transition graph, which is unfortunately not entirely known. Thus, <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. (2019)</a> trained a <a href="/lil-log/2018/11/30/meta-learning.html#convolutional-siamese-neural-network">siamese</a> neural network to predict how many steps separate two states. It contains one embedding network <script type="math/tex">\phi: \mathcal{S} \mapsto \mathbb{R}^n</script> to first encode the states to feature vectors and then one comparator network <script type="math/tex">C: \mathbb{R}^n \times \mathbb{R}^n \mapsto [0, 1]</script> to output a binary label on whether two states are close enough (i.e., reachable within <script type="math/tex">k</script> steps) in the transition graph, <script type="math/tex">C(\phi(s_i), \phi(s_j)) \mapsto [0, 1]</script>.</p>
<p>An episodic memory buffer <script type="math/tex">M</script> stores embeddings of some past observations within the same episode. A new observation will be compared with existing state embeddings via <script type="math/tex">C</script> and the results are aggregated (e.g. max, 90th percentile) to provide a reachability score <script type="math/tex">C^M(\phi(s_t))</script>. The exploration bonus is <script type="math/tex">r^i_t = \big(C' - C^M(\phi(s_t))\big)</script>, where <script type="math/tex">C'</script> is a predefined threshold for determining the sign of the reward (e.g. <script type="math/tex">C'=0.5</script> works well for fixed-duration episodes). High bonus is awarded to new states when they are not easily reachable from states in the memory buffer.</p>
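<p>The per-step logic of the EC module can be sketched as below. The trained comparator network is replaced by a hypothetical <code>toy_comparator</code> based on Euclidean distance, and two details are my own assumptions: an empty memory is treated as "nothing is reachable", and a state is stored whenever its bonus exceeds <code>novelty_threshold</code>.</p>

```python
import numpy as np

def toy_comparator(e1, e2):
    """Stand-in for the trained comparator C: near 1 if close, near 0 if far."""
    diff = np.asarray(e1, dtype=float) - np.asarray(e2, dtype=float)
    return float(np.exp(-np.linalg.norm(diff)))

def reachability_score(embedding, memory, comparator, percentile=90):
    """Aggregate comparator outputs against every embedding in memory."""
    scores = [comparator(stored, embedding) for stored in memory]
    return float(np.percentile(scores, percentile))

def ec_step(embedding, memory, comparator, threshold=0.5, novelty_threshold=0.0):
    """Return the bonus C' - C^M and store the embedding if it is novel enough."""
    score = reachability_score(embedding, memory, comparator) if memory else 0.0
    bonus = threshold - score
    if bonus > novelty_threshold:
        memory.append(embedding)
    return bonus

memory = []                                             # empty at episode start
first = ec_step([0.0, 0.0], memory, toy_comparator)     # nothing to compare against
repeat = ec_step([0.0, 0.0], memory, toy_comparator)    # identical to a stored state
distant = ec_step([10.0, 0.0], memory, toy_comparator)  # far from everything stored
```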
<p>They claimed that the EC module can overcome the <a href="#the-noisy-tv-problem">noisy-TV</a> problem.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/episodic-memory-overview.png" alt="EC module" /></p>
<p><em>Fig. 12. The architecture of episodic curiosity (EC) module for intrinsic reward generation. (Image source: <a href="https://arxiv.org/abs/1810.02274">Savinov, et al. 2019</a>)</em></p>
<h3 id="direct-exploration">Direct Exploration</h3>
<p><strong>Go-Explore</strong> (<a href="https://arxiv.org/abs/1901.10995">Ecoffet, et al., 2019</a>) is an algorithm aiming to solve the “hard-exploration” problem. It is composed of the following two phases.</p>
<p><strong>Phase 1 (“Explore until solved”)</strong> feels quite like <a href="https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm">Dijkstra’s algorithm</a> for finding shortest paths in a graph. Indeed, no neural network is involved in phase 1. By maintaining a memory of interesting states as well as trajectories leading to them, the agent can go back (given a simulator is <em>deterministic</em>) to promising states and continue doing <em>random</em> exploration from there. The state is mapped into a short discretized code (named “cell”) in order to be memorized. The memory is updated if a new state appears or a better/shorter trajectory is found. When selecting which past states to return to, the agent might select one in the memory uniformly or according to heuristics like recency, visit count, count of neighbors in the memory, etc. This process is repeated until the task is solved and at least one solution trajectory is found.</p>
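<p>A toy version of phase 1 on a deterministic 1-D chain shows the "go back, then explore randomly" loop. Everything here is a simplification of my own: the cell is just the position, cells are selected uniformly, and returning to a cell is done by replaying its stored action sequence.</p>

```python
import random

ACTIONS = (-1, +1)
START, GOAL = 0, 6

def step(position, action):
    """Deterministic chain: move left or right, clipped to [START, GOAL]."""
    return min(max(position + action, START), GOAL)

def replay(trajectory):
    """Return to a cell by replaying its action sequence from the start state."""
    position = START
    for action in trajectory:
        position = step(position, action)
    return position

def explore(max_iterations=2000, steps_per_iteration=5, seed=0):
    rng = random.Random(seed)
    archive = {START: []}  # cell -> shortest known action sequence reaching it
    for _ in range(max_iterations):
        cell = rng.choice(sorted(archive))      # select a cell uniformly
        trajectory = list(archive[cell])
        position = replay(trajectory)           # "go": deterministic return
        for _ in range(steps_per_iteration):    # "explore": random actions
            action = rng.choice(ACTIONS)
            position = step(position, action)
            trajectory.append(action)
            known = archive.get(position)
            if known is None or len(trajectory) < len(known):
                archive[position] = list(trajectory)  # new cell or shorter path
        if GOAL in archive:                     # stop once the task is solved
            break
    return archive

archive = explore()
```

<p>Note how the archive only ever improves: a cell is overwritten solely when a shorter trajectory to it is found.</p>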
<p>The high-performance trajectories found above would not work well on evaluation environments with any stochasticity. Thus, <strong>Phase 2 (“Robustification”)</strong> is needed to robustify the solution via imitation learning. They adopted <a href="https://arxiv.org/abs/1812.03381">Backward Algorithm</a>, in which the agent is started near the last state in the trajectory and then runs RL optimization from there.</p>
<p>One important note in phase 1 is: In order to go back to a state deterministically without exploration, Go-Explore depends on a resettable and deterministic simulator, which is a big disadvantage.</p>
<p>To make the algorithm more generally useful to environments with stochasticity, an enhanced version of Go-Explore (<a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>), named <strong>policy-based Go-Explore</strong> was proposed later.</p>
<ul>
<li>Instead of resetting the simulator state effortlessly, the policy-based Go-Explore learns a <em>goal-conditioned policy</em> and uses that to access a known state in memory repeatedly. The goal-conditioned policy is trained to follow the best trajectory that previously led to the selected states in memory. They include a <strong>Self-Imitation Learning</strong> (<strong>SIL</strong>; <a href="https://arxiv.org/abs/1806.05635">Oh, et al. 2018</a>) loss to help extract as much information as possible from successful trajectories.</li>
<li>Also, they found sampling from policy works better than random actions when the agent returns to promising states to continue exploration.</li>
<li>Another improvement in policy-based Go-Explore is to make the downscaling function of images to cells adjustable. It is optimized so that there would be neither too many nor too few cells in the memory.</li>
</ul>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/policy-based-Go-Explore.png" alt="Policy-based Go-Explore" /></p>
<p><em>Fig. 13. An overview of the Go-Explore algorithm. (Image source: <a href="https://arxiv.org/abs/2004.12919">Ecoffet, et al., 2020</a>)</em></p>
<p>After vanilla Go-Explore, <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. (2019)</a> proposed <strong>DTSIL</strong> (Diverse Trajectory-conditioned Self-Imitation Learning), which shares a similar idea with policy-based Go-Explore above. DTSIL maintains a memory of diverse demonstrations collected during training and uses them to train a trajectory-conditioned policy via <a href="https://arxiv.org/abs/1806.05635">SIL</a>. They prioritize trajectories that end with a rare state during sampling.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DTSIL-algo.png" alt="DTSIL" /></p>
<p><em>Fig. 14. Algorithm of DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning). (Image source: <a href="https://arxiv.org/abs/1907.10247">Yijie Guo, et al. 2019</a>)</em></p>
<p>A similar approach is also seen in <a href="https://arxiv.org/abs/1906.07805">Guo, et al. (2019)</a>. The main idea is to store goals with <em>high uncertainty</em> in memory so that later the agent can revisit these goal states with a goal-conditioned policy repeatedly. In each episode, the agent flips a coin (probability 0.5) to decide whether it will act greedily w.r.t. the policy or do directed exploration by sampling goals from the memory.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/directed-exploration.png" alt="Directed exploration" /></p>
<p><em>Fig. 15. Different components in directed exploration with function approximation. (Image source: <a href="https://arxiv.org/abs/1906.07805">Guo, et al. 2019</a>)</em></p>
<p>The uncertainty measure of a state can be something simple like count-based bonuses or something complex like density or Bayesian models. The paper trained a forward dynamics model and took its prediction error as the uncertainty metric.</p>
<h2 id="q-value-exploration">Q-Value Exploration</h2>
<p>Inspired by <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">Thompson sampling</a>, <strong>Bootstrapped DQN</strong> (<a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>) introduces a notion of uncertainty in Q-value approximation in classic <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#deep-q-network">DQN</a> by using the <a href="https://en.wikipedia.org/wiki/Bootstrapping_(statistics)">bootstrapping</a> method. Bootstrapping is to approximate a distribution by sampling with replacement from the same population multiple times and then aggregate the results.</p>
<p>Multiple Q-value heads are trained in parallel but each only consumes a bootstrapped sub-sampled set of data and each has its own corresponding target network. All the Q-value heads share the same backbone network.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bootstrapped-DQN-algo.png" alt="Bootstrapped DQN" /></p>
<p><em>Fig. 16. The algorithm of Bootstrapped DQN. (Image source: <a href="https://arxiv.org/abs/1602.04621">Osband, et al. 2016</a>)</em></p>
<p>At the beginning of one episode, one Q-value head is sampled uniformly and acts for collecting experience data in this episode. Then a binary mask is sampled from the masking distribution <script type="math/tex">m \sim \mathcal{M}</script> and decides which heads can use this data for training. The choice of masking distribution <script type="math/tex">\mathcal{M}</script> determines how bootstrapped samples are generated; for example,</p>
<ul>
<li>If <script type="math/tex">\mathcal{M}</script> is an independent Bernoulli distribution with <script type="math/tex">p=0.5</script>, this corresponds to the double-or-nothing bootstrap.</li>
<li>If <script type="math/tex">\mathcal{M}</script> always returns an all-one mask, the algorithm reduces to an ensemble method.</li>
</ul>
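<p>The head selection and masking can be sketched with NumPy as follows. The number of heads, the helper names and the way TD errors are passed in are assumptions for illustration; in a real agent the mask would be stored in the replay buffer next to each transition.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N_HEADS = 10

def sample_mask(p=0.5):
    """One Bernoulli(p) draw per head: 1 means the head may train on this sample."""
    return rng.binomial(1, p, size=N_HEADS)

def masked_head_losses(td_errors, mask):
    """Squared TD error per head, zeroed out for heads the mask excludes."""
    return np.asarray(mask) * np.asarray(td_errors, dtype=float) ** 2

active_head = int(rng.integers(N_HEADS))  # head that acts for the whole episode
mask = sample_mask(p=0.5)                 # bootstrap-style mask
ensemble_mask = sample_mask(p=1.0)        # all ones: plain ensemble

losses = masked_head_losses(np.full(N_HEADS, 2.0), mask)
```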
<p>However, this kind of exploration is still restricted, because uncertainty introduced by bootstrapping fully relies on the training data. It is better to inject some prior information independent of the data. This “noisy” prior is expected to drive the agent to keep exploring when the reward is sparse. The algorithm of adding random prior into bootstrapped DQN for better exploration (<a href="https://arxiv.org/abs/1806.03335">Osband, et al. 2018</a>) depends on Bayesian linear regression. The core idea of Bayesian regression is: We can <em>“generate posterior samples by training on noisy versions of the data, together with some random regularization”</em>.</p>
<p>Let <script type="math/tex">\theta</script> be the Q function parameter and <script type="math/tex">\theta^-</script> for the target Q, the loss function using a randomized prior function <script type="math/tex">p</script> is:</p>
<script type="math/tex; mode=display">\mathcal{L}(\theta, \theta^{-}, p, \mathcal{D}; \gamma) = \sum_{t\in\mathcal{D}}\Big( r_t + \gamma \max_{a'\in\mathcal{A}} \underbrace{(Q_{\theta^-} + p)}_\text{target Q}(s'_t, a') - \underbrace{(Q_\theta + p)}_\text{Q to optimize}(s_t, a_t) \Big)^2</script>
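<p>For intuition, here is the loss above evaluated with small lookup tables in place of neural networks. The tables <code>q</code>, <code>q_target</code> and <code>prior</code> are indexed by <code>[state, action]</code> and the numbers are made up; the point is simply that the fixed prior is added on both sides of the TD error while only <code>q</code> would receive gradients.</p>

```python
import numpy as np

def randomized_prior_loss(q, q_target, prior, batch, gamma=0.99):
    """Sum of squared TD errors where a fixed prior table is added to both Q tables.

    batch: iterable of (s, a, r, s_next) transitions with integer states/actions.
    """
    loss = 0.0
    for s, a, r, s_next in batch:
        target = r + gamma * np.max(q_target[s_next] + prior[s_next])
        prediction = q[s, a] + prior[s, a]
        loss += (target - prediction) ** 2
    return float(loss)

q = np.zeros((2, 2))                         # trainable Q-values
q_target = np.zeros((2, 2))                  # target network values
prior = np.array([[1.0, 2.0], [3.0, 4.0]])   # fixed, never updated

# One transition: state 0, action 1, reward 1, next state 1.
# target = 1 + 0.5 * max(3, 4) = 3, prediction = 0 + 2 = 2, squared error = 1.
loss = randomized_prior_loss(q, q_target, prior, [(0, 1, 1.0, 1)], gamma=0.5)
```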
<h2 id="varitional-options">Variational Options</h2>
<p>Options are policies with termination conditions. There are a large set of options available in the search space and they are independent of an agent’s intentions. By explicitly including intrinsic options into modeling, the agent can obtain intrinsic rewards for exploration.</p>
<p><strong>VIC</strong> (short for <em>“Variational Intrinsic Control”</em>; <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>) is such a framework for providing the agent with intrinsic exploration bonuses based on modeling options and learning policies conditioned on options. Let <script type="math/tex">\Omega</script> represent an option which starts from <script type="math/tex">s_0</script> and ends at <script type="math/tex">s_f</script>. An environment probability distribution <script type="math/tex">p^J(s_f \vert s_0, \Omega)</script> defines where an option <script type="math/tex">\Omega</script> terminates given a starting state <script type="math/tex">s_0</script>. A controllability distribution <script type="math/tex">p^C(\Omega \vert s_0)</script> defines the probability distribution of options we can sample from. And by definition we have <script type="math/tex">p(s_f, \Omega \vert s_0) = p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0)</script>.</p>
<p>While choosing options, we would like to achieve two goals:</p>
<ul>
<li>Achieve a diverse set of the final states from <script type="math/tex">s_0</script> ⇨ Maximization of <script type="math/tex">H(s_f \vert s_0)</script>.</li>
<li>Know precisely which state a given option <script type="math/tex">\Omega</script> can end with ⇨ Minimization of <script type="math/tex">H(s_f \vert s_0, \Omega)</script>.</li>
</ul>
<p>Combining them, we get mutual information <script type="math/tex">I(\Omega; s_f \vert s_0)</script> to maximize:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
I(\Omega; s_f \vert s_0)
&= H(s_f \vert s_0) - H(s_f \vert s_0, \Omega) \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p(s_f, \Omega \vert s_0) \log \frac{p(s_f, \Omega \vert s_0)}{p^C(\Omega \vert s_0)} \\
&= - \sum_{s_f} p(s_f \vert s_0) \log p(s_f \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log p^J(s_f \vert s_0, \Omega) \\
\end{aligned} %]]></script>
<p>Because mutual information is symmetric, we can switch <script type="math/tex">s_f</script> and <script type="math/tex">\Omega</script> in several places without breaking the equivalence. Also because <script type="math/tex">p(\Omega \vert s_0, s_f)</script> is difficult to observe, let us replace it with an approximation distribution <script type="math/tex">q</script>. According to the variational lower bound, we would have <script type="math/tex">I(\Omega; s_f \vert s_0) \geq I^{VB}(\Omega; s_f \vert s_0)</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
I(\Omega; s_f \vert s_0)
&= I(s_f; \Omega \vert s_0) \\
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{p(\Omega \vert s_0, s_f)}\\
I^{VB}(\Omega; s_f \vert s_0)
&= - \sum_{\Omega} p(\Omega \vert s_0) \log p(\Omega \vert s_0) + \sum_{s_f, \Omega} p^J(s_f \vert s_0, \Omega) p^C(\Omega \vert s_0) \log \color{red}{q(\Omega \vert s_0, s_f)} \\
I(\Omega; s_f \vert s_0) &\geq I^{VB}(\Omega; s_f \vert s_0)
\end{aligned} %]]></script>
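<p>A quick numerical check of the bound for a single start state with discrete options and final states: the lower bound is tight when <code>q</code> equals the true posterior and looser otherwise. The probability tables are toy values of my own and are assumed to be strictly positive so that all logarithms are finite.</p>

```python
import numpy as np

def mutual_information(p_c, p_j):
    """I(Omega; s_f | s_0) for one fixed s_0.

    p_c: shape (n_options,), the option prior p^C(Omega | s_0).
    p_j: shape (n_options, n_states), p^J(s_f | s_0, Omega).
    """
    joint = p_c[:, None] * p_j
    p_sf = joint.sum(axis=0)  # marginal p(s_f | s_0)
    return float(np.sum(joint * np.log(p_j / p_sf)))

def variational_bound(p_c, p_j, q):
    """I^VB, with q of shape (n_options, n_states) holding q(Omega | s_0, s_f)."""
    joint = p_c[:, None] * p_j
    option_entropy = -np.sum(p_c * np.log(p_c))
    return float(option_entropy + np.sum(joint * np.log(q)))

p_c = np.array([0.5, 0.5])
p_j = np.array([[0.9, 0.1],
                [0.2, 0.8]])

joint = p_c[:, None] * p_j
true_posterior = joint / joint.sum(axis=0)  # p(Omega | s_0, s_f)
uninformed_q = np.full((2, 2), 0.5)         # q that ignores s_f entirely

mi = mutual_information(p_c, p_j)
tight = variational_bound(p_c, p_j, true_posterior)
loose = variational_bound(p_c, p_j, uninformed_q)
```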
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VIC-explicit-options.png" alt="VIC" /></p>
<p><em>Fig. 17. The algorithm for VIC (Variational Intrinsic Control). (Image source: <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. 2017</a>)</em></p>
<p>Here <script type="math/tex">\pi(a \vert \Omega, s)</script> can be optimized with any RL algorithm. The option inference function <script type="math/tex">q(\Omega \vert s_0, s_f)</script> is doing supervised learning. The prior <script type="math/tex">p^C</script> is updated so that it tends to choose <script type="math/tex">\Omega</script> with higher rewards. Note that <script type="math/tex">p^C</script> can also be fixed (e.g. a Gaussian). Various <script type="math/tex">\Omega</script> will result in different behavior through learning. Additionally, <a href="https://arxiv.org/abs/1611.07507">Gregor, et al. (2017)</a> observed that it is difficult to make VIC with explicit options work in practice with function approximation and therefore they also proposed another version of VIC with implicit options.</p>
<p>Different from VIC which models <script type="math/tex">\Omega</script> conditioned only on the start and end states, <strong>VALOR</strong> (short for <em>“Variational Auto-encoding Learning of Options by Reinforcement”</em>; <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>) relies on the whole trajectory to extract the option context <script type="math/tex">c</script>, which is sampled from a fixed Gaussian distribution. In VALOR:</p>
<ul>
<li>A policy acts as an encoder, translating contexts from a noise distribution into trajectories</li>
<li>A decoder attempts to recover the contexts from the trajectories, and rewards the policies for making contexts easier to distinguish. The decoder never sees the actions during training, so the agent has to interact with the environment in a way that facilitates communication with the decoder for better prediction. Also, the decoder recurrently takes in a sequence of steps in one trajectory to better model the correlation between timesteps.</li>
</ul>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/VALOR-decoder.png" alt="VALOR" /></p>
<p><em>Fig. 18. The decoder of VALOR is a biLSTM which takes <script type="math/tex">N = 11</script> equally spaced observations from one trajectory as inputs. (Image source: <a href="https://arxiv.org/abs/1807.10299">Achiam, et al. 2018</a>)</em></p>
<p>DIAYN (“Diversity is all you need”; <a href="https://arxiv.org/abs/1802.06070">Eysenbach, et al. 2018</a>) follows the same idea under a different name: it models the policies conditioned on a latent <em>skill</em> variable. See my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#learning-with-random-rewards">previous post</a> for more details.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020exploration,
title = "Exploration Strategies in Deep Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/06/07/exploration-strategies-in-deep-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Pierre-Yves Oudeyer & Frederic Kaplan. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.567.6524&rep=rep1&type=pdf">“How can we define intrinsic motivation?”</a> Conf. on Epigenetic Robotics, 2008.</p>
<p>[2] Marc G. Bellemare, et al. <a href="https://arxiv.org/abs/1606.01868">“Unifying Count-Based Exploration and Intrinsic Motivation”</a>. NIPS 2016.</p>
<p>[3] Georg Ostrovski, et al. <a href="https://arxiv.org/abs/1703.01310">“Count-Based Exploration with Neural Density Models”</a>. PMLR 2017.</p>
<p>[4] Rui Zhao & Volker Tresp. <a href="https://arxiv.org/abs/1902.08039">“Curiosity-Driven Experience Prioritization via
Density Estimation”</a>. NIPS 2018.</p>
<p>[5] Haoran Tang, et al. <a href="https://arxiv.org/abs/1611.04717">“#Exploration: A Study of Count-Based Exploration for Deep Reinforcement Learning”</a>. NIPS 2017.</p>
<p>[6] Jürgen Schmidhuber. <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.45.957">“A possibility for implementing curiosity and boredom in model-building neural controllers”</a> 1991.</p>
<p>[7] Pierre-Yves Oudeyer, et al. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.177.7661&rep=rep1&type=pdf">“Intrinsic Motivation Systems for Autonomous Mental Development”</a> IEEE Transactions on Evolutionary Computation, 2007.</p>
<p>[8] Bradly C. Stadie, et al. <a href="https://arxiv.org/abs/1507.00814">“Incentivizing Exploration In Reinforcement Learning With Deep Predictive Models”</a>. ICLR 2016.</p>
<p>[9] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1705.05363">“Curiosity-driven Exploration by Self-supervised Prediction”</a>. CVPR 2017.</p>
<p>[10] Yuri Burda, Harri Edwards & Deepak Pathak, et al. <a href="https://arxiv.org/abs/1808.04355">“Large-Scale Study of Curiosity-Driven Learning”</a>. arXiv 1808.04355 (2018).</p>
<p>[11] Joshua Achiam & Shankar Sastry. <a href="https://arxiv.org/abs/1703.01732">“Surprise-Based Intrinsic Motivation for Deep Reinforcement Learning”</a> NIPS 2016 Deep RL Workshop.</p>
<p>[12] Rein Houthooft, et al. <a href="https://arxiv.org/abs/1605.09674">“VIME: Variational information maximizing exploration”</a>. NIPS 2016.</p>
<p>[13] Leshem Choshen, Lior Fox & Yonatan Loewenstein. <a href="https://arxiv.org/abs/1804.04012">“DORA the explorer: Directed outreaching reinforcement action-selection”</a>. ICLR 2018</p>
<p>[14] Yuri Burda, et al. <a href="https://arxiv.org/abs/1810.12894">“Exploration by Random Network Distillation”</a> ICLR 2019.</p>
<p>[15] OpenAI Blog: <a href="https://openai.com/blog/reinforcement-learning-with-prediction-based-rewards/">“Reinforcement Learning with
Prediction-Based Rewards”</a> Oct, 2018.</p>
<p>[16] Misha Denil, et al. <a href="https://arxiv.org/abs/1611.01843">“Learning to Perform Physics Experiments via Deep Reinforcement Learning”</a>. ICLR 2017.</p>
<p>[17] Ian Osband, et al. <a href="https://arxiv.org/abs/1602.04621">“Deep Exploration via Bootstrapped DQN”</a>. NIPS 2016.</p>
<p>[18] Ian Osband, John Aslanides & Albin Cassirer. <a href="https://arxiv.org/abs/1806.03335">“Randomized Prior Functions for Deep Reinforcement Learning”</a>. NIPS 2018.</p>
<p>[19] Karol Gregor, Danilo Jimenez Rezende & Daan Wierstra. <a href="https://arxiv.org/abs/1611.07507">“Variational Intrinsic Control”</a>. ICLR 2017.</p>
<p>[20] Joshua Achiam, et al. <a href="https://arxiv.org/abs/1807.10299">“Variational Option Discovery Algorithms”</a>. arXiv 1807.10299 (2018).</p>
<p>[21] Benjamin Eysenbach, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a>. ICLR 2019.</p>
<p>[22] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2002.06038">“Never Give Up (NGU): Learning Directed Exploration Strategies”</a> ICLR 2020.</p>
<p>[23] Adrià Puigdomènech Badia, et al. <a href="https://arxiv.org/abs/2003.13350">“Agent57: Outperforming the Atari Human Benchmark”</a>. arXiv 2003.13350 (2020).</p>
<p>[24] DeepMind Blog: <a href="https://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark">“Agent57: Outperforming the human Atari benchmark”</a> Mar 2020.</p>
<p>[25] Nikolay Savinov, et al. <a href="https://arxiv.org/abs/1810.02274">“Episodic Curiosity through Reachability”</a> ICLR 2019.</p>
<p>[26] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/1901.10995">“Go-Explore: a New Approach for Hard-Exploration Problems”</a>. arXiv 1901.10995 (2019).</p>
<p>[27] Adrien Ecoffet, et al. <a href="https://arxiv.org/abs/2004.12919">“First return then explore”</a>. arXiv 2004.12919 (2020).</p>
<p>[28] Junhyuk Oh, et al. <a href="https://arxiv.org/abs/1806.05635">“Self-Imitation Learning”</a>. ICML 2018.</p>
<p>[29] Yijie Guo, et al. <a href="https://arxiv.org/abs/1907.10247">“Self-Imitation Learning via Trajectory-Conditioned Policy for Hard-Exploration Tasks”</a>. arXiv 1907.10247 (2019).</p>
<p>[30] Zhaohan Daniel Guo & Emma Brunskill. <a href="https://arxiv.org/abs/1906.07805">“Directed Exploration for Reinforcement Learning”</a>. arXiv 1906.07805 (2019).</p>
<p>[31] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1906.04161">“Self-Supervised Exploration via Disagreement.”</a> ICML 2019.</p>
<hr />
<p>The Transformer Family. Lilian Weng, 2020-04-07. https://lilianweng.github.io/lil-log/2020/04/07/the-transformer-family</p>
<blockquote>
<p>Inspired by recent progress on various enhanced versions of Transformer models, this post presents how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving, etc.</p>
</blockquote>
<!--more-->
<p>It has been almost two years since my last post on <a href="/lil-log/2018/06/24/attention-attention.html">attention</a>. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#notations" id="markdown-toc-notations">Notations</a></li>
<li><a href="#attention-and-self-attention" id="markdown-toc-attention-and-self-attention">Attention and Self-Attention</a></li>
<li><a href="#multi-head-self-attention" id="markdown-toc-multi-head-self-attention">Multi-Head Self-Attention</a></li>
<li><a href="#transformer" id="markdown-toc-transformer">Transformer</a></li>
<li><a href="#adaptive-computation-time-act" id="markdown-toc-adaptive-computation-time-act">Adaptive Computation Time (ACT)</a></li>
<li><a href="#improved-attention-span" id="markdown-toc-improved-attention-span">Improved Attention Span</a> <ul>
<li><a href="#longer-attention-span-transformer-xl" id="markdown-toc-longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</a></li>
<li><a href="#adaptive-attention-span" id="markdown-toc-adaptive-attention-span">Adaptive Attention Span</a></li>
<li><a href="#localized-attention-span-image-transformer" id="markdown-toc-localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</a></li>
</ul>
</li>
<li><a href="#less-time-and-memory-cost" id="markdown-toc-less-time-and-memory-cost">Less Time and Memory Cost</a> <ul>
<li><a href="#sparse-attention-matrix-factorization-sparse-transformers" id="markdown-toc-sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</a></li>
<li><a href="#locality-sensitive-hashing-reformer" id="markdown-toc-locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</a></li>
</ul>
</li>
<li><a href="#make-it-recurrent-universal-transformer" id="markdown-toc-make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</a></li>
<li><a href="#stabilization-for-rl-gtrxl" id="markdown-toc-stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h3 id="notations">Notations</h3>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">d</script></td>
<td>The model size / hidden state dimension / positional encoding size.</td>
</tr>
<tr>
<td><script type="math/tex">h</script></td>
<td>The number of heads in multi-head attention layer.</td>
</tr>
<tr>
<td><script type="math/tex">L</script></td>
<td>The segment length of input sequence.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{X} \in \mathbb{R}^{L \times d}</script></td>
<td>The input sequence where each element has been mapped into an embedding vector of shape <script type="math/tex">d</script>, same as the model size.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{W}^k \in \mathbb{R}^{d \times d_k}</script></td>
<td>The key weight matrix.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{W}^q \in \mathbb{R}^{d \times d_k}</script></td>
<td>The query weight matrix.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{W}^v \in \mathbb{R}^{d \times d_v}</script></td>
<td>The value weight matrix. Often we have <script type="math/tex">d_k = d_v = d</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{W}^k_i, \mathbf{W}^q_i \in \mathbb{R}^{d \times d_k/h}; \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}</script></td>
<td>The weight matrices per head.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{W}^o \in \mathbb{R}^{d_v \times d}</script></td>
<td>The output weight matrix.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{Q} = \mathbf{X}\mathbf{W}^q \in \mathbb{R}^{L \times d_k}</script></td>
<td>The query embedding inputs.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{K} = \mathbf{X}\mathbf{W}^k \in \mathbb{R}^{L \times d_k}</script></td>
<td>The key embedding inputs.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{V} = \mathbf{X}\mathbf{W}^v \in \mathbb{R}^{L \times d_v}</script></td>
<td>The value embedding inputs.</td>
</tr>
<tr>
<td><script type="math/tex">S_i</script></td>
<td>A collection of key positions for the <script type="math/tex">i</script>-th query <script type="math/tex">\mathbf{q}_i</script> to attend to.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{A} \in \mathbb{R}^{L \times L}</script></td>
<td>The self-attention matrix between an input sequence of length <script type="math/tex">L</script> and itself. <script type="math/tex">\mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})</script>.</td>
</tr>
<tr>
<td><script type="math/tex">a_{ij} \in \mathbf{A}</script></td>
<td>The scalar attention score between query <script type="math/tex">\mathbf{q}_i</script> and key <script type="math/tex">\mathbf{k}_j</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{P} \in \mathbb{R}^{L \times d}</script></td>
<td>The positional encoding matrix, where the <script type="math/tex">i</script>-th row <script type="math/tex">\mathbf{p}_i</script> is the positional encoding for input <script type="math/tex">\mathbf{x}_i</script>.</td>
</tr>
</tbody>
</table>
<h2 id="attention-and-self-attention">Attention and Self-Attention</h2>
<p><em>Attention</em> is a mechanism by which a neural network learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights, and thus the output is usually formed as a weighted average.</p>
<p><em>Self-attention</em> is a type of attention mechanism where the model makes predictions for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to <a href="https://en.wikipedia.org/wiki/Non-local_means">non-local means</a>. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.</p>
<p>There are various forms of attention / self-attention. The Transformer (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al., 2017</a>) relies on the <em>scaled dot-product attention</em>: given a query matrix <script type="math/tex">\mathbf{Q}</script>, a key matrix <script type="math/tex">\mathbf{K}</script> and a value matrix <script type="math/tex">\mathbf{V}</script>, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:</p>
<script type="math/tex; mode=display">\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q} {\mathbf{K}}^\top}{\sqrt{d_k}})\mathbf{V}</script>
<p>And for a query and a key vector <script type="math/tex">\mathbf{q}_i, \mathbf{k}_j \in \mathbb{R}^{d_k}</script> (row vectors in query and key matrices), we have a scalar score:</p>
<script type="math/tex; mode=display">a_{ij} = \text{softmax}(\frac{\mathbf{q}_i {\mathbf{k}_j}^\top}{\sqrt{d_k}})
= \frac{\exp(\mathbf{q}_i {\mathbf{k}_j}^\top / \sqrt{d_k})}{ \sum_{r \in S_i} \exp(\mathbf{q}_i {\mathbf{k}_r}^\top / \sqrt{d_k}) }</script>
<p>where <script type="math/tex">S_i</script> is a collection of key positions for the <script type="math/tex">i</script>-th query to attend to.</p>
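<p>To make the formula concrete, here is a minimal pure-Python sketch of scaled dot-product attention. It assumes every query attends to all key positions (i.e. <script type="math/tex">S_i</script> is the full set of keys), and the function names and toy matrices are illustrative choices rather than anything from the Transformer paper.</p>

```python
import math

def softmax(row):
    m = max(row)                       # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for matrices given as lists of rows."""
    d_k = len(K[0])
    # scores[i][j] = q_i . k_j / sqrt(d_k)
    scores = [[sum(q * k for q, k in zip(q_i, k_j)) / math.sqrt(d_k) for k_j in K]
              for q_i in Q]
    A = [softmax(row) for row in scores]           # one row of weights per query
    # Each output row is the attention-weighted sum of the value rows.
    out = [[sum(a * v_j[c] for a, v_j in zip(a_i, V)) for c in range(len(V[0]))]
           for a_i in A]
    return out, A

# Toy inputs: 2 queries, 3 keys/values, d_k = d_v = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, A = attention(Q, K, V)
```

<p>Each row of <code>A</code> sums to one, so every output row is a convex combination of the value vectors.</p>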
<p>See my old <a href="/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms">post</a> for other types of attention if interested.</p>
<h2 id="multi-head-self-attention">Multi-Head Self-Attention</h2>
<p>The <em>multi-head self-attention</em> module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{MultiHeadAttention}(\mathbf{X}_q, \mathbf{X}_k, \mathbf{X}_v) &= [\text{head}_1; \dots; \text{head}_h] \mathbf{W}^o \\
\text{where head}_i &= \text{Attention}(\mathbf{X}_q\mathbf{W}^q_i, \mathbf{X}_k\mathbf{W}^k_i, \mathbf{X}_v\mathbf{W}^v_i)
\end{aligned} %]]></script>
<p>where <script type="math/tex">[.;.]</script> is a concatenation operation. <script type="math/tex">\mathbf{W}^q_i, \mathbf{W}^k_i \in \mathbb{R}^{d \times d_k/h}, \mathbf{W}^v_i \in \mathbb{R}^{d \times d_v/h}</script> are weight matrices to map input embeddings of size <script type="math/tex">L \times d</script> into query, key and value matrices. And <script type="math/tex">\mathbf{W}^o \in \mathbb{R}^{d_v \times d}</script> is the output linear transformation. All the weights should be learned during training.</p>
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/multi-head-attention.png" alt="Multi-head scaled dot-product attention" /></p>
<p><em>Fig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>)</em></p>
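<p>The multi-head computation above can be sketched as follows. This is an illustrative pure-Python version with randomly initialized weights. The helper names (<code>matmul</code>, <code>rand_matrix</code>) and the sizes are assumptions for the example, and a real implementation would use a tensor library and batch the heads.</p>

```python
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = matmul(Q, K_T)
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(A, V)

def multi_head_attention(X_q, X_k, X_v, heads, W_o):
    """heads is a list of (W_q_i, W_k_i, W_v_i) triples, one per head."""
    outputs = [attention(matmul(X_q, W_q), matmul(X_k, W_k), matmul(X_v, W_v))
               for W_q, W_k, W_v in heads]
    # Concatenate head outputs along the feature dimension: L x d_v.
    concat = [sum((head[i] for head in outputs), []) for i in range(len(X_q))]
    return matmul(concat, W_o)

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

L, d, h = 5, 8, 2            # sequence length, model size, number of heads
d_k = d_v = d
heads = [(rand_matrix(d, d_k // h), rand_matrix(d, d_k // h), rand_matrix(d, d_v // h))
         for _ in range(h)]
W_o = rand_matrix(d_v, d)
X = rand_matrix(L, d)
Y = multi_head_attention(X, X, X, heads, W_o)   # self-attention: same input three times
```

<p>The output keeps the shape <script type="math/tex">L \times d</script> of the input, which is what allows these modules to be stacked.</p>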
<h2 id="transformer">Transformer</h2>
<p>The <strong>Transformer</strong> (which will be referred to as “vanilla Transformer” to distinguish it from other enhanced versions; <a href="https://arxiv.org/abs/1706.03762">Vaswani, et al., 2017</a>) model has an encoder-decoder architecture, as commonly used in many <a href="/lil-log/2018/06/24/attention-attention.html#born-for-translation">NMT</a> models. Later, simplified single-stack Transformers were shown to achieve great performance in language modeling tasks, as in <a href="/lil-log/2019/01/31/generalized-language-models.html#openai-gpt">GPT and BERT</a> (decoder-only and encoder-only, respectively).</p>
<p><strong>Encoder-Decoder Architecture</strong></p>
<p>The <strong>encoder</strong> generates an attention-based representation with capability to locate a specific piece of information from a large context. It consists of a stack of 6 identical modules, each containing two submodules, a <em>multi-head self-attention</em> layer and a <em>point-wise</em> fully connected feed-forward network. By point-wise, it means that it applies the same linear transformation (with same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension <script type="math/tex">d</script>.</p>
<p>The function of Transformer <strong>decoder</strong> is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is <em>masked</em> to prevent positions from attending to the future.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer.png" alt="Transformer" /></p>
<p><em>Fig. 2. The architecture of the vanilla Transformer model. (Image source: <a href="/lil-log/2018/06/24/attention-attention.html#full-architecture">Figure 17</a>)</em></p>
<p><strong>Positional Encoding</strong></p>
<p>Because the self-attention operation is permutation invariant, it is important to use proper <strong>positional encoding</strong> to provide <em>order information</em> to the model. The positional encoding <script type="math/tex">\mathbf{P} \in \mathbb{R}^{L \times d}</script> has the same dimension as the input embedding, so it can be added on the input directly. The vanilla Transformer considered two types of encodings:</p>
<p>(1) <em>Sinusoidal positional encoding</em> is defined as follows, given the token position <script type="math/tex">i=1,\dots,L</script> and the dimension <script type="math/tex">\delta=1,\dots,d</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\text{PE}(i,\delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases} %]]></script>
<p>In this way, each dimension of the positional encoding corresponds to a sinusoid with a different wavelength, ranging from <script type="math/tex">2\pi</script> to <script type="math/tex">10000 \cdot 2\pi</script>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sinoidual-positional-encoding.png" alt="Transformer" /></p>
<p><em>Fig. 3. Sinusoidal positional encoding with <script type="math/tex">L=32</script> and <script type="math/tex">d=128</script>. The value is between -1 (black) and 1 (white) and the value 0 is in gray.</em></p>
<p>(2) <em>Learned positional encoding</em>, as its name suggests, assigns each element a learned vector which encodes its <em>absolute</em> position (<a href="https://arxiv.org/abs/1705.03122">Gehring, et al. 2017</a>).</p>
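<p>The sinusoidal encoding in (1) can be sketched in a few lines. This follows the indexing used in this post (positions and dimensions both start at 1); the function name <code>sinusoidal_pe</code> is an illustrative choice, and other implementations often use 0-based indices instead.</p>

```python
import math

def sinusoidal_pe(L, d):
    """Return an L x d matrix P; row i-1 encodes position i (1-indexed as in the text)."""
    P = []
    for i in range(1, L + 1):
        row = []
        for delta in range(1, d + 1):
            delta_prime = delta // 2               # delta = 2*delta' or 2*delta' + 1
            angle = i / (10000 ** (2 * delta_prime / d))
            # Even dimensions use sine, odd dimensions use cosine.
            row.append(math.sin(angle) if delta % 2 == 0 else math.cos(angle))
        P.append(row)
    return P

P = sinusoidal_pe(L=32, d=128)   # same sizes as in Fig. 3
```

<p>Because <code>P</code> has the same shape as the input embedding matrix, it can be added to the embeddings element-wise.</p>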
<p><strong>Quick Follow-ups</strong></p>
<p>Following the vanilla Transformer, <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a> added a set of auxiliary losses to enable training a deep Transformer model on character-level language modeling which outperformed LSTMs. Several types of auxiliary tasks are used:</p>
<ul>
<li>Instead of producing only one prediction at the sequence end, every <em>intermediate position</em> is also asked to make a correct prediction, forcing the model to predict given smaller contexts (e.g. first couple tokens at the beginning of a context window).</li>
<li>Each intermediate Transformer layer is used for making predictions as well. Lower layers are weighted to contribute less and less to the total loss as training progresses.</li>
<li>Each position in the sequence can predict multiple targets, i.e. two or more predictions of the future tokens.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-aux-losses.png" alt="Transformer" /></p>
<p><em>Fig. 4. Auxiliary prediction tasks used in deep Transformer for character-level language modeling. (Image source: <a href="https://arxiv.org/abs/1808.04444">Al-Rfou et al. (2018)</a>)</em></p>
<h2 id="adaptive-computation-time-act">Adaptive Computation Time (ACT)</h2>
<p><strong>Adaptive Computation Time</strong> (<strong>ACT</strong> for short; <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>) is a mechanism for dynamically deciding how many computational steps are needed in a recurrent neural network. Here is a cool <a href="https://distill.pub/2016/augmented-rnns/#adaptive-computation-time">tutorial</a> on ACT from distill.pub.</p>
<p>Let’s say we have an RNN model <script type="math/tex">\mathcal{R}</script> composed of input weights <script type="math/tex">W_x</script>, a parametric state transition function <script type="math/tex">\mathcal{S}(.)</script>, a set of output weights <script type="math/tex">W_y</script> and an output bias <script type="math/tex">b_y</script>. Given an input sequence <script type="math/tex">(x_1, \dots, x_L)</script>, the output sequence <script type="math/tex">(y_1, \dots, y_L)</script> is computed by:</p>
<script type="math/tex; mode=display">s_t = \mathcal{S}(s_{t-1}, W_x x_t), \quad y_t = W_y s_t + b_y\quad\text{for }t=1, \dots, L</script>
<p>ACT enables the above RNN setup to perform a variable number of steps at each input element. Multiple computational steps lead to a sequence of intermediate states <script type="math/tex">(s_t^1, \dots, s_t^{N(t)})</script> and outputs <script type="math/tex">(y_t^1, \dots, y_t^{N(t)})</script> — they all share the same state transition function <script type="math/tex">\mathcal{S}(.)</script>, as well as the same output weights <script type="math/tex">W_y</script> and bias <script type="math/tex">b_y</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
s_t^0 &= s_{t-1} \\
s_t^n &= \mathcal{S}(s_{t}^{n-1}, x_t^n) = \mathcal{S}(s_{t}^{n-1}, x_t + \delta_{n,1}) \text{ for } n=1, \dots, N(t)\\
y_t^n &= W_y s_t^n + b_y
\end{aligned} %]]></script>
<p>where <script type="math/tex">\delta_{n,1}</script> is a binary flag indicating whether the input step has been incremented.</p>
<p>The number of steps <script type="math/tex">N(t)</script> is determined by an extra sigmoidal halting unit <script type="math/tex">h</script>, with associated weight matrix <script type="math/tex">W_h</script> and bias <script type="math/tex">b_h</script>, outputting a halting probability <script type="math/tex">p_t^n</script> at intermediate step <script type="math/tex">n</script> for the <script type="math/tex">t</script>-th input element:</p>
<script type="math/tex; mode=display">h_t^n = \sigma(W_h s_t^n + b_h)</script>
<p>In order to allow the computation to halt after a single step, ACT introduces a small constant <script type="math/tex">\epsilon</script> (e.g. 0.01), so that whenever the cumulative probability goes above <script type="math/tex">1-\epsilon</script>, the computation stops.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
N(t) &= \min(\min\{n': \sum_{n=1}^{n'} h_t^n \geq 1 -\epsilon\}, M) \\
p_t^n &= \begin{cases}
h_t^n & \text{if }n < N(t) \\
R(t) = 1 - \sum_{n=1}^{N(t)-1} h_t^n & \text{if }n= N(t)\\
\end{cases}
\end{aligned} %]]></script>
<p>where <script type="math/tex">M</script> is an upper limit for the number of intermediate steps allowed.</p>
<p>The final state and output are mean-field updates:</p>
<script type="math/tex; mode=display">s_t = \sum_{n=1}^{N(t)} p_t^n s_t^n,\quad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n</script>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/ACT-computation-graph.png" alt="ACT computation graph" /></p>
<p><em>Fig. 5. The computation graph of an RNN with ACT mechanism. (Image source: <a href="https://arxiv.org/abs/1603.08983">Graves, 2016</a>)</em></p>
<p>To avoid unnecessary pondering over each input, ACT adds a <em>ponder cost</em> <script type="math/tex">\mathcal{P}(x) = \sum_{t=1}^L \big(N(t) + R(t)\big)</script> to the loss function to encourage a smaller number of intermediate computational steps.</p>
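<p>The halting rule, the mean-field update and the ponder cost can be illustrated with a small sketch for a single input step <script type="math/tex">t</script>. As a simplifying assumption, the halting unit outputs <script type="math/tex">h_t^n</script> are given here as a precomputed list, whereas in ACT they are produced one at a time from the intermediate states. The function names and numbers are hypothetical.</p>

```python
def act_halting(h, eps=0.01, M=10):
    """Given halting unit outputs h = [h_t^1, h_t^2, ...] for one input step t,
    return N(t), the halting probabilities p_t^n and the remainder R(t)."""
    cumulative = 0.0
    N = min(len(h), M)                      # fall back to the step limit
    for n, h_n in enumerate(h[:M], start=1):
        cumulative += h_n
        if cumulative >= 1 - eps:           # stop once the threshold is crossed
            N = n
            break
    R = 1.0 - sum(h[:N - 1])                # remainder assigned to the last step
    p = h[:N - 1] + [R]
    return N, p, R

def mean_field(p, values):
    """Weighted combination sum_n p_t^n * value_t^n for scalar values."""
    return sum(p_n * v for p_n, v in zip(p, values))

h = [0.1, 0.3, 0.7, 0.9]                    # hypothetical halting outputs
N, p, R = act_halting(h)
y = mean_field(p, [1.0, 2.0, 3.0])          # hypothetical per-step outputs y_t^n
ponder = N + R                              # contribution of this step to P(x)
```

<p>Here the cumulative sum crosses <script type="math/tex">1 - \epsilon</script> at the third step, so the fourth halting output is never used and the probabilities <code>p</code> sum to one.</p>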
<h2 id="improved-attention-span">Improved Attention Span</h2>
<p>The goal of improving attention span is to make the context that can be used in self-attention longer, more efficient and flexible.</p>
<h3 id="longer-attention-span-transformer-xl">Longer Attention Span (Transformer-XL)</h3>
<p>The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segment during each update step and no information can flow across separated fixed-length segments.</p>
<p>This <em>context segmentation</em> causes several issues:</p>
<ul>
<li>The model cannot capture very long term dependencies.</li>
<li>It is hard to predict the first few tokens in each segment given no or thin context.</li>
<li>The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens.</li>
</ul>
<p><strong>Transformer-XL</strong> (<a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>; “XL” means “extra long”) solves the context segmentation problem with two main modifications:</p>
<ol>
<li>Reusing hidden states between segments.</li>
<li>Adopting a new positional encoding that is suitable for reused states.</li>
</ol>
<p><strong>Hidden State Reuse</strong></p>
<p>The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer-XL-training.png" alt="Training phase of Transformer-XL" /></p>
<p><em>Fig. 6. A comparison between the training phase of vanilla Transformer & Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in <a href="https://arxiv.org/abs/1901.02860">Dai et al., 2019</a>).</em></p>
<p>Let’s label the hidden state of the <script type="math/tex">n</script>-th layer for the <script type="math/tex">(\tau + 1)</script>-th segment in the model as <script type="math/tex">\mathbf{h}_{\tau+1}^{(n)} \in \mathbb{R}^{L \times d}</script>. In addition to the hidden state of the previous layer for the same segment <script type="math/tex">\mathbf{h}_{\tau+1}^{(n-1)}</script>, it also depends on the hidden state of the previous layer for the previous segment <script type="math/tex">\mathbf{h}_{\tau}^{(n-1)}</script>. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} &= [\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}] \\
\mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)}\mathbf{W}^q \\
\mathbf{K}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^k \\
\mathbf{V}_{\tau+1}^{(n)} &= \color{red}{\widetilde{\mathbf{h}}_{\tau+1}^{(n-1)}} \mathbf{W}^v \\
\mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)})
\end{aligned} %]]></script>
<p>Note that both key and value rely on the extended hidden state, while the query only consumes hidden state at current step. The concatenation operation <script type="math/tex">[. \circ .]</script> is along the sequence length dimension.</p>
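<p>A minimal sketch of how the extended context is assembled for one layer is shown below. It only illustrates the shapes: the identity weight matrices and the toy hidden states are assumptions for the example, and the stop-gradient is only noted in a comment since plain Python lists carry no gradients.</p>

```python
def extended_context(memory, current):
    """Concatenate cached states of the previous segment with the current segment
    along the sequence dimension. In a real framework the memory would be wrapped
    in a stop-gradient (detach) call; plain lists carry no gradient."""
    return memory + current

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

L, d = 4, 3
h_prev = [[0.1 * (r + c) for c in range(d)] for r in range(L)]        # h_tau^(n-1)
h_curr = [[1.0 + 0.1 * (r + c) for c in range(d)] for r in range(L)]  # h_{tau+1}^(n-1)
# Identity weights, purely for illustration.
W_q = W_k = W_v = [[1.0 if r == c else 0.0 for c in range(d)] for r in range(d)]

h_tilde = extended_context(h_prev, h_curr)
Q = matmul(h_curr, W_q)      # queries: current segment only, L x d
K = matmul(h_tilde, W_k)     # keys: previous + current, 2L x d
V = matmul(h_tilde, W_v)     # values: previous + current, 2L x d
```

<p>The attention matrix for this layer therefore has shape <script type="math/tex">L \times 2L</script>: each of the <script type="math/tex">L</script> current positions can look back into the cached segment.</p>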
<p><strong>Relative Positional Encoding</strong></p>
<p>In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding. If we followed the vanilla Transformer and encoded the absolute position, the previous and current segments would be assigned the same encodings, which is undesirable.</p>
<p>To keep the positional information coherent across segments, Transformer-XL encodes the <em>relative</em> position instead, since knowing the position offset <script type="math/tex">i-j</script> between one key vector <script type="math/tex">\mathbf{k}_{\tau, j}</script> and its query <script type="math/tex">\mathbf{q}_{\tau, i}</script> can be sufficient for making good predictions.</p>
<p>If omitting the scalar <script type="math/tex">1/\sqrt{d_k}</script> and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position <script type="math/tex">i</script> and key at position <script type="math/tex">j</script> as:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
a_{ij}
&= \mathbf{q}_i {\mathbf{k}_j}^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q ((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k)^\top \\
&= \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top
\end{aligned} %]]></script>
<p>Transformer-XL reparameterizes the above four terms as follows:</p>
<script type="math/tex; mode=display">a_{ij}^\text{rel} =
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{content-based addressing} +
\underbrace{ \mathbf{x}_i\mathbf{W}^q \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{content-dependent positional bias} +
\underbrace{ \color{red}{\mathbf{u}} \color{blue}{ {\mathbf{W}_E^k}^\top } \mathbf{x}_j^\top }_\text{global content bias} +
\underbrace{ \color{red}{\mathbf{v}} \color{blue}{ {\mathbf{W}_R^k}^\top } \color{green}{\mathbf{r}_{i-j}^\top} }_\text{global positional bias}</script>
<ul>
<li>Replace <script type="math/tex">\mathbf{p}_j</script> with relative positional encoding <script type="math/tex">\mathbf{r}_{i-j} \in \mathbb{R}^{d}</script>;</li>
<li>Replace <script type="math/tex">\mathbf{p}_i\mathbf{W}^q</script> with two trainable parameters <script type="math/tex">\mathbf{u}</script> (for content) and <script type="math/tex">\mathbf{v}</script> (for location) in two different terms;</li>
<li>Split <script type="math/tex">\mathbf{W}^k</script> into two matrices, <script type="math/tex">\mathbf{W}^k_E</script> for content information and <script type="math/tex">\mathbf{W}^k_R</script> for location information.</li>
</ul>
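<p>The four terms can be sketched in plain Python as below. This is only an illustration under simplifying assumptions: vectors are Python lists, the helper names (<code>matvec</code>, <code>dot</code>, <code>rel_attention_terms</code>) are hypothetical, and the toy values are made up so the arithmetic is easy to follow by hand.</p>

```python
def matvec(x, W):
    """Multiply a row vector x (length d_in) by a matrix W (d_in x d_out)."""
    return [sum(x[a] * W[a][b] for a in range(len(x))) for b in range(len(W[0]))]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


def rel_attention_terms(x_i, x_j, r_offset, W_q, W_kE, W_kR, u, v):
    """Return the four terms of a_ij^rel, in the order they appear in the equation."""
    q = matvec(x_i, W_q)                 # x_i W^q
    k_content = matvec(x_j, W_kE)        # x_j W_E^k
    k_position = matvec(r_offset, W_kR)  # r_{i-j} W_R^k
    return (
        dot(q, k_content),   # content-based addressing
        dot(q, k_position),  # content-dependent positional bias
        dot(u, k_content),   # global content bias
        dot(v, k_position),  # global positional bias
    )


# Toy values (assumptions for illustration only), with d = 2.
identity = [[1.0, 0.0], [0.0, 1.0]]
terms = rel_attention_terms(
    x_i=[1.0, 0.0], x_j=[0.0, 1.0], r_offset=[1.0, 1.0],
    W_q=identity, W_kE=identity, W_kR=identity,
    u=[0.5, 0.5], v=[0.0, 1.0],
)
score = sum(terms)
```

<p>Since <code>u</code> and <code>v</code> do not depend on the query position, the sum can also be grouped as <code>dot(q + u, k_content) + dot(q + v, k_position)</code>, so the key-side projections are computed only once per key.</p>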
<h3 id="adaptive-attention-span">Adaptive Attention Span</h3>
<p>One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further back at some times than at others, or one attention head may have a different attention pattern from another. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.</p>
<p>This is the motivation for <strong>Adaptive Attention Span</strong>. <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al., (2019)</a> proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 7) and thus the optimal span would be trained separately per head.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/attention-per-head.png" alt="Attention per head" /></p>
<p><em>Fig. 7. Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>)</em></p>
<p>Given the <script type="math/tex">i</script>-th token, we need to compute the attention weights between this token and other keys at positions <script type="math/tex">j \in S_i</script>, where <script type="math/tex">S_i</script> defines the <script type="math/tex">i</script>-th token’s context window.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
e_{ij} &= \mathbf{q}_i {\mathbf{k}_j}^\top \\
a_{ij} &= \text{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1} \exp(e_{ir})} \\
\mathbf{y}_i &= \sum_{r=i-s}^{i-1}a_{ir}\mathbf{v}_r = \sum_{r=i-s}^{i-1}a_{ir}\mathbf{x}_r\mathbf{W}^v
\end{aligned} %]]></script>
<p>A <em>soft mask function</em> <script type="math/tex">m_z</script> is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. <script type="math/tex">m_z</script> is parameterized by <script type="math/tex">z \in [0, s]</script> and <script type="math/tex">z</script> is to be learned:</p>
<script type="math/tex; mode=display">m_z(x) = \text{clamp}(\frac{1}{R}(R+z-x), 0, 1)</script>
<p>where <script type="math/tex">R</script> is a hyper-parameter which defines the softness of <script type="math/tex">m_z</script>.</p>
<p style="width: 55%;" class="center"><img src="/lil-log/assets/images/soft-masking-function.png" alt="Soft masking function" /></p>
<p><em>Fig. 8. The soft masking function used in the adaptive attention span. (Image source: <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. 2019</a>.)</em></p>
<p>The soft mask function is applied to the softmax elements in the attention weights:</p>
<script type="math/tex; mode=display">a_{ij} = \frac{m_z(i-j)\exp(e_{ij})}{\sum_{r=i-s}^{i-1}m_z(i-r) \exp(e_{ir})}</script>
<p>In the above equation, <script type="math/tex">z</script> is differentiable so it is trained jointly with other parts of the model. Parameters <script type="math/tex">z^{(i)}, i=1, \dots, h</script> are learned <em>separately per head</em>. Moreover, the loss function has an extra L1 penalty on <script type="math/tex">\sum_{i=1}^h z^{(i)}</script>.</p>
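<p>A minimal sketch of the soft mask and the masked attention weights, in plain Python. The function names and the list layout (<code>scores[d - 1]</code> holds the raw score for the key at distance <code>d</code> behind the query) are assumptions made for this illustration, and the learning of <code>z</code> is not shown.</p>

```python
import math


def soft_mask(x, z, R):
    """m_z(x) = clamp((R + z - x) / R, 0, 1)."""
    return min(1.0, max(0.0, (R + z - x) / R))


def masked_attention_weights(scores, z, R):
    """Masked softmax over keys at distances 1..s from the query.

    scores[d - 1] is the raw score e_{i, i-d}; the result has the same layout.
    """
    shift = max(scores)  # subtracted for numerical stability; it cancels in the ratio
    unnorm = [
        soft_mask(d, z, R) * math.exp(e - shift)
        for d, e in enumerate(scores, start=1)
    ]
    total = sum(unnorm)
    if total == 0.0:
        raise ValueError("the mask removed every key; R + z must exceed 1")
    return [w / total for w in unnorm]


# Uniform scores over a window of s = 8 keys, with z = 2 and R = 4.
weights = masked_attention_weights([0.0] * 8, z=2.0, R=4.0)
```

<p>With these toy values the mask equals 1 up to distance <code>z</code>, decays linearly over the next <code>R</code> positions, and is exactly 0 from distance <code>z + R</code> onwards, so keys that far back receive no attention and need not be computed.</p>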
<p>Using <a href="#adaptive-computation-time-act">Adaptive Computation Time</a>, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter <script type="math/tex">z_t</script> of an attention head at time <script type="math/tex">t</script> is a sigmoidal function, <script type="math/tex">z_t = S \sigma(\mathbf{v} \cdot \mathbf{x}_t +b)</script>, where the vector <script type="math/tex">\mathbf{v}</script> and the bias scalar <script type="math/tex">b</script> are learned jointly with other parameters.</p>
<p>In the experiments of Transformer with adaptive attention span, <a href="https://arxiv.org/abs/1905.07799">Sukhbaatar, et al. (2019)</a> found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.</p>
<h3 id="localized-attention-span-image-transformer">Localized Attention Span (Image Transformer)</h3>
<p>The original, also the most popular, use case for Transformer is to do language modeling. The text sequence is one-dimensional in a clearly defined chronological order and thus the attention span grows linearly with increased context size.</p>
<p>However, if we want to use Transformer on images, it is unclear how to define the scope of context or the order. <strong>Image Transformer</strong> (<a href="https://arxiv.org/abs/1802.05751">Parmar, et al 2018</a>) embraces a formulation of image generation similar to sequence modeling within the Transformer framework. Additionally, Image Transformer restricts the self-attention span to only <em>local</em> neighborhoods, so that the model can scale up to process more images in parallel and keep the likelihood loss tractable.</p>
<p>The encoder-decoder architecture remains for image-conditioned generation:</p>
<ul>
<li>The encoder generates a contextualized, per-pixel-channel representation of the source image;</li>
<li>The decoder <em>autoregressively</em> generates an output image, one channel per pixel at each time step.</li>
</ul>
<p>Let’s label the representation of the current pixel to be generated as the query <script type="math/tex">\mathbf{q}</script>. Other positions whose representations will be used for computing <script type="math/tex">\mathbf{q}</script> are key vectors <script type="math/tex">\mathbf{k}_1, \mathbf{k}_2, \dots</script> and they together form a memory matrix <script type="math/tex">\mathbf{M}</script>. The scope of <script type="math/tex">\mathbf{M}</script> defines the context window for pixel query <script type="math/tex">\mathbf{q}</script>.</p>
<p>Image Transformer introduced two types of localized <script type="math/tex">\mathbf{M}</script>, as illustrated below.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/image-transformer-attention.png" alt="Attention patterns in Image Transformer" /></p>
<p><em>Fig. 9. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1802.05751">Parmar et al, 2018</a>)</em></p>
<p>(1) <em>1D Local Attention</em>: The input image is flattened in the <a href="https://en.wikipedia.org/wiki/Raster_scan#Scanning_pattern">raster scanning</a> order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as <script type="math/tex">\mathbf{q}</script> and a fixed number of additional pixels generated before this query block.</p>
<p>(2) <em>2D Local Attention</em>: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.</p>
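<p>To make the 1D case concrete, the sketch below computes which flattened positions fall into the memory block of a query pixel. It is a simplified illustration: the function names and parameters (<code>block_len</code>, <code>extra</code>, <code>num_pixels</code>) are hypothetical, and the additional masking a decoder would apply to pixels that have not been generated yet is left out.</p>

```python
def raster_index(row, col, width):
    """Flatten a (row, col) pixel coordinate in raster scanning order."""
    return row * width + col


def local_1d_context(pos, block_len, extra, num_pixels):
    """Flattened positions in the memory block of the query at `pos`.

    The memory block is the query block containing `pos` plus up to `extra`
    positions that come right before that query block.
    """
    block_start = (pos // block_len) * block_len
    block_end = min(num_pixels, block_start + block_len)  # exclusive
    mem_start = max(0, block_start - extra)
    return list(range(mem_start, block_end))


# A 4x4 image flattened to 16 positions, query blocks of 4, 3 extra positions.
pos = raster_index(row=2, col=1, width=4)
context = local_1d_context(pos, block_len=4, extra=3, num_pixels=16)
```

<p>All queries in the same query block share the same memory block, which is what allows the attention for a whole block to be computed in one batched operation.</p>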
<h2 id="less-time-and-memory-cost">Less Time and Memory Cost</h2>
<p>This section introduces several improvements made on Transformer to reduce the computation time and memory consumption.</p>
<h3 id="sparse-attention-matrix-factorization-sparse-transformers">Sparse Attention Matrix Factorization (Sparse Transformers)</h3>
<p>The compute and memory cost of the vanilla Transformer grows quadratically with sequence length, and thus it is hard to apply to very long sequences.</p>
<p><strong>Sparse Transformer</strong> (<a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a>) introduced <em>factorized self-attention</em>, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.</p>
<p>Let <script type="math/tex">\mathcal{S} = \{S_1, \dots, S_n\}</script> be a set of attention connectivity patterns, where each <script type="math/tex">S_i</script> records the set of key positions that the <script type="math/tex">i</script>-th query vector attends to. The attention output is then defined as:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{Attend}(\mathbf{X}, \mathcal{S}) &= \Big( a(\mathbf{x}_i, S_i) \Big)_{i \in \{1, \dots, L\}} \\
\text{ where } a(\mathbf{x}_i, S_i) &= \text{softmax}\Big(\frac{(\mathbf{x}_i \mathbf{W}^q)(\mathbf{x}_j \mathbf{W}^k)_{j \in S_i}^\top}{\sqrt{d_k}}\Big) (\mathbf{x}_j \mathbf{W}^v)_{j \in S_i}
\end{aligned} %]]></script>
<p>Note that although the size of <script type="math/tex">S_i</script> is not fixed, <script type="math/tex">a(\mathbf{x}_i, S_i)</script> is always of size <script type="math/tex">d_v</script> and thus <script type="math/tex">\text{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{L \times d_v}</script>.</p>
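<p>The sketch below computes one output row for a given connectivity set in plain Python. As a simplification, it assumes the query, keys and values have already been projected, so the three weight matrices do not appear; the function name <code>attend_one</code> is hypothetical.</p>

```python
import math


def attend_one(q_i, keys, values, S_i):
    """Softmax attention of one query over the key positions listed in S_i.

    q_i is the projected query, keys[j] and values[j] the projected key and
    value at position j. The output always has length d_v, whatever |S_i| is.
    """
    d_k = len(q_i)
    positions = sorted(S_i)
    logits = [
        sum(a * b for a, b in zip(q_i, keys[j])) / math.sqrt(d_k)
        for j in positions
    ]
    shift = max(logits)  # for numerical stability; cancels after normalization
    weights = [math.exp(l - shift) for l in logits]
    total = sum(weights)
    d_v = len(values[0])
    return [
        sum(w * values[j][c] for w, j in zip(weights, positions)) / total
        for c in range(d_v)
    ]


# Toy example with 4 positions, d_k = 2 and d_v = 3 (made-up numbers).
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
out_full = attend_one([1.0, 0.0], keys, values, S_i={0, 1, 2, 3})
out_sparse = attend_one([1.0, 0.0], keys, values, S_i={2})
```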
<p>In auto-regressive models, one attention span is defined as <script type="math/tex">S_i = \{j: j \leq i\}</script> as it allows each token to attend to all the positions in the past.</p>
<p>In factorized self-attention, the set <script type="math/tex">S_i</script> is decomposed into a <em>tree</em> of dependencies, such that for every pair of <script type="math/tex">(i, j)</script> where <script type="math/tex">j \leq i</script>, there is a path connecting <script type="math/tex">i</script> back to <script type="math/tex">j</script> and <script type="math/tex">i</script> can attend to <script type="math/tex">j</script> either directly or indirectly.</p>
<p>Precisely, the set <script type="math/tex">S_i</script> is divided into <script type="math/tex">p</script> <em>non-overlapping</em> subsets, where the <script type="math/tex">m</script>-th subset is denoted as <script type="math/tex">A^{(m)}_i \subset S_i, m = 1,\dots, p</script>. Therefore the path between the output position <script type="math/tex">i</script> and any <script type="math/tex">j</script> has a maximum length <script type="math/tex">p + 1</script>. For example, if <script type="math/tex">(j, a, b, c, \dots, i)</script> is a path of indices between <script type="math/tex">i</script> and <script type="math/tex">j</script>, we would have <script type="math/tex">j \in A_a^{(1)}, a \in A_b^{(2)}, b \in A_c^{(3)}, \dots</script>, so on and so forth.</p>
<p><strong>Sparse Factorized Attention</strong></p>
<p>Sparse Transformer proposed two types of factorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sparse-attention.png" alt="Sparse attention" /></p>
<p><em>Fig. 10. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: <a href="https://arxiv.org/abs/1904.10509">Child et al., 2019</a> + a few extra annotations.)</em></p>
<p>(1) <em>Strided</em> attention with stride <script type="math/tex">\ell \sim \sqrt{n}</script>. This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous <script type="math/tex">\ell</script> pixels in the raster scanning order (naturally covering the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
A_i^{(1)} &= \{ t, t+1, \dots, i\} \text{, where } t = \max(0, i - \ell) \\
A_i^{(2)} &= \{j: (i-j) \mod \ell = 0\}
\end{aligned} %]]></script>
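<p>The two strided subsets can be enumerated directly. In the sketch below the function name is hypothetical, and the second subset is restricted to <code>j &lt;= i</code>, which is an assumption taken from the auto-regressive setting rather than something the equation states.</p>

```python
def strided_pattern(i, stride):
    """Return (A1, A2) for query position i under strided attention."""
    t = max(0, i - stride)
    a1 = set(range(t, i + 1))  # the most recent positions, up to and including i
    a2 = {j for j in range(i + 1) if (i - j) % stride == 0}  # same "column" as i
    return a1, a2


a1, a2 = strided_pattern(10, stride=4)
```

<p>Here <code>a1</code> is <code>{6, 7, 8, 9, 10}</code> and <code>a2</code> is <code>{2, 6, 10}</code>. Any earlier position can be reached in two hops: first jump along the column through <code>a2</code>, then look back locally through the <code>a1</code> of the position landed on.</p>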
<p>(2) <em>Fixed</em> attention. A small set of tokens summarize previous locations and propagate that information to all future locations.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
A_i^{(1)} &= \{j: \lfloor \frac{j}{\ell} \rfloor = \lfloor \frac{i}{\ell} \rfloor \} \\
A_i^{(2)} &= \{j: j \mod \ell \in \{\ell-c, \dots, \ell-1\} \}
\end{aligned} %]]></script>
<p>where <script type="math/tex">c</script> is a hyperparameter. If <script type="math/tex">c=1</script>, the representation is restricted because many positions depend on only a few summary positions. The paper chose <script type="math/tex">c\in \{ 8, 16, 32 \}</script> for <script type="math/tex">\ell \in \{ 128, 256 \}</script>.</p>
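<p>A matching sketch for the fixed pattern, again with a hypothetical function name and with both subsets restricted to <code>j &lt;= i</code> as an assumption from the auto-regressive setting:</p>

```python
def fixed_pattern(i, stride, c):
    """Return (A1, A2) for query position i under fixed attention."""
    # A1: positions in the same block of length `stride` as i.
    a1 = {j for j in range(i + 1) if j // stride == i // stride}
    # A2: the last c positions of every block, which act as summary positions.
    a2 = {j for j in range(i + 1) if j % stride >= stride - c}
    return a1, a2


a1_fixed, a2_fixed = fixed_pattern(10, stride=4, c=1)
```

<p>For position 10 with stride 4 and <code>c = 1</code>, <code>a1_fixed</code> is <code>{8, 9, 10}</code> and <code>a2_fixed</code> is <code>{3, 7}</code>: everything before the current block is visible only through one summary position per block, which illustrates why a very small <code>c</code> can be restrictive.</p>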
<p><strong>Use Factorized Self-Attention in Transformer</strong></p>
<p>There are three ways to use sparse factorized attention patterns in Transformer architecture:</p>
<ol>
<li>One attention type per residual block and then interleave them, <br />
<script type="math/tex">\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, A^{(n \mod p)}) \mathbf{W}^o</script>, where <script type="math/tex">n</script> is the index of the current residual block.</li>
<li>Set up a single head which attends to locations that all the factorized heads attend to, <br />
<script type="math/tex">\text{attention}(\mathbf{X}) = \text{Attend}(\mathbf{X}, \cup_{m=1}^p A^{(m)}) \mathbf{W}^o</script>.</li>
<li>Use a multi-head attention mechanism in which, unlike in the vanilla Transformer, each head may adopt one of the patterns presented above (option 1 or 2). This option often performs the best.</li>
</ol>
<p>Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention & FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the <a href="https://arxiv.org/abs/1904.10509">paper</a> for more details.</p>
<h3 id="locality-sensitive-hashing-reformer">Locality-Sensitive Hashing (Reformer)</h3>
<p>The improvements proposed by the <strong>Reformer</strong> model (<a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>) aim to solve the following pain points in Transformer:</p>
<ul>
<li>Memory in a model with <script type="math/tex">N</script> layers is <script type="math/tex">N</script>-times larger than in a single-layer model because we need to store activations for back-propagation.</li>
<li>The intermediate FF layers are often quite large.</li>
<li>The attention matrix on sequences of length <script type="math/tex">L</script> often requires <script type="math/tex">O(L^2)</script> in both memory and time.</li>
</ul>
<p>Reformer proposed two main changes:</p>
<ol>
<li>Replace the dot-product attention with <em>locality-sensitive hashing (LSH) attention</em>, reducing the complexity from <script type="math/tex">O(L^2)</script> to <script type="math/tex">O(L\log L)</script>.</li>
<li>Replace the standard residual blocks with <em>reversible residual layers</em>, which allows storing activations only once during training instead of <script type="math/tex">N</script> times (i.e. proportional to the number of layers).</li>
</ol>
<p><a name="LSH"></a><strong>Locality-Sensitive Hashing Attention</strong></p>
<p>In <script type="math/tex">\mathbf{Q} \mathbf{K}^\top</script> part of the <a href="#attention-and-self-attention">attention formula</a>, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query <script type="math/tex">\mathbf{q}_i \in \mathbf{Q}</script>, we are looking for row vectors in <script type="math/tex">\mathbf{K}</script> closest to <script type="math/tex">\mathbf{q}_i</script>. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-Sensitive Hashing (LSH)</a> into its attention mechanism.</p>
<p>A hashing scheme <script type="math/tex">x \mapsto h(x)</script> is <em>locality-sensitive</em> if it preserves the distance information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. The Reformer adopts the following hashing scheme: given a fixed random matrix <script type="math/tex">\mathbf{R} \in \mathbb{R}^{d \times b/2}</script> (where <script type="math/tex">b</script> is a hyperparameter), the hash function is <script type="math/tex">h(x) = \arg\max([xR; -xR])</script>.</p>
<!-- If we omit the scalar in self-attention and summarize the denominator into a normalizing term $$Z(.)$$, an normal attention output looks as follows:
$$
\mathbf{o}_i = \sum_{j \in S_i} \exp(\mathbf{q}_i \cdot \mathbf{k}_j - Z(i, S_i)) \mathbf{v}_j \text{, where } S_i = \{j: j \leq i\}
$$
-->
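<p>The hash function above fits in a few lines of plain Python. In this sketch the helper names are hypothetical and the entries of the random matrix are drawn from a standard Gaussian with a fixed seed, which is an assumption made here for reproducibility.</p>

```python
import random


def make_projection(d, n_buckets, seed=0):
    """Fixed random matrix R of shape d x (n_buckets / 2); n_buckets must be even."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)] for _ in range(d)]


def lsh_hash(x, R):
    """h(x) = argmax([xR; -xR]), a bucket id in 0..n_buckets-1."""
    proj = [sum(x[a] * R[a][b] for a in range(len(x))) for b in range(len(R[0]))]
    scores = proj + [-p for p in proj]  # concatenate xR and -xR
    return max(range(len(scores)), key=scores.__getitem__)


R = make_projection(d=4, n_buckets=8)
x = [0.3, -1.2, 0.7, 2.0]
bucket = lsh_hash(x, R)
bucket_scaled = lsh_hash([2.0 * v for v in x], R)  # same direction as x
bucket_flipped = lsh_hash([-v for v in x], R)      # opposite direction
```

<p>The bucket depends only on the direction of the vector: scaling <code>x</code> by a positive constant leaves it unchanged, while negating <code>x</code> moves it to the bucket in the other half of the concatenated score vector.</p>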
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/LSH-attention-matrix.png" alt="LSH attention matrix" /></p>
<p><em>Fig. 11. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
<p>In LSH attention, a query can only attend to positions in the same hashing bucket, <script type="math/tex">S_i = \{j: h(\mathbf{q}_i) = h(\mathbf{k}_j)\}</script>. It is carried out in the following process, as illustrated in Fig. 11:</p>
<ul>
<li>(a) The attention matrix for full attention is often sparse.</li>
<li>(b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets.</li>
<li>(c) Set <script type="math/tex">\mathbf{Q} = \mathbf{K}</script> (precisely <script type="math/tex">\mathbf{k}_j = \mathbf{q}_j / \|\mathbf{q}_j\|</script>), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this “shared-QK” config does not affect the performance of the Transformer.</li>
<li>(d) Apply batching where chunks of <script type="math/tex">m</script> consecutive queries are grouped together.</li>
</ul>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/LSH-attention.png" alt="LSH attention" /></p>
<p><em>Fig. 12. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in <a href="https://arxiv.org/abs/2001.04451">Kitaev, et al. 2020</a>).</em></p>
<p><strong>Reversible Residual Network</strong></p>
<p>Another improvement by Reformer is to use <em>reversible residual layers</em> (<a href="https://arxiv.org/abs/1707.04585">Gomez et al. 2017</a>). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.</p>
<p>Given a layer <script type="math/tex">x \mapsto y</script>, the normal residual layer does <script type="math/tex">y = x + F(x)</script>, but the reversible layer splits both input and output into pairs <script type="math/tex">(x_1, x_2) \mapsto (y_1, y_2)</script> and then executes the following:</p>
<script type="math/tex; mode=display">y_1 = x_1 + F(x_2),\; y_2 = x_2 + G(y_1)</script>
<p>and reversing is easy:</p>
<script type="math/tex; mode=display">x_2 = y_2 - G(y_1), \; x_1 = y_1 - F(x_2)</script>
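<p>A small sketch of the forward and inverse computations, using Python lists as activations. The functions <code>F</code> and <code>G</code> below are arbitrary toy functions chosen for this illustration; any deterministic functions would do, since the inversion never needs to invert <code>F</code> or <code>G</code> themselves.</p>

```python
import math


def rev_forward(x1, x2, F, G):
    """y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, F(x2))]
    y2 = [a + b for a, b in zip(x2, G(y1))]
    return y1, y2


def rev_inverse(y1, y2, F, G):
    """Recover the inputs: x2 = y2 - G(y1), then x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, G(y1))]
    x1 = [a - b for a, b in zip(y1, F(x2))]
    return x1, x2


def F(v):
    """Toy stand-in for the first residual function."""
    return [math.tanh(t) for t in v]


def G(v):
    """Toy stand-in for the second residual function."""
    return [0.5 * t * t for t in v]


x1, x2 = [0.1, -0.4, 2.0], [1.5, 0.0, -0.7]
y1, y2 = rev_forward(x1, x2, F, G)
x1_rec, x2_rec = rev_inverse(y1, y2, F, G)
```

<p>The recovered inputs match the originals up to floating-point rounding, which is why the inputs of a layer do not have to be kept in memory for the backward pass.</p>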
<p>Reformer applies the same idea to Transformer by combining attention (<script type="math/tex">F</script>) and feed-forward layers (<script type="math/tex">G</script>) within a reversible net block:</p>
<script type="math/tex; mode=display">Y_1 = X_1 + \text{Attention}(X_2), \; Y_2 = X_2 + \text{FeedForward}(Y_1)</script>
<p>The memory can be further reduced by chunking the feed-forward computation:
<script type="math/tex">Y_2 = [Y_2^{(1)}; \dots; Y_2^{(c)}] = [X_2^{(1)} + \text{FeedForward}(Y_1^{(1)}); \dots; X_2^{(c)} + \text{FeedForward}(Y_1^{(c)})]</script></p>
<p>The resulting reversible Transformer does not need to store activations in every layer.</p>
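<p>Chunking works because the feed-forward layer is applied to each position independently, so processing a few rows at a time gives the same result as processing all rows at once. The sketch below checks this on toy data; <code>feed_forward_row</code> is a hypothetical stand-in for the real layer.</p>

```python
def feed_forward_row(row):
    """Hypothetical position-wise feed-forward: ReLU followed by a scaling."""
    return [2.0 * max(0.0, v) for v in row]


def chunked_residual_ff(X2, Y1, n_chunks):
    """Compute Y2 = X2 + FeedForward(Y1) one chunk of rows at a time."""
    length = len(Y1)
    size = max(1, -(-length // n_chunks))  # ceiling division, at least one row
    out = []
    for start in range(0, length, size):
        for x_row, y_row in zip(X2[start:start + size], Y1[start:start + size]):
            out.append([a + b for a, b in zip(x_row, feed_forward_row(y_row))])
    return out


X2 = [[0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5], [3.0, 0.0]]
Y1 = [[1.0, -1.0], [0.5, 2.0], [-3.0, 0.0], [4.0, 4.0], [-1.0, 1.0]]
full = chunked_residual_ff(X2, Y1, n_chunks=1)
chunked = chunked_residual_ff(X2, Y1, n_chunks=3)
```

<p>Only the intermediate activations of one chunk need to exist at any time, which matters because the hidden dimension inside the feed-forward layer is usually much larger than the model dimension.</p>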
<h2 id="make-it-recurrent-universal-transformer">Make it Recurrent (Universal Transformer)</h2>
<p>The <strong>Universal Transformer</strong> (<a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN.</p>
<p>Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using <a href="#adaptive-computation-time-act">adaptive computation time</a>. If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.</p>
<p>On a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer-loop.png" alt="Universal Transformer Recurrent Step" /></p>
<p><em>Fig. 13. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>).</em></p>
<p>Given an input sequence of length <script type="math/tex">L</script>, Universal Transformer iteratively updates the representation <script type="math/tex">\mathbf{H}^t \in \mathbb{R}^{L \times d}</script> at step <script type="math/tex">t</script> for an adjustable number of steps. At step 0, <script type="math/tex">\mathbf{H}^0</script> is initialized to be the same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{A}^t &= \text{LayerNorm}(\mathbf{H}^{t-1} + \text{MultiHeadAttention}(\mathbf{H}^{t-1} + \mathbf{P}^t)) \\
\mathbf{H}^t &= \text{LayerNorm}(\mathbf{A}^t + \text{Transition}(\mathbf{A}^t))
\end{aligned} %]]></script>
<p>where <script type="math/tex">\text{Transition}(.)</script> is either a <a href="https://arxiv.org/abs/1610.02357">separable convolution</a> or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of <script type="math/tex">\mathbf{A}^t</script> individually) affine transformations + one ReLU.</p>
<p>The positional encoding <script type="math/tex">\mathbf{P}^t</script> uses sinusoidal position signal but with an additional time dimension:</p>
<script type="math/tex; mode=display">% <![CDATA[
\text{PE}(i, t, \delta) =
\begin{cases}
\sin(\frac{i}{10000^{2\delta'/d}}) \oplus \sin(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta'\\
\cos(\frac{i}{10000^{2\delta'/d}}) \oplus \cos(\frac{t}{10000^{2\delta'/d}}) & \text{if } \delta = 2\delta' + 1\\
\end{cases} %]]></script>
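<p>The encoding can be evaluated entry by entry. The sketch below makes one loud assumption: the <script type="math/tex">\oplus</script> symbol is read as plain addition of the position signal and the time-step signal. The function name <code>universal_pe</code> is hypothetical.</p>

```python
import math


def universal_pe(i, t, delta, d):
    """Entry `delta` of the encoding for position i at recurrent step t.

    Assumption: the position and time signals are combined by addition.
    """
    delta_prime = delta // 2
    wavelength = 10000 ** (2 * delta_prime / d)
    if delta % 2 == 0:
        return math.sin(i / wavelength) + math.sin(t / wavelength)
    return math.cos(i / wavelength) + math.cos(t / wavelength)


# Encoding matrix P^t for a sequence of length 5 with d = 8 at step t = 2.
P_t = [[universal_pe(i, 2, delta, 8) for delta in range(8)] for i in range(5)]
```

<p>Unlike the vanilla sinusoidal encoding, the matrix changes from step to step, so the shared recurrent block can tell both where a token is and how many refinement steps have been applied.</p>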
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/universal-transformer.png" alt="Universal Transformer" /></p>
<p><em>Fig. 14. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation <script type="math/tex">\mathbf{H}^T</script>. (Image source: Figure 2 in <a href="https://arxiv.org/abs/1807.03819">Dehghani, et al. 2019</a>)</em></p>
<p>In the adaptive version of Universal Transformer, the number of recurrent steps <script type="math/tex">T</script> is dynamically determined by <a href="#adaptive-computation-time-act">ACT</a>. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.</p>
<h2 id="stabilization-for-rl-gtrxl">Stabilization for RL (GTrXL)</h2>
<p>The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can certainly benefit from these traits. <em>However</em>, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. After all, it can be quite challenging to stabilize and train an LSTM agent by itself.</p>
<p>The <strong>Gated Transformer-XL</strong> (<strong>GTrXL</strong>; <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of <a href="#longer-attention-span-transformer-xl">Transformer-XL</a>:</p>
<ol>
<li>The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer.</li>
<li>The residual connection is replaced with a GRU-style (Gated Recurrent Unit; <a href="https://arxiv.org/abs/1412.3555">Chung et al., 2014</a>) <em>gating</em> mechanism.</li>
</ol>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
r &= \sigma(W_r^{(l)} y + U_r^{(l)} x) \\
z &= \sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\
\hat{h} &= \tanh(W_g^{(l)} y + U_g^{(l)} (r \odot x)) \\
g^{(l)}(x, y) &= (1-z)\odot x + z\odot \hat{h}
\end{aligned} %]]></script>
<p>The gating function parameters are explicitly initialized to be close to an identity map - this is why there is a <script type="math/tex">b_g</script> term. A <script type="math/tex">b_g > 0</script> greatly helps with the learning speedup.</p>
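<p>The effect of the bias can be seen in a scalar version of the gate. The sketch below is a simplification: the inputs and all weights are single numbers rather than vectors and matrices, and the function names are hypothetical.</p>

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def gru_gate(x, y, W_r, U_r, W_z, U_z, W_g, U_g, b_g):
    """Scalar version of g(x, y): x is the shortcut stream, y the module output."""
    r = sigmoid(W_r * y + U_r * x)               # reset gate
    z = sigmoid(W_z * y + U_z * x - b_g)         # update gate, shifted by b_g
    h_hat = math.tanh(W_g * y + U_g * (r * x))   # candidate value
    return (1.0 - z) * x + z * h_hat


zero_weights = dict(W_r=0.0, U_r=0.0, W_z=0.0, U_z=0.0, W_g=0.0, U_g=0.0)
out_no_bias = gru_gate(1.0, 5.0, b_g=0.0, **zero_weights)
out_with_bias = gru_gate(1.0, 5.0, b_g=10.0, **zero_weights)
```

<p>With weights near zero, <code>b_g = 0</code> gives an update gate of 0.5 and the output is halfway between <code>x</code> and the candidate value, whereas a large positive <code>b_g</code> pushes the update gate towards 0 and the output stays close to <code>x</code>, i.e. close to an identity map over the shortcut stream.</p>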
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/gated-transformer-XL.png" alt="GTrXL" /></p>
<p><em>Fig. 15. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in <a href="https://arxiv.org/abs/1910.06764">Parisotto, et al. 2019</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020transformer,
title = "The Transformer Family",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/03/27/the-transformer-family.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Ashish Vaswani, et al. <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[2] Rami Al-Rfou, et al. <a href="https://arxiv.org/abs/1808.04444">“Character-level language modeling with deeper self-attention.”</a> AAAI 2019.</p>
<p>[3] Olah & Carter, <a href="http://doi.org/10.23915/disti">“Attention and Augmented Recurrent Neural Networks”</a>, Distill, 2016.</p>
<p>[4] Sainbayar Sukhbaatar, et al. <a href="https://arxiv.org/abs/1905.07799">“Adaptive Attention Span in Transformers”</a>. ACL 2019.</p>
<p>[5] Rewon Child, et al. <a href="https://arxiv.org/abs/1904.10509">“Generating Long Sequences with Sparse Transformers”</a> arXiv:1904.10509 (2019).</p>
<p>[6] Nikita Kitaev, et al. <a href="https://arxiv.org/abs/2001.04451">“Reformer: The Efficient Transformer”</a> ICLR 2020.</p>
<p>[7] Alex Graves. <a href="https://arxiv.org/abs/1603.08983">“Adaptive Computation Time for Recurrent Neural Networks”</a> arXiv:1603.08983 (2016).</p>
<p>[8] Niki Parmar, et al. <a href="https://arxiv.org/abs/1802.05751">“Image Transformer”</a> ICML 2018.</p>
<p>[9] Zihang Dai, et al. <a href="https://arxiv.org/abs/1901.02860">“Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.”</a> ACL 2019.</p>
<p>[10] Aidan N. Gomez, et al. <a href="https://arxiv.org/abs/1707.04585">“The Reversible Residual Network: Backpropagation Without Storing Activations”</a> NIPS 2017.</p>
<p>[11] Mostafa Dehghani, et al. <a href="https://arxiv.org/abs/1807.03819">“Universal Transformers”</a> ICLR 2019.</p>
<p>[12] Emilio Parisotto, et al. <a href="https://arxiv.org/abs/1910.06764">“Stabilizing Transformers for Reinforcement Learning”</a> arXiv:1910.06764 (2019).</p>
<p>Lilian Weng. Inspired by recent progress on various enhanced versions of Transformer models, this post presents how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving, etc.</p>
<p><strong>Curriculum for Reinforcement Learning</strong> (2020-01-29T18:00:00+00:00, <a href="https://lilianweng.github.io/lil-log/2020/01/29/curriculum-for-reinforcement-learning">https://lilianweng.github.io/lil-log/2020/01/29/curriculum-for-reinforcement-learning</a>)</p>
<blockquote>
<p>A curriculum is an efficient tool for humans to progressively learn from simple concepts to hard problems. It breaks down complex knowledge by providing a sequence of learning steps of increasing difficulty. In this post, we will examine how the idea of curriculum can help reinforcement learning models learn to solve complicated tasks.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-02-03: Mention <a href="#pcg">PCG</a> in the “Task-Specific Curriculum” section.]</span><br />
<span style="color: #286ee0;">[Updated on 2020-02-04: Add a new <a href="#curriculum-through-distillation">“curriculum through distillation”</a> section.]</span></p>
<p>It sounds like an impossible task to teach integrals or derivatives to a 3-year-old who does not even know basic arithmetic. That’s why education is important, as it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard. A curriculum makes learning difficult things easier and more approachable for us humans. But how about machine learning models? Can we train our models more efficiently with a curriculum? Can we design a curriculum to speed up learning?</p>
<p>Back in 1993, Jeffrey Elman proposed the idea of training neural networks with a curriculum. His early work on learning simple language grammar demonstrated the importance of such a strategy: starting with a restricted set of simple data and gradually increasing the complexity of training samples; otherwise the model was not able to learn at all.</p>
<p>Compared to training without a curriculum, we would expect a curriculum to speed up convergence; it may or may not improve the final model performance. Designing an efficient and effective curriculum is not easy. Keep in mind that a bad curriculum may even hamper learning.</p>
<p>Next, we will look into several categories of curriculum learning, as illustrated in Fig. 1. Most cases are applied to Reinforcement Learning, with a few exceptions on Supervised Learning.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/types-of-curriculum-2.png" alt="Types of curriculum" /></p>
<p><em>Fig. 1. Five types of curriculum for reinforcement learning.</em></p>
<p>In “The importance of starting small” paper (<a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.4487&rep=rep1&type=pdf">Elman 1993</a>), I especially like the starting sentences and find them both inspiring and affecting:</p>
<blockquote>
<p>“Humans differ from other species along many dimensions, but two are particularly noteworthy. Humans display an exceptional capacity to learn; and humans are remarkable for the unusually long time it takes to reach maturity. The adaptive advantage of learning is clear, and it may be argued that, through culture, learning has created the basis for a non-genetically based transmission of behaviors which may accelerate the evolution of our species.”</p>
</blockquote>
<p>Indeed, learning is probably the best superpower we humans have.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#task-specific-curriculum" id="markdown-toc-task-specific-curriculum">Task-Specific Curriculum</a></li>
<li><a href="#teacher-guided-curriculum" id="markdown-toc-teacher-guided-curriculum">Teacher-Guided Curriculum</a></li>
<li><a href="#curriculum-through-self-play" id="markdown-toc-curriculum-through-self-play">Curriculum through Self-Play</a></li>
<li><a href="#automatic-goal-generation" id="markdown-toc-automatic-goal-generation">Automatic Goal Generation</a></li>
<li><a href="#skill-based-curriculum" id="markdown-toc-skill-based-curriculum">Skill-Based Curriculum</a></li>
<li><a href="#curriculum-through-distillation" id="markdown-toc-curriculum-through-distillation">Curriculum through Distillation</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="task-specific-curriculum">Task-Specific Curriculum</h2>
<p><a href="https://www.researchgate.net/profile/Y_Bengio/publication/221344862_Curriculum_learning/links/546cd2570cf2193b94c577ac/Curriculum-learning.pdf">Bengio, et al. (2009)</a> provided a good overview of curriculum learning in the old days. The paper presented two ideas with toy experiments using a manually designed task-specific curriculum:</p>
<ol>
<li>Cleaner Examples may yield better generalization faster.</li>
<li>Introducing gradually more difficult examples speeds up online training.</li>
</ol>
<p>It is plausible that some curriculum strategies could be useless or even harmful. A good question to answer in the field is: <em>What could be the general principles that make some curriculum strategies work better than others?</em> The Bengio 2009 paper hypothesized it would be beneficial to make learning focus on “interesting” examples that are neither too hard nor too easy.</p>
<p>If our naive curriculum is to train the model on samples with a gradually increasing level of complexity, we need a way to quantify the difficulty of a task first. One idea is to use its minimal loss with respect to another model while this model is pretrained on other tasks (<a href="https://arxiv.org/abs/1802.03796">Weinshall, et al. 2018</a>). In this way, the knowledge of the pretrained model can be transferred to the new model by suggesting a rank of training samples. Fig. 2 shows the effectiveness of the <code class="language-plaintext highlighter-rouge">curriculum</code> group (green), compared to <code class="language-plaintext highlighter-rouge">control</code> (random order; yellow) and <code class="language-plaintext highlighter-rouge">anti</code> (reverse the order; red) groups.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/curriculum-by-transfer-learning.png" alt="Curriculum by transfer learning" /></p>
<p><em>Fig. 2. Image classification accuracy on test image set (5 member classes of “small mammals” in CIFAR100). There are 4 experimental groups, (a) <code class="language-plaintext highlighter-rouge">curriculum</code>: sort the labels by the confidence of another trained classifier (e.g. the margin of an SVM); (b) <code class="language-plaintext highlighter-rouge">control-curriculum</code>: sort the labels randomly; (c) <code class="language-plaintext highlighter-rouge">anti-curriculum</code>: sort the labels reversely; (d) <code class="language-plaintext highlighter-rouge">None</code>: no curriculum. (Image source: <a href="https://arxiv.org/abs/1802.03796">Weinshall, et al. 2018</a>)</em></p>
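<p>The three orderings compared in Fig. 2 can be sketched in a few lines. This is a minimal illustration rather than the authors’ code: <code class="language-plaintext highlighter-rouge">order_examples</code> and the toy <code class="language-plaintext highlighter-rouge">toy_loss</code> scores are hypothetical names, and in practice the difficulty score would come from a pretrained model, for example its loss or confidence on each sample.</p>

```python
import random

def order_examples(examples, difficulty, mode="curriculum", seed=0):
    """Return examples in the order they would be presented during training.

    difficulty: hypothetical callable mapping an example to a score
    (higher = harder), e.g. the loss of a pretrained model on that example.
    """
    if mode == "curriculum":      # easy -> hard
        return sorted(examples, key=difficulty)
    if mode == "anti":            # hard -> easy
        return sorted(examples, key=difficulty, reverse=True)
    if mode == "control":         # random order
        shuffled = list(examples)
        random.Random(seed).shuffle(shuffled)
        return shuffled
    raise ValueError(f"unknown mode: {mode}")

# Toy difficulty scores standing in for a pretrained model's loss.
toy_loss = {"img_a": 0.1, "img_b": 2.3, "img_c": 0.7}
easy_first = order_examples(list(toy_loss), toy_loss.get)
hard_first = order_examples(list(toy_loss), toy_loss.get, mode="anti")
```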
<p><a href="https://arxiv.org/abs/1410.4615">Zaremba & Sutskever (2014)</a> did an interesting experiment on training an LSTM to predict the output of a short Python program for mathematical ops without actually executing the code. They found that a curriculum is necessary for learning. The program’s complexity is controlled by two parameters, <code class="language-plaintext highlighter-rouge">length</code> ∈ [1, a] and <code class="language-plaintext highlighter-rouge">nesting</code> ∈ [1, b]. Three strategies are considered:</p>
<ol>
<li>Naive curriculum: increase <code class="language-plaintext highlighter-rouge">length</code> first until reaching <code class="language-plaintext highlighter-rouge">a</code>; then increase <code class="language-plaintext highlighter-rouge">nesting</code> and reset <code class="language-plaintext highlighter-rouge">length</code> to 1; repeat this process until both reach maximum.</li>
<li>Mix curriculum: sample <code class="language-plaintext highlighter-rouge">length</code> ~ [1, a] and <code class="language-plaintext highlighter-rouge">nesting</code> ~ [1, b]</li>
<li>Combined: naive + mix.</li>
</ol>
<p>They noticed that the combined strategy always outperformed the naive curriculum and would generally (but not always) outperform the mix strategy, indicating that it is quite important to mix in easy tasks during training to <em>avoid forgetting</em>.</p>
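<p>The three strategies can be sketched as samplers over (<code class="language-plaintext highlighter-rouge">length</code>, <code class="language-plaintext highlighter-rouge">nesting</code>) pairs. This is only one plausible reading of “naive + mix”: each training case comes either from the current naive stage or from the mix strategy. The function names and the mixing probability <code class="language-plaintext highlighter-rouge">p_mix</code> are assumptions made for illustration.</p>

```python
import random

def naive_curriculum(a, b):
    """Yield (length, nesting) stages: sweep length 1..a, then bump nesting
    and reset length to 1, until both reach their maximum."""
    for nesting in range(1, b + 1):
        for length in range(1, a + 1):
            yield length, nesting

def mix_sample(a, b, rng):
    """Sample length from [1, a] and nesting from [1, b] uniformly."""
    return rng.randint(1, a), rng.randint(1, b)

def combined_sample(current_stage, a, b, rng, p_mix=0.5):
    """With probability p_mix (an illustrative choice) draw from the mix
    strategy; otherwise use the current naive-curriculum stage."""
    if rng.random() < p_mix:
        return mix_sample(a, b, rng)
    return current_stage

rng = random.Random(0)
stages = list(naive_curriculum(a=3, b=2))
batch = [combined_sample(stage, 3, 2, rng) for stage in stages]
```

<p>Because the mix branch can return any difficulty level, easy programs keep showing up late in training, which matches the forgetting argument above.</p>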
<p><a name="pcg"></a>Procedural content generation (<a href="https://en.wikipedia.org/wiki/Procedural_generation">PCG</a>) is a popular approach for creating video games of various levels of difficulty. PCG involves algorithmic randomness and a heavy dose of human expertise in designing game elements and dependencies among them. Procedurally generated levels have been introduced into several benchmark environments for evaluating whether an RL agent can generalize to a new level that it is not trained on (<a href="/lil-log/2019/06/23/meta-reinforcement-learning.html">meta-RL</a>!), such as <a href="http://www.gvgai.net/">GVGAI</a>, OpenAI <a href="https://openai.com/blog/quantifying-generalization-in-reinforcement-learning/">CoinRun</a> and <a href="https://openai.com/blog/procgen-benchmark/">Procgen benchmark</a>. Using GVGAI, <a href="https://arxiv.org/abs/1806.10729">Justesen, et al. (2018)</a> demonstrated that an RL policy can easily overfit to a specific game, but training over a simple curriculum that grows the task difficulty together with the model performance helps its generalization to new human-designed levels. Similar results are also found in CoinRun (<a href="https://arxiv.org/abs/1812.02341">Cobbe, et al. 2018</a>). POET (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>) is another example of leveraging an evolutionary algorithm and procedurally generated game levels to improve RL generalization, which I’ve described in detail in my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#evolutionary-algorithm-on-environment-generation">meta-RL post</a>.</p>
<p>To follow the curriculum learning approaches described above, generally we need to figure out two problems in the training procedure:</p>
<ol>
<li>Design a metric to quantify how hard a task is so that we can sort tasks accordingly.</li>
<li>Provide a sequence of tasks with an increasing level of difficulty to the model during training.</li>
</ol>
<p>However, the order of tasks does not have to be sequential. In our Rubik’s cube paper (<a href="https://arxiv.org/abs/1910.07113">OpenAI et al, 2019</a>), we depended on <em>Automatic domain randomization</em> (<strong>ADR</strong>) to generate a curriculum by growing a distribution of environments with increasing complexity. The difficulty of each task (i.e. solving a Rubik’s cube in a set of environments) depends on the randomization ranges of various environmental parameters. Even with a simplified assumption that all the environmental parameters are uncorrelated, we were able to create a decent curriculum for our robot hand to learn the task.</p>
<h2 id="teacher-guided-curriculum">Teacher-Guided Curriculum</h2>
<p><a name="grave-et-al-2017"></a>The idea of <em>Automatic Curriculum Learning</em> was proposed by <a href="https://arxiv.org/abs/1704.03003">Graves, et al. 2017</a> slightly earlier. It considers an <script type="math/tex">N</script>-task curriculum as an <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html"><script type="math/tex">N</script>-armed bandit</a> problem and an adaptive policy which learns to optimize the returns from this bandit.</p>
<p>Two categories of learning signals have been considered in the paper:</p>
<ol>
<li>Loss-driven progress: the loss function change before and after one gradient update. This type of reward signal tracks the speed of the learning process, because the greatest task loss decrease is equivalent to the fastest learning.</li>
<li>Complexity-driven progress: the KL divergence between posterior and prior distribution over network weights. This type of learning signal is inspired by the <a href="https://en.wikipedia.org/wiki/Minimum_description_length">MDL</a> principle, “increasing the model complexity by a certain amount is only worthwhile if it compresses the data by a greater amount”. The model complexity is therefore expected to increase most in response to the model nicely generalizing to training examples.</li>
</ol>
<p><a name="TSCL"></a>This framework of proposing curriculum automatically through another RL agent was formalized as <em>Teacher-Student Curriculum Learning</em> (<strong>TSCL</strong>; <a href="https://arxiv.org/abs/1707.00183">Matiisen, et al. 2017</a>). In TSCL, a <em>student</em> is an RL agent working on actual tasks while a <em>teacher</em> agent is a policy for selecting tasks. The student aims to master a complex task that might be hard to learn directly. To make this task easier to learn, we set up the teacher agent to guide the student’s training process by picking proper sub-tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/teacher-student-curriculum.png" alt="Teacher-student curriculum" /></p>
<p><em>Fig. 3. The setup of teacher-student curriculum learning. (Image source: <a href="https://arxiv.org/abs/1707.00183">Matiisen, et al. 2017</a> + my annotation in red.)</em></p>
<p>In the process, the student should learn tasks which:</p>
<ol>
<li>can help the student make the fastest learning progress, or</li>
<li>are at risk of being forgotten.</li>
</ol>
<blockquote>
<p>Note: The setup of framing the teacher model as an RL problem feels quite similar to Neural Architecture Search (NAS). The difference is that the RL model in TSCL operates on the task space, while NAS operates on the main model architecture space.</p>
</blockquote>
<p>Training the teacher model is to solve a <a href="https://en.wikipedia.org/wiki/Partially_observable_Markov_decision_process">POMDP</a> problem:</p>
<ul>
<li>The unobserved <script type="math/tex">s_t</script> is the full state of the student model.</li>
<li>The observed <script type="math/tex">o = (x_t^{(1)}, \dots, x_t^{(N)})</script> is a list of scores for <script type="math/tex">N</script> tasks.</li>
<li>The action <script type="math/tex">a</script> is to pick one subtask.</li>
<li>The reward per step is the score delta, <script type="math/tex">r_t = \sum_{i=1}^N \big( x_t^{(i)} - x_{t-1}^{(i)} \big)</script> (i.e., equivalent to maximizing the score of all tasks at the end of the episode).</li>
</ul>
<p>The method of estimating learning progress from noisy task scores while balancing exploration vs exploitation can be borrowed from the non-stationary multi-armed bandit problem — use <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#ε-greedy-algorithm">ε-greedy</a>, or <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#thompson-sampling">Thompson sampling</a>.</p>
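<p>A minimal sketch of such a teacher, assuming an ε-greedy bandit that tracks an exponential moving average of score deltas per task. The class name, the smoothing factor <code class="language-plaintext highlighter-rouge">alpha</code> and the use of the absolute value (so that a task whose score is dropping also gets picked) are illustrative assumptions, not the exact algorithm from the paper.</p>

```python
import random

class EpsGreedyTeacher:
    """Hypothetical teacher: picks the task with the largest smoothed |score change|."""

    def __init__(self, n_tasks, epsilon=0.1, alpha=0.1, seed=0):
        self.q = [0.0] * n_tasks           # smoothed score delta per task
        self.last_score = [0.0] * n_tasks  # most recent score per task
        self.epsilon = epsilon
        self.alpha = alpha
        self.rng = random.Random(seed)

    def pick_task(self):
        # Explore: a uniformly random task with probability epsilon.
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(len(self.q))
        # Exploit: fastest progress (positive q) or fastest forgetting (negative q).
        return max(range(len(self.q)), key=lambda i: abs(self.q[i]))

    def update(self, task, score):
        # Reward is the score delta on the task the student just trained on.
        reward = score - self.last_score[task]
        self.last_score[task] = score
        # The moving average lets old observations decay (non-stationary bandit).
        self.q[task] = self.alpha * reward + (1 - self.alpha) * self.q[task]
        return reward

teacher = EpsGreedyTeacher(n_tasks=3, epsilon=0.0)
first_reward = teacher.update(task=1, score=0.5)
next_task = teacher.pick_task()
```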
<p>The core idea, in summary, is to use one policy to propose tasks for another policy to learn better. Interestingly, both works above (in the discrete task space) found that uniformly sampling from all tasks is a surprisingly strong benchmark.</p>
<p>What if the task space is continuous? <a href="https://arxiv.org/abs/1910.07224">Portelas, et al. (2019)</a> studied a continuous teacher-student framework, where the teacher has to sample parameters from a continuous task space to generate a learning curriculum. Given a newly sampled parameter <script type="math/tex">p</script>, the absolute learning progress (short for ALP) is measured as <script type="math/tex">\text{ALP}_p = \vert r - r_\text{old} \vert</script>, where <script type="math/tex">r</script> is the episodic reward associated with <script type="math/tex">p</script> and <script type="math/tex">r_\text{old}</script> is the reward associated with <script type="math/tex">p_\text{old}</script>. Here, <script type="math/tex">p_\text{old}</script> is the previously sampled parameter closest to <script type="math/tex">p</script> in the task space, which can be retrieved by nearest neighbor. Note how this ALP score differs from the learning signals in <a href="#TSCL">TSCL</a> or <a href="#grave-et-al-2017">Graves, et al. 2017</a> above: the ALP score measures the reward difference between two tasks rather than performance at two time steps of the same task.</p>
<p>On top of the task parameter space, a Gaussian mixture model is trained to fit the distribution of <script type="math/tex">\text{ALP}_p</script> over <script type="math/tex">p</script>. ε-greedy is used when sampling the tasks: with some probability, sample a random task; otherwise sample proportionally to the ALP score from the GMM.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ALP-GMM-algorithm.png" alt="ALP-GMM" /></p>
<p><em>Fig. 4. The algorithm of ALP-GMM (absolute learning progress Gaussian mixture model). (Image source: <a href="https://arxiv.org/abs/1910.07224">Portelas, et al., 2019</a>)</em></p>
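<p>The ALP score itself is cheap to compute. Below is a small sketch, assuming task parameters are stored as tuples next to their episodic rewards and that a brute-force nearest-neighbor search is good enough; fitting the GMM is left out. <code class="language-plaintext highlighter-rouge">absolute_learning_progress</code> and <code class="language-plaintext highlighter-rouge">history</code> are hypothetical names.</p>

```python
import math

def absolute_learning_progress(p, r, history):
    """ALP_p = |r - r_old|, where (p_old, r_old) is the previously sampled
    task parameter closest to p in Euclidean distance.

    history: list of (p_old, r_old) pairs collected so far.
    """
    if not history:
        return 0.0  # nothing to compare against yet (an illustrative choice)
    _, r_old = min(history, key=lambda pair: math.dist(pair[0], p))
    return abs(r - r_old)

history = [((0.0, 0.0), 1.0), ((1.0, 1.0), 5.0)]
alp = absolute_learning_progress(p=(0.9, 1.0), r=2.0, history=history)
```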
<h2 id="curriculum-through-self-play">Curriculum through Self-Play</h2>
<p>In the teacher-student framework, the two agents do very different things. The teacher learns to pick a task for the student without any knowledge of the actual task content. What if we want to make both train on the main task directly? How about even making them compete with each other?</p>
<p><a href="https://arxiv.org/abs/1703.05407">Sukhbaatar, et al. (2017)</a> proposed a framework for automatic curriculum learning through <strong>asymmetric self-play</strong>. Two agents, Alice and Bob, play the same task with different goals: Alice challenges Bob to achieve the same state and Bob attempts to complete it as fast as he can.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-play-maze.png" alt="Self-play experiments in MazeBase" /></p>
<p><em>Fig. 5. Illustration of the self-play setup when training two agents. The example task is <a href="https://github.com/facebook/MazeBase">MazeBase</a>: An agent is asked to reach a goal flag in a maze with a light switch, a key and a wall with a door. Toggling the key switch can open or close the door and turning off the light makes only the glowing light switch available to the agent. (Image source: <a href="https://arxiv.org/abs/1703.05407">Sukhbaatar, et al. 2017</a>)</em></p>
<p>Let us consider Alice and Bob as two separate copies of one RL agent trained in the same environment but with different brains. Each of them has independent parameters and loss objective. The self-play-driven training consists of two types of episodes:</p>
<ul>
<li>In the <em>self-play episode</em>, Alice alters the state from <script type="math/tex">s_0</script> to <script type="math/tex">s_t</script> and then Bob is asked to return the environment to its original state <script type="math/tex">s_0</script> to get an internal reward.</li>
<li>In the <em>target task episode</em>, Bob receives an external reward if he visits the target flag.</li>
</ul>
<p>Note that since B has to repeat the actions between the same pair of <script type="math/tex">(s_0, s_t)</script> of A, this framework only works in reversible or resettable environments.</p>
<p>Alice should learn to push Bob out of his comfort zone, but not give him impossible tasks. Bob’s reward is set as <script type="math/tex">R_B = -\gamma t_B</script> and Alice’s reward is <script type="math/tex">R_A = \gamma \max(0, t_B - t_A)</script>, where <script type="math/tex">t_B</script> is the total time for B to complete the task, <script type="math/tex">t_A</script> is the time until Alice performs the STOP action and <script type="math/tex">\gamma</script> is a scalar constant to rescale the reward to be comparable with the external task reward. If B fails a task, <script type="math/tex">t_B = t_\max - t_A</script>.
Both policies are goal-conditioned. The losses imply:</p>
<ol>
<li>B wants to finish a task asap.</li>
<li>A prefers tasks that take B more time.</li>
<li>A does not want to take too many steps when B is failing.</li>
</ol>
<p>In this way, the interaction between Alice and Bob automatically builds a curriculum of increasingly challenging tasks. Meanwhile, as A has done the task herself before proposing the task to B, the task is guaranteed to be solvable.</p>
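<p>The two internal rewards are simple enough to write down directly. This sketch follows the formulas above; the function name and the default value of <code class="language-plaintext highlighter-rouge">gamma</code> are assumptions for illustration.</p>

```python
def self_play_rewards(t_a, t_b, t_max, bob_succeeded, gamma=0.01):
    """Return (R_A, R_B) for one self-play episode.

    t_a: steps until Alice performs STOP; t_b: steps Bob needed.
    If Bob fails, t_b is set to t_max - t_a as described in the text.
    """
    if not bob_succeeded:
        t_b = t_max - t_a
    r_b = -gamma * t_b                 # Bob wants to finish quickly
    r_a = gamma * max(0, t_b - t_a)    # Alice wants Bob to need more time than she did
    return r_a, r_b

# Bob succeeds but needs more time than Alice: Alice is rewarded.
r_a_ok, r_b_ok = self_play_rewards(t_a=10, t_b=30, t_max=100, bob_succeeded=True, gamma=1.0)
# Bob fails after Alice used most of the budget: Alice gets nothing.
r_a_fail, r_b_fail = self_play_rewards(t_a=60, t_b=5, t_max=100, bob_succeeded=False, gamma=1.0)
```

<p>The second call shows the third point in the list: when B fails, A’s reward shrinks as her own step count grows, so long and impossible proposals do not pay off.</p>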
<p>The paradigm of A suggesting tasks and then B solving them does sound similar to the Teacher-Student framework. However, in asymmetric self-play, Alice, who plays a teacher role, also works on the same task to find challenging cases for Bob, rather than explicitly optimizing B’s learning process.</p>
<h2 id="automatic-goal-generation">Automatic Goal Generation</h2>
<p>Often an RL policy needs to be able to perform over a set of tasks. The goal should be carefully chosen so that at every training stage, it would not be too hard or too easy for the current policy. A goal <script type="math/tex">g \in \mathcal{G}</script> can be defined as a set of states <script type="math/tex">S^g</script> and a goal is considered as achieved whenever an agent arrives at any of those states.</p>
<p>The approach of Generative Goal Learning (<a href="https://arxiv.org/abs/1705.06366">Florensa, et al. 2018</a>) relies on a <strong>Goal GAN</strong> to generate desired goals automatically. In their experiment, the reward is very sparse, just a binary flag for whether a goal is achieved or not and the policy is conditioned on goal,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\pi^{*}(a_t\vert s_t, g) &= \arg\max_\pi \mathbb{E}_{g\sim p_g(.)} R^g(\pi) \\
\text{where }R^g(\pi) &= \mathbb{E}_{\pi(.\mid s_t, g)} \mathbf{1}[\exists t \in [1,\dots, T]: s_t \in S^g]
\end{aligned} %]]></script>
<p>Here <script type="math/tex">R^g(\pi)</script> is the expected return, also equivalent to the success probability. Given sampled trajectories from the current policy, as long as any state belongs to the goal set, the return will be positive.</p>
<p>Their approach iterates through 3 steps until the policy converges:</p>
<ol>
<li>Label a set of goals based on whether they are at the appropriate level of difficulty for the current policy.
<ul>
<li>The set of goals at the appropriate level of difficulty is named <strong>GOID</strong> (short for “Goals of Intermediate Difficulty”).<br /><script type="math/tex">\text{GOID}_i := \{g : R_\text{min} \leq R^g(\pi_i) \leq R_\text{max} \} \subseteq \mathcal{G}</script></li>
<li>Here <script type="math/tex">R_\text{min}</script> and <script type="math/tex">R_\text{max}</script> can be interpreted as a minimum and maximum probability of reaching a goal over T time-steps.</li>
</ul>
</li>
<li>Train a Goal GAN model using labelled goals from step 1 to produce new goals</li>
<li>Use these new goals to train the policy, improving its coverage objective.</li>
</ol>
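<p>Step 1 can be sketched as follows, assuming the success probability <script type="math/tex">R^g(\pi_i)</script> is estimated by averaging binary outcomes over a few rollouts per goal. The function names and the threshold values are illustrative assumptions.</p>

```python
def estimate_success_rate(outcomes):
    """Monte Carlo estimate of R^g(pi): the fraction of rollouts that reached the goal."""
    return sum(outcomes) / len(outcomes)

def label_goid(success_rates, r_min=0.1, r_max=0.9):
    """Label each goal 1 if it is of intermediate difficulty, else 0.

    success_rates: dict mapping a goal to its estimated success probability.
    """
    return {g: int(r_min <= rate <= r_max) for g, rate in success_rates.items()}

rollouts = {"easy": [1, 1, 1, 1], "mid": [1, 0, 1, 0], "hard": [0, 0, 0, 0]}
rates = {g: estimate_success_rate(o) for g, o in rollouts.items()}
labels = label_goid(rates)
```

<p>These binary labels are what the Goal GAN in step 2 consumes: goals labelled 1 act as positive samples and goals labelled 0 as negative ones.</p>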
<p>The Goal GAN generates a curriculum automatically:</p>
<ul>
<li>Generator <script type="math/tex">G(z)</script>: produces a new goal. => expected to be a goal uniformly sampled from the <script type="math/tex">\text{GOID}</script> set.</li>
<li>Discriminator <script type="math/tex">D(g)</script>: evaluates whether a goal can be achieved. => expected to tell whether a goal is from the <script type="math/tex">\text{GOID}</script> set.</li>
</ul>
<p>The Goal GAN is constructed similarly to LSGAN (Least Squares GAN; <a href="https://arxiv.org/abs/1611.04076">Mao et al., 2017</a>), which has better stability of learning compared to vanilla GAN. According to LSGAN, we should minimize the following losses for <script type="math/tex">D</script> and <script type="math/tex">G</script> respectively:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{LSGAN}(D) &= \frac{1}{2} \mathbb{E}_{g \sim p_\text{data}(g)} [ (D(g) - b)^2] + \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - a)^2] \\
\mathcal{L}_\text{LSGAN}(G) &= \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - c)^2]
\end{aligned} %]]></script>
<p>where <script type="math/tex">a</script> is the label for fake data, <script type="math/tex">b</script> for real data, and <script type="math/tex">c</script> is the value that <script type="math/tex">G</script> wants <script type="math/tex">D</script> to believe for fake data. In LSGAN paper’s experiments, they used <script type="math/tex">a=-1, b=1, c=0</script>.</p>
<p>The Goal GAN introduces an extra binary flag <script type="math/tex">y_g</script> indicating whether a goal <script type="math/tex">g</script> is real (<script type="math/tex">y_g = 1</script>) or fake (<script type="math/tex">y_g = 0</script>) so that the model can use negative samples for training:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{GoalGAN}(D) &= \frac{1}{2} \mathbb{E}_{g \sim p_\text{data}(g)} [ y_g (D(g) - b)^2 + (1-y_g) (D(g) - a)^2] + \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - a)^2] \\
\mathcal{L}_\text{GoalGAN}(G) &= \frac{1}{2} \mathbb{E}_{z \sim p_z(z)} [ (D(G(z)) - c)^2]
\end{aligned} %]]></script>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/generative-goal-learning-algorithm.png" alt="Generative goal learning" /></p>
<p><em>Fig. 6. The algorithm of Generative Goal Learning. (Image source: <a href="https://arxiv.org/abs/1705.06366">Florensa, et al. 2018</a>)</em></p>
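<p>A small numeric sketch of the two Goal GAN losses, using plain Python lists of discriminator outputs instead of a neural network. It assumes the <script type="math/tex">(D(g) - b)^2</script> term is weighted by <script type="math/tex">y_g</script>, so that a negative sample is only pushed toward <script type="math/tex">a</script>; the function names are hypothetical and the defaults follow the <script type="math/tex">a=-1, b=1, c=0</script> setting mentioned above.</p>

```python
def goal_gan_d_loss(d_real, y_real, d_fake, a=-1.0, b=1.0):
    """Discriminator loss.

    d_real: D(g) for goals from the data; y_real: their 0/1 labels y_g.
    d_fake: D(G(z)) for generated goals.
    """
    real_term = sum(
        y * (d - b) ** 2 + (1 - y) * (d - a) ** 2
        for d, y in zip(d_real, y_real)
    ) / len(d_real)
    fake_term = sum((d - a) ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term

def goal_gan_g_loss(d_fake, c=0.0):
    """Generator loss: make D output c on generated goals."""
    return 0.5 * sum((d - c) ** 2 for d in d_fake) / len(d_fake)

# A discriminator that outputs b on positives and a on negatives and fakes has zero loss.
perfect_d = goal_gan_d_loss(d_real=[1.0, -1.0], y_real=[1, 0], d_fake=[-1.0])
# An undecided discriminator (output 0 everywhere) pays on both terms.
undecided_d = goal_gan_d_loss(d_real=[0.0], y_real=[1], d_fake=[0.0])
```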
<p>Following the same idea, <a href="https://arxiv.org/abs/1909.12892">Racaniere & Lampinen, et al. (2019)</a> designed a method to make the objectives of the goal generator more sophisticated. Their method contains three components, same as generative goal learning above:</p>
<ul>
<li><strong>Solver</strong>/Policy <script type="math/tex">\pi</script>: In each episode, the solver gets a goal <script type="math/tex">g</script> at the beginning and gets a single binary reward <script type="math/tex">R^g</script> at the end.</li>
<li><strong>Judge</strong>/Discriminator <script type="math/tex">D(.)</script>: A classifier to predict the binary reward (whether goal can be achieved or not); precisely it outputs the logit of a probability of achieving the given goal, <script type="math/tex">\sigma(D(g)) = p(R^g=1\vert g)</script>, where <script type="math/tex">\sigma</script> is the sigmoid function.</li>
<li><strong>Setter</strong>/Generator <script type="math/tex">G(.)</script>: The goal setter takes as input a desired feasibility score <script type="math/tex">f \in \text{Unif}(0, 1)</script> and generates <script type="math/tex">g = G(z, f)</script>, where the latent variable <script type="math/tex">z</script> is sampled by <script type="math/tex">z \sim \mathcal{N}(0, I)</script>. The goal generator is designed to be reversible, so <script type="math/tex">G^{-1}</script> can map backwards from a goal <script type="math/tex">g</script> to a latent <script type="math/tex">z = G^{-1}(g, f)</script>.</li>
</ul>
<p>The generator is optimized with three objectives:</p>
<ul>
<li>(1) Goal <strong>validity</strong>: The proposed goal should be achievable by an expert policy. The corresponding generative loss is designed to increase the likelihood of generating goals that the solver policy has achieved before (like in <a href="https://arxiv.org/abs/1707.01495">HER</a>).
<ul>
<li><script type="math/tex">\mathcal{L}_\text{val}</script> is the negative log-likelihood of generated goals that have been solved by the solver in the past.</li>
<li>
<script type="math/tex; mode=display">\begin{align*}
\mathcal{L}_\text{val} = \mathbb{E}_{\substack{
g \sim \text{ achieved by solver}, \\
\xi \in \text{Uniform}(0, \delta), \\
f \in \text{Uniform}(0, 1)
}} \big[ -\log p(G^{-1}(g + \xi, f)) \big]
\end{align*}</script>
</li>
</ul>
</li>
<li>(2) Goal <strong>feasibility</strong>: The proposed goal should be achievable by the current policy; that is, the level of difficulty should be appropriate.
<ul>
<li><script type="math/tex">\mathcal{L}_\text{feas}</script> requires that the probability output by the judge model <script type="math/tex">D</script> on the generated goal <script type="math/tex">G(z, f)</script> match the desired feasibility <script type="math/tex">f</script>.</li>
<li>
<script type="math/tex; mode=display">\begin{align*}
\mathcal{L}_\text{feas} = \mathbb{E}_{\substack{
z \in \mathcal{N}(0, 1), \\
f \in \text{Uniform}(0, 1)
}} \big[ \big( D(G(z, f)) - \sigma^{-1}(f) \big)^2 \big]
\end{align*}</script>
</li>
</ul>
</li>
<li>(3) Goal <strong>coverage</strong>: We should maximize the entropy of generated goals to encourage diverse goals and to improve the coverage over the goal space.
<ul>
<li>
<script type="math/tex; mode=display">\begin{align*}
\mathcal{L}_\text{cov} = \mathbb{E}_{\substack{
z \in \mathcal{N}(0, 1), \\
f \in \text{Uniform}(0, 1)
}} \big[ \log p(G(z, f)) \big]
\end{align*}</script>
</li>
</ul>
</li>
</ul>
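<p>To make the feasibility objective concrete, here is a sketch that reads <script type="math/tex">\mathcal{L}_\text{feas}</script> as the squared difference between the judge’s logit and <script type="math/tex">\sigma^{-1}(f)</script>, averaged over a batch. The judge outputs are passed in as plain numbers, and the function names are assumptions for illustration.</p>

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    """Inverse of the sigmoid, sigma^{-1}(p), for p in (0, 1)."""
    return math.log(p / (1.0 - p))

def feasibility_loss(judge_logits, desired_f):
    """Mean squared gap between the judge logit D(G(z, f)) and sigma^{-1}(f).

    judge_logits: D's output on each generated goal.
    desired_f: the feasibility f that each goal was generated for.
    """
    return sum(
        (d - logit(f)) ** 2 for d, f in zip(judge_logits, desired_f)
    ) / len(judge_logits)

# The judge predicts a 50% success chance for a goal requested at f = 0.5: no loss.
matched = feasibility_loss(judge_logits=[0.0], desired_f=[0.5])
# The judge thinks the goal is a coin flip, but the setter was asked for f = 0.9.
mismatched = feasibility_loss(judge_logits=[0.0], desired_f=[0.9])
```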
<p>Their experiments showed complex environments require all three losses above. When the environment is changing between episodes, both the goal generator and the discriminator need to be conditioned on environmental observation to produce better results. If there is a desired goal distribution, an additional loss can be added to match a desired goal distribution using Wasserstein distance. Using this loss, the generator can push the solver toward mastering the desired tasks more efficiently.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/setter-judge-goal-generation.png" alt="Goal setter and judge models" /></p>
<p><em>Fig. 7. Training schematic for the (a) solver/policy, (b) judge/discriminator, and (c) setter/goal generator models. (Image source: <a href="https://arxiv.org/abs/1909.12892">Racaniere & Lampinen, et al., 2019</a>)</em></p>
<h2 id="skill-based-curriculum">Skill-Based Curriculum</h2>
<p>Another view is to decompose what an agent is able to complete into a variety of skills, where each skill set could be mapped into a task. Imagine an agent interacting with the environment in an unsupervised manner: is there a way to discover useful skills from such interaction and further build them into solutions for more complicated tasks through a curriculum?</p>
<p><a href="https://arxiv.org/abs/1912.04226">Jabri, et al. (2019)</a> developed an automatic curriculum, <strong>CARML</strong> (short for “Curricula for Unsupervised Meta-Reinforcement Learning”), by modeling unsupervised trajectories into a latent skill space, with a focus on training <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html">meta-RL</a> policies (i.e. policies that can transfer to unseen tasks). The setting of training environments in CARML is similar to <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#learning-with-random-rewards">DIAYN</a>. The difference is that CARML is trained on pixel-level observations, while DIAYN operates on the true state space. An RL algorithm <script type="math/tex">\pi_\theta</script>, parameterized by <script type="math/tex">\theta</script>, is trained via unsupervised interaction formulated as a controlled Markov process (CMP, i.e. an MDP without a reward function) combined with a learned reward function <script type="math/tex">r</script>. This setting naturally works for the meta-learning purpose, since a customized reward function can be given only at test time.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CARML.png" alt="CARML" /></p>
<p><em>Fig. 8. An illustration of CARML, containing two steps: (1) organizing experiential data into the latent skill space; (2) meta-training the policy with the reward function constructed from the learned skills. (Image source: <a href="https://arxiv.org/abs/1912.04226">Jabri, et al 2019</a>)</em></p>
<p>CARML is framed as a <a href="https://chrischoy.github.io/research/Expectation-Maximization-and-Variational-Inference/">variational Expectation-Maximization (EM)</a>.</p>
<p>(1) <strong>E-Step</strong>: This is the stage for organizing experiential data. Collected trajectories are modeled with a mixture of latent components forming the <a href="https://en.wikipedia.org/wiki/Basis_(linear_algebra)">basis</a> of <em>skills</em>.</p>
<p>Let <script type="math/tex">z</script> be a latent task variable and <script type="math/tex">q_\phi</script> be a variational distribution of <script type="math/tex">z</script>, which could be a mixture model with discrete <script type="math/tex">z</script> or a VAE with continuous <script type="math/tex">z</script>. A variational posterior <script type="math/tex">q_\phi(z \vert s)</script> works like a classifier, predicting a skill given a state, and we would like to maximize <script type="math/tex">q_\phi(z \vert s)</script> to discriminate between data produced by different skills as much as possible. In E-step, <script type="math/tex">q_\phi</script> is fitted to a set of trajectories produced by <script type="math/tex">\pi_\theta</script>.</p>
<p>Precisely, given a trajectory <script type="math/tex">\tau = (s_1,\dots,s_T)</script>, we would like to find <script type="math/tex">\phi</script> such that</p>
<script type="math/tex; mode=display">\max_\phi \mathbb{E}_{z\sim q_\phi(z)} \big[ \log q_\phi(\tau \vert z) \big]
= \max_\phi \mathbb{E}_{z\sim q_\phi(z)} \big[ \sum_{s_i \in \tau} \log q_\phi(s_i \vert z) \big]</script>
<p>A simplifying assumption is made here to ignore the order of states in one trajectory.</p>
<p>(2) <strong>M-Step</strong>: This is the stage for doing meta-RL training with <script type="math/tex">\pi_\theta</script>. The learned skill space is considered as a training task distribution. CARML is agnostic to the type of meta-RL algorithm for policy parameter updates.</p>
<p>Given a trajectory <script type="math/tex">\tau</script>, it makes sense for the policy to maximize the mutual information between <script type="math/tex">\tau</script> and <script type="math/tex">z</script>, <script type="math/tex">I(\tau;z) = H(\tau) - H(\tau \vert z)</script>, because:</p>
<ul>
<li>maximizing <script type="math/tex">H(\tau)</script> => diversity in the policy data space; expected to be large.</li>
<li>minimizing <script type="math/tex">H(\tau \vert z)</script> => given a certain skill, the behavior should be restricted; expected to be small.</li>
</ul>
<p>Then we have,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
I(\tau; z)
&= \mathcal{H}(z) - \mathcal{H}(z \vert s_1,\dots, s_T) \\
&\geq \mathbb{E}_{s \in \tau} [\mathcal{H}(z) - \mathcal{H}(z\vert s)] & \scriptstyle{\text{; discard the order of states.}} \\
&= \mathbb{E}_{s \in \tau} [\mathcal{H}(s) - \mathcal{H}(s\vert z)] & \scriptstyle{\text{; by definition of MI.}} \\
&= \mathbb{E}_{z\sim q_\phi(z), s\sim \pi_\theta(s|z)} [\log q_\phi(s|z) - \log \pi_\theta(s)] \\
&\approx \mathbb{E}_{z\sim q_\phi(z), s\sim \pi_\theta(s|z)} [\color{green}{\log q_\phi(s|z) - \log q_\phi(s)}] & \scriptstyle{\text{; assume learned marginal distr. matches policy.}}
\end{aligned} %]]></script>
<p>We can set the reward as <script type="math/tex">\log q_\phi(s \vert z) - \log q_\phi(s)</script>, as shown in the <span style="color: green;">green</span> part in the equation above. In order to balance between task-specific exploration (as in <span style="color: red;">red</span> below) and latent skill matching (as in <span style="color: blue;">blue</span> below), a parameter <script type="math/tex">\lambda \in [0, 1]</script> is added. Each realization of <script type="math/tex">z \sim q_\phi(z)</script> induces a reward function <script type="math/tex">r_z(s)</script> (remember that reward + CMP => MDP) as follows:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
r_z(s)
&= \lambda \log q_\phi(s|z) - \log q_\phi(s) \\
&= \lambda \log q_\phi(s|z) - \log \frac{q_\phi(s|z) q_\phi(z)}{q_\phi(z|s)} \\
&= \lambda \log q_\phi(s|z) - \log q_\phi(s|z) - \log q_\phi(z) + \log q_\phi(z|s) \\
&= (\lambda - 1) \log \color{red}{q_\phi(s|z)} + \color{blue}{\log q_\phi(z|s)} + C
\end{aligned} %]]></script>
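<p>To make the reward concrete, here is a minimal NumPy sketch. It assumes, purely for illustration, that the variational distribution is a mixture of isotropic Gaussians with hand-picked means; the names <code>log_gaussian</code> and <code>carml_reward</code> are hypothetical and not taken from the paper.</p>

```python
import numpy as np

def log_gaussian(s, mean, var):
    """Log density of an isotropic Gaussian N(mean, var * I) evaluated at s."""
    d = s.shape[-1]
    sq_dist = np.sum((s - mean) ** 2, axis=-1)
    return -0.5 * (d * np.log(2.0 * np.pi * var) + sq_dist / var)

def carml_reward(s, z, means, var, prior, lam):
    """r_z(s) = lam * log q(s|z) - log q(s), with q a Gaussian mixture (assumption)."""
    # log q(s|k) for every mixture component k, shape (K,)
    log_cond = np.array([log_gaussian(s, m, var) for m in means])
    # log q(s) = logsumexp_k [log q(k) + log q(s|k)], computed stably
    joint = np.log(prior) + log_cond
    log_marginal = np.max(joint) + np.log(np.sum(np.exp(joint - np.max(joint))))
    return lam * log_cond[z] - log_marginal

# Hypothetical two-skill example: the state sits next to the center of skill 0.
means = np.array([[0.0, 0.0], [4.0, 4.0]])
prior = np.array([0.5, 0.5])
s = np.array([0.1, -0.2])
r_own = carml_reward(s, z=0, means=means, var=1.0, prior=prior, lam=1.0)
r_other = carml_reward(s, z=1, means=means, var=1.0, prior=prior, lam=1.0)
```

<p>With <code>lam=1.0</code> the reward reduces to <code>log q(z|s) - log q(z)</code>, so a state that clearly belongs to the sampled skill is rewarded more than the same state under a different skill.</p>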
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CARML-algorithm.png" alt="CARML algorithm" /></p>
<p><em>Fig. 9. The algorithm of CARML. (Image source: <a href="https://arxiv.org/abs/1912.04226">Jabri, et al 2019</a>)</em></p>
<p>Learning a latent skill space can be done in different ways, such as in <a href="https://openreview.net/forum?id=rk07ZXZRb">Hausman, et al. 2018</a>. The goal of their approach is to learn a task-conditioned policy, <script type="math/tex">\pi(a \vert s, t^{(i)})</script>, where <script type="math/tex">t^{(i)}</script> is from a discrete list of <script type="math/tex">N</script> tasks, <script type="math/tex">\mathcal{T} = [t^{(1)}, \dots, t^{(N)}]</script>. However, rather than learning <script type="math/tex">N</script> separate solutions, one per task, it would be nice to learn a latent skill space so that each task could be represented as a distribution over skills and thus skills are <em>reused between tasks</em>. The policy is defined as <script type="math/tex">\pi_\theta(a \vert s,t) = \int \pi_\theta(a \vert z,s,t) p_\phi(z \vert t)\mathrm{d}z</script>, where <script type="math/tex">\pi_\theta</script> and <script type="math/tex">p_\phi</script> are policy and embedding networks to learn, respectively. If <script type="math/tex">z</script> is discrete, i.e. drawn from a set of <script type="math/tex">K</script> skills, then the policy becomes a mixture of <script type="math/tex">K</script> sub-policies. The policy training uses <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#sac">SAC</a> and the dependency on <script type="math/tex">z</script> is introduced in the entropy term.</p>
<h2 id="curriculum-through-distillation">Curriculum through Distillation</h2>
<p>[I was thinking of the name of this section for a while, deciding between cloning, inheritance, and distillation. Eventually, I picked distillation because it sounds the coolest B-)]</p>
<p>The motivation for the <strong>progressive neural network</strong> (<a href="https://arxiv.org/abs/1606.04671">Rusu et al. 2016</a>) architecture is to efficiently transfer learned skills between different tasks and in the meantime avoid catastrophic forgetting. The curriculum is realized through a set of progressively stacked neural network towers (or “columns”, as in the paper).</p>
<p>A progressive network has the following structure:</p>
<ol>
<li>It starts with a single column containing <script type="math/tex">L</script> layers of neurons, in which the corresponding activation layers are labelled as <script type="math/tex">h^{(1)}_i, i=1, \dots, L</script>. We first train this single-column network for one task to convergence, achieving parameter config <script type="math/tex">\theta^{(1)}</script>.</li>
<li>Once we switch to the next task, we need to add a new column to adapt to the new context while freezing <script type="math/tex">\theta^{(1)}</script> to lock down the learned skills from the previous task. The new column has activation layers labelled as <script type="math/tex">h^{(2)}_i, i=1, \dots, L</script>, and parameters <script type="math/tex">\theta^{(2)}</script>.</li>
<li>
<p>Step 2 can be repeated with every new task. The <script type="math/tex">i</script>-th layer activation in the <script type="math/tex">k</script>-th column depends on the previous activation layers in all the existing columns:</p>
<script type="math/tex; mode=display">% <![CDATA[
h^{(k)}_i = f(W^{(k)}_i h^{(k)}_{i-1} + \sum_{j < k} U_i^{(k:j)} h^{(j)}_{i-1}) %]]></script>
<p>where <script type="math/tex">W^{(k)}_i</script> is the weight matrix of the layer <script type="math/tex">i</script> in the column <script type="math/tex">k</script>; <script type="math/tex">% <![CDATA[
U_i^{(k:j)}, j < k %]]></script> are the weight matrices for projecting the layer <script type="math/tex">i-1</script> of the column <script type="math/tex">j</script> to the layer <script type="math/tex">i</script> of column <script type="math/tex">k</script> (<script type="math/tex">% <![CDATA[
j < k %]]></script>). The above weight matrices should be learned. <script type="math/tex">f(.)</script> is a non-linear activation function of choice.</p>
</li>
</ol>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/progressive-networks.png" alt="Progressive networks" /></p>
<p><em>Fig. 10. The progressive neural network architecture. (Image source: <a href="https://arxiv.org/abs/1610.04286">Rusu, et al. 2017</a>)</em></p>
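<p>The layer-wise update can be sketched as a forward pass in NumPy. This is an illustrative sketch only: it assumes every layer has the same width, uses ReLU as the activation, treats the input as the layer-0 activation of every column, and applies the lateral term at every layer exactly as the formula is written. The function name <code>progressive_forward</code> and the nesting of <code>W</code> and <code>U</code> are hypothetical choices.</p>

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def progressive_forward(x, W, U, f=relu):
    """Forward pass through all columns of a progressive network.

    W[k][i]    : weight matrix of layer i in column k.
    U[k][i][j] : lateral weight matrix from the previous layer of column j (j < k)
                 into layer i of column k.
    Returns h, where h[k][0] = x and h[k][i + 1] is the output of layer i in column k.
    """
    num_columns, num_layers = len(W), len(W[0])
    h = [[x] for _ in range(num_columns)]
    for i in range(num_layers):
        for k in range(num_columns):
            # Own-column term: W_i^(k) h_{i-1}^(k)
            pre = W[k][i] @ h[k][i]
            # Lateral terms from all earlier (frozen) columns j < k
            for j in range(k):
                pre = pre + U[k][i][j] @ h[j][i]
            h[k].append(f(pre))
    return h

rng = np.random.default_rng(0)
dim, num_layers, num_columns = 4, 3, 2
W = [[rng.normal(size=(dim, dim)) for _ in range(num_layers)] for _ in range(num_columns)]
U = [[[rng.normal(size=(dim, dim)) for _ in range(k)] for _ in range(num_layers)]
     for k in range(num_columns)]
x = rng.normal(size=dim)
h = progressive_forward(x, W, U)
```

<p>Because information only flows from older columns to newer ones, the activations of the first column are unchanged by adding the second column, which is how the earlier skills stay locked down.</p>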
<p>The paper experimented with Atari games by training a progressive network on multiple games to check whether features learned in one game can transfer to another. That is indeed the case. Though interestingly, learning a high dependency on features in the previous columns does not always indicate good transfer performance on the new task. One hypothesis is that features learned from the old task might introduce biases into the new task, leading to the policy getting trapped in a sub-optimal solution. Overall, the progressive network works better than only fine-tuning the top layer and can achieve transfer performance similar to fine-tuning the entire network.</p>
<p>One use case for the progressive network is to do sim2real transfer (<a href="https://arxiv.org/abs/1610.04286">Rusu, et al. 2017</a>), in which the first column is trained in a simulator with a lot of samples and then the additional columns (could be for different real-world tasks) are added and trained with a few real data samples.</p>
<p><a href="https://arxiv.org/abs/1806.01780">Czarnecki, et al. (2018)</a> proposed another RL training framework, <strong>Mix & Match</strong> (short for <strong>M&M</strong>), to provide a curriculum through copying knowledge between agents. Given a sequence of agents from simple to complex, <script type="math/tex">\pi_1, \dots, \pi_K</script>, each parameterized with some shared weights (e.g. by sharing some lower common layers), M&M trains a mixture of agents, but only the final performance of the most complex one <script type="math/tex">\pi_K</script> matters.</p>
<p>In the meantime, M&M learns a categorical distribution <script type="math/tex">c \sim \text{Categorical}(1, \dots, K \vert \alpha)</script> with <a href="https://en.wikipedia.org/wiki/Probability_mass_function">pmf</a> <script type="math/tex">p(c=i) = \alpha_i</script>, which determines the probability of picking each policy at a given time. The mixed M&M policy is a simple weighted sum: <script type="math/tex">\pi_\text{mm}(a \vert s) = \sum_{i=1}^K \alpha_i \pi_i(a \vert s)</script>. Curriculum learning is realized by dynamically adjusting <script type="math/tex">\alpha_i</script>, from <script type="math/tex">\alpha_K=0</script> to <script type="math/tex">\alpha_K=1</script>. The tuning of <script type="math/tex">\alpha</script> can be manual or through <a href="/lil-log/2019/09/05/evolution-strategies.html#hyperparameter-tuning-pbt">population-based training</a>.</p>
<p>To encourage cooperation rather than competition among policies, besides the RL loss <script type="math/tex">\mathcal{L}_\text{RL}</script>, another <a href="https://arxiv.org/abs/1511.06295">distillation</a>-like loss <script type="math/tex">\mathcal{L}_\text{mm}(\theta)</script> is added. The knowledge transfer loss <script type="math/tex">\mathcal{L}_\text{mm}(\theta)</script> measures the KL divergence between two policies, <script type="math/tex">\propto D_\text{KL}(\pi_{i}(. \vert s) \| \pi_j(. \vert s))</script> for <script type="math/tex">% <![CDATA[
i < j %]]></script>. It encourages complex agents to match the simpler ones early on. The final loss is <script type="math/tex">\mathcal{L} = \mathcal{L}_\text{RL}(\theta \vert \pi_\text{mm}) + \lambda \mathcal{L}_\text{mm}(\theta)</script>.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/mix-and-match.png" alt="Mix & Match" /></p>
<p><em>Fig. 11. The Mix & Match architecture for training a mixture of policies. (Image source: <a href="https://arxiv.org/abs/1806.01780">Czarnecki, et al., 2018</a>)</em></p>
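<p>A small NumPy sketch of the two M&M ingredients, the mixed policy and the KL-based knowledge transfer term, for a single state with discrete actions. The linear schedule for the mixing weights is an assumption made for illustration (the text above only says the tuning can be manual or via population-based training), and all function names are hypothetical.</p>

```python
import numpy as np

def mix_policies(policies, alpha):
    """pi_mm(a|s) = sum_i alpha_i * pi_i(a|s); policies has shape (K, num_actions)."""
    return alpha @ policies

def kl(p, q, eps=1e-12):
    """KL divergence between two discrete distributions."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def mm_loss(policies):
    """Sum of KL(pi_i || pi_j) over all pairs i < j (simple agent first)."""
    K = len(policies)
    return sum(kl(policies[i], policies[j]) for i in range(K) for j in range(i + 1, K))

def alpha_schedule(t, total_steps):
    """Assumed two-agent schedule: alpha_K grows linearly from 0 to 1."""
    a_K = min(1.0, t / total_steps)
    return np.array([1.0 - a_K, a_K])

policies = np.array([[0.7, 0.2, 0.1],    # simple agent pi_1
                     [0.3, 0.3, 0.4]])   # complex agent pi_2
pi_start = mix_policies(policies, alpha_schedule(0, 100))    # equals pi_1
pi_end = mix_policies(policies, alpha_schedule(100, 100))    # equals pi_2
transfer_penalty = mm_loss(policies)
```

<p>At the start of training the mixture behaves like the simple agent, and by the end it is entirely the complex one; the transfer penalty shrinks to zero as the two policies agree.</p>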
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2020curriculum,
title = "Curriculum for Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2020",
url = "https://lilianweng.github.io/lil-log/2020/01/29/curriculum-for-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Jeffrey L. Elman. <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.4487&rep=rep1&type=pdf">“Learning and development in neural networks: The importance of starting small.”</a> Cognition 48.1 (1993): 71-99.</p>
<p>[2] Yoshua Bengio, et al. <a href="https://www.researchgate.net/profile/Y_Bengio/publication/221344862_Curriculum_learning/links/546cd2570cf2193b94c577ac/Curriculum-learning.pdf">“Curriculum learning.”</a> ICML 2009.</p>
<p>[3] Daphna Weinshall, Gad Cohen, and Dan Amir. <a href="https://arxiv.org/abs/1802.03796">“Curriculum learning by transfer learning: Theory and experiments with deep networks.”</a> ICML 2018.</p>
<p>[4] Wojciech Zaremba and Ilya Sutskever. <a href="https://arxiv.org/abs/1410.4615">“Learning to execute.”</a> arXiv preprint arXiv:1410.4615 (2014).</p>
<p>[5] Tambet Matiisen, et al. <a href="https://arxiv.org/abs/1707.00183">“Teacher-student curriculum learning.”</a> IEEE Trans. on neural networks and learning systems (2017).</p>
<p>[6] Alex Graves, et al. <a href="https://arxiv.org/abs/1704.03003">“Automated curriculum learning for neural networks.”</a> ICML 2017.</p>
<p>[7] Remy Portelas, et al. <a href="https://arxiv.org/abs/1910.07224">“Teacher algorithms for curriculum learning of Deep RL in continuously parameterized environments.”</a> CoRL 2019.</p>
<p>[8] Sainbayar Sukhbaatar, et al. <a href="https://arxiv.org/abs/1703.05407">“Intrinsic Motivation and Automatic Curricula via Asymmetric Self-Play.”</a> ICLR 2018.</p>
<p>[9] Carlos Florensa, et al. <a href="https://arxiv.org/abs/1705.06366">“Automatic Goal Generation for Reinforcement Learning Agents.”</a> ICML 2018.</p>
<p>[10] Sebastien Racaniere & Andrew K. Lampinen, et al. <a href="https://arxiv.org/abs/1909.12892">“Automated Curriculum through Setter-Solver Interactions”</a> ICLR 2020.</p>
<p>[11] Allan Jabri, et al. <a href="https://arxiv.org/abs/1912.04226">“Unsupervised Curricula for Visual Meta-Reinforcement Learning.”</a> NeurIPS 2019.</p>
<p>[12] Karol Hausman, et al. <a href="https://openreview.net/forum?id=rk07ZXZRb">“Learning an Embedding Space for Transferable Robot Skills.”</a> ICLR 2018.</p>
<p>[13] Josh Merel, et al. <a href="https://arxiv.org/abs/1911.06636">“Reusable neural skill embeddings for vision-guided whole body movement and object manipulation”</a> arXiv preprint arXiv:1911.06636 (2019).</p>
<p>[14] OpenAI, et al. <a href="https://arxiv.org/abs/1910.07113">“Solving Rubik’s Cube with a Robot Hand.”</a> arXiv preprint arXiv:1910.07113 (2019).</p>
<p>[15] Niels Justesen, et al. <a href="https://arxiv.org/abs/1806.10729">“Illuminating Generalization in Deep Reinforcement Learning through Procedural Level Generation”</a> NeurIPS 2018 Deep RL Workshop.</p>
<p>[16] Karl Cobbe, et al. <a href="https://arxiv.org/abs/1812.02341">“Quantifying Generalization in Reinforcement Learning”</a> arXiv preprint arXiv:1812.02341 (2018).</p>
<p>[17] Andrei A. Rusu et al. <a href="https://arxiv.org/abs/1606.04671">“Progressive Neural Networks”</a> arXiv preprint arXiv:1606.04671 (2016).</p>
<p>[18] Andrei A. Rusu et al. <a href="https://arxiv.org/abs/1610.04286">“Sim-to-Real Robot Learning from Pixels with Progressive Nets.”</a> CoRL 2017.</p>
<p>[19] Wojciech Marian Czarnecki, et al. <a href="https://arxiv.org/abs/1806.01780">“Mix & Match – Agent Curricula for Reinforcement Learning.”</a> ICML 2018.</p>
Lilian Weng
A curriculum is an efficient tool for humans to progressively learn from simple concepts to hard problems. It breaks down complex knowledge by providing a sequence of learning steps of increasing difficulty. In this post, we will examine how the idea of curriculum can help reinforcement learning models learn to solve complicated tasks.
Self-Supervised Representation Learning
2019-11-10T18:00:00+00:00
https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning
<blockquote>
<p>Self-supervised learning opens up a huge opportunity for better utilizing unlabelled data, while learning in a supervised learning manner. This post covers many interesting ideas of self-supervised learning tasks on images, videos, and control problems.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2020-01-09: add a new section on <a href="#contrastive-predictive-coding">Contrastive Predictive Coding</a>.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2020-04-13: add a new <a href="#momentum-contrast">“Momentum Contrast”</a> section on MoCo, SimCLR and CURL.]</span></p>
<p>Given a task and enough labels, supervised learning can solve it really well. Good performance usually requires a decent amount of labels, but collecting manual labels is expensive (e.g. ImageNet) and hard to scale up. Considering that the amount of unlabelled data (e.g. free text, all the images on the Internet) is substantially larger than the limited number of human-curated labelled datasets, it is kinda wasteful not to use it. However, unsupervised learning is not easy and usually works much less efficiently than supervised learning.</p>
<p>What if we can get labels for free for unlabelled data and train on an unlabelled dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form to predict only a subset of information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as <em>self-supervised learning</em>.</p>
<p>This idea has been widely used in language modeling. The default task for a language model is to predict the next word given the past sequence. <a href="/lil-log/2019/01/31/generalized-language-models.html#bert">BERT</a> adds two other auxiliary tasks and both rely on self-generated labels.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/self-sup-lecun.png" alt="Self-supervised learning summary" /></p>
<p><em>Fig. 1. A great summary of how self-supervised learning tasks can be constructed (Image source: <a href="https://www.youtube.com/watch?v=7I0Qt7GALVk">LeCun’s talk</a>)</em></p>
<p><a href="https://github.com/jason718/awesome-self-supervised-learning">Here</a> is a nicely curated list of papers in self-supervised learning. Please check it out if you are interested in reading more in depth.</p>
<p>Note that this post does not focus on either NLP / <a href="/lil-log/2019/01/31/generalized-language-models.html">language modeling</a> or <a href="https://lilianweng.github.io/lil-log/tag/generative-model">generative modeling</a>.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#why-self-supervised-learning" id="markdown-toc-why-self-supervised-learning">Why Self-Supervised Learning?</a></li>
<li><a href="#images-based" id="markdown-toc-images-based">Images-Based</a> <ul>
<li><a href="#distortion" id="markdown-toc-distortion">Distortion</a></li>
<li><a href="#patches" id="markdown-toc-patches">Patches</a></li>
<li><a href="#colorization" id="markdown-toc-colorization">Colorization</a></li>
<li><a href="#generative-modeling" id="markdown-toc-generative-modeling">Generative Modeling</a></li>
<li><a href="#contrastive-predictive-coding" id="markdown-toc-contrastive-predictive-coding">Contrastive Predictive Coding</a></li>
<li><a href="#momentum-contrast" id="markdown-toc-momentum-contrast">Momentum Contrast</a></li>
</ul>
</li>
<li><a href="#video-based" id="markdown-toc-video-based">Video-Based</a> <ul>
<li><a href="#tracking" id="markdown-toc-tracking">Tracking</a></li>
<li><a href="#frame-sequence" id="markdown-toc-frame-sequence">Frame Sequence</a></li>
<li><a href="#video-colorization" id="markdown-toc-video-colorization">Video Colorization</a></li>
</ul>
</li>
<li><a href="#control-based" id="markdown-toc-control-based">Control-Based</a> <ul>
<li><a href="#multi-view-metric-learning" id="markdown-toc-multi-view-metric-learning">Multi-View Metric Learning</a></li>
<li><a href="#autonomous-goal-generation" id="markdown-toc-autonomous-goal-generation">Autonomous Goal Generation</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
</li>
</ul>
<h2 id="why-self-supervised-learning">Why Self-Supervised Learning?</h2>
<p>Self-supervised learning empowers us to exploit a variety of labels that come with the data for free. The motivation is quite straightforward. Producing a dataset with clean labels is expensive but unlabeled data is being generated all the time. To make use of this much larger amount of unlabeled data, one way is to set the learning objectives properly so as to get supervision from the data itself.</p>
<p>The <em>self-supervised task</em>, also known as <em>pretext task</em>, guides us to a supervised loss function. However, we usually don’t care about the final performance of this invented task. Rather we are interested in the learned intermediate representation with the expectation that this representation can carry good semantic or structural meanings and can be beneficial to a variety of practical downstream tasks.</p>
<p>For example, we might rotate images at random and train a model to predict how each input image is rotated. The rotation prediction task is made-up, so the actual accuracy is unimportant, like how we treat auxiliary tasks. But we expect the model to learn high-quality latent variables for real-world tasks, such as constructing an object recognition classifier with very few labeled samples.</p>
<p>Broadly speaking, all the generative models can be considered as self-supervised, but with different goals: Generative models focus on creating diverse and realistic images, while self-supervised representation learning cares about producing good features generally helpful for many tasks. Generative modeling is not the focus of this post, but feel free to check my <a href="https://lilianweng.github.io/lil-log/tag/generative-model">previous posts</a>.</p>
<h2 id="images-based">Images-Based</h2>
<p>Many ideas have been proposed for self-supervised representation learning on images. A common workflow is to train a model on one or multiple pretext tasks with unlabelled images and then use one intermediate feature layer of this model to feed a multinomial logistic regression classifier on ImageNet classification. The final classification accuracy quantifies how good the learned representation is.</p>
<p>Recently, some researchers proposed to train supervised learning on labelled data and self-supervised pretext tasks on unlabelled data simultaneously with shared weights, like in <a href="https://arxiv.org/abs/1905.03670">Zhai et al, 2019</a> and <a href="https://arxiv.org/abs/1909.11825">Sun et al, 2019</a>.</p>
<h3 id="distortion">Distortion</h3>
<p>We expect that a small distortion of an image does not modify its original semantic meaning or geometric forms. Slightly distorted images are considered the same as the original and thus the learned features are expected to be invariant to distortion.</p>
<p><mark><b>Exemplar-CNN</b></mark> (<a href="https://arxiv.org/abs/1406.6909">Dosovitskiy et al., 2015</a>) creates surrogate training datasets with unlabeled image patches:</p>
<ol>
<li>Sample <script type="math/tex">N</script> patches of size 32 × 32 pixels from different images at varying positions and scales, only from regions containing considerable gradients as those areas cover edges and tend to contain objects or parts of objects. They are <em>“exemplary”</em> patches.</li>
<li>Each patch is distorted by applying a variety of random transformations (e.g., translation, rotation, scaling, etc.). All the resulting distorted patches are considered to belong to the <em>same surrogate class</em>.</li>
<li>The pretext task is to discriminate between a set of surrogate classes. We can arbitrarily create as many surrogate classes as we want.</li>
</ol>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/examplar-cnn.png" alt="Examplar CNN" /></p>
<p><em>Fig. 2. The original patch of a cute deer is in the top left corner. Random transformations are applied, resulting in a variety of distorted patches. All of them should be classified into the same class in the pretext task. (Image source: <a href="https://arxiv.org/abs/1406.6909">Dosovitskiy et al., 2015</a>)</em></p>
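<p>The surrogate-class construction can be sketched as follows. The transformations here are deliberately simplified stand-ins (a circular shift for translation, 90-degree rotations, and a horizontal flip) chosen so the sketch needs only NumPy; they are assumptions for illustration, not the transformation family used in the paper. The function names are hypothetical.</p>

```python
import numpy as np

def random_transform(patch, rng, max_shift=4):
    """Apply a random circular shift, a random 90-degree rotation and a random flip."""
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    out = np.roll(patch, shift=(int(dy), int(dx)), axis=(0, 1))
    out = np.rot90(out, k=int(rng.integers(0, 4)), axes=(0, 1))
    if rng.random() < 0.5:
        out = out[:, ::-1]
    return out

def build_surrogate_dataset(patches, copies_per_patch, rng):
    """Every distorted copy of patch number c gets the surrogate class label c."""
    data, labels = [], []
    for class_id, patch in enumerate(patches):
        for _ in range(copies_per_patch):
            data.append(random_transform(patch, rng))
            labels.append(class_id)
    return np.stack(data), np.array(labels)

rng = np.random.default_rng(0)
patches = rng.random((3, 32, 32, 3))   # three hypothetical 32x32 exemplar patches
data, labels = build_surrogate_dataset(patches, copies_per_patch=5, rng=rng)
```

<p>A classifier trained on <code>(data, labels)</code> then solves the pretext task of telling the surrogate classes apart.</p>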
<p><mark><b>Rotation</b></mark> of an entire image (<a href="https://arxiv.org/abs/1803.07728">Gidaris et al. 2018</a>) is another interesting and cheap way to modify an input image while the semantic content stays unchanged. Each input image is first rotated by a multiple of <script type="math/tex">90^\circ</script> at random, corresponding to <script type="math/tex">[0^\circ, 90^\circ, 180^\circ, 270^\circ]</script>. The model is trained to predict which rotation has been applied, thus a 4-class classification problem.</p>
<p>In order to identify the same image with different rotations, the model has to learn to recognize high level object parts, such as heads, noses, and eyes, and the relative positions of these parts, rather than local patterns. This pretext task drives the model to learn semantic concepts of objects in this way.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-sup-rotation.png" alt="Self supervised by rotation prediction" /></p>
<p><em>Fig. 3. Illustration of self-supervised learning by rotating the entire input images. The model learns to predict which rotation is applied. (Image source: <a href="https://arxiv.org/abs/1803.07728">Gidaris et al. 2018</a>)</em></p>
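<p>Generating training data for the rotation pretext task takes only a few lines. The sketch below assumes square images stored as a NumPy array of shape (N, H, W, C); the function name <code>make_rotation_batch</code> is hypothetical.</p>

```python
import numpy as np

def make_rotation_batch(images, rng):
    """Rotate each image by k * 90 degrees, k in {0, 1, 2, 3}; k is the class label."""
    labels = rng.integers(0, 4, size=len(images))
    rotated = [np.rot90(img, k=int(k), axes=(0, 1)) for img, k in zip(images, labels)]
    # Stacking works because square images keep their shape under 90-degree rotations.
    return np.stack(rotated), labels

rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3))   # placeholder data standing in for real images
rotated, labels = make_rotation_batch(images, rng)
```

<p>The pair <code>(rotated, labels)</code> can be fed to any 4-way classifier; no human annotation is involved.</p>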
<h3 id="patches">Patches</h3>
<p>The second category of self-supervised learning tasks extracts multiple patches from one image and asks the model to predict the relationship between these patches.</p>
<p><a href="https://arxiv.org/abs/1505.05192">Doersch et al. (2015)</a> formulate the pretext task as predicting the <mark><b>relative position</b></mark> between two random patches from one image. A model needs to understand the spatial context of objects in order to tell the relative position between parts.</p>
<p>The training patches are sampled in the following way:</p>
<ol>
<li>Randomly sample the first patch without any reference to image content.</li>
<li>Consider the first patch as placed in the middle of a 3x3 grid; the second patch is sampled from one of the 8 neighboring locations around it.</li>
<li>To avoid the model only catching low-level trivial signals, such as connecting a straight line across boundary or matching local patterns, additional noise is introduced by:
<ul>
<li>Add gaps between patches</li>
<li>Small jitters</li>
<li>Randomly downsample some patches to as little as 100 total pixels, and then upsample them, to build robustness to pixelation.</li>
<li>Shift green and magenta toward gray or randomly drop 2 of 3 color channels (See <a href="#chromatic-aberration">“chromatic aberration”</a> below)</li>
</ul>
</li>
<li>The model is trained to predict which one of 8 neighboring locations the second patch is selected from, a classification problem over 8 classes.</li>
</ol>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/self-sup-by-relative-position.png" alt="Self-supervised learning by context" /></p>
<p><em>Fig. 4. Illustration of self-supervised learning by predicting the relative position of two random patches. (Image source: <a href="https://arxiv.org/abs/1505.05192">Doersch et al., 2015</a>)</em></p>
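<p>The patch sampling procedure can be sketched as below. It covers only the grid layout, the gap between patches, and the small jitter; the downsampling and color-channel tricks are left out. The patch size, gap, and jitter values are arbitrary illustrative defaults, and <code>sample_patch_pair</code> is a hypothetical name.</p>

```python
import numpy as np

# (row, col) offsets of the 8 neighbors around the center cell of a 3x3 grid.
# The index into this list is the class label to predict.
NEIGHBOR_OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]

def sample_patch_pair(image, patch=16, gap=4, jitter=2, rng=None):
    """Return (center patch, neighbor patch, label in 0..7) from one image."""
    if rng is None:
        rng = np.random.default_rng()
    h, w = image.shape[:2]
    step = patch + gap              # distance between neighboring patch corners
    margin = step + jitter          # room needed on every side of the center patch
    r0 = int(rng.integers(margin, h - patch - margin + 1))
    c0 = int(rng.integers(margin, w - patch - margin + 1))
    label = int(rng.integers(0, 8))
    dr, dc = NEIGHBOR_OFFSETS[label]
    # Jitter the neighbor's location by a few pixels in each direction.
    r1 = r0 + dr * step + int(rng.integers(-jitter, jitter + 1))
    c1 = c0 + dc * step + int(rng.integers(-jitter, jitter + 1))
    first = image[r0:r0 + patch, c0:c0 + patch]
    second = image[r1:r1 + patch, c1:c1 + patch]
    return first, second, label

rng = np.random.default_rng(0)
image = rng.random((96, 96, 3))     # placeholder image
samples = [sample_patch_pair(image, rng=rng) for _ in range(100)]
```

<p>Each sample is a training example for an 8-way classifier that sees both patches and predicts the label.</p>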
<p><a href="#chromatic-aberration"></a>Other than trivial signals like boundary patterns or continuing textures, another interesting and a bit surprising trivial solution was found, called <a href="https://en.wikipedia.org/wiki/Chromatic_aberration"><em>“chromatic aberration”</em></a>. It is caused by light of different wavelengths having different focal lengths when passing through the lens. In the process, there might exist small offsets between color channels. Hence, the model can learn to tell the relative position by simply comparing how green and magenta are separated differently in two patches. This is a trivial solution and has nothing to do with the image content. Pre-processing images by shifting green and magenta toward gray or randomly dropping 2 of 3 color channels can avoid this trivial solution.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/chromatic-aberration.png" alt="Chromatic aberration" /></p>
<p><em>Fig. 5. Illustration of how chromatic aberration happens. (Image source: <a href="https://upload.wikimedia.org/wikipedia/commons/a/aa/Chromatic_aberration_lens_diagram.svg">wikipedia</a>)</em></p>
<p>Since we have already set up a 3x3 grid in each image in the above task, why not use all of 9 patches rather than only 2 to make the task more difficult? Following this idea, <a href="https://arxiv.org/abs/1603.09246">Noroozi & Favaro (2016)</a> designed a <mark><b>jigsaw puzzle</b></mark> game as pretext task: The model is trained to place 9 shuffled patches back to the original locations.</p>
<p>A convolutional network processes each patch independently with shared weights and outputs a probability vector per patch index out of a predefined set of permutations. To control the difficulty of jigsaw puzzles, the paper proposed to shuffle patches according to a predefined permutation set and configured the model to predict a probability vector over all the indices in the set.</p>
<p>Because how the input patches are shuffled does not alter the correct order to predict, a potential improvement to speed up training is to use a permutation-invariant graph convolutional network (GCN) so that we don’t have to shuffle the same set of patches multiple times, the same idea as in this <a href="https://arxiv.org/abs/1911.00025">paper</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-sup-jigsaw-puzzle.png" alt="Jigsaw puzzle" /></p>
<p><em>Fig. 6. Illustration of self-supervised learning by solving jigsaw puzzle. (Image source: <a href="https://arxiv.org/abs/1603.09246">Noroozi & Favaro, 2016</a>)</em></p>
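<p>Below is a minimal sketch of how jigsaw training examples could be generated from a predefined permutation set. How the set is chosen is not specified above, so this sketch simply draws a random subset of permutations; that choice and the function names are assumptions for illustration.</p>

```python
import numpy as np

def build_permutation_set(num_perms, rng, num_tiles=9):
    """Draw num_perms distinct permutations of the tile indices (random subset: an assumption)."""
    perms = set()
    while len(perms) < num_perms:
        perms.add(tuple(int(i) for i in rng.permutation(num_tiles)))
    return [np.array(p) for p in sorted(perms)]

def make_jigsaw_example(image, perm_set, rng):
    """Cut the image into a 3x3 grid, shuffle the tiles, and return (tiles, permutation index)."""
    h, w = image.shape[0] // 3, image.shape[1] // 3
    tiles = [image[r * h:(r + 1) * h, c * w:(c + 1) * w] for r in range(3) for c in range(3)]
    label = int(rng.integers(len(perm_set)))
    # Position i of the shuffled stack holds original tile number perm_set[label][i].
    shuffled = np.stack([tiles[i] for i in perm_set[label]])
    return shuffled, label

rng = np.random.default_rng(0)
perm_set = build_permutation_set(10, rng)
image = rng.random((30, 30, 3))       # placeholder image
shuffled, label = make_jigsaw_example(image, perm_set, rng)
```

<p>The model receives the 9 shuffled tiles and is trained to output <code>label</code>, the index of the permutation that was applied. A larger permutation set makes the task harder.</p>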
<p>Another idea is to consider “feature” or “visual primitives” as a scalar-value attribute that can be summed up over multiple patches and compared across different patches. Then the relationship between patches can be defined by <mark><b>counting features</b></mark> and simple arithmetic (<a href="https://arxiv.org/abs/1708.06734">Noroozi, et al, 2017</a>).</p>
<p>The paper considers two transformations:</p>
<ol>
<li><em>Scaling</em>: If an image is scaled up by 2x, the number of visual primitives should stay the same.</li>
<li><em>Tiling</em>: If an image is tiled into a 2x2 grid, the number of visual primitives is expected to be the sum, 4 times the original feature counts.</li>
</ol>
<p>The model learns a feature encoder <script type="math/tex">\phi(.)</script> using the above feature counting relationship. Given an input image <script type="math/tex">\mathbf{x} \in \mathbb{R}^{m \times n \times 3}</script>, consider two types of transformation operators:</p>
<ol>
<li>Downsampling operator, <script type="math/tex">D: \mathbb{R}^{m \times n \times 3} \mapsto \mathbb{R}^{\frac{m}{2} \times \frac{n}{2} \times 3}</script>: downsample by a factor of 2</li>
<li>Tiling operator <script type="math/tex">T_i: \mathbb{R}^{m \times n \times 3} \mapsto \mathbb{R}^{\frac{m}{2} \times \frac{n}{2} \times 3}</script>: extract the <script type="math/tex">i</script>-th tile from a 2x2 grid of the image.</li>
</ol>
<p>We expect to learn:</p>
<script type="math/tex; mode=display">\phi(\mathbf{x}) = \phi(D \circ \mathbf{x}) = \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})</script>
<p><a href="#counting-feature-loss"></a>Thus the MSE loss is: <script type="math/tex">\mathcal{L}_\text{feat} = \|\phi(D \circ \mathbf{x}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2</script>. To avoid trivial solution <script type="math/tex">\phi(\mathbf{x}) = \mathbf{0}, \forall{\mathbf{x}}</script>, another loss term is added to encourage the difference between features of two different images: <script type="math/tex">\mathcal{L}_\text{diff} = \max(0, c -\|\phi(D \circ \mathbf{y}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2)</script>, where <script type="math/tex">\mathbf{y}</script> is another input image different from <script type="math/tex">\mathbf{x}</script> and <script type="math/tex">c</script> is a scalar constant. The final loss is:</p>
<script type="math/tex; mode=display">\mathcal{L}
= \mathcal{L}_\text{feat} + \mathcal{L}_\text{diff}
= \|\phi(D \circ \mathbf{x}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2 + \max(0, c -\|\phi(D \circ \mathbf{y}) - \sum_{i=1}^4 \phi(T_i \circ \mathbf{x})\|^2_2)</script>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/self-sup-counting-features.png" alt="Counting features" /></p>
<p><em>Fig. 7. Self-supervised representation learning by counting features. (Image source: <a href="https://arxiv.org/abs/1708.06734">Noroozi, et al, 2017</a>)</em></p>
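<p>The counting loss can be written out directly in NumPy. In this sketch the downsampling operator is plain striding and the encoder is a fixed random linear map, both of which are placeholder assumptions so that the example runs on its own; a real encoder would be a trained network. The margin value <code>c=10.0</code> is arbitrary.</p>

```python
import numpy as np

def downsample(x):
    """D: halve the resolution by keeping every other pixel (simplified stand-in)."""
    return x[::2, ::2]

def tiles(x):
    """T_1..T_4: the four tiles of a 2x2 grid over the image."""
    h, w = x.shape[0] // 2, x.shape[1] // 2
    return [x[:h, :w], x[:h, w:], x[h:, :w], x[h:, w:]]

def counting_loss(phi, x, y, c=10.0):
    """L_feat + L_diff for an image x and a different image y."""
    tile_sum = sum(phi(t) for t in tiles(x))
    # Counts of the downsampled image should match the summed counts of the tiles.
    l_feat = np.sum((phi(downsample(x)) - tile_sum) ** 2)
    # Counts of a different image should differ by at least the margin c.
    l_diff = max(0.0, c - np.sum((phi(downsample(y)) - tile_sum) ** 2))
    return l_feat + l_diff

rng = np.random.default_rng(0)
proj = rng.normal(size=(8, 16 * 16 * 3))        # hypothetical toy encoder weights
phi = lambda patch: proj @ patch.reshape(-1)
x, y = rng.random((32, 32, 3)), rng.random((32, 32, 3))
loss = counting_loss(phi, x, y)

# The trivial all-zero encoder is penalized by exactly the margin c.
trivial_loss = counting_loss(lambda patch: np.zeros(8), x, y, c=10.0)
```

<p>The second computation illustrates why the contrastive term is needed: an encoder that outputs zeros satisfies the counting constraint perfectly but still pays the full margin.</p>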
<h3 id="colorization">Colorization</h3>
<p><mark><b>Colorization</b></mark> can be used as a powerful self-supervised task: a model is trained to color a grayscale input image; precisely the task is to map this image to a distribution over quantized color value outputs (<a href="https://arxiv.org/abs/1603.08511">Zhang et al. 2016</a>).</p>
<p>The model outputs colors in the <a href="https://en.wikipedia.org/wiki/CIELAB_color_space">CIE L*a*b* color space</a>. The L*a*b* color space is designed to approximate human vision, while, in contrast, RGB or CMYK models the color output of physical devices.</p>
<ul>
<li>L* component matches human perception of lightness; L* = 0 is black and L* = 100 indicates white.</li>
<li>a* component represents green (negative) / magenta (positive) value.</li>
<li>b* component models blue (negative) / yellow (positive) value.</li>
</ul>
<p>Due to the multimodal nature of the colorization problem, a cross-entropy loss on the predicted probability distribution over binned color values works better than an L2 loss on the raw color values. The a*b* color space is quantized with bucket size 10.</p>
<p>To balance between common colors (usually low a*b* values, of common backgrounds like clouds, walls, and dirt) and rare colors (which are likely associated with key objects in the image), the loss function is rebalanced with a weighting term that boosts the loss of infrequent color buckets. This is similar to why we need both <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf and idf</a> for scoring words in an information retrieval model. The weighting term is inversely proportional to the mixture (1-λ) * Gaussian-kernel-smoothed empirical probability distribution + λ * a uniform distribution, where both distributions are over the quantized a*b* color space, so that rarer color buckets receive larger weights.</p>
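<p>A rough sketch of how such rebalancing weights could be computed is given below. It assumes, following my reading of Zhang et al., that the weight of each bucket is inversely proportional to the mixture of the smoothed empirical distribution and a uniform distribution, and that weights are normalized to have expectation 1 under the empirical distribution. The Gaussian smoothing step is skipped (the input is assumed to be already smoothed), and the function name and default <code>lam</code> are illustrative.</p>

```python
def rebalancing_weights(p_smoothed, lam=0.5):
    """Per-bucket loss weights that boost rare color buckets.

    p_smoothed: smoothed empirical probability of each quantized a*b* bucket
                (assumed to be strictly positive or mixed with lam > 0)
    lam:        mixing coefficient with the uniform distribution (assumed value)
    """
    q = len(p_smoothed)
    # Mix the empirical distribution with a uniform one over the q buckets.
    mixed = [(1.0 - lam) * p + lam / q for p in p_smoothed]
    # Rare buckets have a small mixture mass, hence a large raw weight.
    raw = [1.0 / m for m in mixed]
    # Normalize so that the expected weight under p_smoothed equals 1.
    norm = sum(p * w for p, w in zip(p_smoothed, raw))
    return [w / norm for w in raw]
```

<p>Setting <code>lam=1.0</code> recovers uniform weights (no rebalancing), while smaller values boost rare buckets more aggressively.</p>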
<h3 id="generative-modeling">Generative Modeling</h3>
<p>The pretext task in generative modeling is to reconstruct the original input while learning meaningful latent representation.</p>
<p>The <mark><b>denoising autoencoder</b></mark> (<a href="https://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf">Vincent, et al, 2008</a>) learns to recover an image from a version that is partially corrupted or has random noise. The design is inspired by the fact that humans can easily recognize objects in pictures even with noise, indicating that key visual features can be extracted and separated from noise. See my <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#denoising-autoencoder">old post</a>.</p>
<p>The <mark><b>context encoder</b></mark> (<a href="https://arxiv.org/abs/1604.07379">Pathak, et al., 2016</a>) is trained to fill in a missing piece in the image. Let <script type="math/tex">\hat{M}</script> be a binary mask, 0 for dropped pixels and 1 for remaining input pixels. The model is trained with a combination of the reconstruction (L2) loss and the adversarial loss. The removed regions defined by the mask could be of any shape.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}(\mathbf{x}) &= \mathcal{L}_\text{recon}(\mathbf{x}) + \mathcal{L}_\text{adv}(\mathbf{x})\\
\mathcal{L}_\text{recon}(\mathbf{x}) &= \|(1 - \hat{M}) \odot (\mathbf{x} - E(\hat{M} \odot \mathbf{x})) \|_2^2 \\
\mathcal{L}_\text{adv}(\mathbf{x}) &= \max_D \mathbb{E}_{\mathbf{x}} [\log D(\mathbf{x}) + \log(1 - D(E(\hat{M} \odot \mathbf{x})))]
\end{aligned} %]]></script>
<p>where <script type="math/tex">E(.)</script> is the context encoder network (encoder followed by decoder) that outputs the filled-in image, and <script type="math/tex">D(.)</script> is the adversarial discriminator.</p>
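<p>To make the role of the mask concrete, here is a small sketch of the reconstruction term on flattened pixel lists. It follows the convention used above (mask value 0 for dropped pixels, 1 for kept pixels); the helper names are hypothetical and <code>x_hat</code> stands for the network output <script type="math/tex">E(\hat{M} \odot \mathbf{x})</script>.</p>

```python
def apply_mask(x, mask):
    """Input seen by the network: dropped pixels (mask == 0) are zeroed out."""
    return [m * xi for xi, m in zip(x, mask)]


def masked_recon_loss(x, x_hat, mask):
    """Squared L2 error restricted to the dropped region.

    x:     original pixel values (flattened)
    x_hat: network output for the masked input, same length as x
    mask:  binary mask, 0 for dropped pixels and 1 for kept pixels
    """
    # (1 - m) selects only the pixels that were hidden from the network.
    return sum(((1 - m) * (xi - xh)) ** 2 for xi, xh, m in zip(x, x_hat, mask))
```

<p>Errors on pixels the network could already see do not contribute, so the model is only scored on what it had to infer from context.</p>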
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/context-encoder.png" alt="Context encoder" /></p>
<p><em>Fig. 8. Illustration of context encoder. (Image source: <a href="https://arxiv.org/abs/1604.07379">Pathak, et al., 2016</a>)</em></p>
<p>When applying a mask on an image, the context encoder removes information of all the color channels in partial regions. How about only hiding a subset of channels? The <mark><b>split-brain autoencoder</b></mark> (<a href="https://arxiv.org/abs/1611.09842">Zhang et al., 2017</a>) does this by predicting a subset of color channels from the rest of the channels. Let the data tensor <script type="math/tex">\mathbf{x} \in \mathbb{R}^{h \times w \times \vert C \vert }</script> with <script type="math/tex">C</script> color channels be the input for the <script type="math/tex">l</script>-th layer of the network. It is split into two disjoint parts, <script type="math/tex">\mathbf{x}_1 \in \mathbb{R}^{h \times w \times \vert C_1 \vert}</script> and <script type="math/tex">\mathbf{x}_2 \in \mathbb{R}^{h \times w \times \vert C_2 \vert}</script>, where <script type="math/tex">C_1 , C_2 \subseteq C</script>. Then two sub-networks are trained to do two complementary predictions: one network <script type="math/tex">f_1</script> predicts <script type="math/tex">\mathbf{x}_2</script> from <script type="math/tex">\mathbf{x}_1</script> and the other network <script type="math/tex">f_2</script> predicts <script type="math/tex">\mathbf{x}_1</script> from <script type="math/tex">\mathbf{x}_2</script>. The loss is either L1 loss or cross entropy if color values are quantized.</p>
<p>The split can happen once on the RGB-D or L*a*b* colorspace, or happen even in every layer of a CNN network in which the number of channels can be arbitrary.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/split-brain-autoencoder.png" alt="Split-brain autoencoder" /></p>
<p><em>Fig. 9. Illustration of split-brain autoencoder. (Image source: <a href="https://arxiv.org/abs/1611.09842">Zhang et al., 2017</a>)</em></p>
<p>Generative adversarial networks (GANs) are able to learn to map from simple latent variables to arbitrarily complex data distributions. Studies have shown that the latent space of such generative models captures semantic variation in the data; e.g. when training GAN models on human faces, some latent variables are associated with facial expression, glasses, gender, etc. (<a href="https://arxiv.org/abs/1511.06434">Radford et al., 2016</a>).</p>
<p><mark><b>Bidirectional GANs</b></mark> (<a href="https://arxiv.org/abs/1605.09782">Donahue, et al, 2017</a>) introduce an additional encoder <script type="math/tex">E(.)</script> to learn the mapping from the input to the latent variable <script type="math/tex">\mathbf{z}</script>. The discriminator <script type="math/tex">D(.)</script> predicts in the joint space of the input data and latent representation, <script type="math/tex">(\mathbf{x}, \mathbf{z})</script>, to tell apart the pair built from real data <script type="math/tex">(\mathbf{x}, E(\mathbf{x}))</script> from the pair built from generated data <script type="math/tex">(G(\mathbf{z}), \mathbf{z})</script>. The model is trained to optimize the objective: <script type="math/tex">\min_{G, E} \max_D V(D, E, G)</script>, where the generator <script type="math/tex">G</script> and the encoder <script type="math/tex">E</script> learn to generate data and latent variables that are realistic enough to confuse the discriminator and at the same time the discriminator <script type="math/tex">D</script> tries to differentiate real and generated data.</p>
<script type="math/tex; mode=display">V(D, E, G) = \mathbb{E}_{\mathbf{x} \sim p_\mathbf{x}} [ \underbrace{\mathbb{E}_{\mathbf{z} \sim p_E(.\vert\mathbf{x})}[\log D(\mathbf{x}, \mathbf{z})]}_{\log D(\text{real})} ] + \mathbb{E}_{\mathbf{z} \sim p_\mathbf{z}} [ \underbrace{\mathbb{E}_{\mathbf{x} \sim p_G(.\vert\mathbf{z})}[\log (1 - D(\mathbf{x}, \mathbf{z}))]}_{\log(1- D(\text{fake}))} ]</script>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/bi-GAN.png" alt="BiGAN" /></p>
<p><em>Fig. 10. Illustration of how Bidirectional GAN works. (Image source: <a href="https://arxiv.org/abs/1605.09782">Donahue, et al, 2017</a>)</em></p>
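<p>For intuition, the value function can be estimated from a batch of discriminator outputs. The sketch below is only a Monte-Carlo estimate under the assumption that the discriminator scores are already available as probabilities in (0, 1); the function and argument names are illustrative.</p>

```python
import math


def bigan_value(d_on_encoder_pairs, d_on_generator_pairs):
    """Monte-Carlo estimate of V(D, E, G).

    d_on_encoder_pairs:   D(x, E(x)) for real samples x, each in (0, 1)
    d_on_generator_pairs: D(G(z), z) for sampled latents z, each in (0, 1)
    """
    # E_x[log D(x, E(x))]: the discriminator wants this close to 0.
    real_term = sum(math.log(d) for d in d_on_encoder_pairs)
    real_term /= len(d_on_encoder_pairs)
    # E_z[log(1 - D(G(z), z))]: also close to 0 for a good discriminator.
    fake_term = sum(math.log(1.0 - d) for d in d_on_generator_pairs)
    fake_term /= len(d_on_generator_pairs)
    return real_term + fake_term
```

<p>A discriminator that cannot distinguish the two kinds of pairs outputs 0.5 everywhere, giving <script type="math/tex">V = 2 \log 0.5</script>; the generator and encoder push toward this point while the discriminator pushes the value up.</p>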
<h3 id="contrastive-predictive-coding">Contrastive Predictive Coding</h3>
<p>The <mark><b>Contrastive Predictive Coding (CPC)</b></mark> (<a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>) is an approach for unsupervised learning from high-dimensional data by translating a generative modeling problem to a classification problem. The <em>contrastive loss</em> or <em>InfoNCE loss</em> in CPC, inspired by <a href="/lil-log/2017/10/15/learning-word-embedding.html#noise-contrastive-estimation-nce">Noise Contrastive Estimation (NCE)</a>, uses cross-entropy loss to measure how well the model can classify the “future” representation amongst a set of unrelated “negative” samples. Such a design is partially motivated by the fact that a unimodal loss like MSE does not have enough capacity, while learning a full generative model could be too expensive.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CPC-audio.png" alt="CPC on audio input" /></p>
<p><em>Fig. 11. Illustration of applying Contrastive Predictive Coding on the audio input. (Image source: <a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>)</em></p>
<p>CPC uses an encoder to compress the input data <script type="math/tex">z_t = g_\text{enc}(x_t)</script> and an <em>autoregressive</em> decoder to learn the high-level context that is potentially shared across future predictions, <script type="math/tex">c_t = g_\text{ar}(z_{\leq t})</script>. The end-to-end training relies on the NCE-inspired contrastive loss.</p>
<p>While predicting future information, CPC is optimized to maximize the mutual information between input <script type="math/tex">x</script> and context vector <script type="math/tex">c</script>:</p>
<script type="math/tex; mode=display">I(x; c) = \sum_{x, c} p(x, c) \log\frac{p(x, c)}{p(x)p(c)} = \sum_{x, c} p(x, c)\log\frac{p(x|c)}{p(x)}</script>
<p>Rather than modeling the future observations <script type="math/tex">p_k(x_{t+k} \vert c_t)</script> directly (which could be fairly expensive), CPC models a density ratio that preserves the mutual information between <script type="math/tex">x_{t+k}</script> and <script type="math/tex">c_t</script>:</p>
<script type="math/tex; mode=display">f_k(x_{t+k}, c_t) = \exp(z_{t+k}^\top W_k c_t) \propto \frac{p(x_{t+k}|c_t)}{p(x_{t+k})}</script>
<p>where <script type="math/tex">f_k</script> can be unnormalized and a linear transformation <script type="math/tex">W_k c_t</script> is used for the prediction with a different <script type="math/tex">W_k</script> matrix for every step <script type="math/tex">k</script>.</p>
<p>Given a set of <script type="math/tex">N</script> random samples <script type="math/tex">X = \{x_1, \dots, x_N\}</script> containing only one positive sample <script type="math/tex">x_{t+k} \sim p(x_{t+k} \vert c_t)</script> and <script type="math/tex">N-1</script> negative samples <script type="math/tex">x_i \sim p(x_{t+k})</script> drawn independently of the context, the cross-entropy loss for classifying the positive sample (where <script type="math/tex">\frac{f_k}{\sum f_k}</script> is the prediction) correctly is:</p>
<script type="math/tex; mode=display">\mathcal{L}_N = - \mathbb{E}_X \Big[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{i=1}^N f_k (x_i, c_t)}\Big]</script>
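<p>The InfoNCE loss for a single context can be sketched in plain Python as follows, using the bilinear score <script type="math/tex">z^\top W_k c_t</script> from above. The function names are illustrative, and the encoded candidates <code>z_candidates</code> are assumed to be precomputed by the encoder; a log-sum-exp is used for numerical stability.</p>

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def matvec(W, c):
    """Matrix-vector product, with W given as a list of rows."""
    return [dot(row, c) for row in W]


def info_nce(z_candidates, c_t, W_k, positive_index):
    """-log( f_k(x_pos, c_t) / sum_i f_k(x_i, c_t) ), f_k = exp(z^T W_k c_t).

    z_candidates:   encodings of the N samples (one positive, N-1 negatives)
    c_t:            context vector
    W_k:            prediction matrix for step k
    positive_index: position of the positive sample in z_candidates
    """
    prediction = matvec(W_k, c_t)
    scores = [dot(z, prediction) for z in z_candidates]  # log f_k per sample
    # Stable log of the denominator: log sum_i exp(score_i).
    top = max(scores)
    log_denominator = top + math.log(sum(math.exp(s - top) for s in scores))
    return log_denominator - scores[positive_index]
```

<p>When the scores carry no information (all equal), the loss is <script type="math/tex">\log N</script>, i.e. the cross entropy of a uniform guess over the <script type="math/tex">N</script> candidates.</p>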
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CPC-image.png" alt="CPC on images" /></p>
<p><em>Fig. 12. Illustration of applying Contrastive Predictive Coding on images. (Image source: <a href="https://arxiv.org/abs/1807.03748">van den Oord, et al. 2018</a>)</em></p>
<p>When using CPC on images (<a href="https://arxiv.org/abs/1905.09272">Henaff, et al. 2019</a>), the predictor network should only access a masked feature set to avoid a trivial prediction. Precisely:</p>
<ol>
<li>Each input image is divided into a set of overlapping patches and each patch is encoded by a ResNet encoder, resulting in a compressed feature vector <script type="math/tex">z_{i,j}</script>.</li>
<li>A masked conv net makes predictions with a mask such that the receptive field of a given output neuron can only see things above it in the image. Otherwise, the prediction problem would be trivial. The prediction can be made in both directions (top-down and bottom-up).</li>
<li>The prediction is made for <script type="math/tex">z_{i+k, j}</script> from context <script type="math/tex">c_{i,j}</script>: <script type="math/tex">\hat{z}_{i+k, j} = W_k c_{i,j}</script>.</li>
</ol>
<p>A contrastive loss quantifies this prediction with the goal of correctly identifying the target among a set of negative representations <script type="math/tex">\{z_l\}</script> sampled from other patches in the same image and other images in the same batch:</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{CPC}
= -\sum_{i,j,k} \log p(z_{i+k, j} \vert \hat{z}_{i+k, j}, \{z_l\})
= -\sum_{i,j,k} \log \frac{\exp(\hat{z}_{i+k, j}^\top z_{i+k, j})}{\exp(\hat{z}_{i+k, j}^\top z_{i+k, j}) + \sum_l \exp(\hat{z}_{i+k, j}^\top z_l)}</script>
<h3 id="momentum-contrast">Momentum Contrast</h3>
<p><a name="moco"></a><mark><b>Momentum Contrast</b></mark> (<strong>MoCo</strong>; <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>) provides a framework for unsupervised learning of visual representations as a <em>dynamic dictionary look-up</em>. The dictionary is structured as a large FIFO queue of encoded representations of data samples.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/MoCo.png" alt="MoCo" /></p>
<p><em>Fig. 13. Illustration of how Momentum Contrast (MoCo) learns visual representations. (Image source: <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>)</em></p>
<p>Given a query sample <script type="math/tex">x_q</script>, we get a query representation <script type="math/tex">q</script> through an encoder <script type="math/tex">f_q</script>: <script type="math/tex">q = f_q(x_q)</script>. Key samples are encoded by a momentum encoder <script type="math/tex">k_i = f_k (x^k_i)</script> to produce a list of key representations <script type="math/tex">\{k_1, k_2, \dots \}</script> in the dictionary. Let’s assume among them there is a single <em>positive</em> key <script type="math/tex">k^+</script> in the dictionary that matches <script type="math/tex">q</script>. In the paper, <script type="math/tex">k^+</script> is created using a copy of <script type="math/tex">x_q</script> with different augmentation. Then the <a href="#contrastive-predictive-coding">InfoNCE</a> contrastive loss is applied for one positive and <script type="math/tex">K</script> negative samples:</p>
<script type="math/tex; mode=display">\mathcal{L}_q = - \log \frac{\exp(q \cdot k^+ / \tau)}{\sum_{i=0}^K \exp(q \cdot k_i / \tau)}</script>
<p>where <script type="math/tex">\tau</script> is a temperature hyper-parameter.</p>
<p>Compared to the similar idea of a <strong>memory bank</strong> (<a href="https://arxiv.org/abs/1805.01978v1">Wu et al, 2018</a>), which stores representations of all the data points in the database and samples a random set of keys as negative examples, a queue-based dictionary in MoCo enables us to reuse representations of the immediately preceding mini-batches of data.</p>
<p>The MoCo dictionary is not differentiable as a queue, so we cannot rely on back-propagation to update the key encoder <script type="math/tex">f_k</script>. One naive way might be to use the same encoder for both <script type="math/tex">f_q</script> and <script type="math/tex">f_k</script>. Instead, MoCo proposes a momentum-based update. Say the parameters of <script type="math/tex">f_q</script> and <script type="math/tex">f_k</script> are labeled as <script type="math/tex">\theta_q</script> and <script type="math/tex">\theta_k</script>, respectively:</p>
<script type="math/tex; mode=display">\theta_k \leftarrow m \theta_k + (1-m) \theta_q</script>
<p>where <script type="math/tex">m \in [0, 1)</script> is a momentum coefficient. No gradient flows through <script type="math/tex">f_k</script>’s update.</p>
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/MoCo-algo.png" alt="MoCo Algorithm" /></p>
<p><em>Fig. 14. Pseudo code of MoCo in PyTorch style. (Image source: <a href="https://arxiv.org/abs/1911.05722">He et al, 2019</a>)</em></p>
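<p>The two mechanisms specific to MoCo, the momentum update of the key encoder and the FIFO dictionary, can be illustrated without any deep learning library. This is a simplified sketch: parameters are flat lists of floats, keys are arbitrary Python objects, and the class name, function name and default momentum value are assumptions made for the example.</p>

```python
from collections import deque


def momentum_update(theta_k, theta_q, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q, applied element-wise.

    No gradient is involved: the key encoder simply trails the query encoder.
    The default m is an illustrative choice close to 1.
    """
    return [m * k + (1.0 - m) * q for k, q in zip(theta_k, theta_q)]


class KeyQueue:
    """Fixed-size FIFO dictionary of encoded keys."""

    def __init__(self, max_size):
        # A deque with maxlen drops the oldest items once it is full.
        self._keys = deque(maxlen=max_size)

    def enqueue(self, batch_keys):
        """Add the keys of the current mini-batch, evicting the oldest ones."""
        self._keys.extend(batch_keys)

    def keys(self):
        """Current dictionary content, oldest first."""
        return list(self._keys)
```

<p>With <script type="math/tex">m</script> close to 1 the key encoder changes slowly, so keys produced several mini-batches ago remain roughly comparable with freshly encoded ones.</p>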
<p><strong>SimCLR</strong> (<a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>) proposed a simple framework for contrastive learning of visual representations. It learns representations for visual inputs by maximizing agreement between differently augmented views of the same sample via a contrastive loss in the latent space.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/SimCLR.png" alt="SimCLR" /></p>
<p><em>Fig. 15. A simple framework for contrastive learning of visual representations. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>)</em></p>
<p>SimCLR works in the following three steps:</p>
<p>(1) Randomly sample a mini-batch of <script type="math/tex">n</script> samples and apply two different data augmentation operations to each sample, resulting in <script type="math/tex">2n</script> augmented samples in total.</p>
<script type="math/tex; mode=display">\tilde{\mathbf{x}}_i = t(\mathbf{x}),\quad\tilde{\mathbf{x}}_j = t'(\mathbf{x}),\quad t, t' \sim \mathcal{T}</script>
<p>where two separate data augmentation operators, <script type="math/tex">t</script> and <script type="math/tex">t'</script>, are sampled from the same family of augmentations <script type="math/tex">\mathcal{T}</script>. Data augmentation includes random crop, resize with random flip, color distortions, and Gaussian blur.</p>
<p>(2) Given one positive pair, the other <script type="math/tex">2(n-1)</script> data points are treated as negative samples. The representation is produced by a base encoder <script type="math/tex">f(.)</script>:</p>
<script type="math/tex; mode=display">\mathbf{h}_i = f(\tilde{\mathbf{x}}_i),\quad \mathbf{h}_j = f(\tilde{\mathbf{x}}_j)</script>
<p>(3) The contrastive loss is defined using cosine similarity <script type="math/tex">\text{sim}(.,.)</script>. Note that the loss operates on top of an extra projection of the representation via <script type="math/tex">g(.)</script> rather than on the representation <script type="math/tex">\mathbf{h}</script> directly. But only the representation <script type="math/tex">\mathbf{h}</script> is used for downstream tasks.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{z}_i &= g(\mathbf{h}_i),\quad
\mathbf{z}_j = g(\mathbf{h}_j),\quad
\text{sim}(\mathbf{z}_i, \mathbf{z}_j) = \frac{\mathbf{z}_i^\top\mathbf{z}_j}{\|\mathbf{z}_i\| \|\mathbf{z}_j\|} \\
\mathcal{L}_{i,j} &= - \log\frac{\exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_j) / \tau)}{\sum_{k=1}^{2n} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_k) / \tau)}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbf{1}_{[k \neq i]}</script> is an indicator function: 1 if <script type="math/tex">k\neq i</script>, 0 otherwise. <script type="math/tex">\tau</script> is a temperature hyperparameter.</p>
<p style="width: 58%;" class="center"><img src="/lil-log/assets/images/SimCLR-algo.png" alt="SimCLR Algorithm" /></p>
<p><em>Fig. 16. The algorithm for SimCLR. (Image source: <a href="https://arxiv.org/abs/2002.05709">Chen et al, 2020</a>).</em></p>
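<p>The per-pair loss <script type="math/tex">\mathcal{L}_{i,j}</script> can be written out directly from the formula above. The sketch below works on a list of already projected vectors <script type="math/tex">\mathbf{z}</script>; the function names, the default temperature and the assumption that the two views of sample <script type="math/tex">k</script> sit at indices <code>2k</code> and <code>2k+1</code> are choices made for this example. Vectors are assumed to be non-zero.</p>

```python
import math


def cosine_sim(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent_pair(z, i, j, tau=0.5):
    """L_{i,j} for the positive pair (i, j) among the 2n projections in z."""
    # Denominator runs over all k != i, which includes the positive j.
    logits = [cosine_sim(z[i], z[k]) / tau for k in range(len(z)) if k != i]
    positive = cosine_sim(z[i], z[j]) / tau
    top = max(logits)
    log_denominator = top + math.log(sum(math.exp(l - top) for l in logits))
    return log_denominator - positive


def nt_xent_batch(z, tau=0.5):
    """Average of L_{i,j} and L_{j,i} over all positive pairs (2k, 2k+1)."""
    n_pairs = len(z) // 2
    total = 0.0
    for k in range(n_pairs):
        total += nt_xent_pair(z, 2 * k, 2 * k + 1, tau)
        total += nt_xent_pair(z, 2 * k + 1, 2 * k, tau)
    return total / (2 * n_pairs)
```

<p>Note that the loss is not symmetric in <script type="math/tex">(i, j)</script>, which is why both orderings of each positive pair are accumulated in the batch loss.</p>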
<p><strong>CURL</strong> (<a href="https://arxiv.org/abs/2004.04136">Srinivas & Laskin, et al. 2020</a>) applies the above ideas in Reinforcement Learning. It learns a visual representation for RL tasks by matching embeddings of two data-augmented versions, <script type="math/tex">o_q</script> and <script type="math/tex">o_k</script>, of the raw observation <script type="math/tex">o</script> via contrastive loss. CURL primarily relies on random crop data augmentation. The key encoder is implemented as a momentum encoder with weights as EMA of the query encoder weights, same as in <a href="#moco">MoCo</a>.</p>
<p>One significant difference between RL and supervised visual tasks is that RL depends on <em>temporal</em> consistency between consecutive frames. Therefore, CURL applies augmentation consistently on each stack of frames to retain information about the temporal structure of the observation.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/CURL.png" alt="CURL" /></p>
<p><em>Fig. 17. The architecture of CURL. (Image source: <a href="https://arxiv.org/abs/2004.04136">Srinivas & Laskin, et al. 2020</a>)</em></p>
<h2 id="video-based">Video-Based</h2>
<p>A video contains a sequence of semantically related frames. Nearby frames are close in time and more correlated than frames further away. The order of frames reflects certain rules of reasoning and physical logic, such as that object motion should be smooth and gravity points down.</p>
<p>A common workflow is to train a model on one or multiple pretext tasks with unlabelled videos and then feed one intermediate feature layer of this model to fine-tune a simple model on downstream tasks of action classification, segmentation or object tracking.</p>
<h3 id="tracking">Tracking</h3>
<p>The movement of an object is traced by a sequence of video frames. The difference between how the same object is captured on the screen in close frames is usually not big, commonly triggered by small motion of the object or the camera. Therefore any visual representation learned for the same object across close frames should be close in the latent feature space. Motivated by this idea, <a href="https://arxiv.org/abs/1505.00687">Wang & Gupta, 2015</a> proposed a way of unsupervised learning of visual representation by <mark><b>tracking moving objects</b></mark> in videos.</p>
<p>Precisely, patches with motion are tracked over a small time window (e.g. 30 frames). The first patch <script type="math/tex">\mathbf{x}</script> and the last patch <script type="math/tex">\mathbf{x}^+</script> are selected and used as training data points. If we train the model directly to minimize the difference between feature vectors of two patches, the model may only learn to map everything to the same value. To avoid such a trivial solution, same as <a href="#counting-feature-loss">above</a>, a random third patch <script type="math/tex">\mathbf{x}^-</script> is added. The model learns the representation by enforcing the distance between two tracked patches to be smaller than the distance between the first patch and a random one in the feature space, <script type="math/tex">D(\mathbf{x}, \mathbf{x}^-) > D(\mathbf{x}, \mathbf{x}^+)</script>, where <script type="math/tex">D(.)</script> is the cosine distance,</p>
<script type="math/tex; mode=display">D(\mathbf{x}_1, \mathbf{x}_2) = 1 - \frac{f(\mathbf{x}_1) \cdot f(\mathbf{x}_2)}{\|f(\mathbf{x}_1)\| \|f(\mathbf{x}_2)\|}</script>
<p>The loss function is:</p>
<script type="math/tex; mode=display">\mathcal{L}(\mathbf{x}, \mathbf{x}^+, \mathbf{x}^-)
= \max\big(0, D(\mathbf{x}, \mathbf{x}^+) - D(\mathbf{x}, \mathbf{x}^-) + M\big) + \text{weight decay regularization term}</script>
<p>where <script type="math/tex">M</script> is a scalar constant controlling the minimum gap between two distances; <script type="math/tex">M=0.5</script> in the paper. The loss enforces <script type="math/tex">D(\mathbf{x}, \mathbf{x}^-) \geq D(\mathbf{x}, \mathbf{x}^+) + M</script> in the optimal case.</p>
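<p>A small sketch of this ranking loss on precomputed feature vectors is shown below. The weight decay term is left out, the function names are illustrative, and feature vectors are assumed to be non-zero.</p>

```python
import math


def cosine_distance(f1, f2):
    """D = 1 - cosine similarity of two non-zero feature vectors."""
    dot = sum(a * b for a, b in zip(f1, f2))
    norm1 = math.sqrt(sum(a * a for a in f1))
    norm2 = math.sqrt(sum(b * b for b in f2))
    return 1.0 - dot / (norm1 * norm2)


def ranking_loss(f_x, f_pos, f_neg, margin=0.5):
    """max(0, D(x, x+) - D(x, x-) + M), without the regularization term.

    f_x:   features of the first tracked patch
    f_pos: features of the last tracked patch (same object)
    f_neg: features of a random patch
    """
    gap = cosine_distance(f_x, f_pos) - cosine_distance(f_x, f_neg)
    return max(0.0, gap + margin)
```

<p>The loss is zero once the random patch is at least <script type="math/tex">M</script> further away than the tracked one, so well-separated triplets stop contributing gradients.</p>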
<p><a href="#triplet-loss"></a>This form of loss function is also known as <a href="https://arxiv.org/abs/1503.03832">triplet loss</a> in the face recognition task, in which the dataset contains images of multiple people from multiple camera angles. Let <script type="math/tex">\mathbf{x}^a</script> be an anchor image of a specific person, <script type="math/tex">\mathbf{x}^p</script> be a positive image of this same person from a different angle and <script type="math/tex">\mathbf{x}^n</script> be a negative image of a different person. In the embedding space, <script type="math/tex">\mathbf{x}^a</script> should be closer to <script type="math/tex">\mathbf{x}^p</script> than to <script type="math/tex">\mathbf{x}^n</script>:</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{triplet}(\mathbf{x}^a, \mathbf{x}^p, \mathbf{x}^n) = \max(0, \|\phi(\mathbf{x}^a) - \phi(\mathbf{x}^p) \|_2^2 - \|\phi(\mathbf{x}^a) - \phi(\mathbf{x}^n) \|_2^2 + M)</script>
<p><a href="#n-pair-loss"></a>A slightly different form of the triplet loss, named <a href="https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective">n-pair loss</a> is also commonly used for learning observation embedding in robotics tasks. See a <a href="#multi-view-metric-learning">later section</a> for more related content.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/tracking-videos.png" alt="tracking videos" /></p>
<p><em>Fig. 18. Overview of learning representation by tracking objects in videos. (a) Identify moving patches in short traces; (b) Feed two related patches and one random patch into a conv network with shared weights. (c) The loss function enforces the distance between related patches to be smaller than the distance between random patches. (Image source: <a href="https://arxiv.org/abs/1505.00687">Wang & Gupta, 2015</a>)</em></p>
<p>Relevant patches are tracked and extracted through a two-step unsupervised <a href="https://en.wikipedia.org/wiki/Optical_flow">optical flow</a> approach:</p>
<ol>
<li>Obtain <a href="https://www.vision.ee.ethz.ch/~surf/eccv06.pdf">SURF</a> interest points and use <a href="https://hal.inria.fr/hal-00873267v2/document">IDT</a> to obtain motion of each SURF point.</li>
<li>Given the trajectories of SURF interest points, classify these points as moving if the flow magnitude is more than 0.5 pixels.</li>
</ol>
<p>During training, given a pair of correlated patches <script type="math/tex">\mathbf{x}</script> and <script type="math/tex">\mathbf{x}^+</script>, <script type="math/tex">K</script> random patches <script type="math/tex">\{\mathbf{x}^-\}</script> are sampled in this same batch to form <script type="math/tex">K</script> training triplets. After a couple of epochs, <em>hard negative mining</em> is applied to make the training harder and more efficient, that is, to search for random patches that maximize the loss and use them to do gradient updates.</p>
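<p>Hard negative mining as described above amounts to ranking candidate negatives by the loss they would produce and keeping the top ones. A minimal sketch on precomputed distances follows; the function name and the choice to return indices are assumptions for illustration.</p>

```python
def hardest_negatives(d_pos, d_negs, k, margin=0.5):
    """Indices of the k candidate negatives with the largest ranking loss.

    d_pos:  distance D(x, x+) between the two tracked patches
    d_negs: distances D(x, x-) from x to each candidate random patch
    k:      number of hard negatives to keep
    """
    losses = [max(0.0, d_pos - d_neg + margin) for d_neg in d_negs]
    # Sort candidate indices by loss, largest (hardest) first.
    order = sorted(range(len(d_negs)), key=lambda i: losses[i], reverse=True)
    return order[:k]
```

<p>Negatives that lie close to the anchor in feature space produce the largest loss and are therefore selected first.</p>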
<h3 id="frame-sequence">Frame Sequence</h3>
<p>Video frames are naturally positioned in chronological order. Researchers have proposed several self-supervised tasks, motivated by the expectation that good representation should learn the <em>correct sequence</em> of frames.</p>
<p>One idea is to <mark><b>validate frame order</b></mark> (<a href="https://arxiv.org/abs/1603.08561">Misra, et al 2016</a>). The pretext task is to determine whether a sequence of frames from a video is placed in the correct temporal order (“temporal valid”). The model needs to track and reason about small motion of an object across frames to complete such a task.</p>
<p>The training frames are sampled from high-motion windows. Every time 5 frames are sampled <script type="math/tex">(f_a, f_b, f_c, f_d, f_e)</script> and the timestamps are in order <script type="math/tex">% <![CDATA[
a < b < c < d < e %]]></script>. Out of 5 frames, one positive tuple <script type="math/tex">(f_b, f_c, f_d)</script> and two negative tuples, <script type="math/tex">(f_b, f_a, f_d)</script> and <script type="math/tex">(f_b, f_e, f_d)</script> are created. The parameter <script type="math/tex">\tau_\max = \vert b-d \vert</script> controls the difficulty of positive training instances (i.e. higher → harder) and the parameter <script type="math/tex">\tau_\min = \min(\vert a-b \vert, \vert d-e \vert)</script> controls the difficulty of negatives (i.e. lower → harder).</p>
<p>The pretext task of video frame order validation is shown to improve the performance on the downstream task of action recognition when used as a pretraining step.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/frame-order-validation.png" alt="frame order validation" /></p>
<p><em>Fig. 19. Overview of learning representation by validating the order of video frames. (a) the data sample process; (b) the model is a triplet siamese network, where all input frames have shared weights. (Image source: <a href="https://arxiv.org/abs/1603.08561">Misra, et al 2016</a>)</em></p>
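<p>The tuple construction for frame order validation is simple enough to spell out. The sketch below takes five frame timestamps and returns the positive tuple, the two negative tuples and the two difficulty parameters; the function name and return layout are illustrative.</p>

```python
def order_verification_tuples(a, b, c, d, e):
    """Build training tuples from five frame timestamps with a < b < c < d < e.

    Returns (positive, negatives, tau_max, tau_min), where tuples contain
    timestamps in the order they are presented to the model.
    """
    if not a < b < c < d < e:
        raise ValueError("timestamps must be strictly increasing")
    positive = (b, c, d)                  # middle frame lies between b and d
    negatives = [(b, a, d), (b, e, d)]    # middle frame lies outside [b, d]
    tau_max = abs(b - d)                  # larger -> harder positives
    tau_min = min(abs(a - b), abs(d - e)) # smaller -> harder negatives
    return positive, negatives, tau_max, tau_min
```

<p>Both negatives keep <script type="math/tex">f_b</script> and <script type="math/tex">f_d</script> fixed and only swap the middle frame, so the model has to reason about the content of the middle frame rather than the endpoints.</p>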
<p>The task in <em>O3N</em> (Odd-One-Out Network; <a href="https://arxiv.org/abs/1611.06646">Fernando et al. 2017</a>) is based on video frame sequence validation too. One step further from above, the task is to <mark><b>pick the incorrect sequence</b></mark> from multiple video clips.</p>
<p>Given <script type="math/tex">N+1</script> input video clips, one of them has frames shuffled, thus in the wrong order, and the remaining <script type="math/tex">N</script> of them stay in the correct temporal order. O3N learns to predict the location of the odd video clip. In their experiments, there are 6 input clips and each contains 6 frames.</p>
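<p>One way to build a training example for this odd-one-out task is sketched below: pick one clip at random, shuffle its frames until the order actually changes, and use its position as the label. This is an illustrative construction only; the function name is hypothetical, clips are represented by lists of frame indices, and the paper's actual sampling strategy may differ.</p>

```python
import random


def make_odd_one_out(clips, rng):
    """Shuffle exactly one clip and return (clips_out, odd_index).

    clips: list of N+1 clips, each a list of frame indices in temporal order
    rng:   a random.Random instance, passed in for reproducibility
    """
    odd_index = rng.randrange(len(clips))
    clips_out = [list(clip) for clip in clips]
    original = clips_out[odd_index]
    if len(set(original)) < 2:
        raise ValueError("a clip needs at least two distinct frames")
    shuffled = list(original)
    # Reshuffle until the order differs from the correct temporal order.
    while shuffled == original:
        rng.shuffle(shuffled)
    clips_out[odd_index] = shuffled
    return clips_out, odd_index
```

<p>The label is the position of the shuffled clip, so the task becomes an <script type="math/tex">(N+1)</script>-way classification problem.</p>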
<p>The <mark><b>arrow of time</b></mark> in a video contains very informative messages, on both low-level physics (e.g. gravity pulls objects down to the ground; smoke rises up; water flows downward.) and high-level event reasoning (e.g. fish swim forward; you can break an egg but cannot revert it.). Inspired by this, another idea is to learn latent representation by predicting the arrow of time (AoT), i.e. whether a video is playing forwards or backwards (<a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">Wei et al., 2018</a>).</p>
<p>A classifier should capture both low-level physics and high-level semantics in order to predict the arrow of time. The proposed <em>T-CAM</em> (Temporal Class-Activation-Map) network accepts <script type="math/tex">T</script> groups, each containing a number of frames of optical flow. The conv layer outputs from each group are concatenated and fed into binary logistic regression for predicting the arrow of time.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/learning-arrow-of-time.png" alt="Learning the arrow of time" /></p>
<p><em>Fig. 20. Overview of learning representation by predicting the arrow of time. (a) Conv features of multiple groups of frame sequences are concatenated. (b) The top level contains 3 conv layers and average pooling. (Image source: <a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">Wei et al, 2018</a>)</em></p>
<p>Interestingly, there exist a couple of artificial cues in the dataset. If not handled properly, they could lead to a trivial classifier without relying on the actual video content:</p>
<ul>
<li>Due to the video compression, the black framing might not be completely black but instead may contain certain information on the chronological order. Hence black framing should be removed in the experiments.</li>
<li>Large camera motion, like vertical translation or zoom-in/out, also provides strong signals for the arrow of time but independent of content. The processing stage should stabilize the camera motion.</li>
</ul>
<p>The AoT pretext task is shown to improve the performance on action classification downstream task when used as a pretraining step. Note that fine-tuning is still needed.</p>
<h3 id="video-colorization">Video Colorization</h3>
<p><a href="https://arxiv.org/abs/1806.09594">Vondrick et al. (2018)</a> proposed <mark><b>video colorization</b></mark> as a self-supervised learning problem, resulting in a rich representation that can be used for video segmentation and unlabelled visual region tracking, <em>without extra fine-tuning</em>.</p>
<p>Unlike the image-based <a href="#colorization">colorization</a>, here the task is to copy colors from a normal reference frame in color to another target frame in grayscale by leveraging the natural temporal coherency of colors across video frames (thus these two frames shouldn’t be too far apart in time). In order to copy colors consistently, the model is designed to learn to keep track of correlated pixels in different frames.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/video-colorization.png" alt="Video colorization" /></p>
<p><em>Fig. 21. Video colorization by copying colors from a reference frame to target frames in grayscale. (Image source: <a href="https://arxiv.org/abs/1806.09594">Vondrick et al. 2018</a>)</em></p>
<p>The idea is quite simple and smart. Let <script type="math/tex">c_i</script> be the true color of the <script type="math/tex">i</script>-th pixel in the reference frame and <script type="math/tex">c_j</script> be the color of <script type="math/tex">j</script>-th pixel in the target frame. The predicted color of the <script type="math/tex">j</script>-th pixel in the target <script type="math/tex">\hat{c}_j</script> is a weighted sum of colors of all the pixels in reference, where the weighting term measures the similarity:</p>
<script type="math/tex; mode=display">\hat{c}_j = \sum_i A_{ij} c_i \text{ where } A_{ij} = \frac{\exp(f_i f_j)}{\sum_{i'} \exp(f_{i'} f_j)}</script>
<p>where <script type="math/tex">f</script> are learned embeddings for corresponding pixels; <script type="math/tex">i'</script> indexes all the pixels in the reference frame. The weighting term implements an attention-based pointing mechanism, similar to <a href="/lil-log/2018/11/30/meta-learning.html#matching-networks">matching network</a> and <a href="/lil-log/2018/06/24/attention-attention.html#pointer-network">pointer network</a>. As the full similarity matrix could be really large, both frames are downsampled. The categorical cross-entropy loss between <script type="math/tex">c_j</script> and <script type="math/tex">\hat{c}_j</script> is used with quantized colors, just like in <a href="https://arxiv.org/abs/1603.08511">Zhang et al. 2016</a>.</p>
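<p>The color-copying step can be sketched in a few lines of NumPy. This is only an illustration of the attention formula above, not the authors’ implementation: the function names, the array shapes and the one-hot encoding of quantized colors are assumptions made for the example.</p>

```python
import numpy as np

def copy_colors(f_ref, f_tgt, c_ref):
    """Predict target colors by attending over reference pixels.

    f_ref: (N_ref, d) embeddings of reference pixels (hypothetical shape)
    f_tgt: (N_tgt, d) embeddings of target pixels
    c_ref: (N_ref, K) one-hot quantized colors of reference pixels
    Returns an (N_tgt, K) array; row j is the predicted color distribution c_hat_j.
    """
    logits = f_ref @ f_tgt.T                              # entry [i, j] = f_i . f_j
    logits = logits - logits.max(axis=0, keepdims=True)   # numerical stability
    A = np.exp(logits)
    A = A / A.sum(axis=0, keepdims=True)                  # softmax over reference pixels i
    return A.T @ c_ref                                    # c_hat_j = sum_i A_ij c_i

def cross_entropy(c_hat, c_tgt, eps=1e-12):
    """Categorical cross-entropy between predicted and true quantized colors."""
    return float(-(c_tgt * np.log(np.clip(c_hat, eps, 1.0))).sum(axis=1).mean())
```

<p>Because each column of the attention matrix sums to one, every predicted color is a convex combination of reference colors, which is what allows labels other than colors (e.g. segmentation masks) to be propagated the same way at test time.</p>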
<p>Based on how the reference frame is marked, the model can be used to complete several color-based downstream tasks such as tracking segmentation or human pose in time. No fine-tuning is needed. See Fig. 22.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/video-colorization-examples.png" alt="Video colorization for tracking" /></p>
<p><em>Fig. 22. Use video colorization to track object segmentation and human pose in time. (Image source: <a href="https://arxiv.org/abs/1806.09594">Vondrick et al. (2018)</a>)</em></p>
<h2 id="control-based">Control-Based</h2>
<p>When running an RL policy in the real world, such as controlling a physical robot on visual inputs, it is non-trivial to properly track states, obtain reward signals or determine whether a goal is achieved for real. The visual data has a lot of noise that is irrelevant to the true state and thus the equivalence of states cannot be inferred from pixel-level comparison. Self-supervised representation learning has shown great potential in learning a useful state embedding that can be used directly as input to a control policy.</p>
<p>All the cases discussed in this section are in robotic learning, mainly for state representation from multiple camera views and goal representation.</p>
<h3 id="multi-view-metric-learning">Multi-View Metric Learning</h3>
<p>The concept of metric learning has been mentioned multiple times in the <a href="#counting-feature-loss">previous</a> <a href="#tracking">sections</a>. A common setting is: Given a triple of samples, (<em>anchor</em> <script type="math/tex">s_a</script>, <em>positive</em> sample <script type="math/tex">s_p</script>, <em>negative</em> sample <script type="math/tex">s_n</script>), the learned representation embedding <script type="math/tex">\phi(s)</script> fulfills that <script type="math/tex">s_a</script> stays close to <script type="math/tex">s_p</script> but far away from <script type="math/tex">s_n</script> in the latent space.</p>
<p><a href="#grasp2vec"></a><mark><b>Grasp2Vec</b></mark> (<a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>) aims to learn an object-centric vision representation in the robot grasping task from free, unlabelled grasping activities. By object-centric, it means that, irrespective of what the environment or the robot looks like, if two images contain similar items, they should be mapped to similar representations; otherwise the embeddings should be far apart.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/grasp2vec.png" alt="Grasp2vec" /></p>
<p><em>Fig. 23. A conceptual illustration of how grasp2vec learns an object-centric state embedding. (Image source: <a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>)</em></p>
<p>The grasping system can tell whether it moves an object but cannot tell which object it is. Cameras are set up to take images of the entire scene and the grasped object. During early training, the grasp robot is executed to grasp any object <script type="math/tex">o</script> at random, producing a triple of images, <script type="math/tex">(s_\text{pre}, s_\text{post}, o)</script>:</p>
<ul>
<li><script type="math/tex">o</script> is an image of the grasped object held up to the camera;</li>
<li><script type="math/tex">s_\text{pre}</script> is an image of the scene <em>before</em> grasping, with the object <script type="math/tex">o</script> in the tray;</li>
<li><script type="math/tex">s_\text{post}</script> is an image of the same scene <em>after</em> grasping, without the object <script type="math/tex">o</script> in the tray.</li>
</ul>
<p>To learn object-centric representation, we expect the difference between embeddings of <script type="math/tex">s_\text{pre}</script> and <script type="math/tex">s_\text{post}</script> to capture the removed object <script type="math/tex">o</script>. The idea is quite interesting and similar to relationships that have been observed in <a href="/lil-log/2017/10/15/learning-word-embedding.html">word embedding</a>, <a href="https://developers.google.com/machine-learning/crash-course/embeddings/translating-to-a-lower-dimensional-space">e.g.</a> distance(“king”, “queen”) ≈ distance(“man”, “woman”).</p>
<p>Let <script type="math/tex">\phi_s</script> and <script type="math/tex">\phi_o</script> be the embedding functions for the scene and the object respectively. The model learns the representation by minimizing the distance between <script type="math/tex">\phi_s(s_\text{pre}) - \phi_s(s_\text{post})</script> and <script type="math/tex">\phi_o(o)</script> using <em>n-pair loss</em>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{grasp2vec} &= \text{NPair}(\phi_s(s_\text{pre}) - \phi_s(s_\text{post}), \phi_o(o)) + \text{NPair}(\phi_o(o), \phi_s(s_\text{pre}) - \phi_s(s_\text{post})) \\
\text{where }\text{NPair}(a, p) &= \sum_{i<B} -\log\frac{\exp(a_i^\top p_i)}{\sum_{j<B}\exp(a_i^\top p_j)} + \lambda (\|a_i\|_2^2 + \|p_i\|_2^2)
\end{aligned} %]]></script>
<p>where <script type="math/tex">B</script> refers to a batch of (anchor, positive) sample pairs.</p>
<p>When framing representation learning as metric learning, <a href="https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective"><strong>n-pair loss</strong></a> is a common choice. Rather than explicitly processing a triple of (anchor, positive, negative) samples, the n-pair loss treats the positive instances of all the other pairs in the same mini-batch as negatives.</p>
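<p>A minimal NumPy sketch of the n-pair loss and the symmetric grasp2vec objective follows. It assumes the standard softmax form in which the denominator sums over every pair in the batch, including the positive one; the function names, argument shapes and the default regularization weight are hypothetical choices for illustration.</p>

```python
import numpy as np

def npair_loss(a, p, lam=0.0):
    """N-pair loss for a batch of (anchor, positive) pairs.

    a, p: (B, d) arrays; row i of p is the positive for row i of a,
    and all other rows of p act as negatives for anchor i.
    lam: weight of the L2 regularizer on the embeddings (assumed default).
    """
    logits = a @ p.T                                      # entry [i, j] = a_i . p_j
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    nll = -np.trace(log_softmax)                          # sum_i -log softmax(a_i . p_i)
    reg = lam * ((a ** 2).sum() + (p ** 2).sum())
    return float(nll + reg)

def grasp2vec_loss(phi_pre, phi_post, phi_obj, lam=0.0):
    """Symmetric loss between scene-embedding differences and object embeddings."""
    diff = phi_pre - phi_post
    return npair_loss(diff, phi_obj, lam) + npair_loss(phi_obj, diff, lam)
```

<p>The loss is small when each anchor has a much larger inner product with its own positive than with the positives of the other pairs in the batch.</p>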
<p>The embedding function <script type="math/tex">\phi_o</script> works great for representing a goal <script type="math/tex">g</script> with an image. The reward function that quantifies how close the actually grasped object <script type="math/tex">o</script> is to the goal is defined as <script type="math/tex">r = \phi_o(g) \cdot \phi_o(o)</script>. Note that computing rewards only relies on the learned latent space and doesn’t involve ground truth positions, so it can be used for training on real robots.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/grasp2vec-attention-map.png" alt="Grasp2vec attention map" /></p>
<p><em>Fig. 24. Localization results of grasp2vec embedding. The heatmap of localizing a goal object in a pre-grasping scene is defined as <script type="math/tex">\phi_o(o)^\top \phi_{s, \text{spatial}} (s_\text{pre})</script>, where <script type="math/tex">\phi_{s, \text{spatial}}</script> is the output of the last resnet block after ReLU. The fourth column is a failure case and the last three columns take real images as goals. (Image source: <a href="https://arxiv.org/abs/1811.06964">Jang & Devin et al., 2018</a>)</em></p>
<p>Other than the embedding-similarity-based reward function, there are a few other tricks for training the RL policy in the grasp2vec framework:</p>
<ul>
<li><em>Posthoc labeling</em>: Augment the dataset by labeling a randomly grasped object as a correct goal, like HER (Hindsight Experience Replay; <a href="https://papers.nips.cc/paper/7090-hindsight-experience-replay.pdf">Andrychowicz, et al., 2017</a>).</li>
<li><em>Auxiliary goal augmentation</em>: Augment the replay buffer even further by relabeling transitions with unachieved goals; precisely, in each iteration, two goals are sampled <script type="math/tex">(g, g')</script> and both are used to add new transitions into replay buffer.</li>
</ul>
<p><a href="#tcn"></a><strong>TCN</strong> (<mark><b>Time-Contrastive Networks</b></mark>; <a href="https://arxiv.org/abs/1704.06888">Sermanet, et al. 2018</a>) learns from multi-camera view videos with the intuition that different viewpoints at the same timestep of the same scene should share the same embedding (like in <a href="https://arxiv.org/abs/1503.03832">FaceNet</a>) while the embedding should vary in time, even from the same camera viewpoint. Therefore the embedding captures the semantic meaning of the underlying state rather than visual similarity. The TCN embedding is trained with <a href="#triplet-loss">triplet loss</a>.</p>
<p>The training data is collected by taking videos of the same scene simultaneously but from different angles. All the videos are unlabelled.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/TCN.png" alt="Time-contrastive network" /></p>
<p><em>Fig. 25. An illustration of time-contrastive approach for learning state embedding. The blue frames selected from two camera views at the same timestep are anchor and positive samples, while the red frame at a different timestep is the negative sample.</em></p>
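<p>The time-contrastive triplet objective can be sketched as follows. This is a generic margin-based triplet loss on squared Euclidean distances, not the exact TCN training code; the function name and the margin value are assumptions for illustration.</p>

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Margin-based triplet loss on embedding batches of shape (B, d).

    anchor / positive: embeddings of frames from two views at the same timestep.
    negative: embedding of a frame at a different timestep.
    margin: hypothetical value; it is a hyperparameter in practice.
    """
    d_pos = ((anchor - positive) ** 2).sum(axis=-1)   # squared distance to the positive
    d_neg = ((anchor - negative) ** 2).sum(axis=-1)   # squared distance to the negative
    return float(np.maximum(0.0, d_pos - d_neg + margin).mean())
```

<p>The loss is zero once every negative is farther from its anchor than the positive by at least the margin.</p>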
<p>TCN embedding extracts visual features that are invariant to camera configurations. It can be used to construct a reward function for imitation learning based on the Euclidean distance between the demo video and the observations in the latent space.</p>
<p>A further improvement over TCN is to learn embedding over multiple frames jointly rather than a single frame, resulting in <strong>mfTCN</strong> (<b><mark>Multi-frame</mark> Time-Contrastive Networks</b>; <a href="https://arxiv.org/abs/1808.00928">Dwibedi et al., 2019</a>). Given a set of videos from several synchronized camera viewpoints, <script type="math/tex">v_1, v_2, \dots, v_k</script>, the frame at time <script type="math/tex">t</script> and the previous <script type="math/tex">n-1</script> frames selected with stride <script type="math/tex">s</script> in each video are aggregated and mapped into one embedding vector, resulting in a lookback window of size <script type="math/tex">(n-1) \times s + 1</script>. Each frame first goes through a CNN to extract low-level features and then we use 3D temporal convolutions to aggregate frames in time. The model is trained with <a href="#n-pair-loss">n-pairs loss</a>.</p>
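<p>To make the lookback arithmetic concrete, here is a small sketch that lists which frame indices feed one mfTCN embedding. The helper names are hypothetical and the code only illustrates the indexing described above.</p>

```python
def lookback_indices(t, n, s):
    """Frame t plus the previous n-1 frames taken with stride s, oldest first."""
    return [t - k * s for k in range(n - 1, -1, -1)]

def lookback_window(n, s):
    """Number of raw frames spanned by n frames sampled with stride s."""
    return (n - 1) * s + 1
```

<p>For example, with <code>t=10</code>, <code>n=4</code> and <code>s=3</code> the selected frames are 1, 4, 7 and 10, which span a window of 10 frames.</p>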
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/mfTCN.png" alt="mfTCN" /></p>
<p><em>Fig. 26. The sampling process for training mfTCN. (Image source: <a href="https://arxiv.org/abs/1808.00928">Dwibedi et al., 2019</a>)</em></p>
<p>The training data is sampled as follows:</p>
<ol>
<li>First we construct two pairs of video clips. Each pair contains two clips from different camera views but with synchronized timesteps. These two sets of videos should be far apart in time.</li>
<li>Sample a fixed number of frames from each video clip in the same pair simultaneously with the same stride.</li>
<li>Frames with the same timesteps are trained as positive samples in the n-pair loss, while frames across pairs are negative samples.</li>
</ol>
<p>mfTCN embedding can capture the position and velocity of objects in the scene (e.g. in cartpole) and can also be used as inputs for policy.</p>
<h3 id="autonomous-goal-generation">Autonomous Goal Generation</h3>
<p><strong>RIG</strong> (<b>Reinforcement learning with <mark>Imagined Goals</mark></b>; <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>) described a way to train a goal-conditioned policy with unsupervised representation learning. A policy learns from self-supervised practice by first imagining “fake” goals and then trying to achieve them.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RIG.png" alt="RIG" /></p>
<p><em>Fig. 27. The workflow of RIG. (Image source: <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>)</em></p>
<p>The task is to control a robot arm to push a small puck on a table to a desired position. The desired position, or the goal, is presented as an image. During training, it learns latent embeddings of both state <script type="math/tex">s</script> and goal <script type="math/tex">g</script> through a <script type="math/tex">\beta</script>-VAE encoder, and the control policy operates entirely in the latent space.</p>
<p>Let’s say a <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#beta-vae"><script type="math/tex">\beta</script>-VAE</a> has an encoder <script type="math/tex">q_\phi</script> mapping input states to latent variable <script type="math/tex">z</script> which is modeled by a Gaussian distribution and a decoder <script type="math/tex">p_\psi</script> mapping <script type="math/tex">z</script> back to the states. The state encoder in RIG is set to be the mean of <script type="math/tex">\beta</script>-VAE encoder.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
z &\sim q_\phi(z \vert s) = \mathcal{N}(z; \mu_\phi(s), \sigma^2_\phi(s)) \\
\mathcal{L}_{\beta\text{-VAE}} &= - \mathbb{E}_{z \sim q_\phi(z \vert s)} [\log p_\psi (s \vert z)] + \beta D_\text{KL}(q_\phi(z \vert s) \| p(z)) \\
e(s) &\triangleq \mu_\phi(s)
\end{aligned} %]]></script>
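<p>As a rough sketch of the objective above, the loss for a diagonal Gaussian encoder can be written in closed form. The code assumes that the KL term is taken against a standard normal prior over <code>z</code> and that the decoder is a unit-variance Gaussian, so the reconstruction term reduces to a squared error up to a constant; the function names and the default <code>beta</code> are hypothetical.</p>

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for each row of a batch."""
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - 1.0 - log_var, axis=-1)

def beta_vae_loss(s, s_recon, mu, log_var, beta=4.0):
    """Reconstruction error plus beta-weighted KL, averaged over the batch.

    s, s_recon: (B, D) flattened states and their reconstructions.
    mu, log_var: (B, d) encoder outputs. beta=4.0 is an assumed default.
    """
    recon = 0.5 * np.sum((s - s_recon) ** 2, axis=-1)   # -log p(s|z) up to a constant
    return float(np.mean(recon + beta * kl_to_standard_normal(mu, log_var)))

def encode_state(mu):
    """State encoder e(s): the mean of the encoder distribution."""
    return mu
```

<p>Using the mean as the state encoding gives a deterministic embedding, which keeps rewards computed in the latent space free of sampling noise.</p>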
<p>The reward is the negative Euclidean distance between state and goal embedding vectors: <script type="math/tex">r(s, g) = -\|e(s) - e(g)\|</script>. Similar to <a href="#grasp2vec">grasp2vec</a>, RIG also applies data augmentation by latent goal relabeling: specifically, half of the goals are generated from the prior at random and the other half are selected using HER. As with grasp2vec, rewards do not depend on any ground truth states but only on the learned state encoding, so it can be used for training on real robots.</p>
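<p>The latent reward and the goal relabeling step might look like the sketch below. It assumes a standard normal prior over the latent space and a HER-style choice of a future state from the same trajectory; both the function names and the <code>p_prior</code> parameter are hypothetical, chosen to mirror the half-and-half split described above.</p>

```python
import numpy as np

def rig_reward(z_state, z_goal):
    """Negative Euclidean distance between latent state and latent goal."""
    return -float(np.linalg.norm(z_state - z_goal))

def relabel_goal(traj_latents, t, rng, p_prior=0.5):
    """Pick a replacement goal for the transition at step t.

    traj_latents: (T, d) encoded states of one trajectory.
    With probability p_prior the goal is sampled from the N(0, I) prior;
    otherwise a latent state from a later step of the trajectory is used.
    """
    n_steps, dim = traj_latents.shape
    if rng.random() < p_prior or t + 1 >= n_steps:
        return rng.standard_normal(dim)          # imagined goal from the prior
    future = rng.integers(t + 1, n_steps)        # upper bound is exclusive
    return traj_latents[future]                  # hindsight goal
```

<p>Transitions at the final step have no future state to relabel with, so the sketch falls back to sampling from the prior in that case.</p>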
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RIG-algorithm.png" alt="RIG algorithm" /></p>
<p><em>Fig. 28. The algorithm of RIG. (Image source: <a href="https://arxiv.org/abs/1807.04742">Nair et al., 2018</a>)</em></p>
<p>The problem with RIG is a lack of object variations in the imagined goal pictures. If <script type="math/tex">\beta</script>-VAE is only trained with a black puck, it would not be able to create a goal with other objects like blocks of different shapes and colors. A follow-up improvement replaces <script type="math/tex">\beta</script>-VAE with a <strong>CC-VAE</strong> (Context-Conditioned VAE; <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>), inspired by <strong>CVAE</strong> (Conditional VAE; <a href="https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models">Sohn, Lee & Yan, 2015</a>), for goal generation.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CC-RIG.png" alt="Context-conditional RIG" /></p>
<p><em>Fig. 29. The workflow of context-conditioned RIG. (Image source: <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>).</em></p>
<p>A CVAE conditions on a context variable <script type="math/tex">c</script>. It trains an encoder <script type="math/tex">q_\phi(z \vert s, c)</script> and a decoder <script type="math/tex">p_\psi (s \vert z, c)</script> and note that both have access to <script type="math/tex">c</script>. The CVAE loss penalizes information passing from the input state <script type="math/tex">s</script> through an information bottleneck but allows for <em>unrestricted</em> information flow from <script type="math/tex">c</script> to both encoder and decoder.</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{CVAE} = - \mathbb{E}_{z \sim q_\phi(z \vert s,c)} [\log p_\psi (s \vert z, c)] + \beta D_\text{KL}(q_\phi(z \vert s, c) \| p(z))</script>
<p>To create plausible goals, CC-VAE conditions on a starting state <script type="math/tex">s_0</script> so that the generated goal presents a consistent type of object as in <script type="math/tex">s_0</script>. This goal consistency is necessary; e.g. if the current scene contains a red puck but the goal has a blue block, it would confuse the policy.</p>
<p>Other than the state encoder <script type="math/tex">e(s) \triangleq \mu_\phi(s)</script>, CC-VAE trains a second convolutional encoder <script type="math/tex">e_0(.)</script> to translate the starting state <script type="math/tex">s_0</script> into a compact context representation <script type="math/tex">c = e_0(s_0)</script>. Two encoders, <script type="math/tex">e(.)</script> and <script type="math/tex">e_0(.)</script>, are intentionally different without shared weights, as they are expected to encode different factors of image variation. In addition to the loss function of CVAE, CC-VAE adds an extra term to learn to reconstruct <script type="math/tex">c</script> back to <script type="math/tex">s_0</script>, <script type="math/tex">\hat{s}_0 = d_0(c)</script>.</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{CC-VAE} = \mathcal{L}_\text{CVAE} - \log p(s_0\vert c)</script>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CC-RIG-goal-samples.png" alt="RIG goal samples" /></p>
<p><em>Fig. 30. Examples of imagined goals generated by CVAE that conditions on the context image (the first row), while VAE fails to capture the object consistency. (Image source: <a href="https://arxiv.org/abs/1910.11670">Nair, et al., 2019</a>).</em></p>
<blockquote>
<p>A couple of common observations:</p>
<ul>
<li>Combining multiple pretext tasks improves performance;</li>
<li>Deeper networks improve the quality of representation;</li>
<li>Supervised learning baselines still beat all of them by far.</li>
</ul>
</blockquote>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019selfsup,
title = "Self-Supervised Representation Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "https://lilianweng.github.io/lil-log/2019/11/10/self-supervised-learning.html"
}
</code></pre></div></div>
<h3 id="references">References</h3>
<p>[1] Alexey Dosovitskiy, et al. <a href="https://arxiv.org/abs/1406.6909">“Discriminative unsupervised feature learning with exemplar convolutional neural networks.”</a> IEEE transactions on pattern analysis and machine intelligence 38.9 (2015): 1734-1747.</p>
<p>[2] Spyros Gidaris, Praveer Singh & Nikos Komodakis. <a href="https://arxiv.org/abs/1803.07728">“Unsupervised Representation Learning by Predicting Image Rotations”</a> ICLR 2018.</p>
<p>[3] Carl Doersch, Abhinav Gupta, and Alexei A. Efros. <a href="https://arxiv.org/abs/1505.05192">“Unsupervised visual representation learning by context prediction.”</a> ICCV. 2015.</p>
<p>[4] Mehdi Noroozi & Paolo Favaro. <a href="https://arxiv.org/abs/1603.09246">“Unsupervised learning of visual representations by solving jigsaw puzzles.”</a> ECCV, 2016.</p>
<p>[5] Mehdi Noroozi, Hamed Pirsiavash, and Paolo Favaro. <a href="https://arxiv.org/abs/1708.06734">“Representation learning by learning to count.”</a> ICCV. 2017.</p>
<p>[6] Richard Zhang, Phillip Isola & Alexei A. Efros. <a href="https://arxiv.org/abs/1603.08511">“Colorful image colorization.”</a> ECCV, 2016.</p>
<p>[7] Pascal Vincent, et al. <a href="https://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf">“Extracting and composing robust features with denoising autoencoders.”</a> ICML, 2008.</p>
<p>[8] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. <a href="https://arxiv.org/abs/1605.09782">“Adversarial feature learning.”</a> ICLR 2017.</p>
<p>[9] Deepak Pathak, et al. <a href="https://arxiv.org/abs/1604.07379">“Context encoders: Feature learning by inpainting.”</a> CVPR. 2016.</p>
<p>[10] Richard Zhang, Phillip Isola, and Alexei A. Efros. <a href="https://arxiv.org/abs/1611.09842">“Split-brain autoencoders: Unsupervised learning by cross-channel prediction.”</a> CVPR. 2017.</p>
<p>[11] Xiaolong Wang & Abhinav Gupta. <a href="https://arxiv.org/abs/1505.00687">“Unsupervised Learning of Visual Representations using Videos.”</a> ICCV. 2015.</p>
<p>[12] Carl Vondrick, et al. <a href="https://arxiv.org/pdf/1806.09594.pdf">“Tracking Emerges by Colorizing Videos”</a> ECCV. 2018.</p>
<p>[13] Ishan Misra, C. Lawrence Zitnick, and Martial Hebert. <a href="https://arxiv.org/abs/1603.08561">“Shuffle and learn: unsupervised learning using temporal order verification.”</a> ECCV. 2016.</p>
<p>[14] Basura Fernando, et al. <a href="https://arxiv.org/abs/1611.06646">“Self-Supervised Video Representation Learning With Odd-One-Out Networks”</a> CVPR. 2017.</p>
<p>[15] Donglai Wei, et al. <a href="https://www.robots.ox.ac.uk/~vgg/publications/2018/Wei18/wei18.pdf">“Learning and Using the Arrow of Time”</a> CVPR. 2018.</p>
<p>[16] Florian Schroff, Dmitry Kalenichenko and James Philbin. <a href="https://arxiv.org/abs/1503.03832">“FaceNet: A Unified Embedding for Face Recognition and Clustering”</a> CVPR. 2015.</p>
<p>[17] Pierre Sermanet, et al. <a href="https://arxiv.org/abs/1704.06888">“Time-Contrastive Networks: Self-Supervised Learning from Video”</a> CVPR. 2018.</p>
<p>[18] Debidatta Dwibedi, et al. <a href="https://arxiv.org/abs/1808.00928">“Learning actionable representations from visual observations.”</a> IROS. 2018.</p>
<p>[19] Eric Jang & Coline Devin, et al. <a href="https://arxiv.org/abs/1811.06964">“Grasp2Vec: Learning Object Representations from Self-Supervised Grasping”</a> CoRL. 2018.</p>
<p>[20] Ashvin Nair, et al. <a href="https://arxiv.org/abs/1807.04742">“Visual reinforcement learning with imagined goals”</a> NeurIPS. 2018.</p>
<p>[21] Ashvin Nair, et al. <a href="https://arxiv.org/abs/1910.11670">“Contextual imagined goals for self-supervised robotic learning”</a> CoRL. 2019.</p>
<p>[22] Aaron van den Oord, Yazhe Li & Oriol Vinyals. <a href="https://arxiv.org/abs/1807.03748">“Representation Learning with Contrastive Predictive Coding”</a> arXiv preprint arXiv:1807.03748, 2018.</p>
<p>[23] Olivier J. Henaff, et al. <a href="https://arxiv.org/abs/1905.09272">“Data-Efficient Image Recognition with Contrastive Predictive Coding”</a> arXiv preprint arXiv:1905.09272, 2019.</p>
<p>[24] Kaiming He, et al. <a href="https://arxiv.org/abs/1911.05722">“Momentum Contrast for Unsupervised Visual Representation Learning.”</a> CVPR 2020.</p>
<p>[25] Zhirong Wu, et al. <a href="https://arxiv.org/abs/1805.01978v1">“Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination.”</a> CVPR 2018.</p>
<p>[26] Ting Chen, et al. <a href="https://arxiv.org/abs/2002.05709">“A Simple Framework for Contrastive Learning of Visual Representations.”</a> arXiv preprint arXiv:2002.05709, 2020.</p>
<p>[27] Aravind Srinivas, Michael Laskin & Pieter Abbeel. <a href="https://arxiv.org/abs/2004.04136">“CURL: Contrastive Unsupervised Representations for Reinforcement Learning.”</a> arXiv preprint arXiv:2004.04136, 2020.</p>
<p>Lilian Weng. Self-supervised learning opens up a huge opportunity for better utilizing unlabelled data, while learning in a supervised learning manner. This post covers many interesting ideas of self-supervised learning tasks on images, videos, and control problems.</p>
<hr />
<p><strong>Evolution Strategies</strong> (2019-09-05, <a href="https://lilianweng.github.io/lil-log/2019/09/05/evolution-strategies">https://lilianweng.github.io/lil-log/2019/09/05/evolution-strategies</a>)</p>
<blockquote>
<p>Gradient descent is not the only option when learning optimal model parameters. Evolution Strategies (ES) works out well in the cases where we don’t know the precise analytic form of an objective function or cannot compute the gradients directly. This post dives into several classic ES methods, as well as how ES can be used in deep reinforcement learning.</p>
</blockquote>
<!--more-->
<p>Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. With black-box optimization algorithms, you can evaluate a target function <script type="math/tex">f(x): \mathbb{R}^n \to \mathbb{R}</script>, even when you don’t know the precise analytic form of <script type="math/tex">f(x)</script> and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include <a href="https://en.wikipedia.org/wiki/Simulated_annealing">Simulated Annealing</a>, <a href="https://en.wikipedia.org/wiki/Hill_climbing">Hill Climbing</a> and <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.</p>
<p><strong>Evolution Strategies (ES)</strong> is one type of black-box optimization algorithm, born in the family of <strong>Evolutionary Algorithms (EA)</strong>. In this post, I dive into a couple of classic ES methods and introduce a few applications of how ES can play a role in deep reinforcement learning.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-are-evolution-strategies" id="markdown-toc-what-are-evolution-strategies">What are Evolution Strategies?</a></li>
<li><a href="#simple-gaussian-evolution-strategies" id="markdown-toc-simple-gaussian-evolution-strategies">Simple Gaussian Evolution Strategies</a></li>
<li><a href="#covariance-matrix-adaptation-evolution-strategies-cma-es" id="markdown-toc-covariance-matrix-adaptation-evolution-strategies-cma-es">Covariance Matrix Adaptation Evolution Strategies (CMA-ES)</a> <ul>
<li><a href="#updating-the-mean" id="markdown-toc-updating-the-mean">Updating the Mean</a></li>
<li><a href="#controlling-the-step-size" id="markdown-toc-controlling-the-step-size">Controlling the Step Size</a></li>
<li><a href="#adapting-the-covariance-matrix" id="markdown-toc-adapting-the-covariance-matrix">Adapting the Covariance Matrix</a></li>
</ul>
</li>
<li><a href="#natural-evolution-strategies" id="markdown-toc-natural-evolution-strategies">Natural Evolution Strategies</a> <ul>
<li><a href="#natural-gradients" id="markdown-toc-natural-gradients">Natural Gradients</a></li>
<li><a href="#estimation-using-fisher-information-matrix" id="markdown-toc-estimation-using-fisher-information-matrix">Estimation using Fisher Information Matrix</a></li>
<li><a href="#nes-algorithm" id="markdown-toc-nes-algorithm">NES Algorithm</a></li>
</ul>
</li>
<li><a href="#applications-es-in-deep-reinforcement-learning" id="markdown-toc-applications-es-in-deep-reinforcement-learning">Applications: ES in Deep Reinforcement Learning</a> <ul>
<li><a href="#openai-es-for-rl" id="markdown-toc-openai-es-for-rl">OpenAI ES for RL</a></li>
<li><a href="#exploration-with-es" id="markdown-toc-exploration-with-es">Exploration with ES</a></li>
<li><a href="#cem-rl" id="markdown-toc-cem-rl">CEM-RL</a></li>
</ul>
</li>
<li><a href="#extension-ea-in-deep-learning" id="markdown-toc-extension-ea-in-deep-learning">Extension: EA in Deep Learning</a> <ul>
<li><a href="#hyperparameter-tuning-pbt" id="markdown-toc-hyperparameter-tuning-pbt">Hyperparameter Tuning: PBT</a></li>
<li><a href="#network-topology-optimization-wann" id="markdown-toc-network-topology-optimization-wann">Network Topology Optimization: WANN</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-are-evolution-strategies">What are Evolution Strategies?</h2>
<p>Evolution strategies (ES) belong to the big family of evolutionary algorithms. The optimization targets of ES are vectors of real numbers, <script type="math/tex">x \in \mathbb{R}^n</script>.</p>
<p>Evolutionary algorithms refer to a division of population-based optimization algorithms inspired by <em>natural selection</em>. The theory of natural selection holds that individuals with traits beneficial to their survival can live through generations and pass down the good characteristics to the next generation. Evolution happens gradually through this selection process and the population grows better adapted to the environment.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/EA-illustration.png" alt="EA" /></p>
<p><em>Fig. 1. How natural selection works. (Image source: Khan Academy: <a href="https://www.khanacademy.org/science/biology/her/evolution-and-natural-selection/a/darwin-evolution-natural-selection">Darwin, evolution, & natural selection</a>)</em></p>
<p>Evolutionary algorithms can be summarized in the following <a href="https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/06-blackBoxOpt.pdf">format</a> as a general optimization solution:</p>
<p>Let’s say we want to optimize a function <script type="math/tex">f(x)</script> and we are not able to compute gradients directly. But we still can evaluate <script type="math/tex">f(x)</script> given any <script type="math/tex">x</script> and the result is deterministic. Our belief in the probability distribution over <script type="math/tex">x</script> as a good solution to <script type="math/tex">f(x)</script> optimization is <script type="math/tex">p_\theta(x)</script>, parameterized by <script type="math/tex">\theta</script>. The goal is to find an optimal configuration of <script type="math/tex">\theta</script>.</p>
<blockquote>
<p>Here, given a fixed form of distribution (e.g. Gaussian), the parameter <script type="math/tex">\theta</script> carries the knowledge about the best solutions and is iteratively updated across generations.</p>
</blockquote>
<p>Starting with an initial value of <script type="math/tex">\theta</script>, we can continuously update <script type="math/tex">\theta</script> by looping three steps as follows:</p>
<ol>
<li>Generate a population of samples <script type="math/tex">D = \{(x_i, f(x_i))\}</script> where <script type="math/tex">x_i \sim p_\theta(x)</script>.</li>
<li>Evaluate the “fitness” of samples in <script type="math/tex">D</script>.</li>
<li>Select the best subset of individuals and use them to update <script type="math/tex">\theta</script>, generally based on fitness or rank.</li>
</ol>
<p>In <strong>Genetic Algorithms (GA)</strong>, another popular subcategory of EA, <script type="math/tex">x</script> is a sequence of binary codes, <script type="math/tex">x \in \{0, 1\}^n</script>. While in ES, <script type="math/tex">x</script> is just a vector of real numbers, <script type="math/tex">x \in \mathbb{R}^n</script>.</p>
<h2 id="simple-gaussian-evolution-strategies">Simple Gaussian Evolution Strategies</h2>
<p><a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">This</a> is the most basic and canonical version of evolution strategies. It models <script type="math/tex">p_\theta(x)</script> as an <script type="math/tex">n</script>-dimensional isotropic Gaussian distribution, in which <script type="math/tex">\theta</script> only tracks the mean <script type="math/tex">\mu</script> and standard deviation <script type="math/tex">\sigma</script>.</p>
<script type="math/tex; mode=display">\theta = (\mu, \sigma),\;p_\theta(x) \sim \mathcal{N}(\mathbf{\mu}, \sigma^2 I) = \mu + \sigma \mathcal{N}(0, I)</script>
<p>The process of Simple-Gaussian-ES, given <script type="math/tex">x \in \mathbb{R}^n</script>:</p>
<ol>
<li>Initialize <script type="math/tex">\theta = \theta^{(0)}</script> and the generation counter <script type="math/tex">t=0</script></li>
<li>Generate the offspring population of size <script type="math/tex">\Lambda</script> by sampling from the Gaussian distribution:<br /><br /><script type="math/tex">D^{(t+1)}=\{ x^{(t+1)}_i \mid x^{(t+1)}_i = \mu^{(t)} + \sigma^{(t)} y^{(t+1)}_i \text{ where } y^{(t+1)}_i \sim \mathcal{N}(x \vert 0, \mathbf{I}),\;i = 1, \dots, \Lambda\}</script><br />.</li>
<li>Select a top subset of <script type="math/tex">\lambda</script> samples with the best <script type="math/tex">f(x_i)</script>; this subset is called the <strong>elite</strong> set. Without loss of generality, we may consider the first <script type="math/tex">\lambda</script> samples in <script type="math/tex">D^{(t+1)}</script> to belong to the elite group. Let’s label them as<br /><br /><script type="math/tex">D^{(t+1)}_\text{elite} = \{x^{(t+1)}_i \mid x^{(t+1)}_i \in D^{(t+1)}, i=1,\dots, \lambda, \lambda\leq \Lambda\}</script><br />.</li>
<li>Then we estimate the new mean and std for the next generation using the elite set:<br /><br />
<script type="math/tex">% <![CDATA[
\begin{aligned}
\mu^{(t+1)} &= \text{avg}(D^{(t+1)}_\text{elite}) = \frac{1}{\lambda}\sum_{i=1}^\lambda x_i^{(t+1)} \\
{\sigma^{(t+1)}}^2 &= \text{var}(D^{(t+1)}_\text{elite}) = \frac{1}{\lambda}\sum_{i=1}^\lambda (x_i^{(t+1)} -\mu^{(t)})^2
\end{aligned} %]]></script><br /></li>
<li>Repeat steps (2)-(4) until the result is good enough ✌️</li>
</ol>
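<p>The steps above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: the function name <code>simple_gaussian_es</code>, the toy objective, the population sizes and the choice to minimize are all assumptions made for the example, and the scalar <script type="math/tex">\sigma</script> is obtained by averaging the squared deviations over coordinates.</p>

```python
import numpy as np


def simple_gaussian_es(f, mu0, sigma0, pop_size=50, elite_size=10,
                       n_generations=100, seed=0):
    """Minimize f with an isotropic Gaussian search distribution.

    pop_size plays the role of Lambda and elite_size the role of lambda.
    """
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu0, dtype=float)
    sigma = float(sigma0)
    for _ in range(n_generations):
        # Step 2: x_i = mu + sigma * y_i with y_i ~ N(0, I).
        x = mu + sigma * rng.standard_normal((pop_size, mu.size))
        # Step 3: keep the elite_size samples with the lowest f(x).
        fitness = np.array([f(xi) for xi in x])
        elite = x[np.argsort(fitness)[:elite_size]]
        # Step 4: the variance is measured around the old mean mu^(t);
        # averaging over coordinates keeps sigma a scalar (an assumption).
        sigma = float(np.sqrt(np.mean((elite - mu) ** 2)))
        mu = elite.mean(axis=0)
    return mu, sigma


def shifted_sphere(x):
    """Toy objective (hypothetical) with its minimum at (3, 3)."""
    return float(np.sum((x - 3.0) ** 2))


best_mu, best_sigma = simple_gaussian_es(shifted_sphere, mu0=[0.0, 0.0], sigma0=1.0)
```

<p>Because the variance is measured around the previous mean, <script type="math/tex">\sigma</script> grows while the mean is still moving and shrinks once the elite samples cluster around it.</p>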
<h2 id="covariance-matrix-adaptation-evolution-strategies-cma-es">Covariance Matrix Adaptation Evolution Strategies (CMA-ES)</h2>
<p>The standard deviation <script type="math/tex">\sigma</script> accounts for the level of exploration: the larger <script type="math/tex">\sigma</script> is, the larger the region of the search space from which we sample our offspring population. In <a href="#simple-gaussian-evolution-strategies">vanilla ES</a>, <script type="math/tex">\sigma^{(t+1)}</script> is highly correlated with <script type="math/tex">\sigma^{(t)}</script>, so the algorithm is not able to rapidly adjust the exploration space when needed (i.e. when the confidence level changes).</p>
<p><a href="https://en.wikipedia.org/wiki/CMA-ES"><strong>CMA-ES</strong></a>, short for <em>“Covariance Matrix Adaptation Evolution Strategy”</em>, fixes the problem by tracking pairwise dependencies between the samples in the distribution with a covariance matrix <script type="math/tex">C</script>. The new distribution parameter becomes:</p>
<script type="math/tex; mode=display">\theta = (\mu, \sigma, C),\; p_\theta(x) \sim \mathcal{N}(\mu, \sigma^2 C) \sim \mu + \sigma \mathcal{N}(0, C)</script>
<p>where <script type="math/tex">\sigma</script> controls for the overall scale of the distribution, often known as <em>step size</em>.</p>
<p>Before we dig into how the parameters are updated in CMA-ES, it is better to review how the covariance matrix works in the multivariate Gaussian distribution first. As a real symmetric matrix, the covariance matrix <script type="math/tex">C</script> has the following nice features (See <a href="http://s3.amazonaws.com/mitsloan-php/wp-faculty/sites/30/2016/12/15032137/Symmetric-Matrices-and-Eigendecomposition.pdf">proof</a> & <a href="http://control.ucsd.edu/mauricio/courses/mae280a/lecture11.pdf">proof</a>):</p>
<ul>
<li>It is always diagonalizable.</li>
<li>Always positive semi-definite.</li>
<li>All of its eigenvalues are real non-negative numbers.</li>
<li>All of its eigenvectors are orthogonal.</li>
<li>There is an orthonormal basis of <script type="math/tex">\mathbb{R}^n</script> consisting of its eigenvectors.</li>
</ul>
<p>Let the matrix <script type="math/tex">C</script> have an <em>orthonormal</em> basis of eigenvectors <script type="math/tex">B = [b_1, \dots, b_n]</script>, with corresponding eigenvalues <script type="math/tex">\lambda_1^2, \dots, \lambda_n^2</script>. Let <script type="math/tex">D=\text{diag}(\lambda_1, \dots, \lambda_n)</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
C = B^\top D^2 B
= \begin{bmatrix}
\mid & \mid & & \mid \\
b_1 & b_2 & \dots & b_n\\
\mid & \mid & & \mid \\
\end{bmatrix}
\begin{bmatrix}
\lambda_1^2 & 0 & \dots & 0 \\
0 & \lambda_2^2 & \dots & 0 \\
\vdots & \dots & \ddots & \vdots \\
0 & \dots & 0 & \lambda_n^2
\end{bmatrix}
\begin{bmatrix}
- & b_1 & - \\
- & b_2 & - \\
& \dots & \\
- & b_n & - \\
\end{bmatrix} %]]></script>
<p>The square root of <script type="math/tex">C</script> is:</p>
<script type="math/tex; mode=display">C^{\frac{1}{2}} = B^\top D B</script>
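<p>As a quick numerical check of the decomposition, here is a small NumPy sketch. The matrix <code>C</code> is a randomly generated example (an assumption of this sketch); <code>B</code> follows the convention used here, with eigenvectors stored as rows.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
C = A @ A.T + 1e-3 * np.eye(3)          # an example symmetric positive definite matrix

eigvals, eigvecs = np.linalg.eigh(C)    # eigh returns eigenvectors as columns
B = eigvecs.T                           # rows of B are the eigenvectors of C
D = np.diag(np.sqrt(eigvals))           # D holds the square roots of the eigenvalues

C_rebuilt = B.T @ D @ D @ B             # C = B^T D^2 B
C_sqrt = B.T @ D @ B                    # C^{1/2} = B^T D B
C_inv_sqrt = B.T @ np.linalg.inv(D) @ B  # C^{-1/2} = B^T D^{-1} B
```

<p>Multiplying <code>C_sqrt</code> by itself recovers <code>C</code>, and <code>C_inv_sqrt @ C @ C_inv_sqrt</code> is the identity up to floating point error.</p>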
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">x_i^{(t)} \in \mathbb{R}^n</script></td>
<td>the <script type="math/tex">i</script>-th sample at generation (t)</td>
</tr>
<tr>
<td><script type="math/tex">y_i^{(t)} \in \mathbb{R}^n</script></td>
<td><script type="math/tex">x_i^{(t)} = \mu^{(t-1)} + \sigma^{(t-1)} y_i^{(t)}</script></td>
</tr>
<tr>
<td><script type="math/tex">\mu^{(t)}</script></td>
<td>mean of the generation (t)</td>
</tr>
<tr>
<td><script type="math/tex">\sigma^{(t)}</script></td>
<td>step size</td>
</tr>
<tr>
<td><script type="math/tex">C^{(t)}</script></td>
<td>covariance matrix</td>
</tr>
<tr>
<td><script type="math/tex">B^{(t)}</script></td>
<td>a matrix of <script type="math/tex">C</script>’s eigenvectors as row vectors</td>
</tr>
<tr>
<td><script type="math/tex">D^{(t)}</script></td>
<td>a diagonal matrix with the square roots of <script type="math/tex">C</script>’s eigenvalues on the diagonal.</td>
</tr>
<tr>
<td><script type="math/tex">p_\sigma^{(t)}</script></td>
<td>evolution path for <script type="math/tex">\sigma</script> at the generation (t)</td>
</tr>
<tr>
<td><script type="math/tex">p_c^{(t)}</script></td>
<td>evolution path for <script type="math/tex">C</script> at the generation (t)</td>
</tr>
<tr>
<td><script type="math/tex">\alpha_\mu</script></td>
<td>learning rate for <script type="math/tex">\mu</script>’s update</td>
</tr>
<tr>
<td><script type="math/tex">\alpha_\sigma</script></td>
<td>learning rate for <script type="math/tex">p_\sigma</script></td>
</tr>
<tr>
<td><script type="math/tex">d_\sigma</script></td>
<td>damping factor for <script type="math/tex">\sigma</script>’s update</td>
</tr>
<tr>
<td><script type="math/tex">\alpha_{cp}</script></td>
<td>learning rate for <script type="math/tex">p_c</script></td>
</tr>
<tr>
<td><script type="math/tex">\alpha_{c\lambda}</script></td>
<td>learning rate for <script type="math/tex">C</script>’s rank-min(λ, n) update</td>
</tr>
<tr>
<td><script type="math/tex">\alpha_{c1}</script></td>
<td>learning rate for <script type="math/tex">C</script>’s rank-1 update</td>
</tr>
</tbody>
</table>
<h3 id="updating-the-mean">Updating the Mean</h3>
<script type="math/tex; mode=display">\mu^{(t+1)} = \mu^{(t)} + \alpha_\mu \frac{1}{\lambda}\sum_{i=1}^\lambda (x_i^{(t+1)} - \mu^{(t)})</script>
<p>CMA-ES has a learning rate <script type="math/tex">\alpha_\mu \leq 1</script> to control how fast the mean <script type="math/tex">\mu</script> should be updated. Usually it is set to 1 and thus the equation becomes the same as in vanilla ES, <script type="math/tex">\mu^{(t+1)} = \frac{1}{\lambda}\sum_{i=1}^\lambda x_i^{(t+1)}</script>.</p>
<h3 id="controlling-the-step-size">Controlling the Step Size</h3>
<p>The sampling process can be decoupled from the mean and standard deviation:</p>
<script type="math/tex; mode=display">x^{(t+1)}_i = \mu^{(t)} + \sigma^{(t)} y^{(t+1)}_i \text{, where } y^{(t+1)}_i = \frac{x_i^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, C)</script>
<p>The parameter <script type="math/tex">\sigma</script> controls the overall scale of the distribution. It is separated from the covariance matrix so that we can change steps faster than the full covariance. A larger step size leads to faster parameter updates. In order to evaluate whether the current step size is proper, CMA-ES constructs an <em>evolution path</em> <script type="math/tex">p_\sigma</script> by summing up a consecutive sequence of moving steps, <script type="math/tex">\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(j)}, j=1, \dots, t</script>. By comparing this path length with its expected length under random selection (meaning single steps are uncorrelated), we are able to adjust <script type="math/tex">\sigma</script> accordingly (See Fig. 2).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CMA-ES-step-size-path.png" alt="CMA-ES step size" /></p>
<p><em>Fig. 2. Three scenarios of how single steps are correlated in different ways and their impacts on step size update. (Image source: additional annotations on Fig 5 in <a href="https://arxiv.org/abs/1604.00772">CMA-ES tutorial</a> paper)</em></p>
<p>Each time the evolution path is updated with the average of moving step <script type="math/tex">y_i</script> in the same generation.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
&\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(t+1)}
= \frac{1}{\lambda} \frac{\sum_{i=1}^\lambda x_i^{(t+1)} - \lambda \mu^{(t)}}{\sigma^{(t)}}
= \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
&\frac{1}{\lambda}\sum_{i=1}^\lambda y_i^{(t+1)}
\sim \frac{1}{\lambda}\mathcal{N}(0, \lambda C^{(t)})
\sim \frac{1}{\sqrt{\lambda}}{C^{(t)}}^{\frac{1}{2}}\mathcal{N}(0, I) \\
&\text{Thus } \sqrt{\lambda}\;{C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, I)
\end{aligned} %]]></script>
<blockquote>
<p>By multiplying with <script type="math/tex">C^{-\frac{1}{2}}</script>, the evolution path is transformed to be independent of its direction. The transformation <script type="math/tex">{C^{(t)}}^{-\frac{1}{2}} = {B^{(t)}}^\top {D^{(t)}}^{-1} {B^{(t)}}</script> works as follows:</p>
<ol>
<li><script type="math/tex">{B^{(t)}}</script> contains row vectors of <script type="math/tex">C</script>’s eigenvectors. It projects the original space onto the perpendicular principal axes.</li>
<li>Then <script type="math/tex">{D^{(t)}}^{-1} = \text{diag}(\frac{1}{\lambda_1}, \dots, \frac{1}{\lambda_n})</script> scales the length of principal axes to be equal.</li>
<li><script type="math/tex">{B^{(t)}}^\top</script> transforms the space back to the original coordinate system.</li>
</ol>
</blockquote>
<p>In order to assign higher weights to recent generations, we use Polyak averaging to update the evolution path with learning rate <script type="math/tex">\alpha_\sigma</script>. Meanwhile, the weights are balanced so that <script type="math/tex">p_\sigma</script> is <a href="https://en.wikipedia.org/wiki/Conjugate_prior">conjugate</a>, <script type="math/tex">\sim \mathcal{N}(0, I)</script> both before and after one update.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_\sigma^{(t+1)}
& = (1 - \alpha_\sigma) p_\sigma^{(t)} + \sqrt{1 - (1 - \alpha_\sigma)^2}\;\sqrt{\lambda}\; {C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
& = (1 - \alpha_\sigma) p_\sigma^{(t)} + \sqrt{\alpha_\sigma (2 - \alpha_\sigma)\lambda}\;{C^{(t)}}^{-\frac{1}{2}} \frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}}
\end{aligned} %]]></script>
<p>The expected length of <script type="math/tex">p_\sigma</script> under random selection is <script type="math/tex">\mathbb{E}\|\mathcal{N}(0,I)\|</script>, that is the expectation of the L2-norm of a <script type="math/tex">\mathcal{N}(0,I)</script> random variable. Following the idea in Fig. 2, we adjust the step size according to the ratio of <script type="math/tex">\|p_\sigma^{(t+1)}\| / \mathbb{E}\|\mathcal{N}(0,I)\|</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\ln\sigma^{(t+1)} &= \ln\sigma^{(t)} + \frac{\alpha_\sigma}{d_\sigma} \Big(\frac{\|p_\sigma^{(t+1)}\|}{\mathbb{E}\|\mathcal{N}(0,I)\|} - 1\Big) \\
\sigma^{(t+1)} &= \sigma^{(t)} \exp\Big(\frac{\alpha_\sigma}{d_\sigma} \Big(\frac{\|p_\sigma^{(t+1)}\|}{\mathbb{E}\|\mathcal{N}(0,I)\|} - 1\Big)\Big)
\end{aligned} %]]></script>
<p>where <script type="math/tex">d_\sigma \approx 1</script> is a damping parameter, scaling how fast <script type="math/tex">\ln\sigma</script> should be changed.</p>
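<p>To make the step-size rule concrete, here is a minimal sketch of one update of <script type="math/tex">p_\sigma</script> and <script type="math/tex">\sigma</script>. The function names, the default values of <code>alpha_sigma</code> and <code>d_sigma</code>, and the toy inputs are assumptions of this sketch. The closed-form approximation of <script type="math/tex">\mathbb{E}\|\mathcal{N}(0,I)\|</script> is the one commonly used in CMA-ES implementations.</p>

```python
import numpy as np


def expected_norm_std_normal(n):
    """Common approximation of E||N(0, I)|| in n dimensions."""
    return np.sqrt(n) * (1.0 - 1.0 / (4.0 * n) + 1.0 / (21.0 * n ** 2))


def update_step_size(p_sigma, sigma, mu_old, mu_new, C, n_elite,
                     alpha_sigma=0.3, d_sigma=1.0):
    """One update of the evolution path p_sigma and the step size sigma."""
    # C^{-1/2} from the eigendecomposition (eigh returns eigenvectors as columns).
    eigvals, eigvecs = np.linalg.eigh(C)
    C_inv_sqrt = eigvecs @ np.diag(eigvals ** -0.5) @ eigvecs.T
    step = (mu_new - mu_old) / sigma
    scale = np.sqrt(alpha_sigma * (2.0 - alpha_sigma) * n_elite)
    p_sigma = (1.0 - alpha_sigma) * p_sigma + scale * (C_inv_sqrt @ step)
    # Compare the path length with its expected length under random selection.
    ratio = np.linalg.norm(p_sigma) / expected_norm_std_normal(p_sigma.size)
    sigma = sigma * np.exp(alpha_sigma / d_sigma * (ratio - 1.0))
    return p_sigma, sigma


n = 2
p0, C0 = np.zeros(n), np.eye(n)
# A large, consistent move of the mean makes the path long, so sigma grows.
_, sigma_big = update_step_size(p0, 1.0, np.zeros(n), np.array([3.0, 0.0]), C0, n_elite=5)
# No movement of the mean makes the path short, so sigma shrinks.
_, sigma_small = update_step_size(p0, 1.0, np.zeros(n), np.zeros(n), C0, n_elite=5)
```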
<h3 id="adapting-the-covariance-matrix">Adapting the Covariance Matrix</h3>
<p>For the covariance matrix, it can be estimated from scratch using <script type="math/tex">y_i</script> of elite samples (recall that <script type="math/tex">y_i \sim \mathcal{N}(0, C)</script>):</p>
<script type="math/tex; mode=display">C_\lambda^{(t+1)}
= \frac{1}{\lambda}\sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top
= \frac{1}{\lambda {\sigma^{(t)}}^2} \sum_{i=1}^\lambda (x_i^{(t+1)} - \mu^{(t)})(x_i^{(t+1)} - \mu^{(t)})^\top</script>
<p>The above estimation is only reliable when the selected population is large enough. However, we do want to run <em>fast</em> iteration with a <em>small</em> population of samples in each generation. That’s why CMA-ES invented a more reliable but also more complicated way to update <script type="math/tex">C</script>. It involves two independent routes,</p>
<ul>
<li><em>Rank-min(λ, n) update</em>: uses the history of <script type="math/tex">\{C_\lambda\}</script>, each estimated from scratch in one generation.</li>
<li><em>Rank-one update</em>: estimates the moving steps <script type="math/tex">y_i</script> and the sign information from the history.</li>
</ul>
<p>The first route considers the estimation of <script type="math/tex">C</script> from the entire history of <script type="math/tex">\{C_\lambda\}</script>. For example, if we have experienced a large number of generations, <script type="math/tex">C^{(t+1)} \approx \text{avg}(C_\lambda^{(i)}; i=1,\dots,t)</script> would be a good estimator. Similar to <script type="math/tex">p_\sigma</script>, we also use Polyak averaging with a learning rate to incorporate the history:</p>
<script type="math/tex; mode=display">C^{(t+1)}
= (1 - \alpha_{c\lambda}) C^{(t)} + \alpha_{c\lambda} C_\lambda^{(t+1)}
= (1 - \alpha_{c\lambda}) C^{(t)} + \alpha_{c\lambda} \frac{1}{\lambda} \sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top</script>
<p>A common choice for the learning rate is <script type="math/tex">\alpha_{c\lambda} \approx \min(1, \lambda/n^2)</script>.</p>
<p>The second route tries to solve the issue that <script type="math/tex">y_i{y_i}^\top = (-y_i)(-y_i)^\top</script> loses the sign information. Similar to how we adjust the step size <script type="math/tex">\sigma</script>, an evolution path <script type="math/tex">p_c</script> is used to track the sign information and it is constructed in a way that <script type="math/tex">p_c</script> is conjugate, <script type="math/tex">\sim \mathcal{N}(0, C)</script> both before and after a new generation.</p>
<p>We may consider <script type="math/tex">p_c</script> as another way to compute <script type="math/tex">\text{avg}_i(y_i)</script> (notice that both <script type="math/tex">\sim \mathcal{N}(0, C)</script>) while the entire history is used and the sign information is maintained. Recall from the <a href="#controlling-the-step-size">last section</a> that <script type="math/tex">\sqrt{\lambda}\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \sim \mathcal{N}(0, C)</script>,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_c^{(t+1)}
&= (1-\alpha_{cp}) p_c^{(t)} + \sqrt{1 - (1-\alpha_{cp})^2}\;\sqrt{\lambda}\;\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}} \\
&= (1-\alpha_{cp}) p_c^{(t)} + \sqrt{\alpha_{cp}(2 - \alpha_{cp})\lambda}\;\frac{\mu^{(t+1)} - \mu^{(t)}}{\sigma^{(t)}}
\end{aligned} %]]></script>
<p>Then the covariance matrix is updated according to <script type="math/tex">p_c</script>:</p>
<script type="math/tex; mode=display">C^{(t+1)} = (1-\alpha_{c1}) C^{(t)} + \alpha_{c1}\;p_c^{(t+1)} {p_c^{(t+1)}}^\top</script>
<p>The <em>rank-one update</em> approach is claimed to generate a significant improvement over the <em>rank-min(λ, n)-update</em> when <script type="math/tex">\lambda</script> is small, because the signs of moving steps and correlations between consecutive steps are all utilized and passed down through generations.</p>
<p>Eventually we combine two approaches together,</p>
<script type="math/tex; mode=display">C^{(t+1)}
= (1 - \alpha_{c\lambda} - \alpha_{c1}) C^{(t)}
+ \alpha_{c1}\;\underbrace{p_c^{(t+1)} {p_c^{(t+1)}}^\top}_\textrm{rank-one update}
+ \alpha_{c\lambda} \underbrace{\frac{1}{\lambda} \sum_{i=1}^\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\top}_{\textrm{rank-min(}\lambda\textrm{, n) update}}</script>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CMA-ES-algorithm.png" alt="CMA-ES Algorithm" /></p>
<p>In all my examples above, each elite sample is considered to contribute an equal weight, <script type="math/tex">1/\lambda</script>. The process can be easily extended to the case where selected samples are assigned different weights, <script type="math/tex">w_1, \dots, w_\lambda</script>, according to their performances. See more details in the <a href="https://arxiv.org/abs/1604.00772">tutorial</a>.</p>
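<p>The combined covariance update with equal weights can be sketched as follows. The function name <code>update_covariance</code>, the learning-rate defaults and the random elite steps are assumptions chosen only for illustration, not recommended settings.</p>

```python
import numpy as np


def update_covariance(C, p_c, y_elite, mean_shift,
                      alpha_cp=0.2, alpha_c1=0.1, alpha_clambda=0.2):
    """One covariance update with equal weights 1/lambda.

    y_elite:    array of shape (lambda, n) holding the elite steps y_i.
    mean_shift: (mu_new - mu_old) / sigma, a vector of length n.
    """
    n_elite = y_elite.shape[0]
    # Evolution path for C keeps the sign information of the mean shifts.
    scale = np.sqrt(alpha_cp * (2.0 - alpha_cp) * n_elite)
    p_c = (1.0 - alpha_cp) * p_c + scale * mean_shift
    rank_one = np.outer(p_c, p_c)
    # Average of y_i y_i^T over the elite set.
    rank_min = y_elite.T @ y_elite / n_elite
    C = ((1.0 - alpha_clambda - alpha_c1) * C
         + alpha_c1 * rank_one
         + alpha_clambda * rank_min)
    return C, p_c


rng = np.random.default_rng(0)
y = rng.standard_normal((5, 3))          # hypothetical elite steps
C_new, p_c_new = update_covariance(np.eye(3), np.zeros(3), y, y.mean(axis=0))
```

<p>With <script type="math/tex">\alpha_\mu = 1</script> the normalized mean shift equals the average of the elite steps, which is why <code>y.mean(axis=0)</code> is passed in the usage lines.</p>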
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/CMA-ES-illustration.png" alt="CMA-ES Illustration" /></p>
<p><em>Fig. 3. Illustration of how CMA-ES works on a 2D optimization problem (the lighter color the better). Black dots are samples in one generation. The samples are more spread out initially but when the model has higher confidence in finding a good solution in the late stage, the samples become very concentrated over the global optimum. (Image source: <a href="https://en.wikipedia.org/wiki/CMA-ES">Wikipedia CMA-ES</a>)</em></p>
<h2 id="natural-evolution-strategies">Natural Evolution Strategies</h2>
<p>Natural Evolution Strategies (<strong>NES</strong>; <a href="https://arxiv.org/abs/1106.4487">Wierstra, et al, 2008</a>) optimizes in a search distribution of parameters and moves the distribution in the direction of high fitness indicated by the <em>natural gradient</em>.</p>
<h3 id="natural-gradients">Natural Gradients</h3>
<p>Given an objective function <script type="math/tex">\mathcal{J}(\theta)</script> parameterized by <script type="math/tex">\theta</script>, let’s say our goal is to find the optimal <script type="math/tex">\theta</script> to maximize the objective function value. A <em>plain gradient</em> finds the steepest direction within a small Euclidean distance from the current <script type="math/tex">\theta</script>; the distance restriction is applied on the parameter space. In other words, we compute the plain gradient with respect to a small change of the absolute value of <script type="math/tex">\theta</script>. The optimal step is:</p>
<script type="math/tex; mode=display">d^{*} = \operatorname*{argmax}_{\|d\| = \epsilon} \mathcal{J}(\theta + d)\text{, where }\epsilon \to 0</script>
<p>Differently, <em>natural gradient</em> works with a probability <a href="https://arxiv.org/abs/1301.3584v7">distribution</a> <a href="https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/">space</a> parameterized by <script type="math/tex">\theta</script>, <script type="math/tex">p_\theta(x)</script> (referred to as “search distribution” in NES <a href="https://arxiv.org/abs/1106.4487">paper</a>). It looks for the steepest direction within a small step in the distribution space where the distance is measured by KL divergence. With this constraint we ensure that each update is moving along the distributional manifold with constant speed, without being slowed down by its curvature.</p>
<script type="math/tex; mode=display">d^{*}_\text{N} = \operatorname*{argmax}_{\text{KL}[p_\theta \| p_{\theta+d}] = \epsilon} \mathcal{J}(\theta + d)</script>
<h3 id="estimation-using-fisher-information-matrix">Estimation using Fisher Information Matrix</h3>
<p>But how do we compute <script type="math/tex">\text{KL}[p_\theta \| p_{\theta+d}]</script>? Using a second-order Taylor expansion of <script type="math/tex">\log p_{\theta + d}</script> at <script type="math/tex">\theta</script>, we get:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& \text{KL}[p_\theta \| p_{\theta+d}] \\
&= \mathbb{E}_{x \sim p_\theta} [\log p_\theta(x) - \log p_{\theta+d}(x)] & \\
&\approx \mathbb{E}_{x \sim p_\theta} [ \log p_\theta(x) -( \log p_{\theta}(x) + \nabla_\theta \log p_{\theta}(x) d + \frac{1}{2}d^\top \nabla^2_\theta \log p_{\theta}(x) d)] & \scriptstyle{\text{; Taylor expand }\log p_{\theta+d}} \\
&\approx - \mathbb{E}_x [\nabla_\theta \log p_{\theta}(x)] d - \frac{1}{2}d^\top \mathbb{E}_x [\nabla^2_\theta \log p_{\theta}(x)] d &
\end{aligned} %]]></script>
<p>where</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbb{E}_x [\nabla_\theta \log p_{\theta}]
&= \int_{x} p_\theta(x) \nabla_\theta \log p_\theta(x) dx & \\
&= \int_{x} p_\theta(x) \frac{1}{p_\theta(x)} \nabla_\theta p_\theta(x) dx & \\
&= \nabla_\theta \Big( \int_{x} p_\theta(x) dx \Big) & \scriptstyle{\textrm{; note that }p_\theta(x)\textrm{ is probability distribution.}} \\
&= \nabla_\theta (1) = 0
\end{aligned} %]]></script>
<p>Finally we have,</p>
<script type="math/tex; mode=display">\text{KL}[p_\theta \| p_{\theta+d}] \approx \frac{1}{2}d^\top \mathbf{F}_\theta d
\text{, where }\mathbf{F}_\theta = -\mathbb{E}_x [\nabla^2_\theta \log p_{\theta}] = \mathbb{E}_x [(\nabla_\theta \log p_{\theta}) (\nabla_\theta \log p_{\theta})^\top]</script>
<p>where <script type="math/tex">\mathbf{F}_\theta</script> is called the <strong><a href="http://mathworld.wolfram.com/FisherInformationMatrix.html">Fisher Information Matrix</a></strong> and <a href="https://wiseodd.github.io/techblog/2018/03/11/fisher-information/">it is</a> the covariance matrix of <script type="math/tex">\nabla_\theta \log p_\theta</script> since <script type="math/tex">\mathbb{E}[\nabla_\theta \log p_\theta] = 0</script>.</p>
<p>The solution to the following optimization problem:</p>
<script type="math/tex; mode=display">\max \mathcal{J}(\theta + d) \approx \max \big( \mathcal{J}(\theta) + {\nabla_\theta\mathcal{J}(\theta)}^\top d \big)\;\text{ s.t. }\text{KL}[p_\theta \| p_{\theta+d}] - \epsilon = 0</script>
<p>can be found using a Lagrangian multiplier,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}(\theta, d, \beta) &= \mathcal{J}(\theta) + \nabla_\theta\mathcal{J}(\theta)^\top d - \beta (\frac{1}{2}d^\top \mathbf{F}_\theta d - \epsilon) \text{, where } \beta > 0 \\
\nabla_d \mathcal{L}(\theta, d, \beta) &= \nabla_\theta\mathcal{J}(\theta) - \beta\mathbf{F}_\theta d = 0 \\
\text{Thus } d_\text{N}^* &= \nabla_\theta^\text{N} \mathcal{J}(\theta) = \mathbf{F}_\theta^{-1} \nabla_\theta\mathcal{J}(\theta)
\end{aligned} %]]></script>
<p>where <script type="math/tex">d_\text{N}^*</script> only extracts the direction of the optimal moving step on <script type="math/tex">\theta</script>, ignoring the scalar <script type="math/tex">\beta^{-1}</script>.</p>
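<p>As a small worked example, consider a one-dimensional Gaussian search distribution with <script type="math/tex">\theta = (\mu, \sigma)</script>, whose Fisher information matrix is known in closed form, <script type="math/tex">\text{diag}(1/\sigma^2, 2/\sigma^2)</script>. The sketch below estimates it from samples of <script type="math/tex">\nabla_\theta \log p_\theta</script> and then rescales a plain gradient. The parameter values and the plain gradient vector are hypothetical placeholders.</p>

```python
import numpy as np


def gaussian_score(x, mu, sigma):
    """Gradient of log N(x | mu, sigma^2) with respect to (mu, sigma)."""
    d_mu = (x - mu) / sigma ** 2
    d_sigma = ((x - mu) ** 2 - sigma ** 2) / sigma ** 3
    return np.stack([d_mu, d_sigma], axis=1)    # shape (num_samples, 2)


rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5
x = rng.normal(mu, sigma, size=200_000)

# Monte Carlo estimate of F = E[score score^T].
scores = gaussian_score(x, mu, sigma)
fisher_mc = scores.T @ scores / len(x)
fisher_exact = np.diag([1.0 / sigma ** 2, 2.0 / sigma ** 2])

# Natural gradient direction: F^{-1} times the plain gradient.
plain_grad = np.array([1.0, 1.0])               # hypothetical plain gradient
natural_grad = np.linalg.solve(fisher_exact, plain_grad)
```

<p>Here the natural gradient shrinks the <script type="math/tex">\sigma</script> component twice as much as the <script type="math/tex">\mu</script> component, and both are scaled by <script type="math/tex">\sigma^2</script>.</p>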
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/CMA-ES-coordinates.png" alt="Plain vs natural coordinates" /></p>
<p><em>Fig. 4. The natural gradient samples (black solid arrows) on the right are the plain gradient samples (black solid arrows) on the left multiplied by the inverse of their covariance. In this way, a gradient direction with high uncertainty (indicated by high covariance with other samples) is penalized with a small weight. The aggregated natural gradient (red dash arrow) is therefore more trustworthy than the plain gradient (green solid arrow). (Image source: additional annotations on Fig 2 in <a href="https://arxiv.org/abs/1106.4487">NES</a> paper)</em></p>
<h3 id="nes-algorithm">NES Algorithm</h3>
<p>The fitness associated with one sample is labeled as <script type="math/tex">f(x)</script> and the search distribution over <script type="math/tex">x</script> is parameterized by <script type="math/tex">\theta</script>. NES is expected to optimize the parameter <script type="math/tex">\theta</script> to achieve maximum expected fitness:</p>
<script type="math/tex; mode=display">\mathcal{J}(\theta) = \mathbb{E}_{x\sim p_\theta(x)} [f(x)] = \int_x f(x) p_\theta(x) dx</script>
<p>Using the same log-likelihood <a href="http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/">trick</a> in <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\nabla_\theta\mathcal{J}(\theta)
&= \nabla_\theta \int_x f(x) p_\theta(x) dx \\
&= \int_x f(x) \frac{p_\theta(x)}{p_\theta(x)}\nabla_\theta p_\theta(x) dx \\
& = \int_x f(x) p_\theta(x) \nabla_\theta \log p_\theta(x) dx \\
& = \mathbb{E}_{x \sim p_\theta} [f(x) \nabla_\theta \log p_\theta(x)]
\end{aligned} %]]></script>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/NES-algorithm.png" alt="NES" /></p>
<p>Besides natural gradients, NES adopts a couple of important heuristics to make the algorithm performance more robust.</p>
<ul>
<li><a name="fitness-shaping"></a>NES applies <strong>rank-based fitness shaping</strong>, that is to use the <em>rank</em> under monotonically increasing fitness values instead of using <script type="math/tex">f(x)</script> directly. Or it can be a function of the rank (“utility function”), which is considered as a free parameter of NES.</li>
<li>NES adopts <strong>adaptation sampling</strong> to adjust hyperparameters at run time. When changing <script type="math/tex">\theta \to \theta'</script>, samples drawn from <script type="math/tex">p_\theta</script> are compared with samples from <script type="math/tex">p_{\theta'}</script> using the <a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Mann-Whitney U-test</a>; if the test shows a positive or negative sign, the target hyperparameter is decreased or increased by a multiplicative constant. Note the score of a sample <script type="math/tex">x'_i \sim p_{\theta'}(x)</script> has importance sampling weights applied <script type="math/tex">w_i' = p_\theta(x) / p_{\theta'}(x)</script>.</li>
</ul>
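<p>Rank-based fitness shaping is easy to sketch. The utility function below is an assumption of this example: it is a log-based form often used with NES, normalized so that the utilities sum to zero. Other monotone functions of the rank would work in the same way.</p>

```python
import numpy as np


def rank_utilities(fitness):
    """Map raw fitness values to rank-based utilities (higher fitness is better)."""
    fitness = np.asarray(fitness, dtype=float)
    n = fitness.size
    # ranks[i] = 1 for the best sample, n for the worst.
    order = np.argsort(-fitness)
    ranks = np.empty(n, dtype=int)
    ranks[order] = np.arange(1, n + 1)
    # Only the better half gets a positive raw utility.
    raw = np.maximum(0.0, np.log(n / 2.0 + 1.0) - np.log(ranks))
    return raw / raw.sum() - 1.0 / n


utilities = rank_utilities([3.0, -100.0, 7.5, 0.2])
# Only the ordering matters: an extreme outlier does not change the result.
utilities_outlier = rank_utilities([3.0, -1e9, 7.5, 0.2])
```

<p>Replacing <script type="math/tex">f(x_i)</script> with these utilities makes the update invariant to any order-preserving transformation of the fitness values.</p>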
<h2 id="applications-es-in-deep-reinforcement-learning">Applications: ES in Deep Reinforcement Learning</h2>
<h3 id="openai-es-for-rl">OpenAI ES for RL</h3>
<p>The concept of using evolutionary algorithms in reinforcement learning can be traced back <a href="https://arxiv.org/abs/1106.0221">long ago</a>, but only constrained to tabular RL due to computational limitations.</p>
<p>Inspired by <a href="#natural-evolution-strategies">NES</a>, researchers at OpenAI (<a href="https://arxiv.org/abs/1703.03864">Salimans, et al. 2017</a>) proposed to use NES as a gradient-free black-box optimizer to find optimal policy parameters <script type="math/tex">\theta</script> that maximize the return function <script type="math/tex">F(\theta)</script>. The key is to add Gaussian noise <script type="math/tex">\epsilon</script> on the model parameter <script type="math/tex">\theta</script> and then use the log-likelihood trick to write it as the gradient of the Gaussian pdf. Eventually only the noise term is left as a weighting scalar for measured performance.</p>
<p>Let’s say the current parameter value is <script type="math/tex">\hat{\theta}</script> (the added hat is to distinguish the value from the random variable <script type="math/tex">\theta</script>). The search distribution of <script type="math/tex">\theta</script> is designed to be an isotropic multivariate Gaussian with a mean <script type="math/tex">\hat{\theta}</script> and a fixed covariance matrix <script type="math/tex">\sigma^2 I</script>,</p>
<script type="math/tex; mode=display">\theta \sim \mathcal{N}(\hat{\theta}, \sigma^2 I) \text{ equivalent to } \theta = \hat{\theta} + \sigma\epsilon, \epsilon \sim \mathcal{N}(0, I)</script>
<p>The gradient for <script type="math/tex">\theta</script> update is:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& \nabla_{\hat{\theta}} \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} F(\theta) \\
&= \nabla_{\hat{\theta}} \int_{\theta} p(\theta; \hat{\theta}) F(\theta) d\theta & \scriptstyle{\text{; Gaussian }p(\theta; \hat{\theta})=(2\pi\sigma^2)^{-\frac{n}{2}} \exp(-\frac{1}{2\sigma^2}\|\theta - \hat{\theta}\|^2)} \\
&= \int_{\theta} p(\theta; \hat{\theta}) \nabla_{\hat{\theta}} \log p(\theta; \hat{\theta})\;F(\theta) d\theta & \scriptstyle{\text{; log-likelihood trick}}\\
&= \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} [ \nabla_{\hat{\theta}} \big(-\frac{1}{2\sigma^2}\|\theta - \hat{\theta}\|^2\big) F(\theta) ] & \\
&= \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} [ \frac{\theta - \hat{\theta}}{\sigma^2} F(\theta) ] & \\
&= \frac{1}{\sigma}\mathbb{E}_{\epsilon\sim\mathcal{N}(0, I)} [ \epsilon F(\hat{\theta} + \sigma\epsilon) ] & \scriptstyle{\text{; substitute }\theta = \hat{\theta} + \sigma\epsilon}
\end{aligned} %]]></script>
<p>In one generation, we can sample many <script type="math/tex">\epsilon_i, i=1,\dots,n</script> and evaluate the fitness <em>in parallel</em>. One beautiful design is that no large model parameter needs to be shared. By only communicating the random seeds between workers, it is enough for the master node to do parameter update. This approach is later extended to adaptively learn a loss function; see my previous post on <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#meta-learning-the-loss-function">Evolved Policy Gradient</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/OpenAI-ES-algorithm.png" alt="ES for RL" /></p>
<p><em>Fig. 5. The algorithm for training a RL policy using evolution strategies. (Image source: <a href="https://arxiv.org/abs/1703.03864">ES-for-RL</a> paper)</em></p>
<p>To make the performance more robust, OpenAI ES adopts virtual batch normalization (BN with mini-batch used for calculating statistics fixed), mirror sampling (sampling a pair of <script type="math/tex">(-\epsilon, \epsilon)</script> for evaluation), and <a href="#fitness-shaping">fitness shaping</a>.</p>
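<p>Here is a single-process sketch of the update with mirrored sampling, run on a toy objective instead of an RL environment. The function name <code>openai_es</code>, the hyperparameter values and the quadratic return are assumptions of this sketch; virtual batch normalization, fitness shaping and the seed-sharing between workers are left out.</p>

```python
import numpy as np


def openai_es(F, theta0, sigma=0.1, alpha=0.05, n_pairs=50, n_iters=300, seed=0):
    """Maximize F with the estimator (1 / sigma) * E[eps * F(theta + sigma * eps)]."""
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta0, dtype=float)
    for _ in range(n_iters):
        eps = rng.standard_normal((n_pairs, theta.size))
        eps = np.concatenate([eps, -eps])            # mirrored sampling
        returns = np.array([F(theta + sigma * e) for e in eps])
        # Monte Carlo estimate of the gradient of the expected return.
        grad = eps.T @ returns / (len(eps) * sigma)
        theta = theta + alpha * grad                 # gradient ascent step
    return theta


def toy_return(theta):
    """Hypothetical return function with its maximum at (1, -2)."""
    return -float(np.sum((theta - np.array([1.0, -2.0])) ** 2))


theta_star = openai_es(toy_return, theta0=[0.0, 0.0])
```

<p>In a distributed setting each worker would evaluate a few of the <code>eps</code> rows and report only scalar returns, since the noise can be regenerated from a shared seed.</p>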
<h3 id="exploration-with-es">Exploration with ES</h3>
<p>Exploration (<a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#exploitation-vs-exploration">vs exploitation</a>) is an important topic in RL. The optimization direction in the ES algorithm <a href="#openai-es-for-rl">above</a> is only extracted from the cumulative return <script type="math/tex">F(\theta)</script>. Without explicit exploration, the agent might get trapped in a local optimum.</p>
<p>Novelty-Search ES (<strong>NS-ES</strong>; <a href="https://arxiv.org/abs/1712.06560">Conti et al, 2018</a>) encourages exploration by updating the parameter in the direction to maximize the <em>novelty</em> score. The novelty score depends on a domain-specific behavior characterization function <script type="math/tex">b(\pi_\theta)</script>. The choice of <script type="math/tex">b(\pi_\theta)</script> is specific to the task and seems to be a bit arbitrary; for example, in the Humanoid locomotion task in the paper, <script type="math/tex">b(\pi_\theta)</script> is the final <script type="math/tex">(x,y)</script> location of the agent.</p>
<ol>
<li>Every policy’s <script type="math/tex">b(\pi_\theta)</script> is pushed to an archive set <script type="math/tex">\mathcal{A}</script>.</li>
<li>Novelty of a policy <script type="math/tex">\pi_\theta</script> is measured as the k-nearest neighbor score between <script type="math/tex">b(\pi_\theta)</script> and all other entries in <script type="math/tex">\mathcal{A}</script>.
(The use case of the archive set sounds quite similar to <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#episodic-control">episodic memory</a>.)</li>
</ol>
<script type="math/tex; mode=display">N(\theta, \mathcal{A}) = \frac{1}{\lambda} \sum_{i=1}^\lambda \| b(\pi_\theta) - b^\text{knn}_i \|_2
\text{, where }b^\text{knn}_i \in \text{kNN}(b(\pi_\theta), \mathcal{A})</script>
<p>The ES optimization step relies on the novelty score instead of fitness:</p>
<script type="math/tex; mode=display">\nabla_\theta \mathbb{E}_{\theta\sim\mathcal{N}(\hat{\theta}, \sigma^2 I)} N(\theta, \mathcal{A})
= \frac{1}{\sigma}\mathbb{E}_{\epsilon\sim\mathcal{N}(0, I)} [ \epsilon N(\hat{\theta} + \sigma\epsilon, \mathcal{A}) ]</script>
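<p>The novelty score itself is just an average distance to the nearest neighbors in the archive. A minimal sketch, assuming behavior characterizations are plain vectors such as final <script type="math/tex">(x, y)</script> positions; the function name <code>novelty</code>, the toy archive and the value of <code>k</code> are made up for illustration.</p>

```python
import numpy as np


def novelty(b, archive, k=10):
    """Mean Euclidean distance from b to its k nearest neighbors in the archive."""
    archive = np.asarray(archive, dtype=float)
    dists = np.linalg.norm(archive - np.asarray(b, dtype=float), axis=1)
    k = min(k, len(dists))
    # np.partition places the k smallest distances in the first k slots.
    nearest = np.partition(dists, k - 1)[:k]
    return float(nearest.mean())


# Hypothetical archive of final (x, y) positions reached by earlier policies.
archive = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
near = novelty([0.1, 0.1], archive, k=2)     # ends close to known behaviors
far = novelty([10.0, 10.0], archive, k=2)    # ends somewhere new
```

<p>A policy that ends far from everything in the archive receives the larger score, so the ES update pushes the parameters toward unvisited behaviors.</p>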
<p>NS-ES maintains a group of <script type="math/tex">M</script> independently trained agents (“meta-population”), <script type="math/tex">\mathcal{M} = \{\theta_1, \dots, \theta_M \}</script> and picks one to advance proportional to the novelty score. Eventually we select the best policy. This process is equivalent to ensembling; also see the same idea in <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#svpg">SVPG</a>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
m &\leftarrow \text{pick } i=1,\dots,M\text{ according to probability }\frac{N(\theta_i, \mathcal{A})}{\sum_{j=1}^M N(\theta_j, \mathcal{A})} \\
\theta_m^{(t+1)} &\leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) \text{ where }\epsilon_i \sim \mathcal{N}(0, I)
\end{aligned} %]]></script>
<p>where <script type="math/tex">N</script> is the number of Gaussian perturbation noise vectors and <script type="math/tex">\alpha</script> is the learning rate.</p>
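<p>The selection and update rule can be sketched as follows. This is an illustration under stated assumptions rather than the paper's implementation: <code class="language-plaintext highlighter-rouge">behavior</code> is a hypothetical stand-in for <script type="math/tex">b(\pi_\theta)</script> that simply returns the parameters, perturbations are scaled by <script type="math/tex">\sigma</script> as in the expectation form above, and the sum is averaged over the perturbations, which only rescales <script type="math/tex">\alpha</script>.</p>

```python
import math
import random


def novelty(b_theta, archive, k=2):
    """Mean distance to the k nearest behaviors in the archive."""
    nearest = sorted(math.dist(b_theta, b) for b in archive)[:k]
    return sum(nearest) / len(nearest)


def behavior(theta):
    # Hypothetical behavior characterization: the parameters themselves.
    return tuple(theta)


def ns_es_step(population, archive, rng, alpha=0.1, sigma=0.5, n_perturb=20):
    """Advance one member of the meta-population by a novelty-driven ES step."""
    # Pick the member to advance with probability proportional to novelty.
    scores = [novelty(behavior(theta), archive) for theta in population]
    m = rng.choices(range(len(population)), weights=scores, k=1)[0]
    theta = population[m]

    # Estimate the search gradient from Gaussian perturbations.
    grad = [0.0] * len(theta)
    for _ in range(n_perturb):
        eps = [rng.gauss(0.0, 1.0) for _ in theta]
        perturbed = [t + sigma * e for t, e in zip(theta, eps)]
        score = novelty(behavior(perturbed), archive)
        grad = [g + e * score for g, e in zip(grad, eps)]

    step = alpha / (n_perturb * sigma)
    population[m] = [t + step * g for t, g in zip(theta, grad)]
    # The updated policy's behavior is pushed to the archive.
    archive.append(behavior(population[m]))
    return m


rng = random.Random(0)
population = [[0.5, 0.5], [2.0, 2.0]]
archive = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
m = ns_es_step(population, archive, rng)
print(m, population[m])
```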
<p>NS-ES completely discards the reward function and only optimizes for novelty to avoid deceptive local optima. To incorporate the fitness back into the formula, two other variants are proposed.</p>
<p><strong>NSR-ES</strong>:</p>
<script type="math/tex; mode=display">\theta_m^{(t+1)} \leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i \frac{N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) + F(\theta^{(t)}_m + \sigma\epsilon_i)}{2}</script>
<p><strong>NSRAdapt-ES (NSRA-ES)</strong>: the adaptive weighting parameter <script type="math/tex">w</script> is initialized to <script type="math/tex">w = 1.0</script>, i.e. pure fitness. We start decreasing <script type="math/tex">w</script> if performance stays flat for a number of generations, which shifts weight toward novelty. When the performance starts to increase again, we stop decreasing <script type="math/tex">w</script> and increase it instead. In this way, novelty is preferred when the performance stops growing and fitness is preferred otherwise.</p>
<script type="math/tex; mode=display">\theta_m^{(t+1)} \leftarrow \theta_m^{(t)} + \alpha \frac{1}{\sigma}\sum_{i=1}^N \epsilon_i \big((1-w) N(\theta^{(t)}_m + \sigma\epsilon_i, \mathcal{A}) + w F(\theta^{(t)}_m + \sigma\epsilon_i)\big)</script>
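<p>A small sketch of the adaptive weighting may help. The step size <code class="language-plaintext highlighter-rouge">delta</code> and the stagnation threshold <code class="language-plaintext highlighter-rouge">patience</code> are illustrative values chosen for this example, and the function names are hypothetical:</p>

```python
def combined_score(novelty, fitness, w):
    """Blend novelty and fitness as (1 - w) * N + w * F."""
    return (1.0 - w) * novelty + w * fitness


def adapt_weight(w, fitness, best_fitness, stalled, delta=0.05, patience=3):
    """Return the updated (w, best_fitness, stalled) after one generation."""
    if fitness > best_fitness:
        # Performance improved: move back toward fitness.
        return min(1.0, w + delta), fitness, 0
    stalled += 1
    if stalled >= patience:
        # Performance has been flat for too long: move toward novelty.
        return max(0.0, w - delta), best_fitness, 0
    return w, best_fitness, stalled


w, best, stalled = 1.0, float("-inf"), 0
history = []
for fitness in [1.0, 1.0, 1.0, 1.0, 2.0]:
    w, best, stalled = adapt_weight(w, fitness, best, stalled)
    history.append(w)
print(history)
```

<p>In this toy run, <script type="math/tex">w</script> stays at 1.0 while the first improvement is fresh, drops once after three flat generations, and climbs back when fitness improves again.</p>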
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NS-ES-experiments.png" alt="NS-ES Experiments" /></p>
<p><em>Fig. 6. (Left) The environment is Humanoid locomotion with a three-sided wall which acts as a deceptive trap and creates a local optimum. (Right) Experiments compare ES baseline and other variations that encourage exploration. (Image source: <a href="https://arxiv.org/abs/1712.06560">NS-ES</a> paper)</em></p>
<h3 id="cem-rl">CEM-RL</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CEM-RL.png" alt="CEM-RL" /></p>
<p><em>Fig. 7. Architectures of the (a) CEM-RL and (b) <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">ERL</a> algorithms (Image source: <a href="https://arxiv.org/abs/1810.01222">CEM-RL</a> paper)</em></p>
<p>The CEM-RL method (<a href="https://arxiv.org/abs/1810.01222">Pourchot & Sigaud, 2019</a>) combines the Cross Entropy Method (CEM) with either <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#ddpg">DDPG</a> or <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#td3">TD3</a>. CEM here works pretty much the same as the simple Gaussian ES described <a href="#simple-gaussian-evolution-strategies">above</a>, and therefore it can be replaced with CMA-ES. CEM-RL is built on the framework of <em>Evolutionary Reinforcement Learning</em> (<em>ERL</em>; <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">Khadka & Tumer, 2018</a>), in which a standard EA selects and evolves a population of actors, and the rollout experience generated in the process is added to the replay buffer for training both the RL-actor and RL-critic networks.</p>
<p>Workflow:</p>
<ul>
<li>1) The mean actor of the CEM population, <script type="math/tex">\pi_\mu</script>, is initialized with a random actor network.</li>
<li>2) The critic network <script type="math/tex">Q</script> is initialized too, which will be updated by DDPG/TD3.</li>
<li>3) Repeat until happy:
<ul>
<li>a. Sample a population of actors <script type="math/tex">\sim \mathcal{N}(\pi_\mu, \Sigma)</script>.</li>
<li>b. Half of the population is evaluated. Their cumulative rewards <script type="math/tex">R</script> are used as fitness scores, and the collected experience is added to the replay buffer.</li>
<li>c. The other half are updated together with the critic.</li>
<li>d. The new <script type="math/tex">\pi_\mu</script> and <script type="math/tex">\Sigma</script> are computed using the top-performing elite samples. <a href="#covariance-matrix-adaptation-evolution-strategies-cma-es">CMA-ES</a> can be used for the parameter update too.</li>
</ul>
</li>
</ul>
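<p>The CEM part of the workflow (steps a and d) can be sketched as below. This is a simplified illustration, not the CEM-RL implementation: it assumes a diagonal covariance, unweighted elites, a toy fitness function in place of episode returns, and it omits the critic and the gradient-updated half of the population (step c) entirely. All names and constants are hypothetical.</p>

```python
import random


def cem_step(mu, var, fitness, rng, pop_size=20, n_elite=5, noise=1e-2):
    """One CEM iteration with a diagonal Gaussian over parameters."""
    # a. Sample a population around the current mean.
    population = [
        [rng.gauss(m, v ** 0.5) for m, v in zip(mu, var)]
        for _ in range(pop_size)
    ]
    # b. Evaluate everyone and keep the top performers.
    elites = sorted(population, key=fitness, reverse=True)[:n_elite]
    # d. Refit mean and variance to the elites; extra noise keeps exploring.
    columns = list(zip(*elites))
    new_mu = [sum(col) / n_elite for col in columns]
    new_var = [
        sum((x - m) ** 2 for x in col) / n_elite + noise
        for col, m in zip(columns, new_mu)
    ]
    return new_mu, new_var


def toy_fitness(theta):
    # Stand-in for an episode return: higher is better, best at (3, 3).
    return -sum((t - 3.0) ** 2 for t in theta)


rng = random.Random(0)
mu, var = [0.0, 0.0], [4.0, 4.0]
for _ in range(30):
    mu, var = cem_step(mu, var, toy_fitness, rng)
print(mu, toy_fitness(mu))
```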
<h2 id="extension-ea-in-deep-learning">Extension: EA in Deep Learning</h2>
<p>(This section is not on evolution strategies, but still an interesting and relevant reading.)</p>
<p><em>Evolutionary Algorithms</em> have been applied to many deep learning problems. POET (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>) is a framework based on EA that attempts to generate a variety of different tasks while the problems themselves are being solved. POET has been introduced in my <a href="/lil-log/2019/06/23/meta-reinforcement-learning.html#task-generation-by-domain-randomization">last post</a> on meta-RL. Evolutionary Reinforcement Learning (ERL) is another example; see Fig. 7 (b).</p>
<p>Below I would like to introduce two applications in more detail, <em>Population-Based Training (PBT)</em> and <em>Weight-Agnostic Neural Networks (WANN)</em>.</p>
<h3 id="hyperparameter-tuning-pbt">Hyperparameter Tuning: PBT</h3>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PBT.png" alt="PBT" /></p>
<p><em>Fig. 8. Paradigms of comparing different ways of hyperparameter tuning. (Image source: <a href="https://arxiv.org/abs/1711.09846">PBT</a> paper)</em></p>
<p>Population-Based Training (<a href="https://arxiv.org/abs/1711.09846">Jaderberg, et al, 2017</a>), or <strong>PBT</strong> for short, applies EA to the problem of hyperparameter tuning. It jointly trains a population of models and corresponding hyperparameters for optimal performance.</p>
<p>PBT starts with a set of random candidates, each containing a pair of model weight initialization and hyperparameters, <script type="math/tex">\{(\theta_i, h_i)\mid i=1, \dots, N\}</script>. Every sample is trained in parallel and asynchronously evaluates its own performance periodically. Whenever a member is deemed ready (i.e. after taking enough gradient update steps, or when the performance is good enough), it has a chance to be updated by comparing with the whole population:</p>
<ul>
<li><strong><code class="language-plaintext highlighter-rouge">exploit()</code></strong>: When this model is under-performing, the weights could be replaced with a better performing model.</li>
<li><strong><code class="language-plaintext highlighter-rouge">explore()</code></strong>: If the model weights are overwritten, <code class="language-plaintext highlighter-rouge">explore</code> step perturbs the hyperparameters with random noise.</li>
</ul>
<p>In this process, only promising model and hyperparameter pairs can survive and keep on evolving, achieving better utilization of computational resources.</p>
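<p>The two steps can be sketched as follows. The concrete strategies here are assumptions made for the example: <code class="language-plaintext highlighter-rouge">exploit</code> copies from a random top-quartile member when the current member is in the bottom quartile, and <code class="language-plaintext highlighter-rouge">explore</code> multiplies each hyperparameter by 0.8 or 1.2. The dictionary layout of a population member is hypothetical.</p>

```python
import copy
import random


def exploit(member, population, rng, frac=0.25):
    """Overwrite an under-performing member with a top performer's state."""
    ranked = sorted(population, key=lambda m: m["score"], reverse=True)
    cutoff = max(1, int(len(ranked) * frac))
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    if any(member is m for m in bottom):
        donor = rng.choice(top)
        member["theta"] = copy.deepcopy(donor["theta"])
        member["h"] = dict(donor["h"])
        return True
    return False


def explore(member, rng, factors=(0.8, 1.2)):
    """Perturb every hyperparameter by a random multiplicative factor."""
    member["h"] = {
        name: value * rng.choice(factors) for name, value in member["h"].items()
    }


rng = random.Random(0)
population = [
    {"theta": [0.1, 0.2], "h": {"lr": 0.01}, "score": 0.9},
    {"theta": [0.3, 0.1], "h": {"lr": 0.03}, "score": 0.5},
    {"theta": [0.0, 0.4], "h": {"lr": 0.10}, "score": 0.4},
    {"theta": [0.7, 0.7], "h": {"lr": 0.30}, "score": 0.1},
]
worst = population[3]
# explore() only runs when the weights were actually overwritten.
if exploit(worst, population, rng):
    explore(worst, rng)
print(worst)
```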
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/PBT-algorithm.png" alt="PBT Algorithm" /></p>
<p><em>Fig. 9. The algorithm of population-based training. (Image source: <a href="https://arxiv.org/abs/1711.09846">PBT</a> paper)</em></p>
<h3 id="network-topology-optimization-wann">Network Topology Optimization: WANN</h3>
<p><em>Weight Agnostic Neural Networks</em> (<strong>WANN</strong> for short; <a href="https://arxiv.org/abs/1906.04358">Gaier & Ha 2019</a>) experiments with searching for the smallest network topologies that can achieve the optimal performance without training the network weights. By not considering the best configuration of network weights, WANN puts much more emphasis on the architecture itself, making the focus different from <a href="http://openaccess.thecvf.com/content_cvpr_2018/papers/Zoph_Learning_Transferable_Architectures_CVPR_2018_paper.pdf">NAS</a>. WANN is heavily inspired by a classic genetic algorithm to evolve network topologies, called <em>NEAT</em> (“Neuroevolution of Augmenting Topologies”; <a href="http://nn.cs.utexas.edu/downloads/papers/stanley.gecco02_1.pdf">Stanley & Miikkulainen 2002</a>).</p>
<p>The workflow of WANN looks pretty much the same as standard GA:</p>
<ol>
<li>Initialize: Create a population of minimal networks.</li>
<li>Evaluation: Test with a range of <em>shared</em> weight values.</li>
<li>Rank and Selection: Rank by performance and complexity.</li>
<li>Mutation: Create new population by varying best networks.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/WANN-mutations.png" alt="Mutation operations in WANN" /></p>
<p><em>Fig. 10. Mutation operations for searching for new network topologies in WANN (Image source: <a href="https://arxiv.org/abs/1906.04358">WANN</a> paper)</em></p>
<p>At the “evaluation” stage, all the network weights are set to be the same. In this way, WANN is actually searching for a network that can be described with a minimal description length. In the “selection” stage, both the network connection and the model performance are considered.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/WANN-results.png" alt="WANN results" /></p>
<p><em>Fig. 11. Performance of network topologies found by WANN on different RL tasks, compared with baseline FF networks commonly used in the literature. “Tuned Shared Weight” only requires adjusting one weight value. (Image source: <a href="https://arxiv.org/abs/1906.04358">WANN</a> paper)</em></p>
<p>As shown in Fig. 11, WANN results are evaluated with both random weights and shared weights (single weight). It is interesting that even when enforcing weight-sharing on all weights and tuning this single parameter, WANN can discover topologies that achieve non-trivially good performance.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019ES,
title = "Evolution Strategies",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "https://lilianweng.github.io/lil-log/2019/09/05/evolution-strategies.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Nikolaus Hansen. <a href="https://arxiv.org/abs/1604.00772">“The CMA Evolution Strategy: A Tutorial”</a> arXiv preprint arXiv:1604.00772 (2016).</p>
<p>[2] Marc Toussaint. <a href="https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/06-blackBoxOpt.pdf">Slides: “Introduction to Optimization”</a></p>
<p>[3] David Ha. <a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">“A Visual Guide to Evolution Strategies”</a> blog.otoro.net. Oct 2017.</p>
<p>[4] Daan Wierstra, et al. <a href="https://arxiv.org/abs/1106.4487">“Natural evolution strategies.”</a> IEEE World Congress on Computational Intelligence, 2008.</p>
<p>[5] Agustinus Kristiadi. <a href="https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/">“Natural Gradient Descent”</a> Mar 2018.</p>
<p>[6] Razvan Pascanu & Yoshua Bengio. <a href="https://arxiv.org/abs/1301.3584v7">“Revisiting Natural Gradient for Deep Networks.”</a> arXiv preprint arXiv:1301.3584 (2013).</p>
<p>[7] Tim Salimans, et al. <a href="https://arxiv.org/abs/1703.03864">“Evolution strategies as a scalable alternative to reinforcement learning.”</a> arXiv preprint arXiv:1703.03864 (2017).</p>
<p>[8] Edoardo Conti, et al. <a href="https://arxiv.org/abs/1712.06560">“Improving exploration in evolution strategies for deep reinforcement learning via a population of novelty-seeking agents.”</a> NIPS. 2018.</p>
<p>[9] Aloïs Pourchot & Olivier Sigaud. <a href="https://arxiv.org/abs/1810.01222">“CEM-RL: Combining evolutionary and gradient-based methods for policy search.”</a> ICLR 2019.</p>
<p>[10] Shauharda Khadka & Kagan Tumer. <a href="https://papers.nips.cc/paper/7395-evolution-guided-policy-gradient-in-reinforcement-learning.pdf">“Evolution-guided policy gradient in reinforcement learning.”</a> NIPS 2018.</p>
<p>[11] Max Jaderberg, et al. <a href="https://arxiv.org/abs/1711.09846">“Population based training of neural networks.”</a> arXiv preprint arXiv:1711.09846 (2017).</p>
<p>[12] Adam Gaier & David Ha. <a href="https://arxiv.org/abs/1906.04358">“Weight Agnostic Neural Networks.”</a> arXiv preprint arXiv:1906.04358 (2019).</p>Lilian WengGradient descent is not the only option when learning optimal model parameters. Evolution Strategies (ES) works out well in the cases where we don’t know the precise analytic form of an objective function or cannot compute the gradients directly. This post dives into several classic ES methods, as well as how ES can be used in deep reinforcement learning.Meta Reinforcement Learning2019-06-23T12:00:00+00:002019-06-23T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/06/23/meta-reinforcement-learning<blockquote>
<p>Meta-RL is meta-learning on reinforcement learning tasks. After being trained over a distribution of tasks, the agent is able to solve a new task by developing a new RL algorithm with its internal activity dynamics. This post starts with the origin of meta-RL and then dives into three key components of meta-RL.</p>
</blockquote>
<!--more-->
<p>In my earlier post on <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a>, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to “meta-learn” <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">Reinforcement Learning (RL)</a> tasks by developing an agent that can solve unseen tasks fast and efficiently.</p>
<p>To recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a <em>mini learning session</em>, happens at test time with limited exposure to the new configurations. Even without any explicit fine-tuning (no gradient backpropagation on trainable variables), the meta-learning model autonomously adjusts internal hidden states to learn.</p>
<p>Training RL algorithms can be notoriously difficult sometimes. If the meta-learning agent could become so smart that the distribution of solvable unseen tasks grows extremely broad, we are on track towards <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">general purpose methods</a> — essentially building a “brain” which would solve all kinds of RL problems without much human interference or manual feature engineering. Sounds amazing, right? 💖</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#on-the-origin-of-meta-rl" id="markdown-toc-on-the-origin-of-meta-rl">On the Origin of Meta-RL</a> <ul>
<li><a href="#back-in-2001" id="markdown-toc-back-in-2001">Back in 2001</a></li>
<li><a href="#proposal-in-2016" id="markdown-toc-proposal-in-2016">Proposal in 2016</a></li>
</ul>
</li>
<li><a href="#define-meta-rl" id="markdown-toc-define-meta-rl">Define Meta-RL</a> <ul>
<li><a href="#formulation" id="markdown-toc-formulation">Formulation</a></li>
<li><a href="#main-differences-from-rl" id="markdown-toc-main-differences-from-rl">Main Differences from RL</a></li>
<li><a href="#key-components" id="markdown-toc-key-components">Key Components</a></li>
</ul>
</li>
<li><a href="#meta-learning-algorithms-for-meta-rl" id="markdown-toc-meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</a> <ul>
<li><a href="#optimizing-model-weights-for-meta-learning" id="markdown-toc-optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</a></li>
<li><a href="#meta-learning-hyperparameters" id="markdown-toc-meta-learning-hyperparameters">Meta-learning Hyperparameters</a></li>
<li><a href="#meta-learning-the-loss-function" id="markdown-toc-meta-learning-the-loss-function">Meta-learning the Loss Function</a></li>
<li><a href="#meta-learning-the-exploration-strategies" id="markdown-toc-meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</a></li>
<li><a href="#episodic-control" id="markdown-toc-episodic-control">Episodic Control</a></li>
</ul>
</li>
<li><a href="#training-task-acquisition" id="markdown-toc-training-task-acquisition">Training Task Acquisition</a> <ul>
<li><a href="#task-generation-by-domain-randomization" id="markdown-toc-task-generation-by-domain-randomization">Task Generation by Domain Randomization</a></li>
<li><a href="#evolutionary-algorithm-on-environment-generation" id="markdown-toc-evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</a></li>
<li><a href="#learning-with-random-rewards" id="markdown-toc-learning-with-random-rewards">Learning with Random Rewards</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="on-the-origin-of-meta-rl">On the Origin of Meta-RL</h2>
<h3 id="back-in-2001">Back in 2001</h3>
<p>I encountered a paper written in 2001 by <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al.</a> when reading <a href="https://arxiv.org/pdf/1611.05763.pdf">Wang et al., 2016</a>. Although the idea was proposed for supervised learning, there are so many resemblances to the current approach to meta-RL.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/Hochreiter-meta-learning.png" alt="Hochreiter 2001" /></p>
<p><em>Fig. 1. The meta-learning system consists of the supervisory and the subordinate systems. The subordinate system is a recurrent neural network that takes as input both the observation at the current time step, <script type="math/tex">x_t</script> and the label at the last time step, <script type="math/tex">y_{t-1}</script>. (Image source: <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al., 2001</a>)</em></p>
<p>Hochreiter’s meta-learning model is a recurrent network with LSTM cells. LSTM is a good choice because it can internalize a history of inputs and tune its own weights effectively through <a href="https://en.wikipedia.org/wiki/Backpropagation_through_time">BPTT</a>. The training data contains <script type="math/tex">K</script> sequences and each sequence consists of <script type="math/tex">N</script> samples generated by a target function <script type="math/tex">f_k(.), k=1, \dots, K</script>,</p>
<script type="math/tex; mode=display">\{\text{input: }(\mathbf{x}^k_i, \mathbf{y}^k_{i-1}) \to \text{label: }\mathbf{y}^k_i\}_{i=1}^N
\text{ where }\mathbf{y}^k_i = f_k(\mathbf{x}^k_i)</script>
<p>Note that <em>the last label</em> <script type="math/tex">\mathbf{y}^k_{i-1}</script> is also provided as an auxiliary input so that the function can learn the presented mapping.</p>
<p>In the experiment of decoding two-dimensional quadratic functions, <script type="math/tex">a x_1^2 + b x_2^2 + c x_1 x_2 + d x_1 + e x_2 + f</script>, with coefficients <script type="math/tex">a</script>-<script type="math/tex">f</script> randomly sampled from [-1, 1], this meta-learning system was able to approximate the function after seeing only ~35 examples.</p>
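<p>To make the data layout concrete, here is a sketch of how one training sequence for the quadratic experiment could be built. The sampling range of the inputs and the placeholder used for the missing label at the first step are assumptions for illustration; the function name is hypothetical.</p>

```python
import random


def make_sequence(rng, n_samples):
    """Build one sequence of ((x1, x2, y_prev), y) pairs for a random quadratic."""
    # Coefficients a..f of the target function, sampled from [-1, 1].
    a, b, c, d, e, f = (rng.uniform(-1, 1) for _ in range(6))

    def target(x1, x2):
        return a * x1**2 + b * x2**2 + c * x1 * x2 + d * x1 + e * x2 + f

    inputs, labels = [], []
    prev_y = 0.0  # assumption: placeholder, since no label precedes step 1
    for _ in range(n_samples):
        x1, x2 = rng.uniform(-1, 1), rng.uniform(-1, 1)  # assumed input range
        y = target(x1, x2)
        # The previous label is appended to the current observation.
        inputs.append((x1, x2, prev_y))
        labels.append(y)
        prev_y = y
    return inputs, labels


rng = random.Random(0)
inputs, labels = make_sequence(rng, n_samples=35)
print(inputs[1], labels[0])
```

<p>Because the label of step <script type="math/tex">i-1</script> arrives with the input of step <script type="math/tex">i</script>, the recurrent model can compare its earlier prediction against the truth and adjust its hidden state without any weight update.</p>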
<h3 id="proposal-in-2016">Proposal in 2016</h3>
<p>In the modern days of DL, <a href="https://arxiv.org/abs/1611.05763">Wang et al.</a> (2016) and <a href="https://arxiv.org/abs/1611.02779">Duan et al.</a> (2017) simultaneously proposed the very similar idea of <strong>Meta-RL</strong> (it is called <strong>RL^2</strong> in the second paper). A meta-RL model is trained over a distribution of MDPs, and at test time, it is able to learn to solve a new task quickly. The goal of meta-RL is ambitious, taking one step further towards general algorithms.</p>
<h2 id="define-meta-rl">Define Meta-RL</h2>
<p><em>Meta Reinforcement Learning</em>, in short, is to do <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a> in the field of <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>. Usually the train and test tasks are different but drawn from the same family of problems; e.g., experiments in the papers included multi-armed bandits with different reward probabilities, mazes with different layouts, the same robots but with different physical parameters in the simulator, and many others.</p>
<h3 id="formulation">Formulation</h3>
<p>Let’s say we have a distribution of tasks, each formulated as an <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#markov-decision-processes">MDP</a> (Markov Decision Process), <script type="math/tex">M_i \in \mathcal{M}</script>. An MDP is determined by a 4-tuple, <script type="math/tex">M_i= \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle</script>:</p>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">\mathcal{S}</script></td>
<td>A set of states.</td>
</tr>
<tr>
<td><script type="math/tex">\mathcal{A}</script></td>
<td>A set of actions.</td>
</tr>
<tr>
<td><script type="math/tex">P_i: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}_{+}</script></td>
<td>Transition probability function.</td>
</tr>
<tr>
<td><script type="math/tex">R_i: \mathcal{S} \times \mathcal{A} \to \mathbb{R}</script></td>
<td>Reward function.</td>
</tr>
</tbody>
</table>
<p>(RL^2 paper adds an extra parameter, horizon <script type="math/tex">T</script>, into the MDP tuple to emphasize that each MDP should have a finite horizon.)</p>
<p>Note that a common state space <script type="math/tex">\mathcal{S}</script> and action space <script type="math/tex">\mathcal{A}</script> are used above, so that a (stochastic) policy <script type="math/tex">\pi_\theta: \mathcal{S} \times \mathcal{A} \to \mathbb{R}_{+}</script> would get inputs compatible across different tasks. The test tasks are sampled from the same distribution <script type="math/tex">\mathcal{M}</script> or a slightly modified version.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/meta-RL-illustration.png" alt="Illustration of meta-RL" /></p>
<p><em>Fig. 2. Illustration of meta-RL, containing two optimization loops. The outer loop samples a new environment in every iteration and adjusts parameters that determine the agent’s behavior. In the inner loop, the agent interacts with the environment and optimizes for the maximal reward. (Image source: <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick, et al. 2019</a>)</em></p>
<h3 id="main-differences-from-rl">Main Differences from RL</h3>
<p>The overall configuration of meta-RL is very similar to an ordinary RL algorithm, except that <strong>the last reward</strong> <script type="math/tex">r_{t-1}</script> and <strong>the last action</strong> <script type="math/tex">a_{t-1}</script> are also incorporated into the policy observation in addition to the current state <script type="math/tex">s_t</script>.</p>
<ul>
<li>In RL: <script type="math/tex">\pi_\theta(s_t) \to</script> a distribution over <script type="math/tex">\mathcal{A}</script></li>
<li>In meta-RL: <script type="math/tex">\pi_\theta(a_{t-1}, r_{t-1}, s_t) \to</script> a distribution over <script type="math/tex">\mathcal{A}</script></li>
</ul>
<p>The intention of this design is to feed a history into the model so that the policy can internalize the dynamics between states, rewards, and actions in the current MDP and adjust its strategy accordingly. This is well aligned with the setup in <a href="#back-in-2001">Hochreiter’s system</a>. Both meta-RL and RL^2 implemented an LSTM policy and the LSTM’s hidden states serve as a <em>memory</em> for tracking characteristics of the trajectories. Because the policy is recurrent, there is no need to feed the last state as inputs explicitly.</p>
<p>The training procedure works as follows:</p>
<ol>
<li>Sample a new MDP, <script type="math/tex">M_i \sim \mathcal{M}</script>;</li>
<li><strong>Reset the hidden state</strong> of the model;</li>
<li>Collect multiple trajectories and update the model weights;</li>
<li>Repeat from step 1.</li>
</ol>
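<p>The data flow of steps 1 to 3 can be sketched as below. This only illustrates how the inputs are assembled and when the hidden state is reset: the two-armed bandit task, the random stand-in for the recurrent policy, and all names are hypothetical, and the weight update itself is left out.</p>

```python
import random


class BernoulliBandit:
    """A two-armed bandit: one task sampled from the task distribution."""

    def __init__(self, rng):
        self.rng = rng
        self.probs = [rng.random(), rng.random()]

    def step(self, action):
        return 1.0 if self.rng.random() < self.probs[action] else 0.0


def policy_step(policy_input, hidden, rng):
    """Stand-in for a recurrent policy: stores the input, acts at random."""
    hidden = hidden + [policy_input]
    return rng.randrange(2), hidden


def collect_trial(rng, n_steps=5):
    env = BernoulliBandit(rng)   # 1. sample a new MDP
    hidden = []                  # 2. reset the hidden state
    prev_action, prev_reward = 0, 0.0
    state = 0                    # a bandit has a single dummy state
    trajectory = []
    for _ in range(n_steps):     # 3. collect experience in this MDP
        # The policy sees (a_{t-1}, r_{t-1}, s_t), not just s_t.
        policy_input = (prev_action, prev_reward, state)
        action, hidden = policy_step(policy_input, hidden, rng)
        reward = env.step(action)
        trajectory.append((policy_input, action, reward))
        prev_action, prev_reward = action, reward
    # A real implementation would update the policy weights here.
    return trajectory


rng = random.Random(0)
trajectory = collect_trial(rng)
print(trajectory[:2])
```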
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/L2RL.png" alt="L2RL" /></p>
<p><em>Fig. 3. In the meta-RL paper, different actor-critic architectures all use a recurrent model. Last reward and last action are additional inputs. The observation is fed into the LSTM either as a one-hot vector or as an embedding vector after being passed through an encoder model. (Image source: <a href="https://arxiv.org/abs/1611.05763">Wang et al., 2016</a>)</em></p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RL_2.png" alt="RL^2" /></p>
<p><em>Fig. 4. As described in the RL^2 paper, illustration of the procedure of the model interacting with a series of MDPs in training time. (Image source: <a href="https://arxiv.org/abs/1611.02779">Duan et al., 2017</a>)</em></p>
<h3 id="key-components">Key Components</h3>
<p>There are three key components in Meta-RL:</p>
<blockquote>
<p>⭐ <strong>A Model with Memory</strong>
<br />
A recurrent neural network maintains a hidden state. Thus, it could acquire and memorize the knowledge about the current task by updating the hidden state during rollouts. Without memory, meta-RL would not work.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>Meta-learning Algorithm</strong>
<br />
A meta-learning algorithm refers to how we can update the model weights to optimize for the purpose of solving an unseen task fast at test time. In both Meta-RL and RL^2 papers, the meta-learning algorithm is the ordinary gradient descent update of LSTM with hidden state reset between a switch of MDPs.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>A Distribution of MDPs</strong>
<br />
While the agent is exposed to a variety of environments and tasks during training, it has to learn how to adapt to different MDPs.</p>
</blockquote>
<p>According to <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick et al.</a> (2019), one source of slowness in RL training is <em>weak <a href="https://en.wikipedia.org/wiki/Inductive_bias">inductive bias</a></em> ( = “a set of assumptions that the learner uses to predict outputs given inputs that it has not encountered”). As a general ML rule, a learning algorithm with weak inductive bias will be able to master a wider range of variance, but usually, will be less sample-efficient. Therefore, narrowing down the hypotheses with stronger inductive biases helps improve the learning speed.</p>
<p>In meta-RL, we impose certain types of inductive biases from the <em>task distribution</em> and store them in <em>memory</em>. Which inductive bias to adopt at test time depends on the <em>algorithm</em>. Together, these three key components depict a compelling view of meta-RL: Adjusting the weights of a recurrent network is slow but it allows the model to work out a new task fast with its own RL algorithm implemented in its internal activity dynamics.</p>
<p>Meta-RL interestingly and not very surprisingly matches the ideas in the <a href="https://arxiv.org/abs/1905.10985">AI-GAs</a> (“AI-Generating Algorithms”) paper by Jeff Clune (2019). He proposed that one efficient way towards building general AI is to make learning as automatic as possible. The AI-GAs approach involves three pillars: (1) meta-learning architectures, (2) meta-learning algorithms, and (3) automatically generated environments for effective learning.</p>
<hr />
<p>The topic of designing good recurrent network architectures is a bit too broad to be discussed here, so I will skip it. Next, let’s look further into the other two components: meta-learning algorithms in the context of meta-RL and how to acquire a variety of training MDPs.</p>
<h2 id="meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</h2>
<p>My previous <a href="/lil-log/2018/11/30/meta-learning.html">post</a> on meta-learning has covered several classic meta-learning algorithms. Here I’m gonna include more related to RL.</p>
<h3 id="optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</h3>
<p>Both MAML (<a href="https://arxiv.org/abs/1703.03400">Finn, et al. 2017</a>) and Reptile (<a href="https://arxiv.org/abs/1803.02999">Nichol et al., 2018</a>) are methods for updating model parameters in order to achieve good generalization performance on new tasks. See the <a href="/lil-log/2018/11/30/meta-learning.html#optimization-based">section</a> on MAML and Reptile in an earlier post.</p>
<h3 id="meta-learning-hyperparameters">Meta-learning Hyperparameters</h3>
<p>The <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function">return</a> function in an RL problem, <script type="math/tex">G_t^{(n)}</script> or <script type="math/tex">G_t^\lambda</script>, involves a few hyperparameters that are often set heuristically, like the discount factor <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function"><script type="math/tex">\gamma</script></a> and the bootstrapping parameter <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#combining-td-and-mc-learning"><script type="math/tex">\lambda</script></a>.
Meta-gradient RL (<a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">Xu et al., 2018</a>) considers them as <em>meta-parameters</em>, <script type="math/tex">\eta=\{\gamma, \lambda \}</script>, that can be tuned and learned <em>online</em> while an agent is interacting with the environment. Therefore, the return becomes a function of <script type="math/tex">\eta</script> and dynamically adapts itself to a specific task over time.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
G_\eta^{(n)}(\tau_t) &= R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1}R_{t+n} + \gamma^n v_\theta(s_{t+n}) & \scriptstyle{\text{; n-step return}} \\
G_\eta^{\lambda}(\tau_t) &= (1-\lambda) \sum_{n=1}^\infty \lambda^{n-1} G_\eta^{(n)} & \scriptstyle{\text{; λ-return, mixture of n-step returns}}
\end{aligned} %]]></script>
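<p>Both returns are easy to compute for a recorded episode. The sketch below assumes an episodic setting, where the infinite sum is truncated and the remaining weight is placed on the full return (the usual convention for terminating episodes). The indexing convention and function names are choices made for this example.</p>

```python
def n_step_return(rewards, values, t, n, gamma):
    """G_t^(n): n discounted rewards plus a bootstrapped value.

    rewards[k] holds R_{k+1}; values[k] holds v(s_k), with values[T] = 0
    for the terminal state.
    """
    g = sum(gamma**i * rewards[t + i] for i in range(n))
    return g + gamma**n * values[t + n]


def lambda_return(rewards, values, t, gamma, lam):
    """Truncated lambda-return: a geometric mixture of n-step returns."""
    max_n = len(rewards) - t
    total = 0.0
    for n in range(1, max_n):
        total += (1 - lam) * lam ** (n - 1) * n_step_return(rewards, values, t, n, gamma)
    # All remaining weight goes to the return that runs to the episode end.
    total += lam ** (max_n - 1) * n_step_return(rewards, values, t, max_n, gamma)
    return total


rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5, 0.0]
one_step = n_step_return(rewards, values, t=0, n=1, gamma=0.9)
td0_like = lambda_return(rewards, values, t=0, gamma=0.9, lam=0.0)
mc_like = lambda_return(rewards, values, t=0, gamma=0.9, lam=1.0)
print(one_step, td0_like, mc_like)
```

<p>Setting <script type="math/tex">\lambda = 0</script> recovers the one-step return and <script type="math/tex">\lambda = 1</script> recovers the Monte Carlo return, which is why both <script type="math/tex">\gamma</script> and <script type="math/tex">\lambda</script> shape the learning target so strongly.</p>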
<p>During training, we would like to update the policy parameters with gradients as a function of all the information in hand, <script type="math/tex">\theta' = \theta + f(\tau, \theta, \eta)</script>, where <script type="math/tex">\theta</script> are the current model weights, <script type="math/tex">\tau</script> is a sequence of trajectories, and <script type="math/tex">\eta</script> are the meta-parameters.</p>
<p>Meanwhile, let’s say we have a meta-objective function <script type="math/tex">J(\tau, \theta, \eta)</script> as a performance measure. The training process follows the principle of online cross-validation, using a sequence of consecutive experiences:</p>
<ol>
<li>Starting with parameter <script type="math/tex">\theta</script>, the policy <script type="math/tex">\pi_\theta</script> is updated on the first batch of samples <script type="math/tex">\tau</script>, resulting in <script type="math/tex">\theta'</script>.</li>
<li>Then we continue running the policy <script type="math/tex">\pi_{\theta'}</script> to collect a new set of experiences <script type="math/tex">\tau'</script>, just following <script type="math/tex">\tau</script> consecutively in time. The performance is measured as <script type="math/tex">J(\tau', \theta', \bar{\eta})</script> with a fixed meta-parameter <script type="math/tex">\bar{\eta}</script>.</li>
<li>The gradient of meta-objective <script type="math/tex">J(\tau', \theta', \bar{\eta})</script> w.r.t. <script type="math/tex">\eta</script> is used to update <script type="math/tex">\eta</script>:</li>
</ol>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\Delta \eta
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{d\theta'}{d\eta} & \scriptstyle{\text{ ; single variable chain rule.}} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial (\theta + f(\tau, \theta, \eta))}{\partial\eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big(\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\frac{d\eta}{d\eta} \Big) & \scriptstyle{\text{; multivariable chain rule.}}\\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big( \color{red}{\big(\mathbf{I} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\big)}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\Big) & \scriptstyle{\text{; secondary gradient term in red.}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\beta</script> is the learning rate for <script type="math/tex">\eta</script>.</p>
<p>The meta-gradient RL algorithm simplifies the computation by setting the secondary gradient term to zero, <script type="math/tex">\mathbf{I} + \partial f(\tau, \theta, \eta)/\partial\theta = 0</script>; this choice prefers the immediate effect of the meta-parameters <script type="math/tex">\eta</script> on the parameters <script type="math/tex">\theta</script>. Eventually we get:</p>
<script type="math/tex; mode=display">\Delta \eta = -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}</script>
<p>Experiments in the paper adopted the same meta-objective function as the <script type="math/tex">TD(\lambda)</script> algorithm, minimizing the error between the approximated value function <script type="math/tex">v_\theta(s)</script> and the <script type="math/tex">\lambda</script>-return:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
J(\tau, \theta, \eta) &= (G^\lambda_\eta(\tau) - v_\theta(s))^2 \\
J(\tau', \theta', \bar{\eta}) &= (G^\lambda_{\bar{\eta}}(\tau') - v_{\theta'}(s'))^2
\end{aligned} %]]></script>
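<p>As a rough numeric illustration of the simplified update, here is a minimal sketch under strong assumptions that are not from the paper: the value function is a single scalar <code>theta</code>, the meta-parameter <code>eta</code> acts as a discount inside a one-step target <code>r + eta * v_next</code>, and <code>f</code> is a plain step toward that target. All function names are hypothetical.</p>

```python
def inner_update(theta, eta, r, v_next, alpha):
    """f(tau, theta, eta): step of size alpha toward the target r + eta * v_next."""
    return alpha * (r + eta * v_next - theta)


def meta_objective(theta_new, r2, v_next2, eta_bar):
    """J(tau', theta', eta_bar): squared error against a target built with eta_bar."""
    return (r2 + eta_bar * v_next2 - theta_new) ** 2


def meta_gradient_step(theta, eta, tau, tau2, alpha, beta, eta_bar):
    """Return (theta', updated eta, delta_eta) using the simplified meta-gradient."""
    r, v_next = tau
    r2, v_next2 = tau2
    theta_new = theta + inner_update(theta, eta, r, v_next, alpha)
    # dJ/dtheta' for the squared error above.
    dj_dtheta_new = -2.0 * (r2 + eta_bar * v_next2 - theta_new)
    # df/deta: only the target r + eta * v_next depends on eta.
    df_deta = alpha * v_next
    delta_eta = -beta * dj_dtheta_new * df_deta
    return theta_new, eta + delta_eta, delta_eta


# Toy numbers (assumed): tau and tau2 are (reward, next-state value) pairs.
theta_new, eta_new, delta_eta = meta_gradient_step(
    theta=0.5, eta=0.9, tau=(1.0, 2.0), tau2=(0.5, 1.5),
    alpha=0.1, beta=0.01, eta_bar=0.95,
)
print(theta_new, eta_new, delta_eta)
```

<p>Because the secondary gradient term is dropped, <code>delta_eta</code> only needs the derivative of the most recent update with respect to <code>eta</code>, so no history of earlier updates has to be stored.</p>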
<h3 id="meta-learning-the-loss-function">Meta-learning the Loss Function</h3>
<p>In policy gradient algorithms, the expected total reward is maximized by updating the policy parameters <script type="math/tex">\theta</script> in the direction of estimated gradient (<a href="https://arxiv.org/abs/1506.02438">Schulman et al., 2016</a>),</p>
<script type="math/tex; mode=display">g = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \nabla_\theta \log \pi_\theta (a_t \mid s_t)]</script>
<p>where the candidates for <script type="math/tex">\Psi_t</script> include the trajectory return <script type="math/tex">G_t</script>, the Q value <script type="math/tex">Q(s_t, a_t)</script>, or the advantage value <script type="math/tex">A(s_t, a_t)</script>. The corresponding surrogate loss function for the policy gradient can be reverse-engineered:</p>
<script type="math/tex; mode=display">L_\text{pg} = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \log \pi_\theta (a_t \mid s_t)]</script>
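<p>To see why this surrogate works, the following sketch checks numerically that differentiating it recovers the policy gradient estimate. It assumes a state-less softmax policy over three actions and a handful of made-up <code>(action, psi)</code> samples; none of these names or numbers come from the paper.</p>

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def surrogate_loss(logits, samples):
    """Mean of psi * log pi(a) over sampled (action, psi) pairs."""
    probs = softmax(logits)
    return sum(psi * math.log(probs[a]) for a, psi in samples) / len(samples)


def policy_gradient(logits, samples):
    """Mean of psi * grad log pi(a); for a softmax, grad log pi(a) = onehot(a) - pi."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for a, psi in samples:
        for j in range(len(logits)):
            indicator = 1.0 if j == a else 0.0
            grad[j] += psi * (indicator - probs[j]) / len(samples)
    return grad


logits = [0.2, -0.1, 0.4]
samples = [(0, 1.0), (2, 0.5), (1, -0.3), (2, 2.0)]  # hypothetical (action, psi) pairs
analytic = policy_gradient(logits, samples)

# Central finite differences of the surrogate with respect to each logit.
h = 1e-6
numeric = []
for j in range(len(logits)):
    up = list(logits)
    down = list(logits)
    up[j] += h
    down[j] -= h
    numeric.append((surrogate_loss(up, samples) - surrogate_loss(down, samples)) / (2 * h))

print(analytic)
print(numeric)
```

<p>The two printed vectors should agree up to finite-difference error, which is all that "reverse-engineering" the loss means here: any function whose gradient matches the estimator can serve as the surrogate.</p>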
<p>This loss function is a measure over a history of trajectories, <script type="math/tex">(s_0, a_0, r_0, \dots, s_t, a_t, r_t, \dots)</script>. <strong>Evolved Policy Gradient</strong> (<strong>EPG</strong>; <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">Houthooft, et al, 2018</a>) takes a step further by defining the policy gradient loss function as a temporal convolution (1-D convolution) over the agent’s past experience, <script type="math/tex">L_\phi</script>. The parameters <script type="math/tex">\phi</script> of the loss function network are evolved in a way that an agent can achieve higher returns.</p>
<p>Similar to many meta-learning algorithms, EPG has two optimization loops:</p>
<ul>
<li>In the inner loop, an agent learns to improve its policy <script type="math/tex">\pi_\theta</script>.</li>
<li>In the outer loop, the model updates the parameters <script type="math/tex">\phi</script> of the loss function <script type="math/tex">L_\phi</script>. Because there is no explicit way to write down a differentiable equation between the return and the loss, EPG turned to <a href="https://en.wikipedia.org/wiki/Evolution_strategy"><em>Evolution Strategies</em></a> (ES).</li>
</ul>
<p>A general idea is to train a population of <script type="math/tex">N</script> agents, each trained with its own loss function <script type="math/tex">L_{\phi + \sigma \epsilon_i}</script>, that is, <script type="math/tex">\phi</script> perturbed by Gaussian noise <script type="math/tex">\epsilon_i \sim \mathcal{N}(0, \mathbf{I})</script> scaled by the standard deviation <script type="math/tex">\sigma</script>. During the inner loop’s training, EPG tracks a history of experience and updates the policy parameters according to the loss function <script type="math/tex">L_{\phi + \sigma\epsilon_i}</script> for each agent:</p>
<script type="math/tex; mode=display">\theta_i \leftarrow \theta - \alpha_\text{in} \nabla_\theta L_{\phi + \sigma \epsilon_i} (\pi_\theta, \tau_{t-K, \dots, t})</script>
<p>where <script type="math/tex">\alpha_\text{in}</script> is the learning rate of the inner loop and <script type="math/tex">\tau_{t-K, \dots, t}</script> is the sequence of most recent transitions, from time step <script type="math/tex">t-K</script> up to the current time step <script type="math/tex">t</script>.</p>
<p>Once the inner loop policy is mature enough, the policy is evaluated by the mean return <script type="math/tex">\bar{G}_{\phi+\sigma\epsilon_i}</script> over multiple randomly sampled trajectories. Eventually, we are able to estimate the gradient of <script type="math/tex">\phi</script> according to <a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">NES</a> numerically (<a href="https://arxiv.org/abs/1703.03864">Salimans et al, 2017</a>). While repeating this process, both the policy parameters <script type="math/tex">\theta</script> and the loss function weights <script type="math/tex">\phi</script> are being updated simultaneously to achieve higher returns.</p>
<script type="math/tex; mode=display">\phi \leftarrow \phi + \alpha_\text{out} \frac{1}{\sigma N} \sum_{i=1}^N \epsilon_i \bar{G}_{\phi+\sigma\epsilon_i}</script>
<p>where <script type="math/tex">\alpha_\text{out}</script> is the learning rate of the outer loop.</p>
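<p>The outer-loop update can be sketched in a few lines. This is only a toy illustration: in EPG each return comes from an agent trained in the inner loop with the perturbed loss, whereas here a hypothetical closed-form function <code>toy_return</code> stands in for that whole process, and the population size, noise scale and learning rate are arbitrary choices.</p>

```python
import random


def es_outer_step(phi, returns_fn, rng, n_agents, sigma, alpha_out):
    """One ES update: phi += alpha_out / (sigma * N) * sum_i eps_i * G_i."""
    grad = [0.0] * len(phi)
    for _ in range(n_agents):
        eps = [rng.gauss(0.0, 1.0) for _ in phi]
        perturbed = [p + sigma * e for p, e in zip(phi, eps)]
        g = returns_fn(perturbed)  # stand-in for the inner-loop agent's mean return
        for j, e in enumerate(eps):
            grad[j] += e * g
    scale = alpha_out / (sigma * n_agents)
    return [p + scale * gj for p, gj in zip(phi, grad)]


def toy_return(phi):
    """Hypothetical return surface with its maximum at phi = (1, -2)."""
    return -((phi[0] - 1.0) ** 2 + (phi[1] + 2.0) ** 2)


rng = random.Random(0)
phi = [0.0, 0.0]
for _ in range(200):
    phi = es_outer_step(phi, toy_return, rng, n_agents=200, sigma=0.5, alpha_out=0.05)
print(phi)
```

<p>No gradient of the return with respect to <code>phi</code> is ever computed; the estimate comes purely from correlating the noise with the observed returns, which is why ES fits the non-differentiable link between loss parameters and return.</p>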
<p>In practice, the loss <script type="math/tex">L_\phi</script> is bootstrapped with an ordinary policy gradient (such as REINFORCE or PPO) surrogate loss <script type="math/tex">L_\text{pg}</script>, <script type="math/tex">\hat{L} = (1-\alpha) L_\phi + \alpha L_\text{pg}</script>. The weight <script type="math/tex">\alpha</script> is annealed from 1 to 0 gradually during training. At test time, the loss function parameter <script type="math/tex">\phi</script> stays fixed and the loss value is computed over a history of experience to update the policy parameters <script type="math/tex">\theta</script>.</p>
<h3 id="meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</h3>
<p>The <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#exploitation-vs-exploration">exploitation vs exploration</a> dilemma is a critical problem in RL. Common ways to do exploration include <script type="math/tex">\epsilon</script>-greedy, random noise on actions, or stochastic policy with built-in randomness on the action space.</p>
<p><strong>MAESN</strong> (<a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al, 2018</a>) is an algorithm to learn structured action noise from prior experience for better and more effective exploration. Simply adding random noise on actions cannot capture task-dependent or time-correlated exploration strategies. MAESN changes the policy to condition on a per-task random variable <script type="math/tex">z_i \sim \mathcal{N}(\mu_i, \sigma_i)</script>, for <script type="math/tex">i</script>-th task <script type="math/tex">M_i</script>, so we would have a policy <script type="math/tex">a \sim \pi_\theta(a\mid s, z_i)</script>.
The latent variable <script type="math/tex">z_i</script> is sampled once and fixed during one episode. Intuitively, the latent variable determines one type of behavior (or skills) that should be explored more at the beginning of a rollout and the agent would adjust its actions accordingly. Both the policy parameters and latent space are optimized to maximize the total task rewards. In the meantime, the policy learns to make use of the latent variables for exploration.</p>
<p>In addition, the loss function includes a KL divergence between the learned latent variable and a unit Gaussian prior, <script type="math/tex">D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i)\|\mathcal{N}(0, \mathbf{I}))</script>. On one hand, it keeps the learned latent space from drifting too far from a common prior. On the other hand, it creates the variational evidence lower bound (<a href="http://users.umiacs.umd.edu/~xyang35/files/understanding-variational-lower.pdf">ELBO</a>) for the reward function. Interestingly, the paper found that <script type="math/tex">(\mu_i, \sigma_i)</script> for each task are usually close to the prior at convergence.</p>
<p style="width: 82%;" class="center"><img src="/lil-log/assets/images/MAESN.png" alt="MAESN" /></p>
<p><em>Fig. 5. The policy is conditioned on a latent variable <script type="math/tex">z_i \sim \mathcal{N}(\mu, \sigma)</script> that is sampled once every episode. Each task has different hyperparameters for the latent variable distribution, <script type="math/tex">(\mu_i, \sigma_i)</script>, and they are optimized in the outer loop. (Image source: <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al, 2018</a>)</em></p>
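<p>The two latent-variable ingredients described above, sampling <code>z_i</code> once per episode and penalizing the KL divergence to a unit Gaussian, can be sketched as follows. This assumes a diagonal Gaussian whose second parameter is a per-dimension standard deviation; the function names and the example values are hypothetical.</p>

```python
import math
import random


def kl_to_unit_gaussian(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return sum(
        0.5 * (s * s + m * m - 1.0) - math.log(s)
        for m, s in zip(mu, sigma)
    )


def sample_episode_latent(mu, sigma, rng):
    """Draw z_i once; the caller keeps it fixed for the whole episode."""
    return [rng.gauss(m, s) for m, s in zip(mu, sigma)]


rng = random.Random(0)
mu_i, sigma_i = [0.3, -0.2], [0.8, 1.1]  # made-up per-task parameters
z_i = sample_episode_latent(mu_i, sigma_i, rng)
print(z_i, kl_to_unit_gaussian(mu_i, sigma_i))
```

<p>The KL term is zero exactly when the per-task distribution equals the prior, so a small penalty at convergence is consistent with the observation that the learned parameters end up close to the prior.</p>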
<h3 id="episodic-control">Episodic Control</h3>
<p>A major criticism of RL is its sample inefficiency. A large number of samples and small learning steps are required for incremental parameter adjustment in RL in order to maximize generalization and avoid catastrophic forgetting of earlier learning (<a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613\(19\)30061-0">Botvinick et al., 2019</a>).</p>
<p><strong>Episodic control</strong> (<a href="http://papers.nips.cc/paper/3311-hippocampal-contributions-to-control-the-third-way.pdf">Lengyel & Dayan, 2008</a>) is proposed as a solution to avoid forgetting and improve generalization while training at a faster speed. It is partially inspired by hypotheses on instance-based <a href="https://en.wikipedia.org/wiki/Hippocampus">hippocampal</a> learning.</p>
<p>An <em>episodic memory</em> keeps explicit records of past events and uses these records directly as points of reference for making new decisions (much like <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning). In <strong>MFEC</strong> (Model-Free Episodic Control; <a href="https://arxiv.org/abs/1606.04460">Blundell et al., 2016</a>), the memory is modeled as a big table, storing the state-action pair <script type="math/tex">(s, a)</script> as the key and the corresponding Q-value <script type="math/tex">Q_\text{EC}(s, a)</script> as the value. When receiving a new observation <script type="math/tex">s</script>, the Q value is estimated in a non-parametric way as the average Q-value of the top <script type="math/tex">k</script> most similar samples:</p>
<script type="math/tex; mode=display">% <![CDATA[
\hat{Q}_\text{EC}(s, a) =
\begin{cases}
Q_\text{EC}(s, a) & \text{if } (s,a) \in Q_\text{EC}, \\
\frac{1}{k} \sum_{i=1}^k Q_\text{EC}(s^{(i)}, a) & \text{otherwise}
\end{cases} %]]></script>
<p>where <script type="math/tex">s^{(i)}, i=1, \dots, k</script> are the top <script type="math/tex">k</script> states with the smallest distances to the state <script type="math/tex">s</script>. The action that yields the highest estimated Q value is then selected, and the memory table is updated according to the return <script type="math/tex">G_t</script> received at <script type="math/tex">s_t</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
Q_\text{EC}(s_t, a_t) \leftarrow
\begin{cases}
\max\{Q_\text{EC}(s_t, a_t), G_t\} & \text{if } (s_t,a_t) \in Q_\text{EC}, \\
G_t & \text{otherwise}
\end{cases} %]]></script>
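<p>The lookup and update rules can be put together in a small toy class. This is a sketch under assumptions that are not from the paper: states are small tuples compared by squared Euclidean distance, an empty buffer returns 0, and when fewer than <code>k</code> neighbours exist the average runs over those available. The class and method names are hypothetical.</p>

```python
class EpisodicTable:
    """Toy MFEC-style memory: one dict per action, keyed by state tuples."""

    def __init__(self, actions, k):
        self.k = k
        self.table = {a: {} for a in actions}

    def estimate(self, state, action):
        """Exact hit returns the stored value; otherwise average the k nearest states."""
        entries = self.table[action]
        if state in entries:
            return entries[state]
        if not entries:
            return 0.0  # assumption: default value for an empty buffer

        def sq_dist(other):
            return sum((x - y) ** 2 for x, y in zip(state, other))

        nearest = sorted(entries, key=sq_dist)[: self.k]
        return sum(entries[s] for s in nearest) / len(nearest)

    def act(self, state):
        """Pick the action with the highest estimated Q value."""
        return max(self.table, key=lambda a: self.estimate(state, a))

    def update(self, state, action, episode_return):
        """Keep the best return ever observed from (state, action)."""
        entries = self.table[action]
        if state in entries:
            entries[state] = max(entries[state], episode_return)
        else:
            entries[state] = episode_return


memory = EpisodicTable(actions=["left", "right"], k=2)
memory.update((0.0, 0.0), "left", 1.0)
memory.update((1.0, 0.0), "left", 3.0)
memory.update((5.0, 5.0), "left", 10.0)
memory.update((0.0, 0.0), "left", 0.5)  # lower return, so the entry stays at 1.0
memory.update((0.0, 0.0), "right", 4.0)
left_estimate = memory.estimate((0.4, 0.0), "left")  # averages the two nearest states
chosen = memory.act((0.4, 0.0))
print(left_estimate, chosen)
```

<p>The <code>max</code> in the update is what makes the table optimistic: it remembers the best outcome ever seen from a state-action pair rather than an average, which suits near-deterministic environments.</p>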
<p>As a tabular RL method, MFEC suffers from large memory consumption and a lack of ways to generalize among similar states. The first one can be fixed with an LRU cache. Inspired by <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning, especially Matching Networks (<a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">Vinyals et al., 2016</a>), the generalization problem is addressed in a follow-up algorithm, <strong>NEC</strong> (Neural Episodic Control; <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>).</p>
<p>The episodic memory in NEC is a Differentiable Neural Dictionary (<strong>DND</strong>), where the key is a convolutional embedding vector of input image pixels and the value stores the estimated Q value. Given a query key, the output is a weighted sum of the values of the most similar keys, where the weight is a normalized kernel measure between the query key and the selected key in the dictionary. This sounds like a hard <a href="/2018/06/24/attention-attention.html">attention</a> mechanism.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/neural-episodic-control.png" alt="Neural episodic control" /></p>
<p><em>Fig. 6. Illustrations of the episodic memory module in NEC and two operations on a differentiable neural dictionary. (Image source: <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>)</em></p>
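<p>A DND read can be sketched as a kernel-weighted average over the nearest keys. The sketch below assumes an inverse-distance kernel with a small smoothing constant <code>delta</code> and uses hand-written 2-D keys instead of learned convolutional embeddings; the function name and all values are hypothetical.</p>

```python
def dnd_lookup(query, keys, values, top_k, delta=1e-3):
    """Kernel-weighted average of the values stored under the top_k nearest keys."""

    def sq_dist(key):
        return sum((q - x) ** 2 for q, x in zip(query, key))

    # Indices of the top_k keys closest to the query.
    order = sorted(range(len(keys)), key=lambda i: sq_dist(keys[i]))[:top_k]
    # Assumed kernel: inverse squared distance, smoothed by delta.
    kernels = [1.0 / (sq_dist(keys[i]) + delta) for i in order]
    total = sum(kernels)
    weights = [kv / total for kv in kernels]
    return sum(w * values[i] for w, i in zip(weights, order))


keys = [(0.0, 0.0), (1.0, 0.0), (4.0, 4.0)]
values = [1.0, 3.0, 10.0]
q_estimate = dnd_lookup((0.1, 0.0), keys, values, top_k=2)
print(q_estimate)
```

<p>Because every step after the neighbour selection is differentiable, gradients can flow from the output back into the keys and values, which is what distinguishes this read from the plain table lookup in MFEC.</p>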
<p>Further, <strong>Episodic LSTM</strong> (<a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>) enhances the basic LSTM architecture with a DND episodic memory, which stores task context embeddings as keys and the LSTM cell states as values. The stored hidden states are retrieved and added directly to the current cell state through the same gating mechanism within LSTM:</p>
<p style="width: 77%;" class="center"><img src="/lil-log/assets/images/episodic-LSTM.png" alt="Episodic LSTM" /></p>
<p><em>Fig. 7. Illustration of the episodic LSTM architecture. The additional structure of episodic memory is in bold. (Image source: <a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>)</em></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{c}_t &= \mathbf{i}_t \circ \mathbf{c}_\text{in} + \mathbf{f}_t \circ \mathbf{c}_{t-1} + \color{green}{\mathbf{r}_t \circ \mathbf{c}_\text{ep}} &\\
\mathbf{i}_t &= \sigma(\mathbf{W}_{i} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) & \scriptstyle{\text{; input gate}} \\
\mathbf{f}_t &= \sigma(\mathbf{W}_{f} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) & \scriptstyle{\text{; forget gate}} \\
\color{green}{\mathbf{r}_t} & \color{green}{=} \color{green}{\sigma(\mathbf{W}_{r} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_r)} & \scriptstyle{\text{; reinstatement gate}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbf{c}_t</script> and <script type="math/tex">\mathbf{h}_t</script> are the cell and hidden states at time <script type="math/tex">t</script>; <script type="math/tex">\mathbf{i}_t</script>, <script type="math/tex">\mathbf{f}_t</script> and <script type="math/tex">\mathbf{r}_t</script> are input, forget and reinstatement gates, respectively; <script type="math/tex">\mathbf{c}_\text{ep}</script> is the retrieved cell state from episodic memory. The newly added episodic memory components are marked in green.</p>
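<p>The cell-state update is just three gated, element-wise terms. A minimal sketch with plain Python lists follows; the weights are made up, and for brevity the same toy weights are reused for all three gates, which a real model would not do.</p>

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gate(weights, bias, h_prev, x):
    """sigma(W . [h_prev, x] + b), one output unit per row of W."""
    inputs = list(h_prev) + list(x)
    return [
        sigmoid(sum(w * v for w, v in zip(row, inputs)) + b)
        for row, b in zip(weights, bias)
    ]


def episodic_cell_state(i_t, f_t, r_t, c_in, c_prev, c_ep):
    """c_t = i_t * c_in + f_t * c_prev + r_t * c_ep, all element-wise."""
    return [
        i * ci + f * cp + r * ce
        for i, f, r, ci, cp, ce in zip(i_t, f_t, r_t, c_in, c_prev, c_ep)
    ]


h_prev, x_t = [0.1, -0.2], [0.5]
w = [[0.3, -0.1, 0.2], [0.0, 0.4, -0.5]]  # hypothetical toy weights, shared by all gates
b = [0.0, 0.1]
i_t = gate(w, b, h_prev, x_t)
f_t = gate(w, b, h_prev, x_t)
r_t = gate(w, b, h_prev, x_t)
c_t = episodic_cell_state(
    i_t, f_t, r_t, c_in=[0.2, 0.1], c_prev=[1.0, -1.0], c_ep=[0.5, 0.5]
)
print(c_t)
```

<p>Setting the reinstatement gate to zero recovers the ordinary LSTM cell update, so the episodic term is a strict extension rather than a replacement.</p>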
<p>This architecture provides a shortcut to the prior experience through context-based retrieval. Meanwhile, explicitly saving the task-dependent experience in an external memory avoids forgetting. In the paper, all the experiments have manually designed context vectors. How to construct an effective and efficient format of task context embeddings for more free-formed tasks would be an interesting topic.</p>
<p>Overall the capacity of episodic control is limited by the complexity of the environment. It is very rare for an agent to repeatedly visit exactly the same states in a real-world task, so properly encoding the states is critical. The learned embedding space compresses the observation data into a lower dimension space and, in the meantime, two states being close in this space are expected to demand similar strategies.</p>
<h2 id="training-task-acquisition">Training Task Acquisition</h2>
<p>Among the three key components, how to design a proper distribution of tasks is the least studied and probably the most specific one to meta-RL itself. As described <a href="#formulation">above</a>, each task is an MDP: <script type="math/tex">M_i = \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle \in \mathcal{M}</script>. We can build a distribution of MDPs by modifying:</p>
<ul>
<li>The <em>reward configuration</em>: Among different tasks, same behavior might get rewarded differently according to <script type="math/tex">R_i</script>.</li>
<li>Or, the <em>environment</em>: The transition function <script type="math/tex">P_i</script> can be reshaped by initializing the environment with varying shifts between states.</li>
</ul>
<h3 id="task-generation-by-domain-randomization">Task Generation by Domain Randomization</h3>
<p>Randomizing parameters in a simulator is an easy way to obtain tasks with modified transition functions. If interested in learning further, check my last <a href="/lil-log/2019/05/05/domain-randomization.html">post</a> on <strong>domain randomization</strong>.</p>
<h3 id="evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</h3>
<p><a href="https://en.wikipedia.org/wiki/Evolutionary_algorithm">Evolutionary algorithm</a> is a gradient-free heuristic-based optimization method, inspired by natural selection. A population of solutions follows a loop of evaluation, selection, reproduction, and mutation. Eventually, good solutions survive and thus get selected.</p>
<p><strong>POET</strong> (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>), a framework based on the evolutionary algorithm, attempts to generate tasks while the problems themselves are being solved. The implementation of POET is designed specifically for a simple 2D <a href="https://gym.openai.com/envs/BipedalWalkerHardcore-v2/">bipedal walker</a> environment, but it points out an interesting direction. It is noteworthy that the evolutionary algorithm has had some compelling applications in Deep Learning like <a href="#meta-learning-the-loss-function">EPG</a> and PBT (Population-Based Training; <a href="https://arxiv.org/abs/1711.09846">Jaderberg et al, 2017</a>).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/POET.png" alt="POET" /></p>
<p><em>Fig. 8. An example bipedal walking environment (top) and an overview of POET (bottom). (Image source: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">POET blog post</a>)</em></p>
<p>The 2D bipedal walking environment is evolving: from a simple flat surface to a much more difficult trail with potential gaps, stumps, and rough terrains. POET pairs the generation of environmental challenges and the optimization of agents together so as to (a) select agents that can resolve current challenges and (b) evolve environments to be solvable. The algorithm maintains a list of <em>environment-agent pairs</em> and repeats the following:</p>
<ol>
<li><em>Mutation</em>: Generate new environments from currently active environments. Note that the mutation operations here are designed just for the bipedal walker; a different environment would demand its own set of configurations.</li>
<li><em>Optimization</em>: Train paired agents within their respective environments.</li>
<li><em>Selection</em>: Periodically attempt to transfer current agents from one environment to another. Copy and update the best performing agent for every environment. The intuition is that skills learned in one environment might be helpful for a different environment.</li>
</ol>
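<p>The three steps above can be sketched as a loop over environment-agent pairs. This is a heavily simplified toy, not the POET implementation: an environment and an agent are each a single number, <code>score</code> is a made-up fitness, optimization is one hill-climbing step instead of ES, a mutation happens every iteration, and the cap on active pairs is arbitrary.</p>

```python
import random


def score(agent, env):
    """Hypothetical fitness: highest when the agent parameter matches the environment."""
    return -(agent - env) ** 2


def optimize(agent, env, rng, step=0.2):
    """One hill-climbing step: keep a random perturbation only if it scores higher."""
    candidate = agent + rng.gauss(0.0, step)
    return candidate if score(candidate, env) > score(agent, env) else agent


def poet_iteration(pairs, rng, max_pairs=5):
    # 1. Mutation: spawn a harder child environment from a random active one,
    #    starting its agent from the parent's agent.
    parent_env, parent_agent = rng.choice(pairs)
    pairs = pairs + [(parent_env + rng.uniform(0.1, 0.5), parent_agent)]
    pairs = pairs[-max_pairs:]  # retire the oldest pairs beyond the cap
    # 2. Optimization: train each agent in its own environment.
    pairs = [(env, optimize(agent, env, rng)) for env, agent in pairs]
    # 3. Selection: for each environment keep the best agent among all candidates.
    agents = [agent for _, agent in pairs]
    return [(env, max(agents, key=lambda a: score(a, env))) for env, _ in pairs]


rng = random.Random(0)
pairs = [(0.0, 0.0)]  # start from one flat environment and a naive agent
for _ in range(50):
    pairs = poet_iteration(pairs, rng)
print(pairs)
```

<p>Even in this toy, the transfer step lets an agent trained in one environment replace a weaker agent in another, which is the mechanism the selection step relies on.</p>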
<p>The procedure above is quite similar to <a href="https://arxiv.org/abs/1711.09846">PBT</a>, but PBT mutates and evolves hyperparameters instead. To some extent, POET is doing <a href="/lil-log/2019/05/05/domain-randomization.html">domain randomization</a>, as all the gaps, stumps and terrain roughness are controlled by some randomization probability parameters. Different from DR, the agents are not exposed to a fully randomized difficult environment all at once, but instead they are learning gradually with a curriculum configured by the evolutionary algorithm.</p>
<h3 id="learning-with-random-rewards">Learning with Random Rewards</h3>
<p>An MDP without a reward function <script type="math/tex">R</script> is known as a <em>Controlled Markov process</em> (CMP). Given a predefined CMP, <script type="math/tex">\langle \mathcal{S}, \mathcal{A}, P\rangle</script>, we can acquire a variety of tasks by generating a collection of reward functions <script type="math/tex">\mathcal{R}</script> that encourage the training of an effective meta-learning policy.</p>
<p><a href="https://arxiv.org/abs/1806.04640">Gupta et al. (2018)</a> proposed two unsupervised approaches for growing the task distribution in the context of CMP. Assuming there is an underlying latent variable <script type="math/tex">z \sim p(z)</script> associated with every task, it parameterizes/determines a reward function: <script type="math/tex">r_z(s) = \log D(z|s)</script>, where a “discriminator” function <script type="math/tex">D(.)</script> is used to extract the latent variable from the state. The paper described two ways to construct a discriminator function:</p>
<ul>
<li>Sample random weights <script type="math/tex">\phi_\text{rand}</script> of the discriminator, <script type="math/tex">D_{\phi_\text{rand}}(z \mid s)</script>.</li>
<li>Learn a discriminator function to encourage diversity-driven exploration. This method is introduced in more detail in a sister paper, “DIAYN” (<a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2018</a>).</li>
</ul>
<p>DIAYN, short for “Diversity is all you need”, is a framework to encourage a policy to learn useful skills without a reward function. It explicitly models the latent variable <script type="math/tex">z</script> as a <em>skill</em> embedding and makes the policy conditioned on <script type="math/tex">z</script> in addition to state <script type="math/tex">s</script>, <script type="math/tex">\pi_\theta(a \mid s, z)</script>. (Ok, this part is same as <a href="#meta-learning-the-exploration-strategies">MAESN</a> unsurprisingly, as the papers are from the same group.) The design of DIAYN is motivated by a few hypotheses:</p>
<ul>
<li>Skills should be diverse and lead to visitations of different states. → maximize the mutual information between states and skills, <script type="math/tex">I(S; Z)</script></li>
<li>Skills should be distinguishable by states, not actions. → minimize the mutual information between actions and skills, conditioned on states <script type="math/tex">I(A; Z \mid S)</script></li>
</ul>
<p>The objective function to maximize is as follows, where the policy entropy is also added to encourage diversity:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{F}(\theta)
&= I(S; Z) + H[A \mid S] - I(A; Z \mid S) & \\
&= (H(Z) - H(Z \mid S)) + H[A \mid S] - (H[A\mid S] - H[A\mid S, Z]) & \\
&= H[A\mid S, Z] \color{green}{- H(Z \mid S) + H(Z)} & \\
&= H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\log p(z \mid s)] - \mathbb{E}_{z\sim p(z)}[\log p(z)] & \scriptstyle{\text{; can infer skills from states & p(z) is diverse.}} \\
&\ge H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\color{red}{\log D_\phi(z \mid s) - \log p(z)}] & \scriptstyle{\text{; according to Jensen's inequality; "pseudo-reward" in red.}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">I(.)</script> is mutual information and <script type="math/tex">H[.]</script> is the entropy measure. We cannot integrate over all states to compute <script type="math/tex">p(z \mid s)</script>, so we approximate it with <script type="math/tex">D_\phi(z \mid s)</script>, which is the diversity-driven discriminator function.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DIAYN.png" alt="DIAYN" /></p>
<p><em>Fig. 9. DIAYN Algorithm. (Image source: <a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2019</a>)</em></p>
<p>Once the discriminator function is learned, sampling a new MDP for training is straightforward: first, sample a latent variable, <script type="math/tex">z \sim p(z)</script>, and construct a reward function <script type="math/tex">r_z(s) = \log(D(z \vert s))</script>. Pairing the reward function with a predefined CMP creates a new MDP.</p>
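<p>A minimal sketch of the pseudo-reward computation, assuming a discrete set of skills with a fixed uniform prior and a discriminator that outputs one logit per skill for the current state. The logits below are made up, and the function names are hypothetical. With a uniform prior the <code>- log p(z)</code> term is a constant offset, so it only shifts the reward relative to the plain <code>log D(z | s)</code> form.</p>

```python
import math


def log_softmax(logits):
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def diayn_pseudo_reward(discriminator_logits, z, prior):
    """log D(z | s) - log p(z) for skill index z at the current state."""
    log_d = log_softmax(discriminator_logits)
    return log_d[z] - math.log(prior[z])


num_skills = 4
prior = [1.0 / num_skills] * num_skills  # fixed uniform p(z), an assumption
logits_at_s = [2.0, 0.1, -1.0, 0.3]  # hypothetical discriminator output for state s
reward_skill_0 = diayn_pseudo_reward(logits_at_s, z=0, prior=prior)
reward_skill_2 = diayn_pseudo_reward(logits_at_s, z=2, prior=prior)
print(reward_skill_0, reward_skill_2)
```

<p>The reward is positive when the discriminator identifies the active skill from the state better than chance and negative otherwise, which pushes each skill toward states that distinguish it from the others.</p>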
<!--
---
So far, experiments of meta-RL are still limited to a collection of very similar tasks, originated from the same family; such as multi-armed bandit with different reward probabilities, mazes with different layouts, or same robots but with different physical parameters in simulator. I'm looking forward to more research demonstrating the power of meta-RL over a more diverse set of tasks.
-->
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019metaRL,
title = "Meta Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/06/23/meta-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Richard S. Sutton. <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">“The Bitter Lesson.”</a> March 13, 2019.</p>
<p>[2] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">“Learning to learn using gradient descent.”</a> Intl. Conf. on Artificial Neural Networks. 2001.</p>
<p>[3] Jane X Wang, et al. <a href="https://arxiv.org/abs/1611.05763">“Learning to reinforcement learn.”</a> arXiv preprint arXiv:1611.05763 (2016).</p>
<p>[4] Yan Duan, et al. <a href="https://arxiv.org/abs/1611.02779">“RL $^ 2$: Fast Reinforcement Learning via Slow Reinforcement Learning.”</a> ICLR 2017.</p>
<p>[5] Matthew Botvinick, et al. <a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613\(19\)30061-0">“Reinforcement Learning, Fast and Slow”</a> Cell Review, Volume 23, Issue 5, P408-422, May 01, 2019.</p>
<p>[6] Jeff Clune. <a href="https://arxiv.org/abs/1905.10985">“AI-GAs: AI-generating algorithms, an alternate paradigm for producing general artificial intelligence”</a> arXiv preprint arXiv:1905.10985 (2019).</p>
<p>[7] Zhongwen Xu, et al. <a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">“Meta-Gradient Reinforcement Learning”</a> NIPS 2018.</p>
<p>[8] Rein Houthooft, et al. <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">“Evolved Policy Gradients.”</a> NIPS 2018.</p>
<p>[9] Tim Salimans, et al. <a href="https://arxiv.org/abs/1703.03864">“Evolution strategies as a scalable alternative to reinforcement learning.”</a> arXiv preprint arXiv:1703.03864 (2017).</p>
<p>[10] Abhishek Gupta, et al. <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">“Meta-Reinforcement Learning of Structured Exploration Strategies.”</a> NIPS 2018.</p>
<p>[11] Alexander Pritzel, et al. <a href="https://arxiv.org/abs/1703.01988">“Neural episodic control.”</a> Proc. Intl. Conf. on Machine Learning, Volume 70, 2017.</p>
<p>[12] Charles Blundell, et al. <a href="https://arxiv.org/abs/1606.04460">“Model-free episodic control.”</a> arXiv preprint arXiv:1606.04460 (2016).</p>
<p>[13] Samuel Ritter, et al. <a href="https://arxiv.org/abs/1805.09692">“Been there, done that: Meta-learning with episodic recall.”</a> ICML, 2018.</p>
<p>[14] Rui Wang, et al. <a href="https://arxiv.org/abs/1901.01753">“Paired Open-Ended Trailblazer (POET): Endlessly Generating Increasingly Complex and Diverse Learning Environments and Their Solutions”</a> arXiv preprint arXiv:1901.01753 (2019).</p>
<p>[15] Uber Engineering Blog: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">“POET: Endlessly Generating Increasingly Complex and Diverse Learning Environments and their Solutions through the Paired Open-Ended Trailblazer.”</a> Jan 8, 2019.</p>
<p>[16] Abhishek Gupta, et al. <a href="https://arxiv.org/abs/1806.04640">“Unsupervised meta-learning for Reinforcement Learning”</a> arXiv preprint arXiv:1806.04640 (2018).</p>
<p>[17] Benjamin Eysenbach, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a> ICLR 2019.</p>
<p>[18] Max Jaderberg, et al. <a href="https://arxiv.org/abs/1711.09846">“Population Based Training of Neural Networks.”</a> arXiv preprint arXiv:1711.09846 (2017).</p>
<p>Lilian Weng. Meta-RL is meta-learning on reinforcement learning tasks. After being trained over a distribution of tasks, the agent is able to solve a new task by developing a new RL algorithm with its internal activity dynamics. This post starts with the origin of meta-RL and then dives into three key components of meta-RL.</p>
<p>Domain Randomization for Sim2Real Transfer. 2019-05-05. https://lilianweng.github.io/lil-log/2019/05/05/domain-randomization</p>
<blockquote>
<p>If a model or policy is mainly trained in a simulator but expected to work on a real robot, it would surely face the sim2real gap. <em>Domain Randomization</em> (DR) is a simple but powerful idea of closing this gap by randomizing properties of the training environment.</p>
</blockquote>
<!--more-->
<p>In Robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots. The gap is triggered by an inconsistency between physical parameters (e.g. friction, kp, damping, mass, density) and, more fatally, the incorrect physical modeling (e.g. collision between soft surfaces).</p>
<p>To close the sim2real gap, we need to improve the simulator and make it closer to reality. A couple of approaches:</p>
<ul>
<li><strong>System identification</strong>
<ul>
<li><em>System identification</em> is to build a mathematical model for a physical system; in the context of RL, the mathematical model is the simulator. To make the simulator more realistic, careful calibration is necessary.</li>
<li>Unfortunately, calibration is expensive. Furthermore, many physical parameters of the same machine might vary significantly due to temperature, humidity, positioning or its wear-and-tear in time.</li>
</ul>
</li>
<li><strong>Domain adaptation</strong>
<ul>
<li><em>Domain adaptation (DA)</em> refers to a set of transfer learning techniques developed to update the data distribution in sim to match the real one through a mapping or regularization enforced by the task model.</li>
<li>Many DA models, especially for image classification or end-to-end image-based RL tasks, are built on an adversarial loss or <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a>.</li>
</ul>
</li>
<li><strong>Domain randomization</strong>
<ul>
<li>With <em>domain randomization (DR)</em>, we are able to create a variety of simulated environments with randomized properties and train a model that works across all of them.</li>
<li>Likely this model can adapt to the real-world environment, as the real system is expected to be one sample in that rich distribution of training variations.</li>
</ul>
</li>
</ul>
<p>Both DA and DR are unsupervised. Compared to DA which requires a decent amount of real data samples to capture the distribution, DR may need <em>only a little or no</em> real data. DR is the focus of this post.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sim2real-transfer.png" alt="Approaches for sim2real transfer" /></p>
<p><em>Fig. 1. Conceptual illustrations of three approaches for sim2real transfer.</em></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-domain-randomization" id="markdown-toc-what-is-domain-randomization">What is Domain Randomization?</a></li>
<li><a href="#uniform-domain-randomization" id="markdown-toc-uniform-domain-randomization">Uniform Domain Randomization</a></li>
<li><a href="#why-does-domain-randomization-work" id="markdown-toc-why-does-domain-randomization-work">Why does Domain Randomization Work?</a> <ul>
<li><a href="#dr-as-optimization" id="markdown-toc-dr-as-optimization">DR as Optimization</a></li>
<li><a href="#dr-as-meta-learning" id="markdown-toc-dr-as-meta-learning">DR as Meta-Learning</a></li>
</ul>
</li>
<li><a href="#guided-domain-randomization" id="markdown-toc-guided-domain-randomization">Guided Domain Randomization</a> <ul>
<li><a href="#optimization-for-task-performance" id="markdown-toc-optimization-for-task-performance">Optimization for Task Performance</a></li>
<li><a href="#match-real-data-distribution" id="markdown-toc-match-real-data-distribution">Match Real Data Distribution</a></li>
<li><a href="#guided-by-data-in-simulator" id="markdown-toc-guided-by-data-in-simulator">Guided by Data in Simulator</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-domain-randomization">What is Domain Randomization?</h2>
<p>To make the definition more general, let us call the environment that we have full access to (i.e. simulator) <strong>source domain</strong> and the environment that we would like to transfer the model to <strong>target domain</strong> (i.e. physical world). Training happens in the source domain. We can control a set of <script type="math/tex">N</script> randomization parameters in the source domain <script type="math/tex">e_\xi</script> with a configuration <script type="math/tex">\xi</script>, sampled from a randomization space, <script type="math/tex">\xi \in \Xi \subset \mathbb{R}^N</script>.</p>
<p>During policy training, episodes are collected from the source domain with randomization applied. Thus the policy is exposed to a variety of environments and learns to generalize. The policy parameter <script type="math/tex">\theta</script> is trained to maximize the expected reward <script type="math/tex">R(.)</script> averaged across a distribution of configurations:</p>
<script type="math/tex; mode=display">\theta^* = \arg\max_\theta \mathbb{E}_{\xi \sim \Xi} [\mathbb{E}_{\pi_\theta, \tau \sim e_\xi} [R(\tau)]]</script>
<p>where <script type="math/tex">\tau</script> is a trajectory collected in the source domain randomized with <script type="math/tex">\xi</script>. In a way, <em>“discrepancies between the source and target domains are modeled as variability in the source domain.”</em> (quote from <a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>).</p>
<h2 id="uniform-domain-randomization">Uniform Domain Randomization</h2>
<p>In the original form of DR (<a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>; <a href="https://arxiv.org/pdf/1611.04201.pdf">Sadeghi et al. 2016</a>), each randomization parameter <script type="math/tex">\xi_i</script> is bounded by an interval, <script type="math/tex">\xi_i \in [\xi_i^\text{low}, \xi_i^\text{high}], i=1,\dots,N</script> and each parameter is uniformly sampled within the range.</p>
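<p>As a minimal sketch of uniform DR, the snippet below draws one configuration <script type="math/tex">\xi</script> per episode by sampling each parameter independently within its interval. The parameter names and ranges are made up for illustration and are not taken from any of the cited papers.</p>

```python
import random

# Hypothetical randomization ranges [low, high] for each parameter xi_i.
# The names and numbers are illustrative, not taken from any paper.
RANDOMIZATION_RANGES = {
    "friction": (0.5, 1.5),
    "object_mass_kg": (0.05, 0.5),
    "light_intensity": (0.2, 1.0),
}


def sample_config(ranges, rng):
    """Draw one configuration xi with each xi_i ~ Uniform(low_i, high_i)."""
    return {name: rng.uniform(low, high) for name, (low, high) in ranges.items()}


rng = random.Random(0)
# One fresh configuration per training episode.
configs = [sample_config(RANDOMIZATION_RANGES, rng) for _ in range(5)]
for xi in configs:
    print({name: round(value, 3) for name, value in xi.items()})
```

<p>In a real training loop, each sampled configuration would be applied to the simulator before the episode starts; that step depends on the simulator and is left out here.</p>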
<p>The randomization parameters can control appearances of the scene, including but not limited to the following (see Fig. 2). A model trained on simulated and randomized images is able to transfer to real non-randomized images.</p>
<ul>
<li>Position, shape, and color of objects,</li>
<li>Material texture,</li>
<li>Lighting condition,</li>
<li>Random noise added to images,</li>
<li>Position, orientation, and field of view of the camera in the simulator.</li>
</ul>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/DR.png" alt="Domain Randomization" /></p>
<p><em>Fig. 2. Images captured in the training environment are randomized. (Image source: <a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>)</em></p>
<p>Physical dynamics in the simulator can also be randomized (<a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>). Studies have shown that a <em>recurrent</em> policy can adapt to different physical dynamics including the partially observable reality. Physical dynamics features include but are not limited to:</p>
<ul>
<li>Mass and dimensions of objects,</li>
<li>Mass and dimensions of robot bodies,</li>
<li>Damping, kp, friction of the joints,</li>
<li>Gains for the PID controller (P term),</li>
<li>Joint limit,</li>
<li>Action delay,</li>
<li>Observation noise.</li>
</ul>
<p>With visual and dynamics DR, at OpenAI Robotics, we were able to learn a policy that works on a real dexterous robot hand (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>). Our manipulation task is to teach the robot hand to rotate an object continuously to achieve 50 successive random target orientations. The sim2real gap in this task is very large, due to (a) a high number of simultaneous contacts between the robot and the object and (b) imperfect simulation of object collision and other motions. At first, the policy could barely survive for more than 5 seconds without dropping the object. But with the help of DR, the policy eventually evolved to work surprisingly well in reality.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/DKe8FumoD4E" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
<h2 id="why-does-domain-randomization-work">Why does Domain Randomization Work?</h2>
<p>Now you may ask, why does domain randomization work so well? The idea sounds really simple. Here are two non-exclusive explanations I found most convincing.</p>
<h3 id="dr-as-optimization">DR as Optimization</h3>
<p>One idea (<a href="https://arxiv.org/abs/1903.11774">Vuong, et al, 2019</a>) is to view learning randomization parameters in DR as a <em>bilevel optimization</em>. Assuming we have access to the real environment <script type="math/tex">e_\text{real}</script> and the randomization config is sampled from a distribution parameterized by <script type="math/tex">\phi</script>, <script type="math/tex">\xi \sim P_\phi(\xi)</script>, we would like to learn a distribution such that a policy <script type="math/tex">\pi_\theta</script> trained on it achieves maximal performance in <script type="math/tex">e_\text{real}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
&\phi^* = \arg\min_{\phi} \mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real}) \\
\text{where } &\theta^*(\phi) = \arg\min_\theta \mathbb{E}_{\xi \sim P_\phi(\xi)}[\mathcal{L}(\pi_\theta; e_\xi)]
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathcal{L}(\pi; e)</script> is the loss function of policy <script type="math/tex">\pi</script> evaluated in the environment <script type="math/tex">e</script>.</p>
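<p>To make the two nested problems concrete, here is a toy sketch under strong assumptions of my own: an environment is a single number, the policy is a single number <script type="math/tex">\theta</script>, the loss is a squared distance, and <script type="math/tex">P_\phi</script> is a uniform range centered at <script type="math/tex">\phi</script>. The inner loop trains <script type="math/tex">\theta</script> on randomized environments only; the outer loop picks the <script type="math/tex">\phi</script> whose trained policy does best on the stand-in real environment.</p>

```python
# Toy bilevel optimization. Assumptions (all hypothetical):
# - an environment is a single number e, and the policy is a single number theta;
# - the loss is L(pi_theta; e) = (theta - e)^2;
# - P_phi is uniform on [phi - 0.5, phi + 0.5], so phi only shifts the range.

E_REAL = 1.3  # stands in for the real environment, queried only by the outer loop


def loss(theta, e):
    return (theta - e) ** 2


def inner_train(phi, num_envs=101, steps=200, lr=0.1):
    """Inner problem: theta*(phi) = argmin_theta E_{xi ~ P_phi}[L(pi_theta; e_xi)]."""
    # Evenly spaced grid over the support of P_phi instead of random samples.
    envs = [phi - 0.5 + i / (num_envs - 1) for i in range(num_envs)]
    theta = 0.0
    for _ in range(steps):
        grad = sum(2.0 * (theta - e) for e in envs) / num_envs
        theta -= lr * grad  # gradient descent on the average loss
    return theta


def outer_search(candidates):
    """Outer problem: pick phi whose trained policy has the lowest real-world loss."""
    return min(candidates, key=lambda phi: loss(inner_train(phi), E_REAL))


candidates = [i / 10 for i in range(0, 31)]  # phi in {0.0, 0.1, ..., 3.0}
best_phi = outer_search(candidates)
print(best_phi)
```

<p>The outer loop here is a brute-force grid search, which is only feasible because the toy inner problem is cheap; each outer evaluation in practice costs a full policy training run plus real-world rollouts.</p>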
<p>Although randomization ranges are hand-picked in uniform DR, it often involves domain knowledge and a couple rounds of trial-and-error adjustment based on the transfer performance. Essentially this is a manual optimization process on tuning <script type="math/tex">\phi</script> for the optimal <script type="math/tex">\mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real})</script>.</p>
<p>Guided domain randomization in the next section is largely inspired by this view, aiming to do bilevel optimization and learn the best parameter distribution automatically.</p>
<h3 id="dr-as-meta-learning">DR as Meta-Learning</h3>
<p>In our learning dexterity project (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>), we trained an LSTM policy to generalize across different environmental dynamics. We observed that once a robot achieved the first rotation, the time it needed for the following successes was much shorter. Also, a FF policy without memory was found not able to transfer to a physical robot. Both are evidence of the policy dynamically learning and adapting to a new environment.</p>
<p>In some ways, domain randomization composes a collection of different tasks. Memory in the recurrent network empowers the policy to achieve <a href="/lil-log/2018/11/30/meta-learning.html"><em>meta-learning</em></a> across tasks and further work on a real-world setting.</p>
<h2 id="guided-domain-randomization">Guided Domain Randomization</h2>
<p>The vanilla DR assumes no access to the real data, and thus the randomization config is sampled as broadly and uniformly as possible in sim, hoping that the real environment could be covered under this broad distribution. It is reasonable to think of a more sophisticated strategy — replacing uniform sampling with guidance from <em>task performance</em>, <em>real data</em>, or <em>simulator</em>.</p>
<p>One motivation for guided DR is to save computation resources by avoiding training models in unrealistic environments. Another is to avoid infeasible solutions that might arise from overly wide randomization distributions and thus might hinder successful policy learning.</p>
<h3 id="optimization-for-task-performance">Optimization for Task Performance</h3>
<p>Say we train a family of policies with different randomization parameters <script type="math/tex">\xi \sim P_\phi(\xi)</script>, where <script type="math/tex">P_\phi</script> is the distribution for <script type="math/tex">\xi</script> parameterized by <script type="math/tex">\phi</script>. Later we decide to try every one of them on the downstream task in the target domain (e.g. control a robot in reality or evaluate on a validation set) to collect feedback. This feedback tells us how good a configuration <script type="math/tex">\xi</script> is and provides signals for optimizing <script type="math/tex">\phi</script>.</p>
<p>Inspired by <a href="https://ai.google/research/pubs/pub45826">NAS</a>, <strong>AutoAugment</strong> (<a href="https://arxiv.org/abs/1805.09501">Cubuk, et al. 2018</a>) frames the problem of learning the best data augmentation operations (e.g. shearing, rotation, invert) for image classification as an RL problem. Note that AutoAugment is not proposed for sim2real transfer, but falls in the bucket of DR guided by task performance. Each individual augmentation configuration is tested on the evaluation set and the performance improvement is used as a reward to train a PPO policy. This policy outputs different augmentation strategies for different datasets; for example, for CIFAR-10 AutoAugment mostly picks color-based transformations, while for ImageNet it prefers geometric ones.</p>
<p><a href="https://arxiv.org/abs/1810.02513">Ruiz (2019)</a> considered the <em>task feedback</em> as the <em>reward</em> in an RL problem and proposed an RL-based method, named “learning to simulate”, for adjusting <script type="math/tex">\xi</script>. A policy, modeled as a multivariate Gaussian, is trained to predict <script type="math/tex">\xi</script> using performance metrics on the validation data of the main task as rewards. Overall the idea is similar to AutoAugment, applying NAS on data generation. According to their experiments, even if the main task model has not converged, it can still provide a reasonable signal to the data generation policy.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/learning-to-simulate.png" alt="Learning to simulate" /></p>
<p><em>Fig. 3. An overview of the “learning to simulate” approach. (Image source: <a href="https://arxiv.org/abs/1810.02513">Ruiz (2019)</a>)</em></p>
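<p>The general recipe of treating task feedback as reward can be sketched with a one-dimensional Gaussian policy over <script type="math/tex">\xi</script> updated by REINFORCE. This is a simplified illustration, not the paper's implementation: the reward function below is a hypothetical stand-in for "train the main task model on data simulated with <script type="math/tex">\xi</script>, then measure validation performance", and the variance of the Gaussian is kept fixed.</p>

```python
import random


def task_reward(xi):
    """Hypothetical stand-in for the validation performance of a model trained
    on data simulated with parameter xi; it peaks at xi = 2.0."""
    return -(xi - 2.0) ** 2


def train_simulation_policy(iterations=300, batch_size=32, lr=0.05, sigma=0.5, seed=0):
    rng = random.Random(seed)
    mu = 0.0  # mean of the Gaussian policy over xi; sigma is kept fixed
    for _ in range(iterations):
        xis = [rng.gauss(mu, sigma) for _ in range(batch_size)]
        rewards = [task_reward(xi) for xi in xis]
        baseline = sum(rewards) / batch_size  # batch-mean baseline for variance reduction
        # REINFORCE: grad_mu log N(xi; mu, sigma^2) = (xi - mu) / sigma^2
        grad = sum((r - baseline) * (xi - mu) / sigma ** 2
                   for r, xi in zip(rewards, xis)) / batch_size
        mu += lr * grad  # gradient ascent on the expected reward
    return mu


mu = train_simulation_policy()
print(round(mu, 2))
```

<p>The policy mean drifts toward the region of <script type="math/tex">\xi</script> where the reward is highest, without ever differentiating through the reward function itself.</p>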
<p>An evolutionary algorithm is another way to go, where the <em>feedback</em> is treated as <em>fitness</em> for guiding evolution (<a href="https://openreview.net/forum?id=H1g6osRcFQ">Yu et al, 2019</a>). In this study, they used <a href="https://en.wikipedia.org/wiki/CMA-ES">CMA-ES</a> (covariance matrix adaptation evolution strategy) while fitness is the performance of a <script type="math/tex">\xi</script>-conditional policy in the target environment. In the appendix, they compared CMA-ES with other ways of modeling the dynamics of <script type="math/tex">\xi</script>, including Bayesian optimization or a neural network. The main claim was that those methods are not as stable or sample efficient as CMA-ES. Interestingly, when modeling <script type="math/tex">P(\xi)</script> as a neural network, LSTM is found to notably outperform FF.</p>
<p>Some believe that the sim2real gap is a combination of an appearance gap and a content gap, and that most GAN-inspired DA models focus on the appearance gap. <strong>Meta-Sim</strong> (<a href="https://arxiv.org/abs/1904.11621">Kar, et al. 2019</a>) aims to close the content gap by generating task-specific synthetic datasets. Meta-Sim uses self-driving car training as an example and thus the scene could be very complicated. In this case, the synthetic scenes are parameterized by a hierarchy of objects with properties (e.g., location, color) as well as relationships between objects. The hierarchy is specified by a probabilistic scene grammar akin to structured domain randomization (<strong>SDR</strong>; <a href="https://arxiv.org/abs/1810.10093">Prakash et al., 2018</a>) and it is assumed to be known beforehand. A model <script type="math/tex">G</script> is trained to augment the distribution of scene properties <script type="math/tex">s</script> by following these steps:</p>
<ol>
<li>Learn the prior first: pre-train <script type="math/tex">G</script> to learn the identity function <script type="math/tex">G(s) = s</script>.</li>
<li>Minimize MMD loss between the real and sim data distributions. This involves backpropagation through a non-differentiable renderer. The paper computes it numerically by perturbing the attributes of <script type="math/tex">G(s)</script>.</li>
<li>Minimize REINFORCE task loss when trained on synthetic data but evaluated on real data. Again, very similar to AutoAugment.</li>
</ol>
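<p>For readers unfamiliar with the MMD (maximum mean discrepancy) loss in step 2, the sketch below computes a biased estimate of squared MMD with a Gaussian kernel between two small sample sets. The 2-D feature vectors and the bandwidth are hypothetical choices for illustration; how Meta-Sim extracts features from images is not shown here.</p>

```python
import math


def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian kernel between two feature vectors of equal length."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased estimate of MMD^2 = E[k(x,x')] - 2 E[k(x,y)] + E[k(y,y')]."""
    def mean_kernel(a, b):
        return sum(rbf_kernel(p, q, bandwidth) for p in a for q in b) / (len(a) * len(b))
    return mean_kernel(xs, xs) - 2.0 * mean_kernel(xs, ys) + mean_kernel(ys, ys)


# Hypothetical 2-D features of "real" and "synthetic" scenes.
real = [(0.0, 0.1), (0.2, -0.1), (-0.1, 0.0)]
sim_close = [(0.1, 0.0), (0.0, -0.1), (-0.2, 0.1)]
sim_far = [(2.0, 2.1), (2.2, 1.9), (1.9, 2.0)]

print(mmd_squared(real, sim_close))  # small: the two sample sets overlap
print(mmd_squared(real, sim_far))    # larger: the two sample sets are far apart
```

<p>A synthetic distribution that matches the real one drives this quantity toward zero, which is what makes it usable as a distribution-matching loss.</p>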
<p>Unfortunately, this family of methods is not well suited to the sim2real case. Either an RL policy or an EA model requires a large number of real samples, and it is really expensive to include real-time feedback collection on a physical robot in the training loop. Whether you want to trade less computation resource for real data collection would depend on your task.</p>
<h3 id="match-real-data-distribution">Match Real Data Distribution</h3>
<p>Using real data to guide domain randomization feels a lot like doing system identification or DA. The core idea behind DA is to improve the synthetic data to match the real data distribution. In the case of real-data-guided DR, we would like to learn the randomization parameters <script type="math/tex">\xi</script> that bring the state distribution in simulator close to the state distribution in the real world.</p>
<p>The <strong>SimOpt</strong> model (<a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>) is trained under an initial randomization distribution <script type="math/tex">P_\phi(\xi)</script> first, getting a policy <script type="math/tex">\pi_{\theta, P_\phi}</script>. Then this policy is deployed on both simulator and physical robot to collect trajectories <script type="math/tex">\tau_\xi</script> and <script type="math/tex">\tau_\text{real}</script> respectively. The optimization objective is to minimize the discrepancy between sim and real trajectories:</p>
<script type="math/tex; mode=display">\phi^* = \arg\min_{\phi}\mathbb{E}_{\xi \sim P_\phi(\xi)} [\mathbb{E}_{\pi_{\theta, P_\phi}} [D(\tau_\xi, \tau_\text{real})]]</script>
<p>where <script type="math/tex">D(.)</script> is a trajectory-based discrepancy measure. Like the “Learning to simulate” paper, SimOpt also has to solve the tricky problem of how to propagate gradients through a non-differentiable simulator. It used a method called <a href="https://www.aaai.org/ocs/index.php/AAAI/AAAI10/paper/viewFile/1851/2264">relative entropy policy search</a>; see the paper for more details.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/simopt.png" alt="SimOpt" /></p>
<p><em>Fig. 4. An overview of the SimOpt framework. (Image source: <a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>)</em></p>
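<p>The loop of "sample <script type="math/tex">\xi</script>, roll out in sim, compare with the real trajectory, update <script type="math/tex">\phi</script>" can be sketched on a toy problem. Everything below is an assumption for illustration: the simulator is a block sliding with friction, the policy is omitted (rollouts are open-loop), the discrepancy is a plain sum of squared differences, and the distribution update uses the cross-entropy method as a simple gradient-free stand-in for the relative entropy policy search used in the paper.</p>

```python
import random


def rollout(friction, steps=20, dt=0.1, v0=1.0):
    """Hypothetical simulator: a block sliding in 1-D, decelerated by friction."""
    x, v, xs = 0.0, v0, []
    for _ in range(steps):
        v = max(0.0, v - friction * dt)
        x += v * dt
        xs.append(x)
    return xs


def discrepancy(tau_sim, tau_real):
    """One simple choice of D: sum of squared differences between observations."""
    return sum((a - b) ** 2 for a, b in zip(tau_sim, tau_real))


def adapt_distribution(tau_real, mean=0.8, std=0.3, iterations=10,
                       num_samples=50, num_elite=10, seed=0):
    """Update phi = (mean, std) of a Gaussian P_phi(xi) with the cross-entropy
    method, used here as a simple stand-in for the paper's REPS update."""
    rng = random.Random(seed)
    for _ in range(iterations):
        # Sample configurations; friction is clipped to be non-negative.
        xis = [max(0.0, rng.gauss(mean, std)) for _ in range(num_samples)]
        # Rank them by how closely the sim trajectory matches the real one.
        xis.sort(key=lambda xi: discrepancy(rollout(xi), tau_real))
        elite = xis[:num_elite]
        # Refit the Gaussian to the best-matching configurations.
        mean = sum(elite) / num_elite
        var = sum((xi - mean) ** 2 for xi in elite) / num_elite
        std = max(var ** 0.5, 0.02)  # floor keeps some exploration
    return mean, std


tau_real = rollout(friction=0.3)  # pretend this came from the physical robot
mean, std = adapt_distribution(tau_real)
print(round(mean, 2), round(std, 3))
```

<p>Note that only a single real trajectory is consumed per update here, which hints at why this family of methods needs far less real data than the task-performance-driven ones above.</p>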
<p><strong>RCAN</strong> (<a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>), short for “Randomized-to-Canonical Adaptation Networks”, is a nice combination of DA and DR for end-to-end RL tasks. An image-conditional GAN (<a href="https://arxiv.org/abs/1611.07004">cGAN</a>) is trained in sim to translate a domain-randomized image into a non-randomized version (aka “canonical version”). Later the same model is used to translate real images into the corresponding simulated version, so that the agent consumes observations consistent with what it encountered in training. Still, the underlying assumption is that the distribution of domain-randomized sim images is broad enough to cover real-world samples.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/RCAN.png" alt="RCAN" /></p>
<p><em>Fig. 5. RCAN is an image-conditional generator that can convert a domain-randomized or real image into its corresponding non-randomized simulator version. (Image source: <a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>)</em></p>
<p>The RL model is trained end-to-end in a simulator to do vision-based robot arm grasping. Randomization is applied at each timestep, including the position of tray divider, objects to grasp, random textures, as well as the position, direction, and color of the lighting. The canonical version is the default simulator look. RCAN is trying to learn a generator</p>
<p><script type="math/tex">G</script>: randomized image <script type="math/tex">\to</script> {canonical image, segmentation, depth}</p>
<p>where segmentation masks and depth images are used as auxiliary tasks. RCAN had a better zero-shot transfer compared to uniform DR, although both were shown to be worse than the model trained on only real images. Conceptually, RCAN operates in a reverse direction of <a href="https://arxiv.org/abs/1709.07857">GraspGAN</a> which translates synthetic images into real ones by domain adaptation.</p>
<h3 id="guided-by-data-in-simulator">Guided by Data in Simulator</h3>
<p>Network-driven domain randomization (<a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>), also known as <strong>DeceptionNet</strong>, is motivated by learning which randomizations are actually useful to bridge the domain gap for image classification tasks.</p>
<p>Randomization is applied through a set of deception modules with encoder-decoder architecture. The deception modules are specifically designed to transform images, such as changing backgrounds, adding distortion, or changing lighting. A separate recognition network handles the main task by running classification on the transformed images.</p>
<p>The training involves two steps:</p>
<ol>
<li>With the recognition network fixed, <em>maximize the difference</em> between the prediction and the labels by applying reversed gradients during backpropagation, so that the deception module can learn the most confusing tricks.</li>
<li>With the deception modules fixed, train the recognition network with input images altered.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/deception-net.png" alt="DeceptionNet" /></p>
<p><em>Fig. 6. How DeceptionNet works. (Image source: <a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>)</em></p>
<p>The feedback for training deception modules is provided by the downstream classifier. But rather than trying to maximize the task performance like <a href="#optimization-for-task-performance">the section</a> above, the randomization modules aim to create harder cases. One big disadvantage is you need to manually design different deception modules for different datasets or tasks, making it not easily scalable. Given the fact that it is zero-shot, the results are still worse than SOTA DA methods on MNIST and LineMOD.</p>
<p>Similarly, Active domain randomization (<strong>ADR</strong>; <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>) also relies on sim data to create harder training samples. ADR searches for the <em>most informative</em> environment variations within the given randomization ranges, where the <em>informativeness</em> is measured as the discrepancies of policy rollouts in randomized and reference (original, non-randomized) environment instances. Sounds a bit like <a href="#match-real-data-distribution">SimOpt</a>? Well, note that SimOpt measures the discrepancy between sim and real rollouts, while ADR measures it between randomized and non-randomized sim, avoiding the expensive real data collection part.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/ADR.png" alt="ADR" /></p>
<p><em>Fig. 7. How active domain randomization (ADR) works. (Image source: <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>)</em></p>
<p>Precisely the training happens as follows:</p>
<ol>
<li>Given a policy, run it on both reference and randomized envs and collect two sets of trajectories respectively.</li>
<li>Train a discriminator model to tell whether a rollout trajectory comes from a randomized environment or from the reference run. The predicted <script type="math/tex">\log p</script> (probability of being randomized) is used as reward. The more the randomized and reference rollouts differ, the easier the prediction and the higher the reward.
<ul>
<li>The intuition is that if an environment is easy, the same policy agent can produce similar trajectories as in the reference one. Then the model should reward and explore hard environments by encouraging different behaviors.</li>
</ul>
</li>
<li>The reward by discriminator is fed into <em>Stein Variational Policy Gradient</em> (<a href="https://arxiv.org/abs/1704.02399">SVPG</a>) particles, outputting a diverse set of randomization configurations.</li>
</ol>
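<p>Steps 1 and 2 above can be sketched with a tiny logistic-regression discriminator. This is a toy illustration under my own assumptions: each trajectory is summarized by a single hypothetical number (say, the final object position), the "easy" and "hard" environments are hand-made Gaussians, and the SVPG particles from step 3 are left out entirely.</p>

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_discriminator(features, labels, steps=500, lr=0.5):
    """Logistic regression on a 1-D trajectory summary; label 1 = randomized env."""
    w, b = 0.0, 0.0
    n = len(features)
    for _ in range(steps):
        errors = [sigmoid(w * x + b) - y for x, y in zip(features, labels)]
        grad_w = sum(err * x for err, x in zip(errors, features)) / n
        grad_b = sum(errors) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


def env_reward(rollouts, w, b):
    """Reward for an environment: mean log-probability of being 'randomized'."""
    return sum(math.log(sigmoid(w * x + b)) for x in rollouts) / len(rollouts)


rng = random.Random(0)
# Hypothetical trajectory summaries collected with one fixed policy.
reference = [rng.gauss(0.0, 0.1) for _ in range(100)]
easy_env = [rng.gauss(0.1, 0.1) for _ in range(50)]  # behaves like the reference
hard_env = [rng.gauss(2.0, 0.1) for _ in range(50)]  # behaves very differently

features = reference + easy_env + hard_env
labels = [0] * len(reference) + [1] * (len(easy_env) + len(hard_env))
w, b = train_discriminator(features, labels)

reward_easy = env_reward(easy_env, w, b)
reward_hard = env_reward(hard_env, w, b)
print(reward_easy, reward_hard)
```

<p>The hard environment receives the higher reward because its rollouts are easy to tell apart from the reference ones, so a sampler trained on this reward would propose more environments like it.</p>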
<p>The idea of ADR is very appealing, with two small concerns. The similarity between trajectories might not be a good way to measure the env difficulty when running a stochastic policy. The sim2real results unfortunately do not look as exciting, but the paper pointed out that the win is that ADR explores a smaller range of randomization parameters.</p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019DR,
title = "Domain Randomization for Sim2Real Transfer",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/05/04/domain-randomization.html"
}
</code></pre></div></div>
<p>Overall, after reading this post, I hope you like domain randomization as much as I do :).</p>
<h2 id="references">References</h2>
<p>[1] Josh Tobin, et al. <a href="https://arxiv.org/pdf/1703.06907.pdf">“Domain randomization for transferring deep neural networks from simulation to the real world.”</a> IROS, 2017.</p>
<p>[2] Fereshteh Sadeghi and Sergey Levine. <a href="https://arxiv.org/abs/1611.04201">“CAD2RL: Real single-image flight without a single real image.”</a> arXiv:1611.04201 (2016).</p>
<p>[3] Xue Bin Peng, et al. <a href="https://arxiv.org/abs/1710.06537">“Sim-to-real transfer of robotic control with dynamics randomization.”</a> ICRA, 2018.</p>
<p>[4] Nataniel Ruiz, et al. <a href="https://openreview.net/forum?id=HJgkx2Aqt7">“Learning to Simulate.”</a> ICLR 2019</p>
<p>[5] OpenAI. <a href="https://arxiv.org/abs/1808.00177">“Learning Dexterous In-Hand Manipulation.”</a> arXiv:1808.00177 (2018).</p>
<p>[6] OpenAI Blog. <a href="https://openai.com/blog/learning-dexterity/">“Learning dexterity”</a> July 30, 2018.</p>
<p>[7] Quan Vuong, et al. <a href="https://arxiv.org/abs/1903.11774">“How to pick the domain randomization parameters for sim-to-real transfer of reinforcement learning policies?.”</a> arXiv:1903.11774 (2019).</p>
<p>[8] Ekin D. Cubuk, et al. <a href="https://arxiv.org/abs/1805.09501">“AutoAugment: Learning augmentation policies from data.”</a> arXiv:1805.09501 (2018).</p>
<p>[9] Wenhao Yu et al. <a href="https://openreview.net/forum?id=H1g6osRcFQ">“Policy Transfer with Strategy Optimization.”</a> ICLR 2019</p>
<p>[10] Yevgen Chebotar et al. <a href="https://arxiv.org/abs/1810.05687">“Closing the Sim-to-Real Loop: Adapting Simulation Randomization with Real World Experience.”</a> arXiv:1810.05687 (2019).</p>
<p>[11] Stephen James et al. <a href="https://arxiv.org/abs/1812.07252">“Sim-to-real via sim-to-sim: Data-efficient robotic grasping via randomized-to-canonical adaptation networks”</a> CVPR 2019.</p>
<p>[12] Bhairav Mehta et al. <a href="https://arxiv.org/abs/1904.04762">“Active Domain Randomization”</a> arXiv:1904.04762</p>
<p>[13] Sergey Zakharov, et al. <a href="https://arxiv.org/abs/1904.02750">“DeceptionNet: Network-Driven Domain Randomization.”</a> arXiv:1904.02750 (2019).</p>
<p>[14] Amlan Kar, et al. <a href="https://arxiv.org/abs/1904.11621">“Meta-Sim: Learning to Generate Synthetic Datasets.”</a> arXiv:1904.11621 (2019).</p>
<p>[15] Aayush Prakash, et al. <a href="https://arxiv.org/abs/1810.10093">“Structured Domain Randomization: Bridging the Reality Gap by Context-Aware Synthetic Data.”</a> arXiv:1810.10093 (2018).</p>
<p>Lilian Weng. If a model or policy is mainly trained in a simulator but expected to work on a real robot, it would surely face the sim2real gap. Domain Randomization (DR) is a simple but powerful idea of closing this gap by randomizing properties of the training environment.</p>
<hr />
<p>Are Deep Neural Networks Dramatically Overfitted? 2019-03-14T12:00:00+00:00. https://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted</p>
<blockquote>
<p>If you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-05-27: add the <a href="#the-lottery-ticket-hypothesis">section</a> on Lottery Ticket Hypothesis.]</span></p>
<p>If you are like me, entering into the field of deep learning with experience in traditional machine learning, you may often ponder over this question: Since a typical deep neural network has so many parameters and training error can easily be perfect, it should surely suffer from substantial overfitting. How could it ever generalize to out-of-sample data points?</p>
<p>The effort in understanding why deep neural networks can generalize somehow reminds me of this interesting paper on Systems Biology, <a href="https://www.cell.com/cancer-cell/pdf/S1535-6108(02)00133-2.pdf">“Can a biologist fix a radio?”</a> (Lazebnik, 2002). If a biologist intends to fix a radio the way she works on a biological system, life could be hard. Because the full mechanism of the radio system is not revealed, poking small local functionalities might give some hints but it can hardly present all the interactions within the system, let alone the entire working flow. No matter whether you think it is relevant to DL, it is a very fun read.</p>
<p>I would like to discuss a couple of papers on generalizability and complexity measurement of deep learning models in the post. Hopefully, it could shed light on your thinking path towards the understanding of why DNN can generalize.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-theorems-on-compression-and-model-selection" id="markdown-toc-classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</a> <ul>
<li><a href="#occams-razor" id="markdown-toc-occams-razor">Occam’s Razor</a></li>
<li><a href="#minimum-description-length-principle" id="markdown-toc-minimum-description-length-principle">Minimum Description Length principle</a></li>
<li><a href="#kolmogorov-complexity" id="markdown-toc-kolmogorov-complexity">Kolmogorov Complexity</a></li>
<li><a href="#solomonoffs-inference-theory" id="markdown-toc-solomonoffs-inference-theory">Solomonoff’s Inference Theory</a></li>
</ul>
</li>
<li><a href="#expressive-power-of-dl-models" id="markdown-toc-expressive-power-of-dl-models">Expressive Power of DL Models</a> <ul>
<li><a href="#universal-approximation-theorem" id="markdown-toc-universal-approximation-theorem">Universal Approximation Theorem</a></li>
<li><a href="#proof-finite-sample-expressivity-of-two-layer-nn" id="markdown-toc-proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</a></li>
<li><a href="#deep-nn-can-learn-random-noise" id="markdown-toc-deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</a></li>
</ul>
</li>
<li><a href="#are-deep-learning-models-dramatically-overfitted" id="markdown-toc-are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</a> <ul>
<li><a href="#modern-risk-curve-for-deep-learning" id="markdown-toc-modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</a></li>
<li><a href="#regularization-is-not-the-key-to-generalization" id="markdown-toc-regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</a></li>
<li><a href="#intrinsic-dimension" id="markdown-toc-intrinsic-dimension">Intrinsic Dimension</a></li>
<li><a href="#heterogeneous-layer-robustness" id="markdown-toc-heterogeneous-layer-robustness">Heterogeneous Layer Robustness</a></li>
<li><a href="#the-lottery-ticket-hypothesis" id="markdown-toc-the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</a></li>
</ul>
</li>
<li><a href="#experiments" id="markdown-toc-experiments">Experiments</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</h2>
<p>Let’s say we have a classification problem and a dataset. We can develop many models to solve it, from fitting a simple linear regression to memorizing the full dataset in disk space. Which one is better? If we only care about the accuracy over training data (especially given that testing data is likely unknown), the memorization approach seems to be the best. Well, that doesn’t sound right.</p>
<p>There are many classic theorems to guide us when deciding what types of properties a good model should possess in such scenarios.</p>
<h3 id="occams-razor">Occam’s Razor</h3>
<p><a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> is an informal principle for problem-solving, proposed by <a href="https://en.wikipedia.org/wiki/William_of_Ockham">William of Ockham</a> in the 14th century:</p>
<blockquote>
<p>“Simpler solutions are more likely to be correct than complex ones.”</p>
</blockquote>
<p>The statement is extremely powerful when we are facing multiple candidates of underlying theories to explain the world and have to pick one. Too many unnecessary assumptions might seem plausible for one problem, but they are harder to generalize to other complications and less likely to eventually lead to the basic principles of the universe.</p>
<p>Think of this: it took people hundreds of years to figure out that the sky being blue in the daytime and reddish at sunset has the same cause (<a href="https://en.wikipedia.org/wiki/Rayleigh_scattering">Rayleigh scattering</a>), although the two phenomena look very different. People must have proposed many other explanations for them separately but the unified and simple version won eventually.</p>
<h3 id="minimum-description-length-principle">Minimum Description Length principle</h3>
<p>The principle of Occam’s Razor can be similarly applied to machine learning models. A formalized version of such concept is called the <em>Minimum Description Length (MDL)</em> principle, used for comparing competing models / explanations given data observed.</p>
<blockquote>
<p>“Comprehension is compression.”</p>
</blockquote>
<p>The fundamental idea in MDL is to <em>view learning as data compression</em>. To compress the data, we need to discover regularities or patterns in the data that have a high potential to generalize to unseen samples. <a href="/lil-log/2017/09/28/anatomize-deep-learning-with-information-theory.html">Information bottleneck</a> theory argues that a deep neural network is first trained to represent the data by minimizing the generalization error and then learns to compress this representation by trimming noise.</p>
<p>Meanwhile, MDL considers the model description as part of the compression delivery, so the model cannot be arbitrarily large.</p>
<p>A <em>two-part version</em> of MDL principle states that: Let <script type="math/tex">\mathcal{H}^{(1)}, \mathcal{H}^{(2)}, \dots</script> be a list of models that can explain the dataset <script type="math/tex">\mathcal{D}</script>. The best hypothesis among them should be the one that minimizes the sum:</p>
<script type="math/tex; mode=display">\mathcal{H}^\text{best} = \arg\min_\mathcal{H} [L(\mathcal{H}) + L(\mathcal{D}\vert\mathcal{H})]</script>
<ul>
<li><script type="math/tex">L(\mathcal{H})</script> is the length of the description of model <script type="math/tex">\mathcal{H}</script> in bits.</li>
<li><script type="math/tex">L(\mathcal{D}\vert\mathcal{H})</script> is the length of the description of the data <script type="math/tex">\mathcal{D}</script> in bits when encoded with <script type="math/tex">\mathcal{H}</script>.</li>
</ul>
<p>In simple words, the <em>best</em> model is the <em>smallest</em> model containing the encoded data and the model itself. Following this criterion, the memorization approach I proposed at the beginning of the section sounds horrible no matter how good accuracy it can achieve on the training data.</p>
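<p>As a toy illustration of the two-part criterion, here is a minimal sketch. The 100-bit dataset, the Bernoulli model and the 8-bit budget for storing its parameter are all assumptions made up for this example; they are not part of any MDL reference implementation.</p>

```python
import math

def bernoulli_code_length(bits, p):
    """Ideal code length (in bits) of a 0/1 sequence under a Bernoulli(p) model."""
    ones = sum(bits)
    zeros = len(bits) - ones
    return -(ones * math.log2(p) + zeros * math.log2(1.0 - p))

def two_part_length(model_bits, data_bits_given_model):
    """Two-part description length: L(H) + L(D|H)."""
    return model_bits + data_bits_given_model

# Hypothetical dataset: 100 binary observations, 90 of them are ones.
data = [1] * 90 + [0] * 10

# Hypothesis A (memorization): the model stores the data verbatim,
# so L(H) = 100 bits and nothing is left to encode given the model.
memorize_total = two_part_length(len(data), 0.0)

# Hypothesis B: a Bernoulli(0.9) model. We assume 8 bits are enough to
# describe the parameter; the data is then encoded with the ideal code.
bernoulli_total = two_part_length(8, bernoulli_code_length(data, 0.9))

print(f"memorize:  {memorize_total:.1f} bits")
print(f"bernoulli: {bernoulli_total:.1f} bits")
```

<p>Under these assumptions the small probabilistic model gives the shorter total description, even though the memorizing model has zero "training error".</p>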
<p>People might argue that Occam’s Razor is wrong: given that the real world can be arbitrarily complicated, why do we have to find simple models? One interesting view by MDL is to consider models as <strong>“languages”</strong> instead of fundamental generative theorems. We would like to find good compression strategies to describe regularity in a small set of samples, and they <strong>do not have to be the “real” generative model</strong> for explaining the phenomenon. Models can be wrong but still useful (e.g., think of any Bayesian prior).</p>
<h3 id="kolmogorov-complexity">Kolmogorov Complexity</h3>
<p>Kolmogorov Complexity relies on the concept of modern computers to define the algorithmic (descriptive) complexity of an object: It is <em>the length of the shortest binary computer program that describes the object</em>. Following MDL, a computer is essentially the most general form of data decompressor.</p>
<p>The formal definition of Kolmogorov Complexity states that: Given a universal computer <script type="math/tex">\mathcal{U}</script> and a program <script type="math/tex">p</script>, let’s denote <script type="math/tex">\mathcal{U}(p)</script> as the output of the computer processing the program and <script type="math/tex">L(p)</script> as the descriptive length of the program. Then Kolmogorov Complexity <script type="math/tex">K_\mathcal{U}</script> of a string <script type="math/tex">s</script> with respect to a universal computer <script type="math/tex">\mathcal{U}</script> is:</p>
<script type="math/tex; mode=display">K_\mathcal{U}(s) = \min_{p: \mathcal{U}(p)=s} L(p)</script>
<p>Note that a universal computer is one that can mimic the actions of any other computer. All modern computers are universal as they can all be reduced to Turing machines. The definition is universal no matter which computer we use, up to an additive constant: another universal computer can always be programmed to clone the behavior of <script type="math/tex">\mathcal{U}</script>, and the length of this clone program is a constant that does not depend on the string <script type="math/tex">s</script>.</p>
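<p>Kolmogorov Complexity itself is not computable, but the output size of any off-the-shelf compressor gives a crude, computable upper-bound proxy for it. The sketch below uses Python's standard <code>zlib</code> module for that purpose; the choice of compressor and the two test strings are assumptions for illustration only.</p>

```python
import random
import zlib

def compressed_length(data: bytes) -> int:
    """Length in bytes of the zlib-compressed data: a rough upper-bound
    proxy for descriptive complexity, not the true Kolmogorov Complexity."""
    return len(zlib.compress(data, 9))

n = 10_000

# A highly regular string: "ab" repeated. A very short program can print it.
regular = b"ab" * (n // 2)

# Pseudo-random bytes from a fixed seed. (Strictly speaking this string also
# has a short description, the seed plus the generator, but a generic
# compressor such as zlib cannot discover that.)
rng = random.Random(0)
noise = bytes(rng.randrange(256) for _ in range(n))

regular_len = compressed_length(regular)
noise_len = compressed_length(noise)
print(regular_len, noise_len)
```

<p>The regular string shrinks to a tiny fraction of its size while the noisy one barely compresses at all, which matches the intuition that regularity means short descriptions.</p>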
<p>There are a lot of connections between Kolmogorov Complexity and Shannon Information Theory, as both are tied to universal coding. It is an amazing fact that the expected Kolmogorov Complexity of a random variable is approximately equal to its Shannon entropy (see Sec 2.3 of <a href="https://homepages.cwi.nl/~paulv/papers/info.pdf">the report</a>). More on this topic is out of the scope here, but there are many interesting readings online. Help yourself :)</p>
<h3 id="solomonoffs-inference-theory">Solomonoff’s Inference Theory</h3>
<p>Another mathematical formalization of Occam’s Razor is Solomonoff’s theory of universal inductive inference (<a href="https://www.sciencedirect.com/science/article/pii/S0019995864902232">Solomonoff</a>, <a href="https://www.sciencedirect.com/science/article/pii/S0019995864901317">1964</a>). The principle is to favor models that correspond to the “shortest program” to produce the training data, based on its Kolmogorov complexity.</p>
<h2 id="expressive-power-of-dl-models">Expressive Power of DL Models</h2>
<p>Deep neural networks have an extremely large number of parameters compared to the traditional statistical models. If we use MDL to measure the complexity of a deep neural network and consider the number of parameters as the model description length, it would look awful. The model description <script type="math/tex">L(\mathcal{H})</script> can easily grow out of control.</p>
<p>However, having numerous parameters is <em>necessary</em> for a neural network to obtain high expressive power. Because of their great capability to capture flexible data representations, deep neural networks have achieved great success in many applications.</p>
<h3 id="universal-approximation-theorem">Universal Approximation Theorem</h3>
<p>The <em>Universal Approximation Theorem</em> states that a feedforward network with: 1) a linear output layer, 2) at least one hidden layer containing a finite number of neurons and 3) a suitable nonlinear activation function can approximate <strong>any</strong> continuous function on a compact subset of <script type="math/tex">\mathbb{R}^n</script> to arbitrary accuracy. The theorem was first proved for the sigmoid activation function (<a href="https://pdfs.semanticscholar.org/05ce/b32839c26c8d2cb38d5529cf7720a68c3fab.pdf">Cybenko, 1989</a>). Later it was shown that the universal approximation property is not specific to the choice of activation (<a href="http://zmjones.com/static/statistical-learning/hornik-nn-1991.pdf">Hornik, 1991</a>) but rather stems from the multilayer feedforward architecture.</p>
<p>Although a feedforward network with a single hidden layer is sufficient to approximate any such function, the width may have to be exponentially large. The universal approximation theorem does not guarantee that the model can be learned or that it generalizes properly. Often, adding more layers helps to reduce the number of hidden neurons needed in a shallow network.</p>
<p>To take advantage of the universal approximation theorem, we can always find a neural network to represent the target function with error under any desired threshold, but we need to pay the price — the network might grow super large.</p>
<h3 id="proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</h3>
<p>The Universal Approximation Theorem we have discussed so far does not consider a finite sample set. <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a> provided a neat proof on the finite-sample expressivity of two-layer neural networks.</p>
<p>A neural network <script type="math/tex">C</script> can represent any function given a sample size <script type="math/tex">n</script> in <script type="math/tex">d</script> dimensions if: For every finite sample set <script type="math/tex">S \subseteq \mathbb{R}^d</script> with <script type="math/tex">\vert S \vert = n</script> and every function defined on this sample set: <script type="math/tex">f: S \mapsto \mathbb{R}</script>, we can find a set of weight configuration for <script type="math/tex">C</script> so that <script type="math/tex">C(\boldsymbol{x}) = f(\boldsymbol{x}), \forall \boldsymbol{x} \in S</script>.</p>
<p>The paper proposed a theorem:</p>
<blockquote>
<p>There exists a two-layer neural network with ReLU activations and <script type="math/tex">2n + d</script> weights that can represent any function on a sample of size <script type="math/tex">n</script> in <script type="math/tex">d</script> dimensions.</p>
</blockquote>
<p><em>Proof.</em> First we would like to construct a two-layer neural network <script type="math/tex">C: \mathbb{R}^d \mapsto \mathbb{R}</script>. The input is a <script type="math/tex">d</script>-dimensional vector, <script type="math/tex">\boldsymbol{x} \in \mathbb{R}^d</script>. The hidden layer has <script type="math/tex">h</script> hidden units, associated with a weight matrix <script type="math/tex">\mathbf{W} \in \mathbb{R}^{d\times h}</script>, a bias vector <script type="math/tex">-\mathbf{b} \in \mathbb{R}^h</script> and ReLU activation function. The second layer outputs a scalar value with weight vector <script type="math/tex">\boldsymbol{v} \in \mathbb{R}^h</script> and zero biases.</p>
<p>The output of network <script type="math/tex">C</script> for an input vector <script type="math/tex">\boldsymbol{x}</script> can be represented as follows:</p>
<script type="math/tex; mode=display">C(\boldsymbol{x})
= \boldsymbol{v} \max\{ \boldsymbol{x}\mathbf{W} - \boldsymbol{b}, 0\}^\top
= \sum_{i=1}^h v_i \max\{\boldsymbol{x}\boldsymbol{W}_{(:,i)} - b_i, 0\}</script>
<p>where <script type="math/tex">\boldsymbol{W}_{(:,i)}</script> is the <script type="math/tex">i</script>-th column in the <script type="math/tex">d \times h</script> matrix.</p>
<p>Given a sample set <script type="math/tex">S = \{\boldsymbol{x}_1, \dots, \boldsymbol{x}_n\}</script> and target values <script type="math/tex">\boldsymbol{y} = \{y_1, \dots, y_n \}</script>, we would like to find proper weights <script type="math/tex">\mathbf{W} \in \mathbb{R}^{d\times h}</script>, <script type="math/tex">\boldsymbol{b}, \boldsymbol{v} \in \mathbb{R}^h</script> so that <script type="math/tex">C(\boldsymbol{x}_i) = y_i, \forall i=1,\dots,n</script>.</p>
<p>Let’s combine all sample points into one batch as one input matrix <script type="math/tex">\mathbf{X} \in \mathbb{R}^{n \times d}</script>. If we set <script type="math/tex">h=n</script>, <script type="math/tex">\mathbf{X}\mathbf{W} - \boldsymbol{b}</script> would be a square matrix of size <script type="math/tex">n \times n</script>.</p>
<script type="math/tex; mode=display">\mathbf{M}_\text{ReLU}
= \max\{\mathbf{X}\mathbf{W} - \boldsymbol{b}, 0 \}
= \begin{bmatrix}
\max\{\boldsymbol{x}_1\mathbf{W} - \boldsymbol{b}, 0\} \\
\dots \\
\max\{\boldsymbol{x}_n\mathbf{W} - \boldsymbol{b}, 0\} \\
\end{bmatrix}
= [\max\{\boldsymbol{x}_i\boldsymbol{W}_{(:,j)} - b_j, 0\}]_{i \times j}</script>
<p>We can simplify <script type="math/tex">\mathbf{W}</script> to have the same column vectors across all the columns:</p>
<script type="math/tex; mode=display">\mathbf{W}_{(:,j)} = \boldsymbol{w} \in \mathbb{R}^{d}, \forall j = 1, \dots, n</script>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/nn-expressivity-proof.png" alt="finite-sample expressivity proof illustration" /></p>
<p>Let <script type="math/tex">a_i = \boldsymbol{x}_i \boldsymbol{w}</script>, we would like to find a suitable <script type="math/tex">\boldsymbol{w}</script> and <script type="math/tex">\boldsymbol{b}</script> such that <script type="math/tex">% <![CDATA[
b_1 < a_1 < b_2 < a_2 < \dots < b_n < a_n %]]></script>. This is always achievable because we try to solve <script type="math/tex">n+d</script> unknown variables with <script type="math/tex">n</script> constraints and the sample points <script type="math/tex">\boldsymbol{x}_i</script> are distinct (e.g. pick a random <script type="math/tex">\boldsymbol{w}</script>, sort the samples by <script type="math/tex">\boldsymbol{x}_i \boldsymbol{w}</script> and then set <script type="math/tex">b_j</script>’s as values in between). Then every entry above the diagonal is clipped to zero by ReLU and <script type="math/tex">\mathbf{M}_\text{ReLU}</script> becomes a lower triangular matrix:</p>
<script type="math/tex; mode=display">% <![CDATA[
\mathbf{M}_\text{ReLU} = [\max\{a_i - b_j, 0\}]_{i \times j}
= \begin{bmatrix}
a_1 - b_1 & 0 & 0 & \dots & 0 \\
\vdots & \ddots & & & \vdots \\
a_i - b_1 & \dots & a_i - b_i & \dots & 0\\
\vdots & & & \ddots & \vdots \\
a_n - b_1 & a_n - b_2 & \dots & \dots & a_n - b_n \\
\end{bmatrix} %]]></script>
<p>It is a nonsingular square matrix, because all the diagonal entries <script type="math/tex">a_i - b_i</script> are positive and thus <script type="math/tex">\det(\mathbf{M}_\text{ReLU}) = \prod_{i=1}^n (a_i - b_i) \neq 0</script>, so we can always find suitable <script type="math/tex">\boldsymbol{v}</script> to solve <script type="math/tex">\mathbf{M}_\text{ReLU}\boldsymbol{v}^\top=\boldsymbol{y}^\top</script> (In other words, the column space of <script type="math/tex">\mathbf{M}_\text{ReLU}</script> is all of <script type="math/tex">\mathbb{R}^n</script> and we can find a linear combination of column vectors to obtain any <script type="math/tex">\boldsymbol{y}</script>).</p>
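<p>The construction in the proof can be checked numerically. Below is a minimal NumPy sketch under made-up assumptions (8 random sample points in 3 dimensions with random real-valued targets); it is an illustration of the argument, not code from the paper.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 3

# Hypothetical sample set and arbitrary target values.
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

# Pick a random direction w and relabel the samples so that a_1 < ... < a_n.
w = rng.normal(size=d)
a = X @ w
order = np.argsort(a)
X, y, a = X[order], y[order], a[order]

# Interleave the biases: b_1 < a_1 < b_2 < a_2 < ... < b_n < a_n.
b = np.empty(n)
b[0] = a[0] - 1.0
b[1:] = (a[:-1] + a[1:]) / 2.0

# M[i, j] = max(a_i - b_j, 0): lower triangular with a positive diagonal.
M = np.maximum(a[:, None] - b[None, :], 0.0)

# Solve M v = y for the output weights.
v = np.linalg.solve(M, y)

def C(x):
    """Two-layer ReLU network in which every hidden unit shares the same w."""
    hidden = np.maximum(np.asarray(x @ w)[..., None] - b, 0.0)
    return hidden @ v

print(np.allclose(C(X), y))            # the network interpolates all n targets
print(w.size + b.size + v.size)        # 2n + d weights in total
```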
<h3 id="deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</h3>
<p>As we know two-layer neural networks are universal approximators, it is less surprising to see that they are able to learn unstructured random noise perfectly, as shown in <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a>. If the labels of an image classification dataset are randomly shuffled, the high expressive power of deep neural networks can still empower them to achieve near-zero training loss. These results do not change with regularization terms added.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/fit-random-labels.png" alt="Fitting random labels" /></p>
<p><em>Fig. 1. Fit models on CIFAR10 with random labels or random pixels: (a) learning curves; (b-c) label corruption ratio is the percentage of randomly shuffled labels. (Image source: <a href="https://arxiv.org/abs/1611.03530">Zhang’s paper</a>)</em></p>
<h2 id="are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</h2>
<p>Deep learning models are heavily over-parameterized and can often get to perfect results on training data. In the traditional view, like bias-variance trade-offs, this should be a disaster: nothing would generalize to the unseen test data. However, as is often the case, such “overfitted” (training error = 0) deep learning models still present a decent performance on out-of-sample test data. Hmm … interesting and why?</p>
<h3 id="modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</h3>
<p>Traditional machine learning uses the following U-shape risk curve to measure the bias-variance trade-offs and quantify how generalizable a model is. If I get asked how to tell whether a model is overfitted, this would be the first thing popping into my mind.</p>
<p>As the model turns larger (more parameters added), the training error decreases to close to zero, but the test error (generalization error) starts to increase once the model complexity grows to pass the threshold between “underfitting” and “overfitting”. In a way, this is well aligned with Occam’s Razor.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bias-variance-risk-curve.png" alt="Bias-variance risk curve" /></p>
<p><em>Fig. 2. U-shaped bias-variance risk curve. (Image source: (left) <a href="https://arxiv.org/abs/1812.11118">paper</a> (right) <a href="http://scott.fortmann-roe.com/docs/BiasVariance.html">fig. 6 of this post</a>)</em></p>
<p>Unfortunately this does not apply to deep learning models. <a href="https://arxiv.org/abs/1812.11118">Belkin et al. (2018)</a> reconciled the traditional bias-variance trade-offs and proposed a new double-U-shaped risk curve for deep neural networks. Once the number of network parameters is high enough, the risk curve enters another regime.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/new-bias-variance-risk-curve.png" alt="new risk curve" /></p>
<p><em>Fig. 3. A new double-U-shaped bias-variance risk curve for deep neural networks. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
<p>The paper claimed that it is likely due to two reasons:</p>
<ul>
<li>The number of parameters is not a good measure of <em>inductive bias</em>, defined as the set of assumptions of a learning algorithm used to predict for unknown samples. See more discussion on DL model complexity in <a href="#intrinsic-dimension">later</a> <a href="#heterogeneous-layer-robustness">sections</a>.</li>
<li>Equipped with a larger model, we might be able to discover larger function classes and further find interpolating functions that have smaller norm and are thus “simpler”.</li>
</ul>
<p>The double-U-shaped risk curve was observed empirically, as shown in the paper. However, I struggled quite a bit to reproduce the results. There are some signs of life, but in order to generate a pretty smooth curve similar to the theoretical one, <a href="#experiments">many details</a> in the experiment have to be taken care of.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/new-risk-curve-mnist.png" alt="New risk curve on MNIST" /></p>
<p><em>Fig. 4. Training and evaluation errors of a one hidden layer fc network of different numbers of hidden units, trained on 4000 data points sampled from MNIST. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
<h3 id="regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</h3>
<p>Regularization is a common way to control overfitting and improve model generalization performance. Interestingly some research (<a href="https://arxiv.org/abs/1611.03530">Zhang, et al. 2017</a>) has shown that explicit regularization (e.g. data augmentation, weight decay and dropout) is neither necessary nor sufficient for reducing generalization error.</p>
<p>Taking the Inception model trained on CIFAR10 as an example (see Fig. 5), regularization techniques help with out-of-sample generalization but not much. No single regularization seems to be critical independent of other terms. Thus, it is unlikely that regularizers are the <em>fundamental reason</em> for generalization.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/regularization-generalization-test.png" alt="regularization test" /></p>
<p><em>Fig. 5. The accuracy of Inception model trained on CIFAR10 with different combinations of taking on or off data augmentation and weight decay. (Image source: Table 1 in the <a href="https://arxiv.org/abs/1611.03530">original paper</a>)</em></p>
<h3 id="intrinsic-dimension">Intrinsic Dimension</h3>
<p>The number of parameters is not correlated with model overfitting in the field of deep learning, suggesting that parameter counting cannot indicate the true complexity of deep neural networks.</p>
<p>Apart from parameter counting, researchers have proposed many ways to quantify the complexity of these models, such as the number of degrees of freedom of models (<a href="https://arxiv.org/abs/1603.09260">Gao & Jojic, 2016</a>), or prequential code (<a href="https://arxiv.org/abs/1802.07044">Blier & Ollivier, 2018</a>).</p>
<p>I would like to discuss a recent method on this matter, named <strong>intrinsic dimension</strong> (<a href="https://arxiv.org/abs/1804.08838">Li et al, 2018</a>). Intrinsic dimension is intuitive, easy to measure, while still revealing many interesting properties of models of different sizes.</p>
<p>Considering a neural network with a great number of parameters, forming a high-dimensional parameter space, the learning happens on this high-dimensional <em>objective landscape</em>.
The shape of the parameter space manifold is critical. For example, a smoother manifold is beneficial for optimization by providing more predictive gradients and allowing for larger learning rates—this was claimed to be the reason why batch normalization has succeeded in stabilizing training (<a href="https://arxiv.org/abs/1805.11604">Santurkar, et al, 2019</a>).</p>
<p>Even though the parameter space is huge, fortunately we don’t have to worry too much about the optimization process getting stuck in local optima, as it has been <a href="https://arxiv.org/abs/1406.2572">shown</a> that critical points in the objective landscape are almost always saddle points rather than valleys. In other words, there is always a subset of dimensions containing paths to leave such points and keep on exploring.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/optimization-landscape-shape.png" alt="parameter landscape shape" /></p>
<p><em>Fig. 6. Illustrations of various types of critical points on the parameter optimization landscape. (Image source: <a href="https://www.offconvex.org/2016/03/22/saddlepoints/">here</a>)</em></p>
<p>One intuition behind the measurement of intrinsic dimension is that, since the parameter space has such high dimensionality, it is probably not necessary to exploit all the dimensions to learn efficiently. If we only travel through a slice of objective landscape and still can learn a good solution, the complexity of the resulting model is likely lower than what it appears to be by parameter-counting. This is essentially what intrinsic dimension tries to assess.</p>
<p>Say a model has <script type="math/tex">D</script> dimensions and its parameters are denoted as <script type="math/tex">\theta^{(D)}</script>. For learning, a smaller <script type="math/tex">d</script>-dimensional subspace is randomly sampled, <script type="math/tex">\theta^{(d)}</script>, where <script type="math/tex">% <![CDATA[
d < D %]]></script>. During one optimization update, rather than taking a gradient step according to all <script type="math/tex">D</script> dimensions, only the smaller subspace <script type="math/tex">\theta^{(d)}</script> is used and remapped to update model parameters.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-illustration.png" alt="illustration" /></p>
<p><em>Fig. 7. Illustration of parameter vectors for direct optimization when <script type="math/tex">D=3</script>. (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The parameterization used for the gradient update looks as follows:</p>
<script type="math/tex; mode=display">\theta^{(D)} = \theta_0^{(D)} + \mathbf{P} \theta^{(d)}</script>
<p>where <script type="math/tex">\theta_0^{(D)}</script> are the initialization values and <script type="math/tex">\mathbf{P}</script> is a <script type="math/tex">D \times d</script> projection matrix that is randomly sampled before training. Both <script type="math/tex">\theta_0^{(D)}</script> and <script type="math/tex">\mathbf{P}</script> are not trainable and fixed during training. <script type="math/tex">\theta^{(d)}</script> is initialized as all zeros.</p>
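<p>To make the subspace training concrete, here is a minimal NumPy sketch. A toy least-squares loss stands in for the network loss, and the sizes, the learning rate and the scaling of <code>P</code> are all assumptions chosen for illustration; this is not the setup of the intrinsic dimension paper. Only the <code>d</code>-dimensional vector is updated, and its gradient follows from the chain rule as <code>P.T @ grad_D</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, n = 200, 20, 50   # full dimension, subspace dimension, number of samples

# Toy least-squares objective with D parameters (a stand-in for a network loss).
A = rng.normal(size=(n, D))
y = rng.normal(size=n)

def loss_and_grad(theta):
    """Loss and its gradient with respect to the full D-dim parameter vector."""
    residual = A @ theta - y
    return 0.5 * np.mean(residual ** 2), A.T @ residual / n

theta_0 = 0.1 * rng.normal(size=D)          # fixed random initialization
P = rng.normal(size=(D, d)) / np.sqrt(D)    # fixed random projection, D x d
theta_d = np.zeros(d)                       # the only trainable variables

lr = 0.02
losses = []
for _ in range(300):
    theta = theta_0 + P @ theta_d           # map the subspace back to D dims
    loss, grad_D = loss_and_grad(theta)
    losses.append(loss)
    theta_d -= lr * (P.T @ grad_D)          # chain rule: dL/dtheta_d = P^T dL/dtheta

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```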
<p>By searching through the value of <script type="math/tex">d = 1, 2, \dots, D</script>, the corresponding <script type="math/tex">d</script> when the solution emerges is defined as the <em>intrinsic dimension</em>.</p>
<p>It turns out many problems have much smaller intrinsic dimensions than the number of parameters. For example, on CIFAR10 image classification, a fully-connected network with 650k+ parameters has only 9k intrinsic dimension and a convolutional network containing 62k parameters has an even lower intrinsic dimension of 2.9k.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension.png" alt="intrinsic dimension results" /></p>
<p><em>Fig. 8. The measured intrinsic dimensions <script type="math/tex">d</script> for various models achieving 90% of the best performance. (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The measurement of intrinsic dimensions suggests that deep learning models are significantly simpler than what they might appear to be.</p>
<h3 id="heterogeneous-layer-robustness">Heterogeneous Layer Robustness</h3>
<p><a href="https://arxiv.org/abs/1902.01996">Zhang et al. (2019)</a> investigated the role of parameters in different layers. The fundamental question raised by the paper is: <em>“are all layers created equal?”</em> The short answer is: No. The model is more sensitive to changes in some layers but not others.</p>
<p>The paper proposed two types of operations that can be applied to parameters of the <script type="math/tex">\ell</script>-th layer, <script type="math/tex">\ell = 1, \dots, L</script>, at time <script type="math/tex">t</script>, <script type="math/tex">\theta^{(\ell)}_t</script> to test their impacts on model robustness:</p>
<ul>
<li>
<p><strong>Re-initialization</strong>: Reset the parameters to the initial values, <script type="math/tex">\theta^{(\ell)}_t \leftarrow \theta^{(\ell)}_0</script>. The performance of a network in which layer <script type="math/tex">\ell</script> was re-initialized is referred to as the <em>re-initialization robustness</em> of layer <script type="math/tex">\ell</script>.</p>
</li>
<li>
<p><strong>Re-randomization</strong>: Re-sample the layer’s parameters at random, <script type="math/tex">\theta^{(\ell)}_t \leftarrow \tilde{\theta}^{(\ell)} \sim \mathcal{P}^{(\ell)}</script>. The corresponding network performance is called the <em>re-randomization robustness</em> of layer <script type="math/tex">\ell</script>.</p>
</li>
</ul>
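<p>The two operations can be sketched on a dictionary of weight arrays. The helper names <code>re_initialize</code> and <code>re_randomize</code>, the layer shapes and the "trained" weights (just the initial weights plus noise, with no real training) are hypothetical and only serve the illustration.</p>

```python
import numpy as np

def re_initialize(trained, initial, layer):
    """Reset one layer to its initial values; other layers keep trained values."""
    out = {name: w.copy() for name, w in trained.items()}
    out[layer] = initial[layer].copy()
    return out

def re_randomize(trained, layer, sampler):
    """Replace one layer with fresh samples from the initialization distribution."""
    out = {name: w.copy() for name, w in trained.items()}
    out[layer] = sampler(trained[layer].shape)
    return out

rng = np.random.default_rng(0)
shapes = {"layer1": (784, 256), "layer2": (256, 256), "layer3": (256, 10)}

# Initial weights, and a stand-in for trained weights (no training happens here).
theta_0 = {name: 0.05 * rng.normal(size=s) for name, s in shapes.items()}
theta_t = {name: w + 0.01 * rng.normal(size=w.shape) for name, w in theta_0.items()}

reinit = re_initialize(theta_t, theta_0, "layer2")
rerand = re_randomize(theta_t, "layer2", lambda shape: 0.05 * rng.normal(size=shape))
```

<p>Evaluating the network with <code>reinit</code> or <code>rerand</code> in place of the trained weights then gives the re-initialization or re-randomization robustness of that layer.</p>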
<p>Layers can be grouped into two categories with the help of these two operations:</p>
<ul>
<li><strong>Robust Layers</strong>: The network has no or only negligible performance degradation after re-initializing or re-randomizing the layer.</li>
<li><strong>Critical Layers</strong>: Otherwise.</li>
</ul>
<p>Similar patterns are observed on fully-connected and convolutional networks. Re-randomizing any of the layers <em>completely destroys</em> the model performance, as the prediction drops to random guessing immediately. More interestingly and surprisingly, when applying re-initialization, only the first or the first few layers (those closest to the input layer) are critical, while re-initializing higher levels causes <em>only negligible decrease</em> in performance.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer-robustness-results.png" alt="Re-initialization robustness" /></p>
<p><em>Fig. 9. (a) A fc network trained on MNIST. Each row corresponds to one layer in the network. The first column is re-randomization robustness of each layer and the rest of the columns indicate re-initialization robustness at different training time. (b) VGG11 model (conv net) trained on CIFAR 10. Similar representation as in (a) but rows and columns are transposed. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>ResNet is able to use shortcuts between non-adjacent layers to re-distribute the sensitive layers across the networks rather than just at the bottom. With the help of residual block architecture, the network can <em>be evenly robust to re-randomization</em>. Only the first layer of each residual block is still sensitive to both re-initialization and re-randomization. If we consider each residual block as a local sub-network, the robustness pattern resembles the fc and conv nets above.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/layer-robustness-resnet.png" alt="ResNet robustness" /></p>
<p><em>Fig. 10. Re-randomization (first row) and re-initialization (the rest of the rows) robustness of layers in ResNet-50 model trained on CIFAR10. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>Based on the fact that many top layers in deep neural networks are not critical to the model performance after re-initialization, the paper loosely concluded that:</p>
<blockquote>
<p>“Over-capacitated deep networks trained with stochastic gradient have low-complexity due to self-restricting the number of critical layers.”</p>
</blockquote>
<p>We can consider re-initialization as a way to reduce the effective number of parameters, and thus the observation is aligned with what intrinsic dimension has demonstrated.</p>
<h3 id="the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</h3>
<p>The lottery ticket hypothesis (<a href="https://arxiv.org/abs/1803.03635">Frankle & Carbin, 2019</a>) is another intriguing and inspiring discovery, supporting that only a subset of network parameters have impact on the model performance and thus the network is not overfitted. The lottery ticket hypothesis states that a randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset are <em>“winning tickets”</em> which can achieve the optimal performance when <em>trained in isolation</em>.</p>
<p>The idea is motivated by network pruning techniques — removing unnecessary weights (i.e. tiny weights that are almost negligible) without harming the model performance. Although the final network size can be reduced dramatically, it is hard to train such a pruned network architecture successfully from scratch. It feels like in order to successfully train a neural network, we need a large number of parameters, but we don’t need that many parameters to keep the accuracy high once the model is trained. Why is that?</p>
<p>The lottery ticket hypothesis paper ran the following experiment:</p>
<ol>
<li>Randomly initialize a dense feed-forward network with initialization values <script type="math/tex">\theta_0</script>;</li>
<li>Train the network for multiple iterations to achieve a good performance with parameter config <script type="math/tex">\theta</script>;</li>
<li>Run pruning on <script type="math/tex">\theta</script> and create a mask <script type="math/tex">m</script>;</li>
<li>The “winning ticket” initialization config is <script type="math/tex">m \odot \theta_0</script>.</li>
</ol>
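<p>The four steps above can be sketched in NumPy as follows. This is a minimal illustration under loud assumptions: the "trained" weights are just the initial weights plus noise rather than the result of real training, pruning is done by global weight magnitude, and the helper name <code>magnitude_mask</code> and the 20% keep ratio are hypothetical.</p>

```python
import numpy as np

def magnitude_mask(theta, keep_fraction):
    """Binary mask keeping the given fraction of weights with largest magnitude."""
    k = max(1, int(round(keep_fraction * theta.size)))
    threshold = np.sort(np.abs(theta), axis=None)[-k]
    return (np.abs(theta) >= threshold).astype(theta.dtype)

rng = np.random.default_rng(0)

# Step 1: random initialization theta_0.
theta_0 = 0.1 * rng.normal(size=(300, 100))

# Step 2: stand-in for the trained parameters theta (no real training here).
theta = theta_0 + 0.1 * rng.normal(size=theta_0.shape)

# Step 3: prune the trained weights and record the mask m.
m = magnitude_mask(theta, keep_fraction=0.2)

# Step 4: the "winning ticket" initialization is m * theta_0.
winning_ticket = m * theta_0

print(int(m.sum()), "of", m.size, "weights kept")
```

<p>In an actual experiment the masked network would then be retrained starting from <code>winning_ticket</code>, keeping the pruned weights fixed at zero.</p>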
<p>Only training the small “winning ticket” subset of parameters with the initial values as found in step 1, the model is able to achieve the same level of accuracy as in step 2. It turns out a large parameter space is not needed in the final solution representation, but needed for training as it provides a big pool of initialization configs of many much smaller subnetworks.</p>
<p>The lottery ticket hypothesis opens a new perspective on interpreting and dissecting deep neural network results. Many interesting follow-up works are on the way.</p>
<h2 id="experiments">Experiments</h2>
<p>After seeing all the interesting findings above, it should be pretty fun to reproduce them. Some results are easier to reproduce than others. Details are described below. My code is available on github <a href="https://github.com/lilianweng/generalization-experiment">lilianweng/generalization-experiment</a>.</p>
<p><strong>New Risk Curve for DL Models</strong></p>
<p>This is the trickiest one to reproduce. The authors did give me a lot of good advice and I appreciate it a lot. Here are a couple of noticeable settings in their experiments:</p>
<ul>
<li>There are no regularization terms like weight decay, dropout.</li>
<li>In Fig. 4, the training set contains 4k samples. It is only sampled once and fixed for all the models. The evaluation uses the full MNIST test set.</li>
<li>Each network is trained for a long time to achieve near-zero training risk. The learning rate is adjusted differently for models of different sizes.</li>
<li>To make the model less sensitive to the initialization in the under-parameterization region, their experiments adopted a <em>“weight reuse”</em> scheme: the parameters obtained from training a smaller neural network are used as initialization for training larger networks.</li>
</ul>
<p>I did not train or tune each model long enough to get perfect training performance, but evaluation error indeed shows a special twist around the interpolation threshold, different from training error. For example, for MNIST, the threshold is the number of training samples times the number of classes (10), that is 40000.</p>
<p>The x-axis is the number of model parameters: (28 * 28 + 1) * num. units + num. units * 10, in logarithm.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/risk_curve_loss-mse_sample-4000_epoch-500.png" alt="risk curve experiment 1" /></p>
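<p>To locate the interpolation threshold on the x-axis, the parameter-count formula above can be inverted for the hidden width. The snippet below is only a small sketch of that arithmetic; the function names are my own and, like the formula in the text, it does not count the output bias.</p>

```python
import math

def num_params(num_units, input_dim=28 * 28, num_classes=10):
    """Parameter count from the text: (input_dim + 1) * units + units * classes."""
    return (input_dim + 1) * num_units + num_units * num_classes

def units_at_threshold(num_samples=4000, num_classes=10, input_dim=28 * 28):
    """Smallest hidden width whose parameter count reaches num_samples * num_classes."""
    threshold = num_samples * num_classes
    per_unit = (input_dim + 1) + num_classes  # parameters contributed by one hidden unit
    return math.ceil(threshold / per_unit)

width = units_at_threshold()
print(width, num_params(width), math.log10(num_params(width)))
```

<p>With 4k samples and 10 classes, the threshold of 40000 parameters is crossed at roughly 50 hidden units, which is where the twist in the evaluation curve should be expected.</p>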
<p><br /></p>
<p><strong>Layers are not Created Equal</strong></p>
<p>This one is fairly easy to reproduce. See my implementation <a href="https://github.com/lilianweng/generalization-experiment/blob/master/layer_equality.py">here</a>.</p>
<p>In the first experiment, I used a three-layer fc network with 256 units in each layer. Layer 0 is the input layer while layer 3 is the output. The network is trained on MNIST for 100 epochs.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_256x3.png" alt="Layer equality experiment 1" /></p>
<p>In the second experiment, I used a four-layer fc network with 128 units in each layer. Other settings are the same as experiment 1.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_128x4.png" alt="Layer equality experiment 2" /></p>
<p><br /></p>
<p><strong>Intrinsic Dimension Measurement</strong></p>
<p>To correctly map the <script type="math/tex">d</script>-dimensional subspace to the full parameter space, the projection matrix <script type="math/tex">\mathbf{P}</script> should have orthogonal columns. Because the product <script type="math/tex">\mathbf{P}\theta^{(d)}</script> is the sum of columns of <script type="math/tex">\mathbf{P}</script> scaled by corresponding scalar values in the <script type="math/tex">d</script>-dim vector, <script type="math/tex">\sum_{i=1}^d \theta^{(d)}_i \mathbf{P}^\top_{(:,i)}</script>, it is better to fully utilize the subspace with orthogonal columns in <script type="math/tex">\mathbf{P}</script>.</p>
<p>My implementation follows a naive approach by sampling a large matrix with independent entries from a standard normal distribution. In a high-dimensional space, such independently sampled columns are expected to be approximately orthogonal. This works when the dimension is not too large. When exploring with a large <script type="math/tex">d</script>, there are methods for creating sparse projection matrices, which is what the intrinsic dimension paper suggested.</p>
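<p>The near-orthogonality claim is easy to check numerically. The sketch below is not the code from my repo, just a standard-library illustration with made-up sizes: it samples Gaussian columns and reports the largest absolute cosine similarity between any two of them.</p>

```python
import math
import random

def random_projection(full_dim, sub_dim, seed=0):
    """Return sub_dim columns, each a list of full_dim standard-normal entries."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(full_dim)] for _ in range(sub_dim)]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def max_abs_cosine(columns):
    """Largest |cosine similarity| over all pairs of distinct columns."""
    return max(
        abs(cosine(columns[i], columns[j]))
        for i in range(len(columns))
        for j in range(i + 1, len(columns))
    )

low_dim = max_abs_cosine(random_projection(full_dim=10, sub_dim=5))
high_dim = max_abs_cosine(random_projection(full_dim=5000, sub_dim=5))
print(low_dim, high_dim)
```

<p>The cosine between two such columns has a standard deviation of about <script type="math/tex">1/\sqrt{D}</script> for a <script type="math/tex">D</script>-dimensional parameter space, so the columns get closer to orthogonal as the full parameter space grows.</p>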
<p>Here are experiment runs on two networks: (left) a two-layer fc network with 64 units in each layer and (right) a one-layer fc network with 128 hidden units, trained on 10% of MNIST. For every <script type="math/tex">d</script>, the model is trained for 100 epochs. See the <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions.py">code</a> <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions_measurement.py">here</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-net-64-64-and-128.png" alt="intrinsic dimension experiment 1" /></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019overfit,
title = "Are Deep Neural Networks Dramatically Overfitted?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Wikipedia page on <a href="https://en.wikipedia.org/wiki/Occam%27s_razor">Occam’s Razor</a>.</p>
<p>[2] <a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> on Principia Cybernetica Web.</p>
<p>[3] Peter Grunwald. <a href="https://arxiv.org/abs/math/0406077">“A Tutorial Introduction to the Minimum Description Length Principle”</a>. 2004.</p>
<p>[4] Ian Goodfellow, et al. <a href="https://www.deeplearningbook.org/">Deep Learning</a>. 2016. <a href="https://www.deeplearningbook.org/contents/mlp.html">Sec 6.4.1</a>.</p>
<p>[5] Zhang, Chiyuan, et al. <a href="https://arxiv.org/abs/1611.03530">“Understanding deep learning requires rethinking generalization.”</a> ICLR 2017.</p>
<p>[6] Shibani Santurkar, et al. <a href="https://arxiv.org/abs/1805.11604">“How does batch normalization help optimization?.”</a> NIPS 2018.</p>
<p>[7] Mikhail Belkin, et al. <a href="https://arxiv.org/abs/1812.11118">“Reconciling modern machine learning and the bias-variance trade-off.”</a> arXiv:1812.11118, 2018.</p>
<p>[8] Chiyuan Zhang, et al. <a href="https://arxiv.org/abs/1902.01996">“Are All Layers Created Equal?”</a> arXiv:1902.01996, 2019.</p>
<p>[9] Chunyuan Li, et al. <a href="https://arxiv.org/abs/1804.08838">“Measuring the intrinsic dimension of objective landscapes.”</a> ICLR 2018.</p>
<p>[10] Jonathan Frankle and Michael Carbin. <a href="https://arxiv.org/abs/1803.03635">“The lottery ticket hypothesis: Finding sparse, trainable neural networks.”</a> ICLR 2019.</p>Lilian WengIf you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.Generalized Language Models2019-01-31T12:00:00+00:002019-01-31T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/01/31/generalized-language-models<blockquote>
<p>As a follow up of word embedding post, we will discuss the models on learning contextualized word vectors, as well as the new trend in large unsupervised pre-trained language models which have achieved amazing SOTA results on a variety of language tasks.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-02-14: add <a href="#ulmfit">ULMFiT</a> and <a href="#openai-gpt-2">OpenAI GPT-2</a>.]</span><br />
<span style="color: #286ee0;">[Updated on 2020-02-29: add <a href="#albert">ALBERT</a>.]</span></p>
<p style="width: 60%;" class="center"><br />
<img src="/lil-log/assets/images/elmo-and-bert.png" alt="Elmo & Bert" /></p>
<p><em>Fig. 0. I guess they are Elmo & Bert? (Image source: <a href="https://www.youtube.com/watch?v=l5einDQ-Ttc">here</a>)</em>
<br /></p>
<p>We have seen amazing progress in NLP in 2018. Large-scale pre-trained language models like <a href="https://blog.openai.com/language-unsupervised/">OpenAI GPT</a> and <a href="https://arxiv.org/abs/1810.04805">BERT</a> have achieved great performance on a variety of language tasks using generic model architectures. The idea is similar to how ImageNet classification pre-training helps many vision tasks (*). Even better than vision classification pre-training, this simple and powerful approach in NLP does not require labeled data for pre-training, allowing us to experiment with increased training scale, up to our very limit.</p>
<p><em>(*) Although recently He et al. (2018) <a href="https://arxiv.org/abs/1811.08883">found</a> that pre-training might not be necessary for image segmentation task.</em></p>
<p>In my previous NLP <a href="/lil-log/2017/10/15/learning-word-embedding.html">post on word embedding</a>, the introduced embeddings are not context-specific: they are learned based on word co-occurrence but not sequential context. So in two sentences, “<em>I am eating an apple</em>” and “<em>I have an Apple phone</em>”, the two “apple” words refer to very different things but they would still share the same word embedding vector.</p>
<p>Moreover, early applications of word embeddings simply used them as additional features for an existing task-specific model, so the improvement they could bring was bounded.</p>
<p>In this post, we will discuss how various approaches were proposed to make embeddings dependent on context, and to make them easier and cheaper to apply to downstream tasks in a general form.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#cove" id="markdown-toc-cove">CoVe</a> <ul>
<li><a href="#nmt-recap" id="markdown-toc-nmt-recap">NMT Recap</a></li>
<li><a href="#use-cove-in-downstream-tasks" id="markdown-toc-use-cove-in-downstream-tasks">Use CoVe in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#elmo" id="markdown-toc-elmo">ELMo</a> <ul>
<li><a href="#bidirectional-language-model" id="markdown-toc-bidirectional-language-model">Bidirectional Language Model</a></li>
<li><a href="#elmo-representations" id="markdown-toc-elmo-representations">ELMo Representations</a></li>
<li><a href="#use-elmo-in-downstream-tasks" id="markdown-toc-use-elmo-in-downstream-tasks">Use ELMo in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#cross-view-training" id="markdown-toc-cross-view-training">Cross-View Training</a> <ul>
<li><a href="#model-architecture" id="markdown-toc-model-architecture">Model Architecture</a></li>
<li><a href="#multi-task-learning" id="markdown-toc-multi-task-learning">Multi-Task Learning</a></li>
<li><a href="#use-cvt-in-downstream-tasks" id="markdown-toc-use-cvt-in-downstream-tasks">Use CVT in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#ulmfit" id="markdown-toc-ulmfit">ULMFiT</a></li>
<li><a href="#openai-gpt" id="markdown-toc-openai-gpt">OpenAI GPT</a> <ul>
<li><a href="#transformer-decoder-as-language-model" id="markdown-toc-transformer-decoder-as-language-model">Transformer Decoder as Language Model</a></li>
<li><a href="#byte-pair-encoding" id="markdown-toc-byte-pair-encoding">Byte Pair Encoding</a></li>
<li><a href="#supervised-fine-tuning" id="markdown-toc-supervised-fine-tuning">Supervised Fine-Tuning</a></li>
</ul>
</li>
<li><a href="#bert" id="markdown-toc-bert">BERT</a> <ul>
<li><a href="#pre-training-tasks" id="markdown-toc-pre-training-tasks">Pre-training Tasks</a></li>
<li><a href="#input-embedding" id="markdown-toc-input-embedding">Input Embedding</a></li>
<li><a href="#use-bert-in-downstream-tasks" id="markdown-toc-use-bert-in-downstream-tasks">Use BERT in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#albert" id="markdown-toc-albert">ALBERT</a> <ul>
<li><a href="#factorized-embedding-parameterization" id="markdown-toc-factorized-embedding-parameterization">Factorized Embedding Parameterization</a></li>
<li><a href="#cross-layer-parameter-sharing" id="markdown-toc-cross-layer-parameter-sharing">Cross-layer Parameter Sharing</a></li>
<li><a href="#sentence-order-prediction-sop" id="markdown-toc-sentence-order-prediction-sop">Sentence-Order Prediction (SOP)</a></li>
</ul>
</li>
<li><a href="#openai-gpt-2" id="markdown-toc-openai-gpt-2">OpenAI GPT-2</a> <ul>
<li><a href="#zero-shot-transfer" id="markdown-toc-zero-shot-transfer">Zero-Shot Transfer</a></li>
<li><a href="#bpe-on-byte-sequences" id="markdown-toc-bpe-on-byte-sequences">BPE on Byte Sequences</a></li>
<li><a href="#model-modifications" id="markdown-toc-model-modifications">Model Modifications</a></li>
</ul>
</li>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
<li><a href="#metric-perplexity" id="markdown-toc-metric-perplexity">Metric: Perplexity</a></li>
<li><a href="#common-tasks-and-datasets" id="markdown-toc-common-tasks-and-datasets">Common Tasks and Datasets</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="cove">CoVe</h2>
<p><strong>CoVe</strong> (<a href="https://arxiv.org/abs/1708.00107">McCann et al. 2017</a>), short for <strong>Contextual Word Vectors</strong>, is a type of word embedding learned by an encoder in an <a href="/lil-log/2018/06/24/attention-attention.html#born-for-translation">attentional seq-to-seq</a> machine translation model.
Different from traditional word embeddings introduced <a href="/lil-log/2017/10/15/learning-word-embedding.html">here</a>, CoVe word representations are functions of the entire input sentence.</p>
<h3 id="nmt-recap">NMT Recap</h3>
<p>Here the Neural Machine Translation (<a href="https://github.com/THUNLP-MT/MT-Reading-List">NMT</a>) model is composed of a standard, two-layer, bidirectional LSTM encoder and an attentional two-layer unidirectional LSTM decoder. It is pre-trained on the English-German translation task. The encoder learns and optimizes the embedding vectors of English words in order to translate them to German. With the intuition that the encoder should capture high-level semantic and syntactic meanings before transforming words into another language, the encoder output is used to provide contextualized word embeddings for various downstream language tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/nmt-recap.png" alt="NMT Recap" /></p>
<p><em>Fig. 1. The NMT base model used in CoVe.</em></p>
<ul>
<li>A sequence of <script type="math/tex">n</script> words in source language (English): <script type="math/tex">x = [x_1, \dots, x_n]</script>.</li>
<li>A sequence of <script type="math/tex">m</script> words in target language (German): <script type="math/tex">y = [y_1, \dots, y_m]</script>.</li>
<li>The <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">GloVe</a> vectors of source words: <script type="math/tex">\text{GloVe}(x)</script>.</li>
<li>Randomly initialized embedding vectors of target words: <script type="math/tex">z = [z_1, \dots, z_m]</script>.</li>
<li>The biLSTM encoder outputs a sequence of hidden states: <script type="math/tex">h = [h_1, \dots, h_n] = \text{biLSTM}(\text{GloVe}(x))</script> and <script type="math/tex">h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t]</script> where the forward LSTM computes <script type="math/tex">\overrightarrow{h}_t = \text{LSTM}(x_t, \overrightarrow{h}_{t-1})</script> and the backward computation, which reads the sequence from right to left, gives us <script type="math/tex">\overleftarrow{h}_t = \text{LSTM}(x_t, \overleftarrow{h}_{t+1})</script>.</li>
<li>The attentional decoder outputs a distribution over words: <script type="math/tex">p(y_t \mid H, y_1, \dots, y_{t-1})</script> where <script type="math/tex">H</script> is a stack of hidden states <script type="math/tex">\{h\}</script> along the time dimension:</li>
</ul>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{decoder hidden state: } s_t &= \text{LSTM}([z_{t-1}; \tilde{h}_{t-1}], s_{t-1}) \\
\text{attention weights: } \alpha_t &= \text{softmax}(H(W_1 s_t + b_1)) \\
\text{context-adjusted hidden state: } \tilde{h}_t &= \tanh(W_2[H^\top\alpha_t;s_t] + b_2) \\
\text{decoder output: } p(y_t\mid H, y_1, \dots, y_{t-1}) &= \text{softmax}(W_\text{out} \tilde{h}_t + b_\text{out})
\end{aligned} %]]></script>
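<p>To make the shapes concrete, here is a rough sketch of the attention part of one decoder step (the LSTM update producing <script type="math/tex">s_t</script> and the output softmax are left out). The helper names and the toy numbers are made up for illustration, and plain Python lists stand in for tensors.</p>

```python
import math

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_step(H, s_t, W1, b1, W2, b2):
    """H: n encoder states of size d (one row per source word); s_t: decoder state."""
    query = [q + b for q, b in zip(matvec(W1, s_t), b1)]   # W_1 s_t + b_1, size d
    alpha = softmax(matvec(H, query))                      # one weight per source word
    context = [sum(a * h[k] for a, h in zip(alpha, H))     # H^T alpha, size d
               for k in range(len(H[0]))]
    pre = matvec(W2, context + s_t)                        # W_2 [H^T alpha; s_t]
    h_tilde = [math.tanh(p + b) for p, b in zip(pre, b2)]  # context-adjusted state
    return alpha, h_tilde

# Toy example: 3 source words, encoder and decoder states of size 2.
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_t = [0.5, -0.5]
W1, b1 = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
W2, b2 = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], [0.0, 0.0]
alpha, h_tilde = attention_step(H, s_t, W1, b1, W2, b2)
print(alpha, h_tilde)
```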
<h3 id="use-cove-in-downstream-tasks">Use CoVe in Downstream Tasks</h3>
<p>The hidden states of NMT encoder are defined as <strong>context vectors</strong> for other language tasks:</p>
<script type="math/tex; mode=display">\text{CoVe}(x) = \text{biLSTM}(\text{GloVe}(x))</script>
<p>The paper proposed to use the concatenation of GloVe and CoVe for question-answering and classification tasks. GloVe learns from the ratios of global word co-occurrences, so it has no sentence context, while CoVe, generated by processing text sequences, is able to capture the contextual information.</p>
<script type="math/tex; mode=display">v = [\text{GloVe}(x); \text{CoVe}(x)]</script>
<p>Given a downstream task, we first generate the concatenation of GloVe + CoVe vectors of input words and then feed them into the task-specific models as additional features.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CoVe.png" alt="CoVe model" /></p>
<p><em>Fig. 2. The CoVe embeddings are generated by an encoder trained for machine translation task. The encoder can be plugged into any downstream task-specific model. (Image source: <a href="https://arxiv.org/abs/1708.00107">original paper</a>)</em></p>
<p><strong>Summary</strong>: The limitations of CoVe are obvious: (1) pre-training is bounded by available datasets on the supervised translation task; (2) the contribution of CoVe to the final performance is constrained by the task-specific model architecture.</p>
<p>In the following sections, we will see that ELMo overcomes issue (1) by unsupervised pre-training and OpenAI GPT & BERT further overcome both problems by unsupervised pre-training + using generative model architecture for different downstream tasks.</p>
<h2 id="elmo">ELMo</h2>
<p><strong>ELMo</strong>, short for <strong>Embeddings from Language Model</strong> (<a href="https://arxiv.org/abs/1802.05365">Peters, et al, 2018</a>) learns contextualized word representation by pre-training a language model in an <em>unsupervised</em> way.</p>
<h3 id="bidirectional-language-model">Bidirectional Language Model</h3>
<p>The bidirectional Language Model (<strong>biLM</strong>) is the foundation for ELMo. Given an input sequence of <script type="math/tex">n</script> tokens, <script type="math/tex">(x_1, \dots, x_n)</script>, the language model learns to predict the probability of the next token given the history.</p>
<p>In the forward pass, the history contains words before the target token,</p>
<script type="math/tex; mode=display">p(x_1, \dots, x_n) = \prod_{i=1}^n p(x_i \mid x_1, \dots, x_{i-1})</script>
<p>In the backward pass, the history contains words after the target token,</p>
<script type="math/tex; mode=display">p(x_1, \dots, x_n) = \prod_{i=1}^n p(x_i \mid x_{i+1}, \dots, x_n)</script>
<p>The predictions in both directions are modeled by multi-layer LSTMs with hidden states <script type="math/tex">\overrightarrow{\mathbf{h}}_{i,\ell}</script> and <script type="math/tex">\overleftarrow{\mathbf{h}}_{i,\ell}</script> for input token <script type="math/tex">x_i</script> at the layer level <script type="math/tex">\ell=1,\dots,L</script>.
The final layer’s hidden state <script type="math/tex">\mathbf{h}_{i,L} = [\overrightarrow{\mathbf{h}}_{i,L}; \overleftarrow{\mathbf{h}}_{i,L}]</script> is used to output the probabilities over tokens after softmax normalization. They share the embedding layer and the softmax layer, parameterized by <script type="math/tex">\Theta_e</script> and <script type="math/tex">\Theta_s</script> respectively.</p>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/ELMo-biLSTM.png" alt="ELMo biLSTM" /></p>
<p><em>Fig. 3. The biLSTM base model of ELMo. (Image source: recreated based on the figure in <a href="http://colah.github.io/posts/2015-09-NN-Types-FP/">“Neural Networks, Types, and Functional Programming”</a> by Christopher Olah.)</em></p>
<p>The model is trained to minimize the negative log likelihood (= maximize the log likelihood for true words) in both directions:</p>
<script type="math/tex; mode=display">\begin{aligned}
\mathcal{L} = - \sum_{i=1}^n \Big(
\log p(x_i \mid x_1, \dots, x_{i-1}; \Theta_e, \overrightarrow{\Theta}_\text{LSTM}, \Theta_s) + \\
\log p(x_i \mid x_{i+1}, \dots, x_n; \Theta_e, \overleftarrow{\Theta}_\text{LSTM}, \Theta_s) \Big)
\end{aligned}</script>
<h3 id="elmo-representations">ELMo Representations</h3>
<p>On top of a <script type="math/tex">L</script>-layer biLM, ELMo stacks all the hidden states across layers together by learning a task-specific linear combination. The hidden state representation for the token <script type="math/tex">x_i</script> contains <script type="math/tex">2L+1</script> vectors:</p>
<p><script type="math/tex">R_i = \{ \mathbf{h}_{i,\ell} \mid \ell = 0, \dots, L \}</script>
where <script type="math/tex">\mathbf{h}_{i, 0}</script> is the embedding layer output and <script type="math/tex">\mathbf{h}_{i, \ell} = [\overrightarrow{\mathbf{h}}_{i,\ell}; \overleftarrow{\mathbf{h}}_{i,\ell}]</script>.</p>
<p>The weights, <script type="math/tex">\mathbf{s}^\text{task}</script>, in the linear combination are learned for each end task and normalized by softmax. The scaling factor <script type="math/tex">\gamma^\text{task}</script> is used to correct the misalignment between the distribution of biLM hidden states and the distribution of task specific representations.</p>
<script type="math/tex; mode=display">v_i = f(R_i; \Theta^\text{task}) = \gamma^\text{task} \sum_{\ell=0}^L s^\text{task}_\ell \mathbf{h}_{i,\ell}</script>
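<p>As a small illustration of this weighted sum, the sketch below combines the per-layer states of a single token with softmax-normalized layer weights and a scalar <script type="math/tex">\gamma^\text{task}</script>. The function name and the toy vectors are made up; in practice the raw weights and the scale are learned together with the task model.</p>

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def elmo_vector(layer_states, raw_weights, gamma):
    """layer_states: L+1 vectors for one token (embedding layer plus L biLM layers).
    raw_weights: L+1 unnormalized scalars, one per layer; gamma: task-specific scale."""
    s = softmax(raw_weights)  # normalized layer weights, shared by all tokens
    dim = len(layer_states[0])
    return [gamma * sum(w * h[k] for w, h in zip(s, layer_states)) for k in range(dim)]

# Toy example with L = 2: three layer vectors of size 2 and equal raw weights.
layers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = elmo_vector(layers, raw_weights=[0.0, 0.0, 0.0], gamma=2.0)
print(v)
```

<p>Note that the layer weights depend only on the layer index, not on the token position, so one small vector of weights is learned per end task.</p>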
<p>To evaluate what kind of information is captured by hidden states across different layers, ELMo is applied on semantic-intensive and syntax-intensive tasks respectively using representations in different layers of biLM:</p>
<ul>
<li><strong>Semantic task</strong>: The <em>word sense disambiguation (WSD)</em> task emphasizes the meaning of a word given a context. The biLM top layer is better at this task than the first layer.</li>
<li><strong>Syntax task</strong>: The <em><a href="https://en.wikipedia.org/wiki/Part-of-speech_tagging">part-of-speech</a> (POS) tagging</em> task aims to infer the grammatical role of a word in one sentence. A higher accuracy can be achieved by using the biLM first layer than the top layer.</li>
</ul>
<p>The comparison study indicates that syntactic information is better represented at lower layers while semantic information is captured by higher layers. Because different layers tend to carry different types of information, <em>stacking them together helps</em>.</p>
<h3 id="use-elmo-in-downstream-tasks">Use ELMo in Downstream Tasks</h3>
<p>Similar to how <a href="#use-cove-in-downstream-tasks">CoVe</a> can help different downstream tasks, ELMo embedding vectors are included in the input or lower levels of task-specific models. Moreover, for some tasks (e.g., <a href="#nli">SNLI</a> and <a href="#qa">SQuAD</a>, but not <a href="#srl">SRL</a>), adding them into the output level helps too.</p>
<p>The improvements brought by ELMo are largest for tasks with a small supervised dataset. With ELMo, we can also achieve similar performance with much less labeled data.</p>
<p><strong>Summary</strong>: The language model pre-training is unsupervised and theoretically the pre-training can be scaled up as much as possible since the unlabeled text corpora are abundant. However, it still has the dependency on task-customized models and thus the improvement is only incremental, while searching for a good model architecture for every task remains non-trivial.</p>
<h2 id="cross-view-training">Cross-View Training</h2>
<p>In ELMo the unsupervised pre-training and task-specific learning happen for two independent models in two separate training stages. <strong>Cross-View Training</strong> (abbr. <strong>CVT</strong>; <a href="https://arxiv.org/abs/1809.08370">Clark et al., 2018</a>) combines them into one unified semi-supervised learning procedure where the representation of a biLSTM encoder is improved by both supervised learning with labeled data and unsupervised learning with unlabeled data on auxiliary tasks.</p>
<h3 id="model-architecture">Model Architecture</h3>
<p>The model consists of a two-layer bidirectional LSTM encoder and a primary prediction module. During training, the model is fed with labeled and unlabeled data batches alternately.</p>
<ul>
<li>On <em>labeled examples</em>, all the model parameters are updated by standard supervised learning. The loss is the standard cross entropy.</li>
<li>On <em>unlabeled examples</em>, the primary prediction module can still produce a “soft” target, even though we cannot know exactly how accurate it is. In a couple of auxiliary tasks, the predictor only sees and processes a restricted view of the input, such as only using encoder hidden state representation in one direction. The auxiliary task outputs are expected to match the primary prediction target for a full view of input. <br />In this way, the encoder is forced to distill the knowledge of the full context into partial representation. At this stage, gradients are backpropagated through the biLSTM encoder, but the primary prediction module is <em>fixed</em>. The loss minimizes the distance between auxiliary and primary predictions.</li>
</ul>
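<p>The loss on unlabeled examples can be sketched as follows, assuming the distance between predictions is the KL divergence (my reading of the CVT paper; any other divergence between distributions could be swapped in). The function names and the toy distributions are made up for illustration.</p>

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def cvt_unlabeled_loss(primary_probs, auxiliary_probs_list):
    """Sum of KL(primary || auxiliary) over all auxiliary prediction modules.
    The primary prediction is a fixed soft target: no gradient flows into it."""
    return sum(kl_divergence(primary_probs, aux) for aux in auxiliary_probs_list)

# Toy example: the full-view prediction and two restricted-view predictions.
primary = [0.7, 0.2, 0.1]
auxiliary = [[0.7, 0.2, 0.1], [0.4, 0.4, 0.2]]
loss = cvt_unlabeled_loss(primary, auxiliary)
print(loss)
```

<p>An auxiliary module that already agrees with the primary prediction contributes nothing to the loss, so only the restricted views that disagree push the shared encoder to change.</p>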
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CVT.png" alt="CVT" /></p>
<p><em>Fig. 4. The overview of semi-supervised language model cross-view training. (Image source: <a href="https://arxiv.org/abs/1809.08370">original paper</a>)</em></p>
<h3 id="multi-task-learning">Multi-Task Learning</h3>
<p>When training for multiple tasks simultaneously, CVT adds several extra primary prediction models for additional tasks. They all share the same sentence representation encoder.
During supervised training, once one task is randomly selected, parameters in its corresponding predictor and the representation encoder are updated.
With unlabeled data samples, the encoder is optimized jointly across all the tasks by minimizing the differences between auxiliary outputs and primary prediction for every task.</p>
<p>The multi-task learning encourages better generality of representation and in the meantime produces a nice side-product: all-tasks-labeled examples from unlabeled data. They are precious data labels considering that cross-task labels are useful but fairly rare.</p>
<h3 id="use-cvt-in-downstream-tasks">Use CVT in Downstream Tasks</h3>
<p>Theoretically the primary prediction module can take any form, generic or task-specific design. The examples presented in the CVT paper include both cases.</p>
<p>In sequential tagging tasks (classification for every token) like <a href="#ner">NER</a> or <a href="#pos">POS</a> tagging, the predictor module contains two fully connected layers and a softmax layer on the output to produce a probability distribution over class labels.
For each token <script type="math/tex">\mathbf{x}_i</script>, we take the corresponding hidden states in two layers, <script type="math/tex">\mathbf{h}_1^{(i)}</script> and <script type="math/tex">\mathbf{h}_2^{(i)}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_\theta(y_i \mid \mathbf{x}_i)
&= \text{NN}(\mathbf{h}^{(i)}) \\
&= \text{NN}([\mathbf{h}_1^{(i)}; \mathbf{h}_2^{(i)}]) \\
&= \text{softmax} \big( \mathbf{W}\cdot\text{ReLU}(\mathbf{W'}\cdot[\mathbf{h}_1^{(i)}; \mathbf{h}_2^{(i)}]) + \mathbf{b} \big)
\end{aligned} %]]></script>
<p>The auxiliary tasks are only fed with forward or backward LSTM state in the first layer. Because they only observe partial context, either on the left or right, they have to learn like a language model, trying to predict the next token given the context. The <code class="language-plaintext highlighter-rouge">fwd</code> and <code class="language-plaintext highlighter-rouge">bwd</code> auxiliary tasks only take one direction. The <code class="language-plaintext highlighter-rouge">future</code> and <code class="language-plaintext highlighter-rouge">past</code> tasks take one step further in forward and backward direction, respectively.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_\theta^\text{fwd}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{fwd}(\overrightarrow{\mathbf{h}}^{(i)}) \\
p_\theta^\text{bwd}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{bwd}(\overleftarrow{\mathbf{h}}^{(i)}) \\
p_\theta^\text{future}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{future}(\overrightarrow{\mathbf{h}}^{(i-1)}) \\
p_\theta^\text{past}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{past}(\overleftarrow{\mathbf{h}}^{(i+1)})
\end{aligned} %]]></script>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/CVT-example.png" alt="CVT sequential tagging" /></p>
<p><em>Fig. 5. The sequential tagging task depends on four auxiliary prediction models, their inputs only involving hidden states in one direction: forward, backward, future and past. (Image source: <a href="https://arxiv.org/abs/1809.08370">original paper</a>)</em></p>
<p>Note that if the primary prediction module has dropout, the dropout layer works as usual when training with labeled data, but it is not applied when generating “soft” target for auxiliary tasks during training with unlabeled data.</p>
<p>In the machine translation task, the primary prediction module is replaced with a standard unidirectional LSTM decoder with attention. There are two auxiliary tasks: (1) apply dropout on the attention weight vector by randomly zeroing out some values; (2) predict the future word in the target sequence. The primary prediction for auxiliary tasks to match is the best predicted target sequence produced by running the fixed primary decoder on the input sequence with <a href="https://en.wikipedia.org/wiki/Beam_search">beam search</a>.</p>
<h2 id="ulmfit">ULMFiT</h2>
<p>The idea of using generative pretrained LM + task-specific fine-tuning was first explored in ULMFiT (<a href="https://arxiv.org/abs/1801.06146">Howard & Ruder, 2018</a>), directly motivated by the success of using ImageNet pre-training for computer vision tasks. The base model is <a href="https://arxiv.org/abs/1708.02182">AWD-LSTM</a>.</p>
<p>ULMFiT follows three steps to achieve good transfer learning results on downstream language classification tasks:</p>
<p>1) <em>General LM pre-training</em>: on Wikipedia text.</p>
<p>2) <em>Target task LM fine-tuning</em>: ULMFiT proposed two training techniques for stabilizing the fine-tuning process. See below.</p>
<ul>
<li>
<p><strong>Discriminative fine-tuning</strong> is motivated by the fact that different layers of LM capture different types of information (see <a href="#elmo-representations">discussion</a> above). ULMFiT proposed to tune each layer with different learning rates, <script type="math/tex">\{\eta^1, \dots, \eta^\ell, \dots, \eta^L\}</script>, where <script type="math/tex">\eta</script> is the base learning rate for the first layer, <script type="math/tex">\eta^\ell</script> is for the <script type="math/tex">\ell</script>-th layer and there are <script type="math/tex">L</script> layers in total.</p>
</li>
<li>
<p><strong>Slanted triangular learning rates (STLR)</strong> refer to a special learning rate scheduling that first linearly increases the learning rate and then linearly decays it. The increase stage is short so that the model can converge to a parameter space suitable for the task fast, while the decay period is long allowing for better fine-tuning.</p>
</li>
</ul>
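<p>Both techniques come down to a few lines of arithmetic. The sketch below assumes the heuristics I recall from the ULMFiT paper: the learning rate is picked for the last layer and divided by 2.6 for each lower layer, and STLR uses <code class="language-plaintext highlighter-rouge">cut_frac=0.1</code> and <code class="language-plaintext highlighter-rouge">ratio=32</code>. Treat these constants and the function names as assumptions rather than a reference implementation.</p>

```python
import math

def discriminative_lrs(top_lr, num_layers, decay=2.6):
    """Learning rates for layers 1..L: the last layer gets top_lr and
    each lower layer uses the rate of the layer above divided by decay."""
    return [top_lr / decay ** (num_layers - layer) for layer in range(1, num_layers + 1)]

def slanted_triangular_lr(t, total_steps, max_lr, cut_frac=0.1, ratio=32):
    """Learning rate at step t: short linear warm-up, then long linear decay.
    Assumes total_steps * cut_frac >= 1 so that the warm-up has at least one step."""
    cut = math.floor(total_steps * cut_frac)             # step where the rate peaks
    if t < cut:
        p = t / cut                                      # increasing phase
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))   # decaying phase
    return max_lr * (1 + p * (ratio - 1)) / ratio

print(discriminative_lrs(top_lr=0.01, num_layers=3))
print([slanted_triangular_lr(t, 1000, 0.01) for t in (0, 100, 1000)])
```

<p>With these settings the schedule starts and ends at <code class="language-plaintext highlighter-rouge">max_lr / ratio</code> and reaches <code class="language-plaintext highlighter-rouge">max_lr</code> after the first 10% of the steps.</p>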
<p>3) <em>Target task classifier fine-tuning</em>: The pretrained LM is augmented with two standard feed-forward layers and a softmax normalization at the end to predict a target label distribution.</p>
<ul>
<li>
<p><strong>Concat pooling</strong> extracts max-pooling and mean-pooling over the history of hidden states and concatenates them with the final hidden state.</p>
</li>
<li>
<p><strong>Gradual unfreezing</strong> helps to avoid catastrophic forgetting by gradually unfreezing the model layers starting from the last one. First the last layer is unfrozen and fine-tuned for one epoch. Then the next lower layer is unfrozen. This process is repeated until all the layers are tuned.</p>
</li>
</ul>
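<p>A tiny sketch of both ideas, with plain lists in place of tensors. The function names are made up, and the unfreezing schedule assumes exactly one more layer is unfrozen per epoch.</p>

```python
def concat_pooling(hidden_states):
    """hidden_states: list of T vectors. Returns [h_T; maxpool(H); meanpool(H)]."""
    dim = len(hidden_states[0])
    last = list(hidden_states[-1])
    max_pool = [max(h[k] for h in hidden_states) for k in range(dim)]
    mean_pool = [sum(h[k] for h in hidden_states) / len(hidden_states) for k in range(dim)]
    return last + max_pool + mean_pool

def unfrozen_layers(epoch, num_layers):
    """Indices of trainable layers at a given epoch (0-based), starting from the last layer."""
    first = max(0, num_layers - 1 - epoch)
    return list(range(first, num_layers))

H = [[1.0, -2.0], [3.0, 0.0], [2.0, 4.0]]
print(concat_pooling(H))
print([unfrozen_layers(epoch, num_layers=4) for epoch in range(5)])
```

<p>The classifier input is therefore three times as wide as a single hidden state, and the lowest layers, which hold the most general knowledge, are the last ones to be updated.</p>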
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ULMFiT.png" alt="ULMFiT" /></p>
<p><em>Fig. 6. Three training stages of ULMFiT. (Image source: <a href="https://arxiv.org/abs/1801.06146">original paper</a>)</em></p>
<h2 id="openai-gpt">OpenAI GPT</h2>
<p>Following a similar idea to ELMo, OpenAI <strong>GPT</strong>, short for <strong>Generative Pre-training Transformer</strong> (<a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">Radford et al., 2018</a>), expands the unsupervised language model to a much larger scale by training on a giant collection of free text corpora. Despite the similarity, GPT has two major differences from ELMo.</p>
<ol>
<li>The model architectures are different: ELMo uses a shallow concatenation of independently trained left-to-right and right-to-left multi-layer LSTMs, while GPT is a multi-layer transformer decoder.</li>
<li>The use of contextualized embeddings in downstream tasks is different: ELMo feeds embeddings into models customized for specific tasks as additional features, while GPT fine-tunes the same base model for all end tasks.</li>
</ol>
<h3 id="transformer-decoder-as-language-model">Transformer Decoder as Language Model</h3>
<p>Compared to the <a href="https://arxiv.org/abs/1706.03762">original transformer</a> architecture, the <a href="https://arxiv.org/abs/1801.10198">transformer decoder</a> model discards the encoder part, so there is only one single input sentence rather than two separate source and target sequences.</p>
<p>This model applies multiple transformer blocks over the embeddings of input sequences. Each block contains a masked <em>multi-headed self-attention</em> layer and a <em>pointwise feed-forward</em> layer. The final output produces a distribution over target tokens after softmax normalization.</p>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/OpenAI-GPT-transformer-decoder.png" alt="OpenAI GPT transformer decoder" /></p>
<p><em>Fig. 7. The transformer decoder model architecture in OpenAI GPT.</em></p>
<p>The loss is the negative log-likelihood, same as <a href="#elmo">ELMo</a>, but without backward computation. Suppose the context window of size <script type="math/tex">k</script> is located before the target word; the loss then looks like:</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{LM} = -\sum_{i} \log p(x_i\mid x_{i-k}, \dots, x_{i-1})</script>
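<p>The language modeling loss can be sketched in a few lines of Python. This is only an illustration: <code class="language-plaintext highlighter-rouge">cond_prob</code> stands in for the transformer decoder and is a hypothetical callable, and the uniform toy model is an assumption made so the example runs on its own.</p>

```python
import math


def lm_loss(tokens, cond_prob, k):
    """Negative log-likelihood with a context window of k preceding tokens."""
    total = 0.0
    for i, target in enumerate(tokens):
        context = tuple(tokens[max(0, i - k):i])  # x_{i-k}, ..., x_{i-1}
        total -= math.log(cond_prob(target, context))
    return total


def uniform_prob(target, context):
    # Toy stand-in for the model: uniform over a 4-word vocabulary.
    return 0.25


loss = lm_loss(["the", "cat", "sat"], uniform_prob, k=2)
```

<p>With a uniform model over four words, every token contributes <script type="math/tex">\log 4</script> to the loss; a better model assigns higher probability to the observed tokens and gets a lower loss.</p>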
<h3 id="byte-pair-encoding">Byte Pair Encoding</h3>
<p><strong>Byte Pair Encoding</strong> (<a href="https://arxiv.org/abs/1508.07909"><strong>BPE</strong></a>) is used to encode the input sequences. BPE was originally proposed as a data compression algorithm in the 1990s and was then adopted to solve the open-vocabulary issue in machine translation, as we can easily run into rare and unknown words when translating into a new language. Motivated by the intuition that rare and unknown words can often be decomposed into multiple subwords, BPE finds the best word segmentation by iteratively and greedily merging frequent pairs of characters.</p>
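<p>The greedy merge loop can be sketched as follows. The toy vocabulary, the end-of-word marker <code class="language-plaintext highlighter-rouge">&lt;/w&gt;</code>, and the helper names are assumptions chosen for illustration; this is a simplified sketch, not the reference implementation.</p>

```python
from collections import Counter


def get_pair_counts(vocab):
    """vocab maps a word (tuple of symbols) to its corpus frequency."""
    pairs = Counter()
    for symbols, freq in vocab.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs


def merge_pair(pair, vocab):
    """Replace every adjacent occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in vocab.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


# Toy corpus: words split into characters, "</w>" marks the end of a word.
vocab = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "e", "s", "t", "</w>"): 6,
    ("w", "i", "d", "e", "s", "t", "</w>"): 3,
}

merges = []
for _ in range(3):
    pairs = get_pair_counts(vocab)
    best = max(pairs, key=pairs.get)  # greedy: most frequent pair first
    vocab = merge_pair(best, vocab)
    merges.append(best)
```

<p>After three merges the frequent suffix “est” plus the end-of-word marker has become a single symbol, which is how common subwords end up in the vocabulary while rare words stay decomposed.</p>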
<h3 id="supervised-fine-tuning">Supervised Fine-Tuning</h3>
<p>The most substantial upgrade that OpenAI GPT proposed is to get rid of the task-specific model and use the pre-trained language model directly!</p>
<p>Let’s take classification as an example. Say, in the labeled dataset, each input has <script type="math/tex">n</script> tokens, <script type="math/tex">\mathbf{x} = (x_1, \dots, x_n)</script>, and one label <script type="math/tex">y</script>. GPT first processes the input sequence <script type="math/tex">\mathbf{x}</script> through the pre-trained transformer decoder and the last layer output for the last token <script type="math/tex">x_n</script> is <script type="math/tex">\mathbf{h}_L^{(n)}</script>. Then with only one new trainable weight matrix <script type="math/tex">\mathbf{W}_y</script>, it can predict a distribution over class labels.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/GPT-classification.png" alt="GPT classification" /></p>
<script type="math/tex; mode=display">P(y\mid x_1, \dots, x_n) = \text{softmax}(\mathbf{h}_L^{(n)}\mathbf{W}_y)</script>
<p>The training objective is to minimize the negative log-likelihood of the true labels. In addition, adding the LM loss as an auxiliary loss is found to be beneficial, because:</p>
<ul>
<li>(1) it helps accelerate convergence during training and</li>
<li>(2) it is expected to improve the generalization of the supervised model.</li>
</ul>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{cls} &= -\sum_{(\mathbf{x}, y) \in \mathcal{D}} \log P(y\mid x_1, \dots, x_n) = -\sum_{(\mathbf{x}, y) \in \mathcal{D}} \log \text{softmax}(\mathbf{h}_L^{(n)}(\mathbf{x})\mathbf{W}_y) \\
\mathcal{L}_\text{LM} &= -\sum_{i} \log p(x_i\mid x_{i-k}, \dots, x_{i-1}) \\
\mathcal{L} &= \mathcal{L}_\text{cls} + \lambda \mathcal{L}_\text{LM}
\end{aligned} %]]></script>
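<p>To make the combined objective concrete, here is a small sketch for a single labeled example. The toy hidden state, the weight matrix, and the value of the weight <code class="language-plaintext highlighter-rouge">lam</code> are illustrative assumptions, and the classification term is written as a negative log-likelihood so that both terms are minimized.</p>

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classification_nll(h_last, W_y, label):
    """-log softmax(h W_y)[label] for one example.

    h_last: final-layer hidden state of the last token (length d).
    W_y:    d x num_classes weight matrix, given as a list of rows.
    """
    num_classes = len(W_y[0])
    logits = [sum(h_last[j] * W_y[j][c] for j in range(len(h_last)))
              for c in range(num_classes)]
    return -math.log(softmax(logits)[label])


def fine_tuning_loss(cls_nll, lm_nll, lam):
    # Total loss = classification loss + lam * auxiliary LM loss.
    return cls_nll + lam * lm_nll


h = [1.0, 0.0]
W = [[2.0, 0.0], [0.0, 1.0]]
cls = classification_nll(h, W, label=0)
total = fine_tuning_loss(cls, lm_nll=4.0, lam=0.5)
```

<p>Only <code class="language-plaintext highlighter-rouge">W</code> is new at fine-tuning time; everything that produces <code class="language-plaintext highlighter-rouge">h</code> comes from the pre-trained model.</p>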
<p>With similar designs, no customized model structure is needed for other end tasks (see Fig. 8). If the task input contains multiple sentences, a special delimiter token (<code class="language-plaintext highlighter-rouge">$</code>) is added between each pair of sentences. The embedding for this delimiter token is a new parameter we need to learn, but the added parameter count is minimal.</p>
<p>For the sentence similarity task, because the ordering does not matter, both orderings are included. For the multiple choice task, the context is paired with every answer candidate.</p>
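<p>These input transformations can be sketched as plain list manipulation. The delimiter <code class="language-plaintext highlighter-rouge">$</code> comes from the text above; the start and extract token strings and the function names are assumptions for this example.</p>

```python
# DELIM is the delimiter token from the text; the START and EXTRACT
# strings are hypothetical placeholders for the boundary tokens.
START, DELIM, EXTRACT = "<s>", "$", "<e>"


def entailment_input(premise, hypothesis):
    return [[START, *premise, DELIM, *hypothesis, EXTRACT]]


def similarity_inputs(sent_a, sent_b):
    # No inherent ordering, so both orderings are fed to the model.
    return [
        [START, *sent_a, DELIM, *sent_b, EXTRACT],
        [START, *sent_b, DELIM, *sent_a, EXTRACT],
    ]


def multiple_choice_inputs(context, answers):
    # The context is paired with every answer candidate.
    return [[START, *context, DELIM, *ans, EXTRACT] for ans in answers]


sim = similarity_inputs(["a"], ["b"])
mc = multiple_choice_inputs(["ctx"], [["x"], ["y"], ["z"]])
```

<p>Each resulting sequence is processed by the same pre-trained model, so the task differences live entirely in how the input is laid out.</p>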
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/GPT-downstream-tasks.png" alt="GPT downstream tasks" /></p>
<p><em>Fig. 8. Training objectives in slightly modified GPT transformer models for downstream tasks. (Image source: <a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">original paper</a>)</em></p>
<p><strong>Summary</strong>: It is super neat and encouraging to see that such a general framework was capable of beating SOTA on most language tasks at that time (June 2018). In the first stage, generative pre-training of a language model can absorb as much free text as possible. Then in the second stage, the model is fine-tuned on specific tasks with a small labeled dataset and a minimal set of new parameters to learn.</p>
<p>One limitation of GPT is its uni-directional nature: the model is only trained to predict the next token from the left-to-right context.</p>
<h2 id="bert">BERT</h2>
<p><strong>BERT</strong>, short for <strong>Bidirectional Encoder Representations from Transformers</strong> (<a href="https://arxiv.org/abs/1810.04805">Devlin, et al., 2019</a>), is a direct descendant of <a href="#openai-gpt">GPT</a>: train a large language model on free text and then fine-tune on specific tasks without customized network architectures.</p>
<p>Compared to GPT, the largest difference and improvement of BERT is to make training <strong>bi-directional</strong>. The model learns to predict a token from the context on both its left and its right. Based on its ablation study, the paper claims that:</p>
<blockquote>
<p>“bidirectional nature of our model is the single most important new contribution”</p>
</blockquote>
<h3 id="pre-training-tasks">Pre-training Tasks</h3>
<p>The model architecture of BERT is a multi-layer bidirectional Transformer encoder.</p>
<p style="width: 25%;" class="center"><img src="/lil-log/assets/images/transformer-encoder-2.png" alt="transformer encoder" /></p>
<p><em>Fig. 9. Recap of Transformer Encoder model architecture. (Image source: <a href="https://arxiv.org/abs/1706.03762">Transformer paper</a>)</em></p>
<p>To encourage the bi-directional prediction and sentence-level understanding, BERT is trained with two tasks instead of the basic language task (that is, to predict the next token given context).</p>
<p><strong>Task 1: Masked language model (MLM)</strong></p>
<blockquote>
<p>From <a href="https://en.wikipedia.org/wiki/Cloze_test">Wikipedia</a>: “A cloze test (also cloze deletion test) is an exercise, test, or assessment consisting of a portion of language with certain items, words, or signs removed (cloze text), where the participant is asked to replace the missing language item. … The exercise was first described by W.L. Taylor in 1953.”</p>
</blockquote>
<p>It is unsurprising that a representation that learns from the context on both sides of a word, rather than from one side only, is able to better capture its meaning, both syntactically and semantically. BERT encourages the model to do so by training on the <em>“masked language model” task</em>:</p>
<ol>
<li>Randomly mask 15% of tokens in each sequence. If we only replaced the chosen tokens with a special placeholder <code class="language-plaintext highlighter-rouge">[MASK]</code>, the model would see a token during pre-training that it never encounters during fine-tuning. Hence, BERT employs several heuristic tricks:
<ul>
<li>(a) with 80% probability, replace the chosen words with <code class="language-plaintext highlighter-rouge">[MASK]</code>;</li>
<li>(b) with 10% probability, replace with a random word;</li>
<li>(c) with 10% probability, keep it the same.</li>
</ul>
</li>
<li>The model only predicts the missing words, but it has no information on which words have been replaced or which words should be predicted. The output size is only 15% of the input size.</li>
</ol>
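<p>The masking procedure above can be sketched as follows. The function name, the toy replacement vocabulary, and the rounding of the 15% budget are assumptions for this example rather than details taken from the paper.</p>

```python
import random


def mask_tokens(tokens, vocab, mask_rate=0.15, seed=0):
    rng = random.Random(seed)
    n_pick = max(1, round(len(tokens) * mask_rate))
    positions = rng.sample(range(len(tokens)), n_pick)
    corrupted = list(tokens)
    targets = {}
    for pos in positions:
        targets[pos] = tokens[pos]  # the model must recover the original
        r = rng.random()
        if r < 0.8:
            corrupted[pos] = "[MASK]"          # (a) 80%: placeholder
        elif r < 0.9:
            corrupted[pos] = rng.choice(vocab)  # (b) 10%: random word
        # (c) remaining 10%: keep the original token unchanged
    return corrupted, targets


tokens = ("the quick brown fox jumps over the lazy dog and "
          "then runs far away into the deep dark green forest").split()
corrupted, targets = mask_tokens(tokens, vocab=["cat", "tree", "blue"])
```

<p>Only the positions stored in <code class="language-plaintext highlighter-rouge">targets</code> contribute to the loss, which is why the output size is a small fraction of the input size.</p>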
<p><a name="NSP"></a><strong>Task 2: Next sentence prediction</strong></p>
<p>Motivated by the fact that many downstream tasks involve the understanding of relationships between sentences (e.g., <a href="#qa">QA</a>, <a href="#nli">NLI</a>), BERT adds another auxiliary task: training a <em>binary classifier</em> to tell whether one sentence is the next sentence of the other:</p>
<ol>
<li>Sample sentence pairs (A, B) so that:
<ul>
<li>(a) 50% of the time, B follows A;</li>
<li>(b) 50% of the time, B does not follow A.</li>
</ul>
</li>
<li>The model processes both sentences and outputs a binary label indicating whether B is the next sentence of A.</li>
</ol>
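<p>Generating training pairs for this task needs nothing but a corpus split into sentences. In the sketch below, drawing the negative sentence from a different document is an assumption made to keep the example unambiguous, and the toy documents are hypothetical.</p>

```python
import random


def sample_nsp_pair(documents, rng):
    """documents: list of documents, each a list of at least 2 sentences."""
    doc_idx = rng.randrange(len(documents))
    doc = documents[doc_idx]
    i = rng.randrange(len(doc) - 1)
    sent_a = doc[i]
    if rng.random() < 0.5:
        return sent_a, doc[i + 1], 1  # B follows A
    # B does not follow A: draw it from a different document.
    other_idx = rng.choice([j for j in range(len(documents)) if j != doc_idx])
    return sent_a, rng.choice(documents[other_idx]), 0


# Toy corpus: the letter names the document, the number the sentence index.
docs = [["a1", "a2", "a3"], ["b1", "b2"], ["c1", "c2", "c3", "c4"]]
rng = random.Random(0)
samples = [sample_nsp_pair(docs, rng) for _ in range(200)]
```

<p>Because the labels come for free from sentence order, no manual annotation is required.</p>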
<p>The training data for both auxiliary tasks above can be trivially generated from any monolingual corpus. Hence the scale of training is unbounded. The training loss is the sum of the mean masked LM likelihood and mean next sentence prediction likelihood.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/language-model-comparison.png" alt="Language model comparison" /></p>
<p><em>Fig. 10. Comparison of BERT, OpenAI GPT and ELMo model architectures. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<h3 id="input-embedding">Input Embedding</h3>
<p>The input embedding is the sum of three parts:</p>
<ol>
<li><em>WordPiece tokenization embeddings</em>: The <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">WordPiece</a> <a href="https://arxiv.org/pdf/1609.08144.pdf">model</a> was originally proposed for the Japanese and Korean segmentation problem. Instead of using naturally split English words, words can be further divided into smaller sub-word units so that rare or unknown words are handled more effectively. Please read the <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">linked</a> <a href="https://arxiv.org/pdf/1609.08144.pdf">papers</a> for the optimal way to split words if interested.</li>
<li><em>Segment embeddings</em>: If the input contains two sentences, they have sentence A embeddings and sentence B embeddings respectively and they are separated by a special token <code class="language-plaintext highlighter-rouge">[SEP]</code>; only sentence A embeddings are used if the input contains just one sentence.</li>
<li><em>Position embeddings</em>: Positional embeddings are learned rather than hard-coded.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERT-input-embedding.png" alt="BERT input embedding" /></p>
<p><em>Fig. 11. BERT input representation. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<p>Note that the first token is always forced to be <code class="language-plaintext highlighter-rouge">[CLS]</code> — a placeholder that will be used later for prediction in downstream tasks.</p>
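<p>A minimal sketch of how the three embeddings are combined is shown below. The lookup tables are tiny hand-written lists chosen for illustration; in the real model they are learned parameters.</p>

```python
def input_embeddings(token_ids, segment_ids, tok_table, seg_table, pos_table):
    """Sum token, segment and position embeddings for every input position."""
    out = []
    for pos, (tok, seg) in enumerate(zip(token_ids, segment_ids)):
        vec = [t + s + p for t, s, p in
               zip(tok_table[tok], seg_table[seg], pos_table[pos])]
        out.append(vec)
    return out


# Toy 2-dimensional tables (illustrative values only).
tok_table = [[1.0, 0.0], [0.0, 2.0]]   # token id -> embedding
seg_table = [[0.25, 0.25], [0.5, 0.5]]  # sentence A / sentence B
pos_table = [[0.0, 0.0], [0.5, 0.5]]    # position 0, position 1

embedded = input_embeddings([0, 1], [0, 0], tok_table, seg_table, pos_table)
```

<p>All three tables share the same embedding width, which is what makes the element-wise sum possible.</p>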
<h3 id="use-bert-in-downstream-tasks">Use BERT in Downstream Tasks</h3>
<p>BERT fine-tuning requires only a few new parameters added, just like OpenAI GPT.</p>
<p>For classification tasks, we get the prediction by taking the final hidden state of the special first token <code class="language-plaintext highlighter-rouge">[CLS]</code>, <script type="math/tex">\mathbf{h}^\text{[CLS]}_L</script>, and multiplying it with a small weight matrix, <script type="math/tex">\text{softmax}(\mathbf{h}^\text{[CLS]}_L \mathbf{W}_\text{cls})</script>.</p>
<p>For <a href="#qa">QA</a> tasks like SQuAD, we need to predict the text span in the given paragraph for a given question. BERT predicts two probability distributions over the tokens, one for the start and one for the end of the text span. Only two new small matrices, <script type="math/tex">\mathbf{W}_\text{s}</script> and <script type="math/tex">\mathbf{W}_\text{e}</script>, are learned during fine-tuning, and <script type="math/tex">\text{softmax}(\mathbf{h}^\text{(i)}_L \mathbf{W}_\text{s})</script> and <script type="math/tex">\text{softmax}(\mathbf{h}^\text{(i)}_L \mathbf{W}_\text{e})</script> define the two probability distributions.</p>
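<p>Span prediction can be sketched as two softmaxes over token positions. Here the two new weight matrices are treated as vectors of the hidden size, and the rule of picking the most probable pair with the end not before the start is an assumption for this illustration; the toy hidden states are hypothetical.</p>

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def predict_span(hidden, w_start, w_end):
    """hidden: final-layer states, one vector per paragraph token."""
    start_probs = softmax([dot(h, w_start) for h in hidden])
    end_probs = softmax([dot(h, w_end) for h in hidden])
    n = len(hidden)
    # Pick the most probable (start, end) pair with end >= start.
    candidates = [(i, j) for i in range(n) for j in range(i, n)]
    return max(candidates,
               key=lambda ij: start_probs[ij[0]] * end_probs[ij[1]])


hidden = [[1.0, 0.0], [0.0, 1.0], [3.0, 0.0], [0.0, 4.0]]
span = predict_span(hidden, w_start=[1.0, 0.0], w_end=[0.0, 1.0])
```

<p>In this toy input the third token scores highest as a start and the fourth as an end, so the predicted span covers those two positions.</p>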
<p>Overall the add-on part for end task fine-tuning is very minimal: one or two weight matrices to convert the Transformer hidden states to an interpretable format. Check the paper for implementation details for other cases.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERT-downstream-tasks.png" alt="BERT downstream tasks" /></p>
<p><em>Fig. 12. Training objectives in slightly modified BERT models for downstream tasks. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<p>A summary table compares differences between fine-tuning of OpenAI GPT and BERT.</p>
<table class="info">
<tbody>
<tr>
<td> </td>
<td><strong>OpenAI GPT</strong></td>
<td><strong>BERT</strong></td>
</tr>
<tr>
<td>Special char</td>
<td><code class="language-plaintext highlighter-rouge">[SEP]</code> and <code class="language-plaintext highlighter-rouge">[CLS]</code> are only introduced at fine-tuning stage.</td>
<td><code class="language-plaintext highlighter-rouge">[SEP]</code> and <code class="language-plaintext highlighter-rouge">[CLS]</code> and sentence A/B embeddings are learned at the pre-training stage.</td>
</tr>
<tr>
<td>Training process</td>
<td>1M steps, batch size 32k words.</td>
<td>1M steps, batch size 128k words.</td>
</tr>
<tr>
<td>Fine-tuning</td>
<td>lr = 5e-5 for all fine-tuning tasks.</td>
<td>Use task-specific lr for fine-tuning.</td>
</tr>
</tbody>
</table>
<h2 id="albert">ALBERT</h2>
<p><strong>ALBERT</strong> (<a href="https://arxiv.org/abs/1909.11942">Lan, et al. 2019</a>), short for <strong>A Lite BERT</strong>, is a lightweight version of the <a href="#bert">BERT</a> model. An ALBERT model can be trained 1.7x faster with 18x fewer parameters, compared to a BERT model of similar configuration. ALBERT incorporates three changes: the first two reduce parameters and memory consumption and hence speed up training, while the third replaces the next sentence prediction (NSP) objective with a more challenging training task.</p>
<h3 id="factorized-embedding-parameterization">Factorized Embedding Parameterization</h3>
<p>In BERT, the WordPiece tokenization embedding size <script type="math/tex">E</script> is configured to be the same as the hidden state size <script type="math/tex">H</script>. That is to say, if we want to increase the model size (larger <script type="math/tex">H</script>), we need to learn a larger tokenization embedding too, which is expensive because it depends on the vocabulary size (<script type="math/tex">V</script>).</p>
<p>Conceptually, because the tokenization embedding is expected to learn <em>context-independent</em> representation and the hidden states are <em>context-dependent</em>, it makes sense to separate the size of the hidden layers from the size of vocabulary embedding. Using factorized embedding parameterization, the large vocabulary embedding matrix of size <script type="math/tex">V \times H</script> is decomposed into two small matrices of size <script type="math/tex">V \times E</script> and <script type="math/tex">E \times H</script>. Given <script type="math/tex">H \gt E</script> or even <script type="math/tex">H \gg E</script>, factorization can result in significant parameter reduction.</p>
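<p>A quick parameter count shows the size of the saving. The concrete values of <code class="language-plaintext highlighter-rouge">V</code>, <code class="language-plaintext highlighter-rouge">H</code> and <code class="language-plaintext highlighter-rouge">E</code> below are illustrative assumptions, not numbers taken from this post.</p>

```python
def embedding_params(V, H, E=None):
    """Number of embedding parameters, with or without factorization."""
    if E is None:
        return V * H          # one V x H matrix
    return V * E + E * H      # V x E lookup followed by an E x H projection


V, H, E = 30000, 768, 128     # illustrative sizes
full = embedding_params(V, H)
factorized = embedding_params(V, H, E)
ratio = full / factorized
```

<p>With these sizes the factorized embedding needs 3,938,304 parameters instead of 23,040,000, and the gap widens as <script type="math/tex">H</script> grows while <script type="math/tex">E</script> stays fixed.</p>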
<h3 id="cross-layer-parameter-sharing">Cross-layer Parameter Sharing</h3>
<p>Parameter sharing across layers can happen in many ways: (a) only share feed-forward part; (b) only share attention parameters; or (c) share all the parameters. This technique reduces the number of parameters by a ton and does not damage the performance too much.</p>
<h3 id="sentence-order-prediction-sop">Sentence-Order Prediction (SOP)</h3>
<p>Interestingly, the <a href="#NSP">next sentence prediction (NSP)</a> task of BERT turned out to be too easy. ALBERT instead adopted a sentence-order prediction (SOP) <a href="/lil-log/2019/11/10/self-supervised-learning.html">self-supervised</a> loss:</p>
<ul>
<li>Positive sample: two consecutive segments from the same document.</li>
<li>Negative sample: same as above, but the segment order is switched.</li>
</ul>
<p>For the NSP task, the model can make reasonable predictions if it is able to detect topics when A and B are from different contexts. In comparison, SOP is harder as it requires the model to fully understand the coherence and ordering between segments.</p>
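<p>Building SOP training examples is a one-liner once two consecutive segments are available. The function name and the label convention (1 for the original order, 0 for the swapped order) are assumptions for this sketch.</p>

```python
def sop_examples(segment_1, segment_2):
    """segment_1 and segment_2 are consecutive segments of one document."""
    positive = (segment_1, segment_2, 1)  # original order
    negative = (segment_2, segment_1, 0)  # same segments, order switched
    return [positive, negative]


examples = sop_examples("He opened the door.", "Then he walked in.")
```

<p>Both examples share the same topic and vocabulary, so the model cannot fall back on topic detection and has to judge ordering and coherence.</p>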
<h2 id="openai-gpt-2">OpenAI GPT-2</h2>
<p>The <a href="https://blog.openai.com/better-language-models/">OpenAI</a> <a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">GPT-2</a> language model is a direct successor to <a href="#openai-gpt">GPT</a>. GPT-2 has 1.5B parameters, 10x more than the original GPT, and it achieves SOTA results on 7 out of 8 tested language modeling datasets in a <em>zero-shot transfer setting</em> without any task-specific fine-tuning. The pre-training dataset contains 8 million Web pages collected by crawling qualified outbound links from <a href="https://www.reddit.com/">Reddit</a>. Large improvements by OpenAI GPT-2 are especially noticeable on small datasets and datasets used for measuring <em>long-term dependency</em>.</p>
<h3 id="zero-shot-transfer">Zero-Shot Transfer</h3>
<p>The pre-training task for GPT-2 is solely language modeling. All the downstream language tasks are framed as predicting conditional probabilities and there is no task-specific fine-tuning.</p>
<ul>
<li>Text generation is straightforward using LM.</li>
<li>Machine translation task, for example, English to Chinese, is induced by conditioning LM on pairs of “English sentence = Chinese sentence” and “the target English sentence =” at the end.
<ul>
<li>For example, the conditional probability to predict might look like: <code class="language-plaintext highlighter-rouge">P(? | I like green apples. = 我喜欢绿苹果。 A cat meows at him. = 一只猫对他喵。It is raining cats and dogs. =)</code></li>
</ul>
</li>
<li>QA task is formatted similarly to translation, with pairs of questions and answers in the context.</li>
<li>Summarization task is induced by adding <code class="language-plaintext highlighter-rouge">TL;DR:</code> after the articles in the context.</li>
</ul>
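<p>The task framings above amount to string formatting of the conditioning context. In the sketch below, the helper names and the use of newlines to separate examples are assumptions; only the “source = target” pattern and the <code class="language-plaintext highlighter-rouge">TL;DR:</code> suffix come from the description above.</p>

```python
def translation_prompt(examples, source):
    """examples: list of (source sentence, target sentence) pairs."""
    lines = [f"{src} = {tgt}" for src, tgt in examples]
    lines.append(f"{source} =")  # the LM continues with the translation
    return "\n".join(lines)


def summarization_prompt(article):
    # The LM is expected to continue with a summary after the marker.
    return f"{article}\nTL;DR:"


pairs = [
    ("I like green apples.", "我喜欢绿苹果。"),
    ("A cat meows at him.", "一只猫对他喵。"),
]
prompt = translation_prompt(pairs, "It is raining cats and dogs.")
summary_prompt = summarization_prompt("Some long article text.")
```

<p>No parameters change between tasks; the language model simply continues whatever pattern the context sets up.</p>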
<h3 id="bpe-on-byte-sequences">BPE on Byte Sequences</h3>
<p>Same as the original GPT, GPT-2 uses <a href="#byte-pair-encoding">BPE</a> but on <a href="https://en.wikipedia.org/wiki/UTF-8">UTF-8</a> byte sequences. Each byte can represent 256 different values in 8 bits, while UTF-8 uses up to 4 bytes for one character, which leaves 21 bits for the code point and thus supports up to <script type="math/tex">2^{21}</script> characters in total. Therefore, with byte sequence representation we only need a base vocabulary of size 256 and do not need to worry about pre-processing, tokenization, etc. Despite the benefit, current byte-level LMs still have a non-negligible performance gap with the SOTA word-level LMs.</p>
<p>BPE merges frequently co-occurring byte pairs in a greedy manner. To prevent it from generating multiple versions of common words (e.g. <code class="language-plaintext highlighter-rouge">dog.</code>, <code class="language-plaintext highlighter-rouge">dog!</code> and <code class="language-plaintext highlighter-rouge">dog?</code> for the word <code class="language-plaintext highlighter-rouge">dog</code>), GPT-2 prevents BPE from merging characters across categories (thus <code class="language-plaintext highlighter-rouge">dog</code> would not be merged with punctuation like <code class="language-plaintext highlighter-rouge">.</code>, <code class="language-plaintext highlighter-rouge">!</code> and <code class="language-plaintext highlighter-rouge">?</code>). This trick helps increase the quality of the final byte segmentation.</p>
<p>Using the byte sequence representation, GPT-2 is able to assign a probability to any Unicode string, regardless of any pre-processing steps.</p>
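<p>The byte-level view is easy to inspect with Python's built-in UTF-8 codec. The sample string is an arbitrary choice for illustration.</p>

```python
text = "dog 苹果"

# Every string becomes a sequence of integers in [0, 255].
byte_ids = list(text.encode("utf-8"))

# ASCII characters take one byte each; each of these CJK characters takes three.
bytes_per_char = [len(ch.encode("utf-8")) for ch in text]

# The mapping is lossless: decoding the bytes restores the original string.
restored = bytes(byte_ids).decode("utf-8")
```

<p>Since any Unicode string maps to symbols from the same 256-value alphabet, there are no out-of-vocabulary inputs at the byte level; BPE merges then build longer units on top of these bytes.</p>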
<h3 id="model-modifications">Model Modifications</h3>
<p>Compared to GPT, other than having many more transformer layers and parameters, GPT-2 incorporates only a few architecture modifications:</p>
<ul>
<li><a href="https://arxiv.org/abs/1607.06450">Layer normalization</a> was moved to the input of each sub-block, similar to a residual unit of type <a href="https://arxiv.org/abs/1603.05027">“building block”</a> (different from the original type <a href="https://arxiv.org/abs/1512.03385">“bottleneck”</a>, it has batch normalization applied before weight layers).</li>
<li>An additional layer normalization was added after the final self-attention block.</li>
<li>A modified initialization was constructed as a function of the model depth.</li>
<li>The weights of residual layers were initially scaled by a factor of <script type="math/tex">1/ \sqrt{N}</script> where N is the number of residual layers.</li>
<li>Use larger vocabulary size and context size.</li>
</ul>
<h2 id="summary">Summary</h2>
<table class="info">
<thead>
<tr>
<th> </th>
<th>Base model</th>
<th>Pre-training</th>
<th>Downstream tasks</th>
<th>Downstream model</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>CoVe</td>
<td>seq2seq NMT model</td>
<td>supervised</td>
<td>feature-based</td>
<td>task-specific</td>
<td>/</td>
</tr>
<tr>
<td>ELMo</td>
<td>two-layer biLSTM</td>
<td>unsupervised</td>
<td>feature-based</td>
<td>task-specific</td>
<td>/</td>
</tr>
<tr>
<td>CVT</td>
<td>two-layer biLSTM</td>
<td>semi-supervised</td>
<td>model-based</td>
<td>task-specific / task-agnostic</td>
<td>/</td>
</tr>
<tr>
<td>ULMFiT</td>
<td>AWD-LSTM</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>all layers; with various training tricks</td>
</tr>
<tr>
<td>GPT</td>
<td>Transformer decoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>pre-trained layers + top task layer(s)</td>
</tr>
<tr>
<td>BERT</td>
<td>Transformer encoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>pre-trained layers + top task layer(s)</td>
</tr>
<tr>
<td>GPT-2</td>
<td>Transformer decoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>/ (zero-shot; no task-specific fine-tuning)</td>
</tr>
</tbody>
</table>
<h2 id="metric-perplexity">Metric: Perplexity</h2>
<p>Perplexity is often used as an intrinsic evaluation metric for gauging how well a language model can capture the real word distribution conditioned on the context.</p>
<p>The <a href="https://en.wikipedia.org/wiki/Perplexity">perplexity</a> of a discrete probability distribution <script type="math/tex">p</script> is defined as the exponentiation of the entropy:</p>
<script type="math/tex; mode=display">2^{H(p)} = 2^{-\sum_x p(x) \log_2 p(x)}</script>
<p>Given a sentence with <script type="math/tex">N</script> words, <script type="math/tex">s = (w_1, \dots, w_N)</script>, the entropy looks as follows, simply assuming that each word has the same frequency, <script type="math/tex">\frac{1}{N}</script>:</p>
<script type="math/tex; mode=display">H(s) = -\sum_{i=1}^N P(w_i) \log_2 p(w_i) = -\sum_{i=1}^N \frac{1}{N} \log_2 p(w_i)</script>
<p>The perplexity for the sentence becomes:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
2^{H(s)} &= 2^{-\frac{1}{N} \sum_{i=1}^N \log_2 p(w_i)}
= (2^{\sum_{i=1}^N \log_2 p(w_i)})^{-\frac{1}{N}}
= (p(w_1) \dots p(w_N))^{-\frac{1}{N}}
\end{aligned} %]]></script>
<p>A good language model should predict high word probabilities. Therefore, the lower the perplexity, the better.</p>
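<p>The sentence-level formula translates directly into code. The probability values below are made up for illustration; in practice they would be the probabilities a language model assigns to each word of the sentence.</p>

```python
import math


def perplexity(word_probs):
    """2 ** H(s), with H(s) the average negative log2 probability per word."""
    n = len(word_probs)
    entropy = -sum(math.log2(p) for p in word_probs) / n
    return 2 ** entropy


# A model that is as uncertain as a uniform choice among 4 words.
uncertain = perplexity([0.25] * 8)

# A model that assigns high probability to every observed word.
confident = perplexity([0.9] * 8)
```

<p>A perplexity of 4 means the model is, on average, as uncertain as if it were choosing uniformly among four words; the more confident model scores close to 1.</p>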
<h2 id="common-tasks-and-datasets">Common Tasks and Datasets</h2>
<p><a name="qa"></a>
<strong>Question-Answering</strong></p>
<ul>
<li><a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD</a> (Stanford Question Answering Dataset): A reading comprehension dataset, consisting of questions posed on a set of Wikipedia articles, where the answer to every question is a span of text.</li>
<li><a href="http://www.qizhexie.com/data/RACE_leaderboard">RACE</a> (ReAding Comprehension from Examinations): A large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students.</li>
</ul>
<p><strong>Commonsense Reasoning</strong></p>
<ul>
<li><a href="http://cs.rochester.edu/nlp/rocstories/">Story Cloze Test</a>: A commonsense reasoning framework for evaluating story understanding and generation. The test requires a system to choose the correct ending to multi-sentence stories from two options.</li>
<li><a href="https://rowanzellers.com/swag/">SWAG</a> (Situations With Adversarial Generations): multiple choices; contains 113k sentence-pair completion examples that evaluate grounded common-sense inference</li>
</ul>
<p><a name="nli"></a>
<strong>Natural Language Inference (NLI)</strong>: also known as <strong>Text Entailment</strong>, an exercise to discern in logic whether one sentence can be inferred from another.</p>
<ul>
<li><a href="https://aclweb.org/aclwiki/Textual_Entailment_Resource_Pool">RTE</a> (Recognizing Textual Entailment): A set of datasets initiated by text entailment challenges.</li>
<li><a href="https://nlp.stanford.edu/projects/snli/">SNLI</a> (Stanford Natural Language Inference): A collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels <code class="language-plaintext highlighter-rouge">entailment</code>, <code class="language-plaintext highlighter-rouge">contradiction</code>, and <code class="language-plaintext highlighter-rouge">neutral</code>.</li>
<li><a href="https://www.nyu.edu/projects/bowman/multinli/">MNLI</a> (Multi-Genre NLI): Similar to SNLI, but with a more diverse variety of text styles and topics, collected from transcribed speech, popular fiction, and government reports.</li>
<li><a href="https://gluebenchmark.com/tasks">QNLI</a> (Question NLI): Converted from SQuAD dataset to be a binary classification task over pairs of (question, sentence).</li>
<li><a href="http://data.allenai.org/scitail/">SciTail</a>: An entailment dataset created from multiple-choice science exams and web sentences.</li>
</ul>
<p><a name="ner"></a>
<strong>Named Entity Recognition (NER)</strong>: labels sequences of words in a text which are the names of things, such as person and company names, or gene and protein names</p>
<ul>
<li><a href="https://www.clips.uantwerpen.be/conll2003/">CoNLL 2003 NER task</a>: consists of newswire from Reuters, concentrating on four types of named entities: persons, locations, organizations and names of miscellaneous entities.</li>
<li><a href="https://catalog.ldc.upenn.edu/LDC2013T19">OntoNotes 5.0</a>: This corpus contains text in English, Arabic and Chinese, tagged with four different entity types (PER, LOC, ORG, MISC).</li>
<li><a href="https://trec.nist.gov/data/reuters/reuters.html">Reuters Corpus</a>: A large collection of Reuters News stories.</li>
<li>Fine-Grained NER (FGN)</li>
</ul>
<p><strong>Sentiment Analysis</strong></p>
<ul>
<li><a href="https://nlp.stanford.edu/sentiment/index.html">SST</a> (Stanford Sentiment Treebank)</li>
<li><a href="http://ai.stanford.edu/~amaas/data/sentiment/">IMDb</a>: A large dataset of movie reviews with binary sentiment classification labels.</li>
</ul>
<p><a name="srl"></a>
<strong>Semantic Role Labeling (SRL)</strong>: models the predicate-argument structure of a sentence, and is often described as answering “Who did what to whom”.</p>
<ul>
<li><a href="http://www.lsi.upc.edu/~srlconll/">CoNLL-2004 & CoNLL-2005</a></li>
</ul>
<p><strong>Sentence similarity</strong>: also known as <em>paraphrase detection</em></p>
<ul>
<li><a href="https://www.microsoft.com/en-us/download/details.aspx?id=52398">MRPC</a> (MicRosoft Paraphrase Corpus): It contains pairs of sentences extracted from news sources on the web, with annotations indicating whether each pair is semantically equivalent.</li>
<li><a href="https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs">QQP</a> (Quora Question Pairs)</li>
<li>STS Benchmark: Semantic Textual Similarity</li>
</ul>
<p><strong>Sentence Acceptability</strong>: a task to annotate sentences for grammatical acceptability.</p>
<ul>
<li><a href="https://nyu-mll.github.io/CoLA/">CoLA</a> (Corpus of Linguistic Acceptability): a binary single-sentence classification task.</li>
</ul>
<p><strong>Text Chunking</strong>: To divide a text in syntactically correlated parts of words.</p>
<ul>
<li><a href="https://www.clips.uantwerpen.be/conll2000/chunking/">CoNLL-2000</a></li>
</ul>
<p><a name="pos"></a>
<strong>Part-of-Speech (POS) Tagging</strong>: tag parts of speech to each token, such as noun, verb, adjective, etc.</p>
<ul>
<li>The Wall Street Journal portion of the Penn Treebank (Marcus et al., 1993).</li>
</ul>
<p><strong>Machine Translation</strong>: See the <a href="https://nlp.stanford.edu/projects/nmt/">Stanford NLP</a> page.</p>
<ul>
<li>WMT 2015 English-Czech data (Large)</li>
<li>WMT 2014 English-German data (Medium)</li>
<li>IWSLT 2015 English-Vietnamese data (Small)</li>
</ul>
<p><strong>Coreference Resolution</strong>: cluster mentions in text that refer to the same underlying real world entities.</p>
<ul>
<li><a href="http://conll.cemantix.org/2012/data.html">CoNLL-2012</a></li>
</ul>
<p><strong>Long-range Dependency</strong></p>
<ul>
<li><a href="http://clic.cimec.unitn.it/lambada/">LAMBADA</a> (LAnguage Modeling Broadened to Account for Discourse Aspects): A collection of narrative passages extracted from the BookCorpus. The task is to predict the last word, which requires at least 50 tokens of context for a human to predict successfully.</li>
<li><a href="https://research.fb.com/downloads/babi/">Children’s Book Test</a>: built from books that are freely available in <a href="https://www.gutenberg.org/">Project Gutenberg</a>. The task is to predict the missing word among 10 candidates.</li>
</ul>
<p><strong>Multi-task benchmark</strong></p>
<ul>
<li>GLUE multi-task benchmark: <a href="https://gluebenchmark.com/">https://gluebenchmark.com</a></li>
<li>decaNLP benchmark: <a href="https://decanlp.com/">https://decanlp.com</a></li>
</ul>
<p><strong>Unsupervised pretraining dataset</strong></p>
<ul>
<li><a href="https://googlebooks.byu.edu/">Books corpus</a>: The corpus contains “over 7,000 unique unpublished books from a variety of genres including Adventure, Fantasy, and Romance.”</li>
<li><a href="http://www.statmt.org/lm-benchmark/">1B Word Language Model Benchmark</a></li>
<li><a href="https://en.wikipedia.org/wiki/Wikipedia:Database_download#English-language_Wikipedia">English Wikipedia</a>: ~2500M words</li>
</ul>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019LM,
title = "Generalized Language Models",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/01/31/generalized-language-models.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Bryan McCann, et al. <a href="https://arxiv.org/abs/1708.00107">“Learned in translation: Contextualized word vectors.”</a> NIPS. 2017.</p>
<p>[2] Kevin Clark et al. <a href="https://arxiv.org/abs/1809.08370">“Semi-Supervised Sequence Modeling with Cross-View Training.”</a> EMNLP 2018.</p>
<p>[3] Matthew E. Peters, et al. <a href="https://arxiv.org/abs/1802.05365">“Deep contextualized word representations.”</a> NAACL-HLT 2018.</p>
<p>[4] OpenAI Blog <a href="https://blog.openai.com/language-unsupervised/">“Improving Language Understanding with Unsupervised Learning”</a>, June 11, 2018.</p>
<p>[5] OpenAI Blog <a href="https://blog.openai.com/better-language-models/">“Better Language Models and Their Implications.”</a> Feb 14, 2019.</p>
<p>[6] Jeremy Howard and Sebastian Ruder. <a href="https://arxiv.org/abs/1801.06146">“Universal language model fine-tuning for text classification.”</a> ACL 2018.</p>
<p>[7] Alec Radford et al. <a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">“Improving Language Understanding by Generative Pre-Training”</a>. OpenAI Blog, June 11, 2018.</p>
<p>[8] Jacob Devlin, et al. <a href="https://arxiv.org/abs/1810.04805">“BERT: Pre-training of deep bidirectional transformers for language understanding.”</a> arXiv:1810.04805 (2018).</p>
<p>[9] Mike Schuster, and Kaisuke Nakajima. <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">“Japanese and Korean voice search.”</a> ICASSP. 2012.</p>
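<p>[10] Yonghui Wu, et al. <a href="https://arxiv.org/abs/1609.08144">“Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.”</a> arXiv preprint arXiv:1609.08144 (2016).</p>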
<p>[11] Ashish Vaswani, et al. <a href="https://arxiv.org/abs/1706.03762">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[12] Peter J. Liu, et al. <a href="https://arxiv.org/abs/1801.10198">“Generating wikipedia by summarizing long sequences.”</a> ICLR 2018.</p>
<p>[13] Sebastian Ruder. <a href="http://ruder.io/10-exciting-ideas-of-2018-in-nlp/">“10 Exciting Ideas of 2018 in NLP”</a> Dec 2018.</p>
<p>[14] Alec Radford, et al. <a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">“Language Models are Unsupervised Multitask Learners.”</a>. 2019.</p>
<p>[15] Rico Sennrich, et al. <a href="https://arxiv.org/abs/1508.07909">“Neural machine translation of rare words with subword units.”</a> arXiv preprint arXiv:1508.07909. 2015.</p>
<p>[16] Zhenzhong Lan, et al. <a href="https://arxiv.org/abs/1909.11942">“ALBERT: A Lite BERT for Self-supervised Learning of Language Representations”</a> arXiv Preprint arXiv:1909.11942 (2019).</p>Lilian WengAs a follow up of word embedding post, we will discuss the models on learning contextualized word vectors, as well as the new trend in large unsupervised pre-trained language models which have achieved amazing SOTA results on a variety of language tasks.Object Detection Part 4: Fast Detection Models2018-12-27T12:00:00+00:002018-12-27T12:00:00+00:00https://lilianweng.github.io/lil-log/2018/12/27/object-detection-part-4<blockquote>
<p>Part 4 of the “Object Detection for Dummies” series focuses on one-stage models for fast detection, including SSD, RetinaNet, and models in the YOLO family. These models skip the explicit region proposal stage but apply the detection directly on dense sampled areas.</p>
</blockquote>
<!--more-->
<p>In <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html">Part 3</a>, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.</p>
<p>Links to all the posts in the series:
[<a href="/lil-log/2017/10/29/object-recognition-for-dummies-part-1.html">Part 1</a>]
[<a href="/lil-log/2017/12/15/object-recognition-for-dummies-part-2.html">Part 2</a>]
[<a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html">Part 3</a>]
[<a href="/lil-log/2018/12/27/object-detection-part-4.html">Part 4</a>].</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#two-stage-vs-one-stage-detectors" id="markdown-toc-two-stage-vs-one-stage-detectors">Two-stage vs One-stage Detectors</a></li>
<li><a href="#yolo-you-only-look-once" id="markdown-toc-yolo-you-only-look-once">YOLO: You Only Look Once</a> <ul>
<li><a href="#workflow" id="markdown-toc-workflow">Workflow</a></li>
<li><a href="#network-architecture" id="markdown-toc-network-architecture">Network Architecture</a></li>
<li><a href="#loss-function" id="markdown-toc-loss-function">Loss Function</a></li>
</ul>
</li>
<li><a href="#ssd-single-shot-multibox-detector" id="markdown-toc-ssd-single-shot-multibox-detector">SSD: Single Shot MultiBox Detector</a> <ul>
<li><a href="#image-pyramid" id="markdown-toc-image-pyramid">Image Pyramid</a></li>
<li><a href="#workflow-1" id="markdown-toc-workflow-1">Workflow</a></li>
<li><a href="#loss-function-1" id="markdown-toc-loss-function-1">Loss Function</a></li>
</ul>
</li>
<li><a href="#yolov2--yolo9000" id="markdown-toc-yolov2--yolo9000">YOLOv2 / YOLO9000</a> <ul>
<li><a href="#yolov2-improvement" id="markdown-toc-yolov2-improvement">YOLOv2 Improvement</a></li>
<li><a href="#yolo9000-rich-dataset-training" id="markdown-toc-yolo9000-rich-dataset-training">YOLO9000: Rich Dataset Training</a></li>
</ul>
</li>
<li><a href="#retinanet" id="markdown-toc-retinanet">RetinaNet</a> <ul>
<li><a href="#focal-loss" id="markdown-toc-focal-loss">Focal Loss</a></li>
<li><a href="#featurized-image-pyramid" id="markdown-toc-featurized-image-pyramid">Featurized Image Pyramid</a></li>
<li><a href="#model-architecture" id="markdown-toc-model-architecture">Model Architecture</a></li>
</ul>
</li>
<li><a href="#yolov3" id="markdown-toc-yolov3">YOLOv3</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="two-stage-vs-one-stage-detectors">Two-stage vs One-stage Detectors</h2>
<p>Models in the R-CNN family are all region-based. The detection happens in two stages: (1) First, the model proposes a set of regions of interest by selective search or a region proposal network. The proposed regions are sparse as the potential bounding box candidates can be infinite. (2) Then a classifier only processes the region candidates.</p>
<p>The alternative approach skips the region proposal stage and runs detection directly over a dense sampling of possible locations. This is how a one-stage object detection algorithm works. It is faster and simpler, but might drag down the performance a bit.</p>
<p>All the models introduced in this post are one-stage detectors.</p>
<h2 id="yolo-you-only-look-once">YOLO: You Only Look Once</h2>
<p>The <strong>YOLO</strong> model (<strong>“You Only Look Once”</strong>; <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">Redmon et al., 2016</a>) is the very first attempt at building a fast real-time object detector. Because YOLO does not undergo the region proposal step and only predicts over a limited number of bounding boxes, it is able to do inference super fast.</p>
<h3 id="workflow">Workflow</h3>
<ol>
<li>
<p><strong>Pre-train</strong> a CNN network on an image classification task.</p>
</li>
<li>Split an image into <script type="math/tex">S \times S</script> cells. If an object’s center falls into a cell, that cell is “responsible” for detecting the existence of that object. Each cell predicts (a) the location of <script type="math/tex">B</script> bounding boxes, (b) a confidence score, and (c) a probability of object class conditioned on the existence of an object in the bounding box.
<br />
<br />
<ul>
<li>The <strong>coordinates</strong> of a bounding box are defined by a tuple of 4 values, (center x-coord, center y-coord, width, height), denoted <script type="math/tex">(x, y, w, h)</script>, where <script type="math/tex">x</script> and <script type="math/tex">y</script> are offsets relative to the cell location. Moreover, <script type="math/tex">x</script>, <script type="math/tex">y</script>, <script type="math/tex">w</script> and <script type="math/tex">h</script> are normalized by the image width and height, and thus all fall in (0, 1].</li>
<li>A <strong>confidence score</strong> indicates the likelihood that the cell contains an object: <code class="language-plaintext highlighter-rouge">Pr(containing an object) x IoU(pred, truth)</code>; where <code class="language-plaintext highlighter-rouge">Pr</code> = probability and <code class="language-plaintext highlighter-rouge">IoU</code> = intersection over union.</li>
<li>If the cell contains an object, it predicts a <strong>probability</strong> of this object belonging to every class <script type="math/tex">C_i, i=1, \dots, K</script>: <code class="language-plaintext highlighter-rouge">Pr(the object belongs to the class C_i | containing an object)</code>. At this stage, the model only predicts one set of class probabilities per cell, regardless of the number of bounding boxes, <script type="math/tex">B</script>.</li>
<li>In total, one image contains <script type="math/tex">S \times S \times B</script> bounding boxes, each box corresponding to 4 location predictions and 1 confidence score, and each cell additionally predicts K conditional probabilities for object classification. The total number of prediction values for one image is <script type="math/tex">S \times S \times (5B + K)</script>, which is the shape of the final output tensor of the model.
<br />
<br /></li>
</ul>
</li>
<li>The final layer of the pre-trained CNN is modified to output a prediction tensor of size <script type="math/tex">S \times S \times (5B + K)</script>.</li>
</ol>
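<p>To make the confidence score and the output size concrete, here is a minimal Python sketch. The helper names <code>iou</code> and <code>yolo_output_size</code> are made up for illustration, and boxes are assumed to be given as (center x, center y, width, height) tuples in the same normalized units.</p>

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (center_x, center_y, width, height)."""
    # Convert from center format to corner format.
    ax1, ay1 = box_a[0] - box_a[2] / 2, box_a[1] - box_a[3] / 2
    ax2, ay2 = box_a[0] + box_a[2] / 2, box_a[1] + box_a[3] / 2
    bx1, by1 = box_b[0] - box_b[2] / 2, box_b[1] - box_b[3] / 2
    bx2, by2 = box_b[0] + box_b[2] / 2, box_b[1] + box_b[3] / 2

    # The overlap is empty when the boxes do not intersect along either axis.
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union else 0.0


def yolo_output_size(S, B, K):
    """Number of predicted values per image: S x S cells, each with 5 values per box plus K class scores."""
    return S * S * (5 * B + K)
```

<p>With the PASCAL VOC setting used in the YOLO paper (S=7, B=2, K=20), <code>yolo_output_size(7, 2, 20)</code> corresponds to a 7 x 7 x 30 prediction tensor.</p>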
<p class="center"><img src="/lil-log/assets/images/yolo.png" alt="YOLO workflow" /></p>
<p><em>Fig. 1. The workflow of YOLO model. (Image source: <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">original paper</a>)</em></p>
<h3 id="network-architecture">Network Architecture</h3>
<p>The base model is similar to <a href="https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf">GoogLeNet</a> with the inception modules replaced by 1x1 and 3x3 conv layers. The final prediction of shape <script type="math/tex">S \times S \times (5B + K)</script> is produced by two fully connected layers over the whole conv feature map.</p>
<p class="center"><img src="/lil-log/assets/images/yolo-network-architecture.png" alt="YOLO architecture" /></p>
<p><em>Fig. 2. The network architecture of YOLO.</em></p>
<h3 id="loss-function">Loss Function</h3>
<p>The loss consists of two parts, the <em>localization loss</em> for bounding box offset prediction and the <em>classification loss</em> for conditional class probabilities. Both parts are computed as the sum of squared errors. Two scale parameters are used to control how much we want to increase the loss from bounding box coordinate predictions (<script type="math/tex">\lambda_\text{coord}</script>) and how much we want to decrease the loss of confidence score predictions for boxes without objects (<script type="math/tex">\lambda_\text{noobj}</script>). Down-weighting the loss contributed by background boxes is important as most of the bounding boxes involve no instance. In the paper, the model sets <script type="math/tex">\lambda_\text{coord} = 5</script> and <script type="math/tex">\lambda_\text{noobj} = 0.5</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{loc} &= \lambda_\text{coord} \sum_{i=0}^{S^2} \sum_{j=0}^B \mathbb{1}_{ij}^\text{obj} [(x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 + (\sqrt{w_i} - \sqrt{\hat{w}_i})^2 + (\sqrt{h_i} - \sqrt{\hat{h}_i})^2 ] \\
\mathcal{L}_\text{cls} &= \sum_{i=0}^{S^2} \sum_{j=0}^B \big( \mathbb{1}_{ij}^\text{obj} + \lambda_\text{noobj} (1 - \mathbb{1}_{ij}^\text{obj})\big) (C_{ij} - \hat{C}_{ij})^2 + \sum_{i=0}^{S^2} \sum_{c \in \mathcal{C}} \mathbb{1}_i^\text{obj} (p_i(c) - \hat{p}_i(c))^2\\
\mathcal{L} &= \mathcal{L}_\text{loc} + \mathcal{L}_\text{cls}
\end{aligned} %]]></script>
<blockquote>
<p>NOTE: In the original YOLO paper, the loss function uses <script type="math/tex">C_i</script> instead of <script type="math/tex">C_{ij}</script> as confidence score. I made the correction based on my own understanding, since every bounding box should have its own confidence score. Please kindly let me know if you do not agree. Many thanks.</p>
</blockquote>
<p>where,</p>
<ul>
<li><script type="math/tex">\mathbb{1}_i^\text{obj}</script>: An indicator function of whether the cell i contains an object.</li>
<li><script type="math/tex">\mathbb{1}_{ij}^\text{obj}</script>: It indicates whether the j-th bounding box of the cell i is “responsible” for the object prediction (see Fig. 3).</li>
<li><script type="math/tex">C_{ij}</script>: The confidence score of the j-th bounding box in cell i, <code class="language-plaintext highlighter-rouge">Pr(containing an object) * IoU(pred, truth)</code>.</li>
<li><script type="math/tex">\hat{C}_{ij}</script>: The predicted confidence score.</li>
<li><script type="math/tex">\mathcal{C}</script>: The set of all classes.</li>
<li><script type="math/tex">p_i(c)</script>: The conditional probability of whether cell i contains an object of class <script type="math/tex">c \in \mathcal{C}</script>.</li>
<li><script type="math/tex">\hat{p}_i(c)</script>: The predicted conditional class probability.</li>
</ul>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/yolo-responsible-predictor.png" alt="YOLO responsible predictor" /></p>
<p><em>Fig. 3. At one location, in cell i, the model proposes B bounding box candidates and the one that has highest overlap with the ground truth is the “responsible” predictor.</em></p>
<p>The loss function only penalizes classification error if an object is present in that grid cell, <script type="math/tex">\mathbb{1}_i^\text{obj} = 1</script>. It also only penalizes bounding box coordinate error if that predictor is “responsible” for the ground truth box, <script type="math/tex">\mathbb{1}_{ij}^\text{obj} = 1</script>.</p>
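<p>The loss above can be sketched in NumPy for a single image. This is only an illustration of the formula as written in this post, not the authors’ training code: the function name, argument names and array layouts are assumptions, and a cell is treated as containing an object whenever one of its boxes is marked as “responsible”.</p>

```python
import numpy as np


def yolo_loss(pred_box, true_box, pred_conf, true_conf, pred_cls, true_cls,
              responsible, lambda_coord=5.0, lambda_noobj=0.5):
    """Sum-of-squared-errors YOLO loss for one image.

    pred_box, true_box:   (S, S, B, 4) arrays of (x, y, w, h); w and h must be non-negative.
    pred_conf, true_conf: (S, S, B) confidence scores.
    pred_cls, true_cls:   (S, S, K) conditional class probabilities.
    responsible:          (S, S, B) 0/1 mask, the indicator 1_ij^obj.
    """
    resp = responsible.astype(float)
    # Assumption: cell i contains an object (1_i^obj) if any of its boxes is responsible.
    cell_has_obj = resp.max(axis=-1)

    # Localization loss: squared error on (x, y) and on the square roots of (w, h).
    xy_err = ((pred_box[..., :2] - true_box[..., :2]) ** 2).sum(axis=-1)
    wh_err = ((np.sqrt(pred_box[..., 2:]) - np.sqrt(true_box[..., 2:])) ** 2).sum(axis=-1)
    loc_loss = lambda_coord * (resp * (xy_err + wh_err)).sum()

    # Confidence loss: boxes without objects are down-weighted by lambda_noobj.
    conf_weight = resp + lambda_noobj * (1.0 - resp)
    conf_loss = (conf_weight * (true_conf - pred_conf) ** 2).sum()

    # Class probability loss: only cells that contain an object contribute.
    cls_loss = (cell_has_obj[..., None] * (true_cls - pred_cls) ** 2).sum()

    return loc_loss + conf_loss + cls_loss
```

<p>When the predictions match the targets exactly the loss is zero, and a confidence error on a background box costs only half as much as the same error on a responsible box under the default <code>lambda_noobj</code>.</p>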
<p>As a one-stage object detector, YOLO is super fast, but it is not good at recognizing irregularly shaped objects or a group of small objects due to a limited number of bounding box candidates.</p>
<h2 id="ssd-single-shot-multibox-detector">SSD: Single Shot MultiBox Detector</h2>
<p>The <strong>Single Shot Detector</strong> (<strong>SSD</strong>; <a href="https://arxiv.org/abs/1512.02325">Liu et al., 2016</a>) is one of the first attempts at using convolutional neural network’s pyramidal feature hierarchy for efficient detection of objects of various sizes.</p>
<h3 id="image-pyramid">Image Pyramid</h3>
<p>SSD uses the <a href="https://arxiv.org/abs/1409.1556">VGG-16</a> model pre-trained on ImageNet as its base model for extracting useful image features.
On top of VGG16, SSD adds several conv feature layers of decreasing sizes. They can be seen as a <em>pyramid representation</em> of images at different scales. Intuitively, large fine-grained feature maps at earlier levels are good at capturing small objects and small coarse-grained feature maps can detect large objects well. In SSD, the detection happens in every pyramidal layer, targeting objects of various sizes.</p>
<p class="center"><img src="/lil-log/assets/images/SSD-architecture.png" alt="SSD architecture" /></p>
<p><em>Fig. 4. The model architecture of SSD.</em></p>
<h3 id="workflow-1">Workflow</h3>
<p>Unlike YOLO, SSD does not split the image into grids of arbitrary size but predicts offsets of predefined <em>anchor boxes</em> (called “default boxes” in the paper) for every location of the feature map. Each box has a fixed size and position relative to its corresponding cell. All the anchor boxes tile the whole feature map in a convolutional manner.</p>
<p>Feature maps at different levels have different receptive field sizes. The anchor boxes on different levels are rescaled so that one feature map is only responsible for objects at one particular scale. For example, in Fig. 5 the dog can only be detected in the 4x4 feature map (higher level) while the cat is just captured by the 8x8 feature map (lower level).</p>
<p class="center"><img src="/lil-log/assets/images/SSD-framework.png" alt="SSD framework" /></p>
<p><em>Fig. 5. The SSD framework. (a) The training data contains images and ground truth boxes for every object. (b) In a fine-grained feature map (8 x 8), the anchor boxes of different aspect ratios correspond to a smaller area of the raw input. (c) In a coarse-grained feature map (4 x 4), the anchor boxes cover a larger area of the raw input. (Image source: <a href="https://arxiv.org/abs/1512.02325">original paper</a>)</em></p>
<p>The width, height and the center location of an anchor box are all normalized to be in (0, 1). At a location <script type="math/tex">(i, j)</script> of the <script type="math/tex">\ell</script>-th feature layer of size <script type="math/tex">m \times n</script>, <script type="math/tex">i=0,\dots,m-1, j=0,\dots,n-1</script>, we have a unique linear scale proportional to the layer level and 5 different box aspect ratios (width-to-height ratios), in addition to a special scale when the aspect ratio is 1 (why do we need this? The paper didn’t explain; maybe it is just a heuristic trick). This gives us 6 anchor boxes in total per feature cell.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{level index: } &\ell = 1, \dots, L \\
\text{scale of boxes: } &s_\ell = s_\text{min} + \frac{s_\text{max} - s_\text{min}}{L - 1} (\ell - 1) \\
\text{aspect ratio: } &r \in \{1, 2, 3, 1/2, 1/3\}\\
\text{additional scale: } & s'_\ell = \sqrt{s_\ell s_{\ell + 1}} \text{ when } r = 1 \text{, thus 6 boxes in total.}\\
\text{width: } &w_\ell^r = s_\ell \sqrt{r} \\
\text{height: } &h_\ell^r = s_\ell / \sqrt{r} \\
\text{center location: } & (x^i_\ell, y^j_\ell) = (\frac{i+0.5}{m}, \frac{j+0.5}{n})
\end{aligned} %]]></script>
<p class="center"><img src="/lil-log/assets/images/SSD-box-scales.png" alt="Box scales" /></p>
<p><em>Fig. 6. An example of how the anchor box size is scaled up with the layer index <script type="math/tex">\ell</script> for <script type="math/tex">L=6, s_\text{min} = 0.2, s_\text{max} = 0.9</script>. Only the boxes of aspect ratio <script type="math/tex">r=1</script> are illustrated.</em></p>
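<p>The anchor box formulas above can be sketched in a few lines of Python. The function names are hypothetical. One assumption to flag: for the last level, the extra scale needs <code>s</code> at level L+1, which this sketch simply extrapolates with the same linear rule; the post does not specify this detail.</p>

```python
import math


def ssd_anchor_shapes(num_levels=6, s_min=0.2, s_max=0.9, ratios=(1, 2, 3, 1 / 2, 1 / 3)):
    """Return {level: [(width, height), ...]} with 6 anchor shapes per level."""

    def scale(level):
        # Linear interpolation between s_min and s_max; levels are 1-based.
        # Assumption: level num_levels + 1 is extrapolated with the same rule.
        return s_min + (s_max - s_min) / (num_levels - 1) * (level - 1)

    shapes = {}
    for level in range(1, num_levels + 1):
        s = scale(level)
        # One box per aspect ratio r: width = s * sqrt(r), height = s / sqrt(r).
        boxes = [(s * math.sqrt(r), s / math.sqrt(r)) for r in ratios]
        # Additional square box at the geometric mean of this scale and the next.
        s_extra = math.sqrt(s * scale(level + 1))
        boxes.append((s_extra, s_extra))
        shapes[level] = boxes
    return shapes


def anchor_centers(m, n):
    """Normalized centers ((i + 0.5) / m, (j + 0.5) / n) for an m x n feature map."""
    return [((i + 0.5) / m, (j + 0.5) / n) for j in range(n) for i in range(m)]
```

<p>With the default arguments, the scales grow from 0.2 at the first level to 0.9 at the last one, matching the setting in Fig. 6, and every box with ratio <code>r</code> keeps the area <code>s * s</code> of its level.</p>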
<p>At every location, the model outputs 4 offsets and <script type="math/tex">c</script> class probabilities by applying a <script type="math/tex">3 \times 3 \times p</script> conv filter (where <script type="math/tex">p</script> is the number of channels in the feature map) for every one of <script type="math/tex">k</script> anchor boxes. Therefore, given a feature map of size <script type="math/tex">m \times n</script>, we need <script type="math/tex">k(c+4)</script> prediction filters per location, yielding <script type="math/tex">kmn(c+4)</script> outputs in total.</p>
<h3 id="loss-function-1">Loss Function</h3>
<p>Same as YOLO, the loss function is the sum of a localization loss and a classification loss.</p>
<script type="math/tex; mode=display">\mathcal{L} = \frac{1}{N}(\mathcal{L}_\text{cls} + \alpha \mathcal{L}_\text{loc})</script>
<p>where <script type="math/tex">N</script> is the number of matched bounding boxes and <script type="math/tex">\alpha</script> balances the weights between two losses, picked by cross validation.</p>
<p>The <em>localization loss</em> is a <a href="https://github.com/rbgirshick/py-faster-rcnn/files/764206/SmoothL1Loss.1.pdf">smooth L1 loss</a> between the predicted bounding box correction and the true values. The coordinate correction transformation is the same as what <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#r-cnn">R-CNN</a> does in <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#bounding-box-regression">bounding box regression</a>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{loc} &= \sum_{i,j} \sum_{m\in\{x, y, w, h\}} \mathbb{1}_{ij}^\text{match}
 L_1^\text{smooth}(d_m^i - t_m^j)\\
L_1^\text{smooth}(x) &= \begin{cases}
0.5 x^2 & \text{if } \vert x \vert < 1\\
\vert x \vert - 0.5 & \text{otherwise}
\end{cases} \\
t^j_x &= (g^j_x - p^i_x) / p^i_w \\
t^j_y &= (g^j_y - p^i_y) / p^i_h \\
t^j_w &= \log(g^j_w / p^i_w) \\
t^j_h &= \log(g^j_h / p^i_h)
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbb{1}_{ij}^\text{match}</script> indicates whether the <script type="math/tex">i</script>-th bounding box with coordinates <script type="math/tex">(p^i_x, p^i_y, p^i_w, p^i_h)</script> is matched to the <script type="math/tex">j</script>-th ground truth box with coordinates <script type="math/tex">(g^j_x, g^j_y, g^j_w, g^j_h)</script> for any object. <script type="math/tex">d^i_m, m\in\{x, y, w, h\}</script> are the predicted correction terms. See <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#bounding-box-regression">this</a> for how the transformation works.</p>
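<p>As a small illustration of the localization loss, the sketch below encodes a ground truth box relative to an anchor box and applies the smooth L1 function to the residuals. Function names are made up for this example; boxes are (center x, center y, width, height) tuples.</p>

```python
import math


def smooth_l1(x):
    """Quadratic near zero, linear for large residuals."""
    ax = abs(x)
    return ax - 0.5 if ax >= 1 else 0.5 * x * x


def encode_box(gt, prior):
    """Regression targets (t_x, t_y, t_w, t_h) of a ground truth box w.r.t. an anchor box."""
    gx, gy, gw, gh = gt
    px, py, pw, ph = prior
    return ((gx - px) / pw, (gy - py) / ph, math.log(gw / pw), math.log(gh / ph))


def localization_loss(pred_offsets, gt, prior):
    """Smooth L1 loss between predicted corrections d and targets t for one matched pair."""
    targets = encode_box(gt, prior)
    return sum(smooth_l1(d - t) for d, t in zip(pred_offsets, targets))
```

<p>If the anchor box already coincides with the ground truth box, all four targets are zero, so predicting zero corrections gives zero loss.</p>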
<p>The <em>classification loss</em> is a softmax loss over multiple classes (<a href="https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits">softmax_cross_entropy_with_logits</a> in tensorflow):</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{cls} = -\sum_{i \in \text{pos}} \mathbb{1}_{ij}^k \log(\hat{c}_i^k) - \sum_{i \in \text{neg}} \log(\hat{c}_i^0)\text{, where }\hat{c}_i^k = \text{softmax}(c_i^k)</script>
<p>where <script type="math/tex">\mathbb{1}_{ij}^k</script> indicates whether the <script type="math/tex">i</script>-th bounding box and the <script type="math/tex">j</script>-th ground truth box are matched for an object in class <script type="math/tex">k</script>. <script type="math/tex">\text{pos}</script> is the set of matched bounding boxes (<script type="math/tex">N</script> items in total) and <script type="math/tex">\text{neg}</script> is the set of negative examples. SSD uses <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#common-tricks">hard negative mining</a> to select easily misclassified negative examples to construct this <script type="math/tex">\text{neg}</script> set: Once all the anchor boxes are sorted by objectness confidence score, the model picks the top candidates for training so that neg:pos is at most 3:1.</p>
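<p>The selection step of hard negative mining can be sketched as follows. This is a simplified illustration, assuming each negative anchor box already has a scalar confidence loss; the function name and arguments are hypothetical.</p>

```python
def select_hard_negatives(neg_losses, num_pos, max_ratio=3):
    """Return indices of the hardest negatives so that neg:pos is at most max_ratio:1.

    neg_losses: list of confidence losses, one per negative anchor box.
    num_pos:    number of matched (positive) anchor boxes.
    """
    keep = min(len(neg_losses), max_ratio * num_pos)
    # Sort negatives from the highest loss (hardest) to the lowest.
    order = sorted(range(len(neg_losses)), key=lambda i: neg_losses[i], reverse=True)
    return order[:keep]
```

<p>For example, with one positive box and seven negatives, only the three negatives with the highest loss are kept for training.</p>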
<h2 id="yolov2--yolo9000">YOLOv2 / YOLO9000</h2>
<p><strong>YOLOv2</strong> (<a href="https://arxiv.org/abs/1612.08242">Redmon & Farhadi, 2017</a>) is an enhanced version of YOLO. <strong>YOLO9000</strong> is built on top of YOLOv2 but trained with a joint dataset combining the COCO detection dataset and the top 9000 classes from ImageNet.</p>
<h3 id="yolov2-improvement">YOLOv2 Improvement</h3>
<p>A variety of modifications are applied to make YOLO prediction more accurate and faster, including:</p>
<p><strong>1. BatchNorm helps</strong>: Add <em>batch norm</em> on all the convolutional layers, leading to significant improvement in convergence.</p>
<p><strong>2. Image resolution matters</strong>: Fine-tuning the base model with <em>high resolution</em> images improves the detection performance.</p>
<p><strong>3. Convolutional anchor box detection</strong>: Rather than predicting the bounding box position with fully-connected layers over the whole feature map, YOLOv2 uses <em>convolutional layers</em> to predict locations of <em>anchor boxes</em>, like in Faster R-CNN. The predictions of spatial locations and class probabilities are decoupled. Overall, the change leads to a slight decrease in mAP, but an increase in recall.</p>
<p><strong>4. K-means clustering of box dimensions</strong>: Different from Faster R-CNN that uses hand-picked sizes of anchor boxes, YOLOv2 runs k-means clustering on the training data to find good priors on anchor box dimensions. The distance metric is designed to <em>rely on IoU scores</em>:</p>
<script type="math/tex; mode=display">\text{dist}(x, c_i) = 1 - \text{IoU}(x, c_i), i=1,\dots,k</script>
<p>where <script type="math/tex">x</script> is a ground truth box candidate and <script type="math/tex">c_i</script> is one of the centroids. The best number of centroids (anchor boxes) <script type="math/tex">k</script> can be chosen by the <a href="https://en.wikipedia.org/wiki/Elbow_method_(clustering)">elbow method</a>.</p>
<p>The anchor boxes generated by clustering provide better average IoU conditioned on a fixed number of boxes.</p>
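<p>A rough NumPy sketch of k-means with the IoU-based distance is shown below. It is an illustration under a few assumptions that do not come from the paper: boxes are compared by (width, height) only, as if they shared the same corner; centroids are updated with the plain mean of their members; and the initialization picks random training boxes.</p>

```python
import numpy as np


def wh_iou(boxes, centroids):
    """IoU between (N, 2) boxes and (k, 2) centroids of (width, height), aligned at one corner."""
    inter_w = np.minimum(boxes[:, None, 0], centroids[None, :, 0])
    inter_h = np.minimum(boxes[:, None, 1], centroids[None, :, 1])
    inter = inter_w * inter_h
    area_boxes = boxes[:, 0] * boxes[:, 1]
    area_centroids = centroids[:, 0] * centroids[:, 1]
    return inter / (area_boxes[:, None] + area_centroids[None, :] - inter)


def kmeans_iou(boxes, k, num_iters=50, seed=0):
    """Cluster box dimensions with dist(x, c) = 1 - IoU(x, c); returns (centroids, assignment)."""
    rng = np.random.default_rng(seed)
    # Assumption: initialize centroids with k distinct boxes from the data.
    centroids = boxes[rng.choice(len(boxes), size=k, replace=False)].astype(float)
    for _ in range(num_iters):
        # Assign each box to the centroid with the smallest IoU distance.
        assign = np.argmin(1.0 - wh_iou(boxes, centroids), axis=1)
        for c in range(k):
            members = boxes[assign == c]
            if len(members):  # keep the old centroid if a cluster is empty
                centroids[c] = members.mean(axis=0)
    # Recompute the assignment so it matches the final centroids.
    assign = np.argmin(1.0 - wh_iou(boxes, centroids), axis=1)
    return centroids, assign
```

<p>Using IoU rather than Euclidean distance means large boxes do not dominate the clustering just because their absolute errors are larger.</p>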
<p><strong>5. Direct location prediction</strong>: YOLOv2 formulates the bounding box prediction in a way that it would <em>not diverge</em> from the center location too much. If the box location prediction can place the box in any part of the image, like in a region proposal network, the model training could become unstable.</p>
<p>Given the anchor box of size <script type="math/tex">(p_w, p_h)</script> at the grid cell with its top left corner at <script type="math/tex">(c_x, c_y)</script>, the model predicts the offset and the scale, <script type="math/tex">(t_x, t_y, t_w, t_h)</script> and the corresponding predicted bounding box <script type="math/tex">b</script> has center <script type="math/tex">(b_x, b_y)</script> and size <script type="math/tex">(b_w, b_h)</script>. The confidence score is the sigmoid (<script type="math/tex">\sigma</script>) of another output <script type="math/tex">t_o</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
b_x &= \sigma(t_x) + c_x\\
b_y &= \sigma(t_y) + c_y\\
b_w &= p_w e^{t_w}\\
b_h &= p_h e^{t_h}\\
\text{Pr}(\text{object}) &\cdot \text{IoU}(b, \text{object}) = \sigma(t_o)
\end{aligned} %]]></script>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/yolov2-loc-prediction.png" alt="YOLOv2 bbox location prediction" /></p>
<p><em>Fig. 7. YOLOv2 bounding box location prediction. (Image source: <a href="https://arxiv.org/abs/1612.08242">original paper</a>)</em></p>
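<p>The decoding equations above translate almost directly into code. The sketch below is a minimal illustration with hypothetical function names; coordinates are in units of grid cells, as in the formulas.</p>

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def decode_yolov2_box(t, cell_xy, prior_wh):
    """Turn raw outputs (t_x, t_y, t_w, t_h, t_o) into a box and a confidence score.

    cell_xy:  top left corner (c_x, c_y) of the grid cell.
    prior_wh: anchor box size (p_w, p_h).
    """
    t_x, t_y, t_w, t_h, t_o = t
    c_x, c_y = cell_xy
    p_w, p_h = prior_wh
    # The sigmoid keeps the predicted center inside the cell it belongs to.
    b_x = sigmoid(t_x) + c_x
    b_y = sigmoid(t_y) + c_y
    # The exponential rescales the anchor box and keeps the size positive.
    b_w = p_w * math.exp(t_w)
    b_h = p_h * math.exp(t_h)
    return (b_x, b_y, b_w, b_h), sigmoid(t_o)
```

<p>With all raw outputs equal to zero, the decoded box sits at the center of its cell and has exactly the size of the anchor box, which is why the center cannot drift to an arbitrary part of the image.</p>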
<p><strong>6. Add fine-grained features</strong>: YOLOv2 adds a passthrough layer to bring <em>fine-grained features</em> from an earlier layer to the last output layer. The mechanism of this passthrough layer is similar to <em>identity mappings in ResNet</em> to extract higher-dimensional features from previous layers. This leads to a 1% performance increase.</p>
<p><strong>7. Multi-scale training</strong>: In order to train the model to be robust to input images of different sizes, a <em>new size</em> of input dimension is <em>randomly sampled</em> every 10 batches. Since conv layers of YOLOv2 downsample the input dimension by a factor of 32, the newly sampled size is a multiple of 32.</p>
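<p>A tiny sketch of the size sampling is given below. The range 320 to 608 follows the set of sizes reported in the YOLOv2 paper; the function name and the use of a Python <code>random.Random</code> instance are choices made for this example only.</p>

```python
import random


def sample_input_size(rng, min_size=320, max_size=608, stride=32):
    """Pick a random square input size that is a multiple of the network stride."""
    return rng.randrange(min_size, max_size + stride, stride)


rng = random.Random(0)
# In training, a new size would be drawn once every 10 batches.
sizes = [sample_input_size(rng) for _ in range(5)]
```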
<p><strong>8. Lightweight base model</strong>: To make prediction even faster, YOLOv2 adopts a lightweight base model, DarkNet-19, which has 19 conv layers and 5 max-pooling layers. The key point is to insert avg poolings and 1x1 conv filters between 3x3 conv layers.</p>
<h3 id="yolo9000-rich-dataset-training">YOLO9000: Rich Dataset Training</h3>
<p>Because drawing bounding boxes on images for object detection is much more expensive than tagging images for classification, the paper proposed a way to combine a small object detection dataset with the large ImageNet so that the model can be exposed to a much larger number of object categories. The name of YOLO9000 comes from the top 9000 classes in ImageNet. During joint training, if an input image comes from the classification dataset, it only backpropagates the classification loss.</p>
<p>The detection dataset has much fewer and more general labels and, moreover, labels across multiple datasets are often not mutually exclusive. For example, ImageNet has a label “Persian cat” while in COCO the same image would be labeled as “cat”. Without mutual exclusiveness, it does not make sense to apply softmax over all the classes.</p>
<p>In order to efficiently merge ImageNet labels (1000 classes, fine-grained) with COCO/PASCAL (< 100 classes, coarse-grained), YOLO9000 built a hierarchical tree structure with reference to <a href="https://wordnet.princeton.edu/">WordNet</a> so that general labels are closer to the root and the fine-grained class labels are leaves. In this way, “cat” is the parent node of “Persian cat”.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/word-tree.png" alt="WordTree" /></p>
<p><em>Fig. 8. The WordTree hierarchy merges labels from COCO and ImageNet. Blue nodes are COCO labels and red nodes are ImageNet labels. (Image source: <a href="https://arxiv.org/abs/1612.08242">original paper</a>)</em></p>
<p>To predict the probability of a class node, we can follow the path from the node to the root:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Pr("persian cat" | contain a "physical object")
= Pr("persian cat" | "cat")
Pr("cat" | "animal")
Pr("animal" | "physical object")
Pr(contain a "physical object") # confidence score.
</code></pre></div></div>
<p>Note that <code class="language-plaintext highlighter-rouge">Pr(contain a "physical object")</code> is the confidence score, predicted separately in the bounding box detection pipeline. The path of conditional probability prediction can stop at any step, depending on which labels are available.</p>
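<p>The chain of conditional probabilities can be sketched with a simple parent map. The tree, the probability values and the function name below are all made up for illustration; in YOLO9000 the conditional probabilities come from softmaxes over sibling nodes in WordTree.</p>

```python
def path_probability(node, parent, cond_prob, root_prob=1.0):
    """Multiply Pr(node | parent) along the path from a node up to the root.

    parent:    dict mapping each non-root label to its parent label.
    cond_prob: dict mapping each non-root label to Pr(label | its parent).
    root_prob: Pr(contain a "physical object"), i.e. the confidence score.
    """
    prob = root_prob
    while node in parent:  # stop once the root is reached
        prob *= cond_prob[node]
        node = parent[node]
    return prob


# Hypothetical toy tree and predictions.
parent = {"persian cat": "cat", "cat": "animal", "animal": "physical object"}
cond_prob = {"persian cat": 0.5, "cat": 0.8, "animal": 0.9}
p_persian = path_probability("persian cat", parent, cond_prob, root_prob=0.7)
p_cat = path_probability("cat", parent, cond_prob, root_prob=0.7)
```

<p>Because each factor is at most 1, a more specific label can never be more probable than its ancestors, so stopping the path earlier (e.g. at “cat”) gives a more confident but coarser prediction.</p>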
<h2 id="retinanet">RetinaNet</h2>
<p>The <strong>RetinaNet</strong> (<a href="https://arxiv.org/abs/1708.02002">Lin et al., 2018</a>) is a one-stage dense object detector. Two crucial building blocks are the <em>featurized image pyramid</em> and the use of <em>focal loss</em>.</p>
<h3 id="focal-loss">Focal Loss</h3>
<p>One issue for object detection model training is an extreme imbalance between background that contains no object and foreground that holds objects of interest. <strong>Focal loss</strong> is designed to assign more weight to hard, easily misclassified examples (e.g. background with noisy texture or partial object) and to down-weight easy examples (e.g. obviously empty background).</p>
<p>Starting with a normal cross entropy loss for binary classification,</p>
<script type="math/tex; mode=display">\text{CE}(p, y) = -y\log p - (1-y)\log(1-p)</script>
<p>where <script type="math/tex">y \in \{0, 1\}</script> is a ground truth binary label, indicating whether a bounding box contains an object, and <script type="math/tex">p \in [0, 1]</script> is the predicted probability of objectness (aka confidence score).</p>
<p>For notational convenience,</p>
<script type="math/tex; mode=display">% <![CDATA[
\text{let } p_t = \begin{cases}
p & \text{if } y = 1\\
1-p & \text{otherwise}
\end{cases},
\text{then } \text{CE}(p, y)=\text{CE}(p_t) = -\log p_t %]]></script>
<p>Easily classified examples with large <script type="math/tex">p_t \gg 0.5</script>, that is, when <script type="math/tex">p</script> is very close to 0 (when y=0) or 1 (when y=1), can incur a loss with non-trivial magnitude. Focal loss explicitly adds a weighting factor <script type="math/tex">(1-p_t)^\gamma, \gamma \geq 0</script> to each term in cross entropy so that the weight is small when <script type="math/tex">p_t</script> is large and therefore easy examples are down-weighted.</p>
<script type="math/tex; mode=display">\text{FL}(p_t) = -(1-p_t)^\gamma \log p_t</script>
<p style="width:65%;" class="center"><img src="/lil-log/assets/images/focal-loss.png" alt="Focal Loss" /></p>
<p><em>Fig. 9. The focal loss focuses less on easy examples with a factor of <script type="math/tex">(1-p_t)^\gamma</script>. (Image source: <a href="https://arxiv.org/abs/1708.02002">original paper</a>)</em></p>
<p>For a better control of the shape of the weighting function (see Fig. 10), RetinaNet uses an <script type="math/tex">\alpha</script>-balanced variant of the focal loss, where <script type="math/tex">\alpha=0.25, \gamma=2</script> works the best.</p>
<script type="math/tex; mode=display">\text{FL}(p_t) = -\alpha (1-p_t)^\gamma \log p_t</script>
<p style="width:90%;" class="center"><img src="/lil-log/assets/images/focal-loss-weights.png" alt="Focal loss weights" /></p>
<p><em>Fig. 10. The plot of focal loss weights <script type="math/tex">\alpha (1-p_t)^\gamma</script> as a function of <script type="math/tex">p_t</script>, given different values of <script type="math/tex">\alpha</script> and <script type="math/tex">\gamma</script>.</em></p>
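<p>A minimal sketch of the focal loss for a single prediction, following the formula as written in this post with one global <script type="math/tex">\alpha</script>. Note that the RetinaNet paper uses a class-dependent <script type="math/tex">\alpha_t</script> (<script type="math/tex">\alpha</script> for positives and <script type="math/tex">1-\alpha</script> for negatives); that detail is left out here, and the function name is hypothetical.</p>

```python
import math


def focal_loss(p, y, gamma=2.0, alpha=0.25):
    """Focal loss for one binary prediction.

    p: predicted probability that the box contains an object, strictly between 0 and 1.
    y: ground truth label, 0 or 1.
    """
    # p_t is the probability assigned to the true label.
    p_t = p if y == 1 else 1.0 - p
    # The factor (1 - p_t)^gamma shrinks the loss of well-classified examples.
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)
```

<p>Setting <code>gamma=0</code> and <code>alpha=1</code> recovers the plain cross entropy. With <code>gamma=2</code>, an easy example with <script type="math/tex">p_t = 0.9</script> has its cross entropy scaled by <script type="math/tex">(1-0.9)^2 = 0.01</script>, while a hard example with <script type="math/tex">p_t = 0.1</script> is scaled by 0.81.</p>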
<h3 id="featurized-image-pyramid">Featurized Image Pyramid</h3>
<p>The <strong>featurized image pyramid</strong> (<a href="https://arxiv.org/abs/1612.03144">Lin et al., 2017</a>) is the backbone network for RetinaNet. Following the same approach as the <a href="#image-pyramid">image pyramid</a> in SSD, featurized image pyramids provide a basic vision component for object detection at different scales.</p>
<p>The key idea of feature pyramid network is demonstrated in Fig. 11. The base structure contains a sequence of <em>pyramid levels</em>, each corresponding to one network <em>stage</em>. One stage contains multiple convolutional layers of the same size and the stage sizes are scaled down by a factor of 2. Let’s denote the last layer of the <script type="math/tex">i</script>-th stage as <script type="math/tex">C_i</script>.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/featurized-image-pyramid.png" alt="Featurized image pyramid" /></p>
<p><em>Fig. 11. The illustration of the featurized image pyramid module. (Replot based on figure 3 in <a href="https://arxiv.org/abs/1612.03144">FPN paper</a>)</em></p>
<p>Two pathways connect conv layers:</p>
<ul>
<li><strong>Bottom-up pathway</strong> is the normal feedforward computation.</li>
<li><strong>Top-down pathway</strong> goes in the inverse direction, adding coarse but semantically stronger feature maps back into the previous pyramid levels of a larger size via lateral connections.
<ul>
<li>First, the spatially coarser higher-level features are upsampled to be 2x larger. For image upscaling, the paper used nearest neighbor upsampling. There are many other <a href="https://en.wikipedia.org/wiki/Image_scaling#Algorithms">image upscaling algorithms</a>, such as <a href="https://www.tensorflow.org/api_docs/python/tf/layers/conv2d_transpose">deconv</a>, and adopting another image scaling method might or might not improve the performance of RetinaNet.</li>
<li>The larger feature map undergoes a 1x1 conv layer to reduce the channel dimension.</li>
<li>Finally, these two feature maps are merged by element-wise addition.
<br />
<br />
The lateral connections only happen at the last layer in stages, denoted as <script type="math/tex">\{C_i\}</script>, and the process continues until the finest (largest) merged feature map is generated. The prediction is made out of every merged map after a 3x3 conv layer, <script type="math/tex">\{P_i\}</script>.</li>
</ul>
</li>
</ul>
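<p>One step of the top-down pathway described above can be sketched in NumPy as follows. The helper names, channel counts and random weights are toy assumptions chosen for readability; the final 3x3 conv that turns a merged map into <script type="math/tex">P_i</script> is left out.</p>

```python
import numpy as np

def upsample_nearest_2x(x):
    # x: (channels, height, width); repeat every pixel twice along H and W.
    return x.repeat(2, axis=1).repeat(2, axis=2)

def conv1x1(x, weight):
    # A 1x1 conv is a per-pixel linear map over channels.
    # x: (c_in, h, w), weight: (c_out, c_in) -> (c_out, h, w)
    return np.einsum("oc,chw->ohw", weight, x)

def top_down_merge(coarse, lateral, weight):
    """Merge a coarser top-down map with the finer bottom-up map C_i."""
    # Upsample the coarse map, project C_i to the same channel dimension,
    # then combine the two by element-wise addition.
    return upsample_nearest_2x(coarse) + conv1x1(lateral, weight)

rng = np.random.default_rng(0)
d = 4                                   # shared channel dimension (toy value)
c5 = rng.standard_normal((16, 4, 4))    # coarsest bottom-up map (toy shape)
c4 = rng.standard_normal((8, 8, 8))     # next finer bottom-up map (toy shape)
w5 = rng.standard_normal((d, 16))
w4 = rng.standard_normal((d, 8))

m5 = conv1x1(c5, w5)                    # the coarsest map starts the pathway
m4 = top_down_merge(m5, c4, w4)         # shape (4, 8, 8)
```

<p>Repeating <code>top_down_merge</code> with each finer <script type="math/tex">C_i</script> continues the process until the largest merged map is produced.</p>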
<p>According to ablation studies, the importance rank of components of the featurized image pyramid design is as follows: <strong>1x1 lateral connection</strong> > detect object across multiple layers > top-down enrichment > pyramid representation (compared to only checking the finest layer).</p>
<h3 id="model-architecture">Model Architecture</h3>
<p>The featurized pyramid is constructed on top of the ResNet architecture. Recall that <a href="TBA">ResNet</a> has 5 conv blocks (= network stages / pyramid levels). The last layer of the <script type="math/tex">i</script>-th pyramid level, <script type="math/tex">C_i</script>, has a resolution <script type="math/tex">2^i</script> times lower than the raw input dimension.</p>
<p>RetinaNet utilizes feature pyramid levels <script type="math/tex">P_3</script> to <script type="math/tex">P_7</script>:</p>
<ul>
<li><script type="math/tex">P_3</script> to <script type="math/tex">P_5</script> are computed from the corresponding ResNet residual stage from <script type="math/tex">C_3</script> to <script type="math/tex">C_5</script>. They are connected by both top-down and bottom-up pathways.</li>
<li><script type="math/tex">P_6</script> is obtained via a 3×3 stride-2 conv on top of <script type="math/tex">C_5</script>.</li>
<li><script type="math/tex">P_7</script> applies ReLU and a 3×3 stride-2 conv on <script type="math/tex">P_6</script>.</li>
</ul>
<p>Adding higher pyramid levels on ResNet improves the performance for detecting large objects.</p>
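<p>To make the scales concrete, the sketch below lists the spatial size of each pyramid level for a given input, using the fact that level <script type="math/tex">i</script> is downsampled by <script type="math/tex">2^i</script>. The function name is hypothetical, and rounding up for inputs that are not divisible by the stride is an assumption about the padding.</p>

```python
import math

def pyramid_shapes(height, width, levels=range(3, 8)):
    """Spatial size of P_3..P_7 for an input of the given size."""
    shapes = {}
    for i in levels:
        stride = 2 ** i  # level i is 2^i times smaller than the input
        # Assumption: padded stride-2 convs round odd sizes up.
        shapes[f"P{i}"] = (math.ceil(height / stride), math.ceil(width / stride))
    return shapes

shapes = pyramid_shapes(512, 512)
for name, (h, w) in shapes.items():
    print(name, h, w)
```

<p>For a 512×512 input this gives 64×64 on <script type="math/tex">P_3</script> down to 4×4 on <script type="math/tex">P_7</script>, so each cell of the extra coarse levels covers a large image region.</p>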
<p>Same as in SSD, detection happens in all pyramid levels by making a prediction out of every merged feature map. Because all the levels share the same classifier and box regressor, every feature map is formed to have the same channel dimension d=256.</p>
<p>There are A=9 anchor boxes per level:</p>
<ul>
<li>The base size corresponds to areas of <script type="math/tex">32^2</script> to <script type="math/tex">512^2</script> pixels on <script type="math/tex">P_3</script> to <script type="math/tex">P_7</script> respectively. There are three size ratios, <script type="math/tex">\{2^0, 2^{1/3}, 2^{2/3}\}</script>.</li>
<li>For each size, there are three aspect ratios {1/2, 1, 2}.</li>
</ul>
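<p>The 9 anchor shapes per level can be enumerated as in the sketch below. The names are made up for this example, and treating the aspect ratio as height / width is an assumption; the opposite convention only swaps the two sides.</p>

```python
import itertools
import math

SCALES = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]   # three size ratios
ASPECT_RATIOS = [0.5, 1.0, 2.0]                 # assumed to mean height / width

def base_size(level):
    # 32 px on P3, doubling per level up to 512 px on P7.
    return 32 * 2 ** (level - 3)

def anchor_shapes(level):
    """Return the (width, height) of the 9 anchors at one pyramid level."""
    shapes = []
    for scale, ratio in itertools.product(SCALES, ASPECT_RATIOS):
        size = base_size(level) * scale
        # Keep the area at size**2 while setting height / width = ratio.
        w = size / math.sqrt(ratio)
        h = size * math.sqrt(ratio)
        shapes.append((w, h))
    return shapes

anchors_p3 = anchor_shapes(3)
for w, h in anchors_p3:
    print(f"{w:6.1f} x {h:6.1f}")
```

<p>Placing these 9 shapes at every location of a level yields the dense set of candidate boxes that the two subnets score and refine.</p>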
<p>As usual, for each anchor box, the model outputs a class probability for each of <script type="math/tex">K</script> classes in the classification subnet and regresses the offset from this anchor box to the nearest ground truth object in the box regression subnet. The classification subnet adopts the focal loss introduced above.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/retina-net.png" alt="RetinaNet" /></p>
<p><em>Fig. 12. The RetinaNet model architecture uses an <a href="https://arxiv.org/abs/1612.03144">FPN</a> backbone on top of ResNet. (Image source: the <a href="https://arxiv.org/abs/1612.03144">FPN</a> paper)</em></p>
<h2 id="yolov3">YOLOv3</h2>
<p><a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">YOLOv3</a> is created by applying a bunch of design tricks on YOLOv2. The changes are inspired by recent advances in the object detection world.</p>
<p>Here is a list of changes:</p>
<p><strong>1. Logistic regression for confidence scores</strong>: YOLOv3 predicts a confidence score for each bounding box using <em>logistic regression</em>, while YOLO and YOLOv2 use the sum of squared errors for classification terms (see the <a href="#loss-function">loss function</a> above). Linear regression of offset prediction leads to a decrease in mAP.</p>
<p><strong>2. No more softmax for class prediction</strong>: When predicting class confidence, YOLOv3 uses <em>multiple independent logistic classifiers</em>, one for each class, rather than one softmax layer. This is very helpful especially considering that one image might have multiple labels and not all the labels are guaranteed to be mutually exclusive.</p>
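<p>A small numeric sketch shows why independent logistic classifiers suit overlapping labels better than a softmax. The logits and the label names in the comment are hypothetical values picked for illustration.</p>

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; outputs sum to 1.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def sigmoid(z):
    # Each class gets its own independent probability in (0, 1).
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical logits for the labels ["woman", "person", "car"].
logits = np.array([4.0, 4.0, -3.0])

softmax_probs = softmax(logits)
sigmoid_probs = sigmoid(logits)

print(np.round(softmax_probs, 3))
print(np.round(sigmoid_probs, 3))
```

<p>Softmax forces the two equally confident labels to split the probability mass, so neither can exceed 0.5, whereas the per-class sigmoids can assign a high score to both.</p>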
<p><strong>3. Darknet + ResNet as the base model</strong>: The new Darknet-53 still relies on successive 3x3 and 1x1 conv layers, just like the original Darknet architecture, but has residual blocks added.</p>
<p><strong>4. Multi-scale prediction</strong>: Inspired by image pyramid, YOLOv3 adds several conv layers after the base feature extractor model and makes predictions at three different scales among these conv layers. In this way, it has to deal with many more bounding box candidates of various sizes overall.</p>
<p><strong>5. Skip-layer concatenation</strong>: YOLOv3 also adds cross-layer connections between two prediction layers (except for the output layer) and earlier finer-grained feature maps. The model first up-samples the coarse feature maps and then merges them with the previous features by concatenation. The combination with finer-grained information makes it better at detecting small objects.</p>
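<p>The skip-layer merge can be sketched as below. The tensor shapes are illustrative only, and nearest neighbor upsampling is an assumption, since the text above only says the coarse maps are up-sampled.</p>

```python
import numpy as np

def upsample_nearest_2x(x):
    # x: (channels, height, width); repeat every pixel twice along H and W.
    return x.repeat(2, axis=1).repeat(2, axis=2)

def skip_concat(coarse, fine):
    """Upsample the coarse map and stack it with the finer map along channels."""
    up = upsample_nearest_2x(coarse)
    return np.concatenate([up, fine], axis=0)

coarse = np.zeros((256, 13, 13))   # later, coarser feature map (toy shape)
fine = np.ones((512, 26, 26))      # earlier, finer-grained feature map (toy shape)

merged = skip_concat(coarse, fine)
print(merged.shape)                # channels add up: 256 + 512 = 768
```

<p>Unlike the element-wise addition used in the featurized image pyramid, concatenation keeps both sets of features intact and needs no 1x1 conv to match channel dimensions, at the cost of a wider merged map.</p>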
<p>Interestingly, focal loss does not help YOLOv3, possibly because of the usage of <script type="math/tex">\lambda_\text{noobj}</script> and <script type="math/tex">\lambda_\text{coord}</script>: they increase the loss from bounding box location predictions and decrease the loss from confidence predictions for background boxes.</p>
<p>Overall, YOLOv3 is both more accurate and faster than SSD; it is less accurate than RetinaNet but 3.8x faster.</p>
<p style="width:80%;" class="center"><img src="/lil-log/assets/images/yolov3-perf.png" alt="YOLOv3 performance" /></p>
<p><em>Fig. 13. The comparison of various fast object detection models on speed and mAP performance. (Image source: <a href="https://arxiv.org/abs/1708.02002">focal loss</a> paper with additional labels from the <a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">YOLOv3</a> paper.)</em></p>
<hr />
<p>Cited as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018detection4,
title = "Object Detection Part 4: Fast Detection Models",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/12/27/object-detection-part-4.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Joseph Redmon, et al. <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">“You only look once: Unified, real-time object detection.”</a> CVPR 2016.</p>
<p>[2] Joseph Redmon and Ali Farhadi. <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Redmon_YOLO9000_Better_Faster_CVPR_2017_paper.pdf">“YOLO9000: Better, Faster, Stronger.”</a> CVPR 2017.</p>
<p>[3] Joseph Redmon, Ali Farhadi. <a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">“YOLOv3: An incremental improvement.”</a>.</p>
<p>[4] Wei Liu et al. <a href="https://arxiv.org/abs/1512.02325">“SSD: Single Shot MultiBox Detector.”</a> ECCV 2016.</p>
<p>[5] Tsung-Yi Lin, et al. <a href="https://arxiv.org/abs/1612.03144">“Feature Pyramid Networks for Object Detection.”</a> CVPR 2017.</p>
<p>[6] Tsung-Yi Lin, et al. <a href="https://arxiv.org/abs/1708.02002">“Focal Loss for Dense Object Detection.”</a> IEEE transactions on pattern analysis and machine intelligence, 2018.</p>
<p>[7] <a href="https://towardsdatascience.com/yolo-v3-object-detection-53fb7d3bfe6b">“What’s new in YOLO v3?”</a> by Ayoosh Kathuria on “Towards Data Science”, Apr 23, 2018.</p>