<!-- "Meta Reinforcement Learning" — Lilian Weng, Lil'Log, 2019-06-23 -->
<blockquote>
<p>Meta-RL is meta-learning on reinforcement learning tasks. After being trained over a distribution of tasks, the agent is able to solve a new task by developing a new RL algorithm with its internal activity dynamics. This post starts with the origin of meta-RL and then dives into three key components of meta-RL.</p>
</blockquote>
<!--more-->
<p>In my earlier post on <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a>, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to “meta-learn” <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">Reinforcement Learning (RL)</a> tasks by developing an agent that can solve unseen tasks fast and efficiently.</p>
<p>To recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a <em>mini learning session</em>, happens at test time with limited exposure to the new configurations. Even without any explicit fine-tuning (no gradient backpropagation on trainable variables), the meta-learning model autonomously adjusts internal hidden states to learn.</p>
<p>Training RL algorithms can be notoriously difficult. If the meta-learning agent could become so smart that the distribution of solvable unseen tasks grows extremely broad, we are on track towards <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">general purpose methods</a> — essentially building a “brain” which would solve all kinds of RL problems without much human interference or manual feature engineering. Sounds amazing, right? 💖</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#on-the-origin-of-meta-rl" id="markdown-toc-on-the-origin-of-meta-rl">On the Origin of Meta-RL</a> <ul>
<li><a href="#back-in-2001" id="markdown-toc-back-in-2001">Back in 2001</a></li>
<li><a href="#proposal-in-2016" id="markdown-toc-proposal-in-2016">Proposal in 2016</a></li>
</ul>
</li>
<li><a href="#define-meta-rl" id="markdown-toc-define-meta-rl">Define Meta-RL</a> <ul>
<li><a href="#formulation" id="markdown-toc-formulation">Formulation</a></li>
<li><a href="#main-differences-from-rl" id="markdown-toc-main-differences-from-rl">Main Differences from RL</a></li>
<li><a href="#key-components" id="markdown-toc-key-components">Key Components</a></li>
</ul>
</li>
<li><a href="#meta-learning-algorithms-for-meta-rl" id="markdown-toc-meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</a> <ul>
<li><a href="#optimizing-model-weights-for-meta-learning" id="markdown-toc-optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</a></li>
<li><a href="#meta-learning-hyperparameters" id="markdown-toc-meta-learning-hyperparameters">Meta-learning Hyperparameters</a></li>
<li><a href="#meta-learning-the-loss-function" id="markdown-toc-meta-learning-the-loss-function">Meta-learning the Loss Function</a></li>
<li><a href="#meta-learning-the-exploration-strategies" id="markdown-toc-meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</a></li>
<li><a href="#episodic-control" id="markdown-toc-episodic-control">Episodic Control</a></li>
</ul>
</li>
<li><a href="#training-task-acquisition" id="markdown-toc-training-task-acquisition">Training Task Acquisition</a> <ul>
<li><a href="#task-generation-by-domain-randomization" id="markdown-toc-task-generation-by-domain-randomization">Task Generation by Domain Randomization</a></li>
<li><a href="#evolutionary-algorithm-on-environment-generation" id="markdown-toc-evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</a></li>
<li><a href="#learning-with-random-rewards" id="markdown-toc-learning-with-random-rewards">Learning with Random Rewards</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="on-the-origin-of-meta-rl">On the Origin of Meta-RL</h2>
<h3 id="back-in-2001">Back in 2001</h3>
<p>I encountered a paper written in 2001 by <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al.</a> when reading <a href="https://arxiv.org/pdf/1611.05763.pdf">Wang et al., 2016</a>. Although the idea was proposed for supervised learning, it bears many resemblances to the current approach to meta-RL.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/Hochreiter-meta-learning.png" alt="Hochreiter 2001" /></p>
<p><em>Fig. 1. The meta-learning system consists of the supervisory and the subordinate systems. The subordinate system is a recurrent neural network that takes as input both the observation at the current time step, <script type="math/tex">x_t</script> and the label at the last time step, <script type="math/tex">y_{t-1}</script>. (Image source: <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">Hochreiter et al., 2001</a>)</em></p>
<p>Hochreiter’s meta-learning model is a recurrent network with LSTM cells. LSTM is a good choice because it can internalize a history of inputs and tune its own weights effectively through <a href="https://en.wikipedia.org/wiki/Backpropagation_through_time">BPTT</a>. The training data contains <script type="math/tex">K</script> sequences and each sequence consists of <script type="math/tex">N</script> samples generated by a target function <script type="math/tex">f_k(.), k=1, \dots, K</script>,</p>
<script type="math/tex; mode=display">\{\text{input: }(\mathbf{x}^k_i, \mathbf{y}^k_{i-1}) \to \text{label: }\mathbf{y}^k_i\}_{i=1}^N
\text{ where }\mathbf{y}^k_i = f_k(\mathbf{x}^k_i)</script>
<p>Note that <em>the last label</em> <script type="math/tex">\mathbf{y}^k_{i-1}</script> is also provided as an auxiliary input so that the function can learn the presented mapping.</p>
<p>In the experiment of decoding two-dimensional quadratic functions, <script type="math/tex">a x_1^2 + b x_2^2 + c x_1 x_2 + d x_1 + e x_2 + f</script>, with coefficients <script type="math/tex">a</script>-<script type="math/tex">f</script> randomly sampled from [-1, 1], this meta-learning system was able to approximate the function after seeing only ~35 examples.</p>
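A sketch of generating one such training sequence is below; the shapes and the zero placeholder for the first auxiliary label are my own assumptions, not details from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_sequence(n_samples=35):
    # one target function f_k: a random 2-D quadratic with coefficients a..f
    a, b, c, d, e, f = rng.uniform(-1, 1, size=6)
    x = rng.uniform(-1, 1, size=(n_samples, 2))
    y = (a * x[:, 0]**2 + b * x[:, 1]**2 + c * x[:, 0] * x[:, 1]
         + d * x[:, 0] + e * x[:, 1] + f)
    # the last label y_{i-1} is fed back as an auxiliary input (0.0 at i=0)
    y_prev = np.concatenate([[0.0], y[:-1]])
    inputs = np.column_stack([x, y_prev])      # network input: (x_i, y_{i-1})
    return inputs, y                           # target: y_i

inputs, targets = make_sequence()
```

Feeding `inputs` through the LSTM step by step and regressing on `targets` would then reproduce the setup of Fig. 1.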
<h3 id="proposal-in-2016">Proposal in 2016</h3>
<p>In the modern days of DL, <a href="https://arxiv.org/abs/1611.05763">Wang et al.</a> (2016) and <a href="https://arxiv.org/abs/1611.02779">Duan et al.</a> (2017) simultaneously proposed very similar ideas of <strong>Meta-RL</strong> (called <strong>RL^2</strong> in the second paper). A meta-RL model is trained over a distribution of MDPs, and at test time, it is able to learn to solve a new task quickly. The goal of meta-RL is ambitious, taking one step further towards general algorithms.</p>
<h2 id="define-meta-rl">Define Meta-RL</h2>
<p><em>Meta Reinforcement Learning</em>, in short, is to do <a href="/lil-log/2018/11/30/meta-learning.html">meta-learning</a> in the field of <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>. Usually the train and test tasks are different but drawn from the same family of problems; e.g., experiments in the papers include multi-armed bandits with different reward probabilities, mazes with different layouts, the same robots with different physical parameters in the simulator, and many others.</p>
<h3 id="formulation">Formulation</h3>
<p>Let’s say we have a distribution of tasks, each formularized as an <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#markov-decision-processes">MDP</a> (Markov Decision Process), <script type="math/tex">M_i \in \mathcal{M}</script>. An MDP is determined by a 4-tuple, <script type="math/tex">M_i= \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle</script>:</p>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">\mathcal{S}</script></td>
<td>A set of states.</td>
</tr>
<tr>
<td><script type="math/tex">\mathcal{A}</script></td>
<td>A set of actions.</td>
</tr>
<tr>
<td><script type="math/tex">P_i: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}_{+}</script></td>
<td>Transition probability function.</td>
</tr>
<tr>
<td><script type="math/tex">R_i: \mathcal{S} \times \mathcal{A} \to \mathbb{R}</script></td>
<td>Reward function.</td>
</tr>
</tbody>
</table>
<p>(The RL^2 paper adds an extra parameter, the horizon <script type="math/tex">T</script>, into the MDP tuple to emphasize that each MDP should have a finite horizon.)</p>
<p>Note that a common state space <script type="math/tex">\mathcal{S}</script> and action space <script type="math/tex">\mathcal{A}</script> are used above, so that a (stochastic) policy <script type="math/tex">\pi_\theta: \mathcal{S} \times \mathcal{A} \to \mathbb{R}_{+}</script> would get inputs compatible across different tasks. The test tasks are sampled from the same distribution <script type="math/tex">\mathcal{M}</script> or a slightly modified version of it.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/meta-RL-illustration.png" alt="Illustration of meta-RL" /></p>
<p><em>Fig. 2. Illustration of meta-RL, containing two optimization loops. The outer loop samples a new environment in every iteration and adjusts parameters that determine the agent’s behavior. In the inner loop, the agent interacts with the environment and optimizes for the maximal reward. (Image source: <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick, et al. 2019</a>)</em></p>
<h3 id="main-differences-from-rl">Main Differences from RL</h3>
<p>The overall configuration of meta-RL is very similar to that of an ordinary RL algorithm, except that <strong>the last reward</strong> <script type="math/tex">r_{t-1}</script> and <strong>the last action</strong> <script type="math/tex">a_{t-1}</script> are also incorporated into the policy observation in addition to the current state <script type="math/tex">s_t</script>.</p>
<ul>
<li>In RL: <script type="math/tex">\pi_\theta(s_t) \to</script> a distribution over <script type="math/tex">\mathcal{A}</script></li>
<li>In meta-RL: <script type="math/tex">\pi_\theta(a_{t-1}, r_{t-1}, s_t) \to</script> a distribution over <script type="math/tex">\mathcal{A}</script></li>
</ul>
<p>The intention of this design is to feed a history into the model so that the policy can internalize the dynamics between states, rewards, and actions in the current MDP and adjust its strategy accordingly. This is well aligned with the setup in <a href="#back-in-2001">Hochreiter’s system</a>. Both meta-RL and RL^2 implemented an LSTM policy and the LSTM’s hidden states serve as a <em>memory</em> for tracking characteristics of the trajectories. Because the policy is recurrent, there is no need to feed the last state as inputs explicitly.</p>
<p>The training procedure works as follows:</p>
<ol>
<li>Sample a new MDP, <script type="math/tex">M_i \sim \mathcal{M}</script>;</li>
<li><strong>Reset the hidden state</strong> of the model;</li>
<li>Collect multiple trajectories and update the model weights;</li>
<li>Repeat from step 1.</li>
</ol>
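<p>The loop above can be made concrete with a toy two-armed-bandit task distribution. In the sketch below, the <code>hidden</code> dict is a hypothetical stand-in for the LSTM hidden state and the within-episode adaptation it enables — not the actual recurrent architecture from either paper:</p>

```python
import random

def sample_mdp(rng):
    # step 1: sample a new MDP — here, a bandit with arm probabilities (p, 1-p)
    p = rng.random()
    return [p, 1.0 - p]

def run_episode(arm_probs, rng, horizon=100):
    hidden = {"counts": [0, 0], "wins": [0, 0]}   # step 2: reset hidden state
    total = 0
    for t in range(horizon):
        if min(hidden["counts"]) < 3:             # brief forced exploration
            a = t % 2
        else:                                     # then greedy on estimates
            a = max((0, 1), key=lambda i: hidden["wins"][i] / hidden["counts"][i])
        r = 1 if rng.random() < arm_probs[a] else 0
        hidden["counts"][a] += 1                  # adapt within the episode,
        hidden["wins"][a] += r                    # without any weight update
        total += r
    return total

rng = random.Random(0)
# steps 3-4: collect trajectories, repeat with freshly sampled MDPs
returns = [run_episode(sample_mdp(rng), rng) for _ in range(200)]
```

The per-episode statistics play the role that the recurrent hidden state plays in the real models: they let the agent beat the random baseline on each new bandit without changing any "weights" between tasks.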
<p style="width: 68%;" class="center"><img src="/lil-log/assets/images/L2RL.png" alt="L2RL" /></p>
<p><em>Fig. 3. In the meta-RL paper, different actor-critic architectures all use a recurrent model. The last reward and last action are additional inputs. The observation is fed into the LSTM either as a one-hot vector or as an embedding vector after being passed through an encoder model. (Image source: <a href="https://arxiv.org/abs/1611.05763">Wang et al., 2016</a>)</em></p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/RL_2.png" alt="RL^2" /></p>
<p><em>Fig. 4. Illustration of the procedure of the model interacting with a series of MDPs at training time, as described in the RL^2 paper. (Image source: <a href="https://arxiv.org/abs/1611.02779">Duan et al., 2017</a>)</em></p>
<h3 id="key-components">Key Components</h3>
<p>There are three key components in Meta-RL:</p>
<blockquote>
<p>⭐ <strong>A Model with Memory</strong>
<br />
A recurrent neural network maintains a hidden state. Thus, it could acquire and memorize the knowledge about the current task by updating the hidden state during rollouts. Without memory, meta-RL would not work.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>Meta-learning Algorithm</strong>
<br />
A meta-learning algorithm refers to how we can update the model weights to optimize for the purpose of solving an unseen task fast at test time. In both the Meta-RL and RL^2 papers, the meta-learning algorithm is the ordinary gradient descent update of the LSTM with the hidden state reset between switches of MDPs.</p>
</blockquote>
<blockquote>
<p>⭐ <strong>A Distribution of MDPs</strong>
<br />
As the agent is exposed to a variety of environments and tasks during training, it has to learn how to adapt to different MDPs.</p>
</blockquote>
<p>According to <a href="https://www.cell.com/action/showPdf?pii=S1364-6613%2819%2930061-0">Botvinick et al.</a> (2019), one source of slowness in RL training is <em>weak <a href="https://en.wikipedia.org/wiki/Inductive_bias">inductive bias</a></em> ( = “a set of assumptions that the learner uses to predict outputs given inputs that it has not encountered”). As a general ML rule, a learning algorithm with weak inductive bias will be able to master a wider range of variance, but usually will be less sample-efficient. Therefore, narrowing down the hypothesis set with stronger inductive biases helps improve the learning speed.</p>
<p>In meta-RL, we impose certain types of inductive biases from the <em>task distribution</em> and store them in <em>memory</em>. Which inductive bias to adopt at test time depends on the <em>algorithm</em>. Together, these three key components depict a compelling view of meta-RL: Adjusting the weights of a recurrent network is slow but it allows the model to work out a new task fast with its own RL algorithm implemented in its internal activity dynamics.</p>
<p>Meta-RL interestingly and not very surprisingly matches the ideas in the <a href="https://arxiv.org/abs/1905.10985">AI-GAs</a> (“AI-Generating Algorithms”) paper by Jeff Clune (2019). He proposed that one efficient way towards building general AI is to make learning as automatic as possible. The AI-GAs approach involves three pillars: (1) meta-learning architectures, (2) meta-learning algorithms, and (3) automatically generated environments for effective learning.</p>
<hr />
<p>The topic of designing good recurrent network architectures is a bit too broad to be discussed here, so I will skip it. Next, let’s look further into another two components: meta-learning algorithms in the context of meta-RL and how to acquire a variety of training MDPs.</p>
<h2 id="meta-learning-algorithms-for-meta-rl">Meta-Learning Algorithms for Meta-RL</h2>
<p>My previous <a href="/lil-log/2018/11/30/meta-learning.html">post</a> on meta-learning has covered several classic meta-learning algorithms. Here I’m going to include more that are related to RL.</p>
<h3 id="optimizing-model-weights-for-meta-learning">Optimizing Model Weights for Meta-learning</h3>
<p>Both MAML (<a href="https://arxiv.org/abs/1703.03400">Finn, et al. 2017</a>) and Reptile (<a href="https://arxiv.org/abs/1803.02999">Nichol et al., 2018</a>) are methods for updating model parameters in order to achieve good generalization performance on new tasks. See the earlier post’s <a href="/lil-log/2018/11/30/meta-learning.html#optimization-based">section</a> on MAML and Reptile.</p>
<h3 id="meta-learning-hyperparameters">Meta-learning Hyperparameters</h3>
<p>The <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function">return</a> function in an RL problem, <script type="math/tex">G_t^{(n)}</script> or <script type="math/tex">G_t^\lambda</script>, involves a few hyperparameters that are often set heuristically, like the discount factor <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function"><script type="math/tex">\gamma</script></a> and the bootstrapping parameter <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#combining-td-and-mc-learning"><script type="math/tex">\lambda</script></a>.
Meta-gradient RL (<a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">Xu et al., 2018</a>) considers them as <em>meta-parameters</em>, <script type="math/tex">\eta=\{\gamma, \lambda \}</script>, that can be tuned and learned <em>online</em> while an agent is interacting with the environment. Therefore, the return becomes a function of <script type="math/tex">\eta</script> and dynamically adapts itself to a specific task over time.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
G_\eta^{(n)}(\tau_t) &= R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1}R_{t+n} + \gamma^n v_\theta(s_{t+n}) & \scriptstyle{\text{; n-step return}} \\
G_\eta^{\lambda}(\tau_t) &= (1-\lambda) \sum_{n=1}^\infty \lambda^{n-1} G_\eta^{(n)} & \scriptstyle{\text{; λ-return, mixture of n-step returns}}
\end{aligned} %]]></script>
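<p>To make the return definitions concrete, here is a small sketch with a toy reward sequence and a stub value function (both my own assumptions); the λ-return is computed as a truncated mixture of n-step returns:</p>

```python
def n_step_return(rewards, values, t, n, gamma):
    # G^(n)_t = R_{t+1} + γR_{t+2} + ... + γ^{n-1}R_{t+n} + γ^n v(s_{t+n})
    g = sum(gamma**k * rewards[t + k] for k in range(n))
    return g + gamma**n * values[t + n]

def lambda_return(rewards, values, t, gamma, lam, n_max):
    # (1-λ) Σ_n λ^{n-1} G^(n)_t, with the infinite sum truncated at n_max
    mix = (1 - lam) * sum(
        lam**(n - 1) * n_step_return(rewards, values, t, n, gamma)
        for n in range(1, n_max))
    # remaining tail weight λ^{n_max-1} goes to the longest n-step return
    return mix + lam**(n_max - 1) * n_step_return(rewards, values, t, n_max, gamma)

rewards = [1.0, 0.0, 1.0, 0.5, 0.0]   # toy R_{t+1}, ...
values = [0.2] * 6                    # stub v_theta(s) at every state
```

With λ = 0 the λ-return collapses to the one-step return, and with λ = 1 it becomes the `n_max`-step return — which is how the two hyperparameters in η interpolate between TD and Monte-Carlo targets.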
<p>During training, we would like to update the policy parameters with gradients as a function of all the information in hand, <script type="math/tex">\theta' = \theta + f(\tau, \theta, \eta)</script>, where <script type="math/tex">\theta</script> are the current model weights, <script type="math/tex">\tau</script> is a sequence of trajectories, and <script type="math/tex">\eta</script> are the meta-parameters.</p>
<p>Meanwhile, let’s say we have a meta-objective function <script type="math/tex">J(\tau, \theta, \eta)</script> as a performance measure. The training process follows the principle of online cross-validation, using a sequence of consecutive experiences:</p>
<ol>
<li>Starting with parameter <script type="math/tex">\theta</script>, the policy <script type="math/tex">\pi_\theta</script> is updated on the first batch of samples <script type="math/tex">\tau</script>, resulting in <script type="math/tex">\theta'</script>.</li>
<li>Then we continue running the policy <script type="math/tex">\pi_{\theta'}</script> to collect a new set of experiences <script type="math/tex">\tau'</script>, just following <script type="math/tex">\tau</script> consecutively in time. The performance is measured as <script type="math/tex">J(\tau', \theta', \bar{\eta})</script> with a fixed meta-parameter <script type="math/tex">\bar{\eta}</script>.</li>
<li>The gradient of meta-objective <script type="math/tex">J(\tau', \theta', \bar{\eta})</script> w.r.t. <script type="math/tex">\eta</script> is used to update <script type="math/tex">\eta</script>:</li>
</ol>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\Delta \eta
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{d\theta'}{d\eta} & \scriptstyle{\text{ ; single variable chain rule.}} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial (\theta + f(\tau, \theta, \eta))}{\partial\eta} \\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big(\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\frac{d\eta}{d\eta} \Big) & \scriptstyle{\text{; multivariable chain rule.}}\\
&= -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \Big( \color{red}{\big(\mathbf{I} + \frac{\partial f(\tau, \theta, \eta)}{\partial\theta}\big)}\frac{d\theta}{d\eta} + \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}\Big) & \scriptstyle{\text{; secondary gradient term in red.}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\beta</script> is the learning rate for <script type="math/tex">\eta</script>.</p>
<p>The meta-gradient RL algorithm simplifies the computation by setting the secondary gradient term to zero, <script type="math/tex">\mathbf{I} + \partial f(\tau, \theta, \eta)/\partial\theta = 0</script> — this choice prefers the immediate effect of the meta-parameters <script type="math/tex">\eta</script> on the parameters <script type="math/tex">\theta</script>. Eventually we get:</p>
<script type="math/tex; mode=display">\Delta \eta = -\beta \frac{\partial J(\tau', \theta', \bar{\eta})}{\partial \theta'} \frac{\partial f(\tau, \theta, \eta)}{\partial\eta}</script>
<p>Experiments in the paper adopted the meta-objective function same as <script type="math/tex">TD(\lambda)</script> algorithm, minimizing the error between the approximated value function <script type="math/tex">v_\theta(s)</script> and the <script type="math/tex">\lambda</script>-return:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
J(\tau, \theta, \eta) &= (G^\lambda_\eta(\tau) - v_\theta(s))^2 \\
J(\tau', \theta', \bar{\eta}) &= (G^\lambda_{\bar{\eta}}(\tau') - v_{\theta'}(s'))^2
\end{aligned} %]]></script>
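<p>Here is a one-dimensional numeric sketch of the simplified meta-gradient update; every function below is a toy stand-in of my own, not the paper’s setup:</p>

```python
# Toy 1-D illustration of  Δη = -β ∂J(τ',θ',η̄)/∂θ' · ∂f(τ,θ,η)/∂η
alpha, beta = 0.1, 0.05
target = 1.0                           # stands in for the λ-return target

def f(theta, eta, grad):
    # inner update whose step size is scaled by the meta-parameter η
    return -alpha * eta * grad

theta, eta = 0.0, 0.5
grad = 2 * (theta - target)            # ∂J/∂θ on batch τ, with J = (target - θ)²
theta_new = theta + f(theta, eta, grad)

dJ_dtheta = 2 * (theta_new - target)   # ∂J/∂θ' measured on the next batch τ'
df_deta = -alpha * grad                # ∂f/∂η, analytic for this toy f
delta_eta = -beta * dJ_dtheta * df_deta
eta = eta + delta_eta                  # η grows: a larger inner step helped here
```

Because the inner step moved θ toward the target but undershot, the meta-gradient increases η, i.e. the meta-parameter adapts online in the direction that would have improved performance on the held-out batch.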
<h3 id="meta-learning-the-loss-function">Meta-learning the Loss Function</h3>
<p>In policy gradient algorithms, the expected total reward is maximized by updating the policy parameters <script type="math/tex">\theta</script> in the direction of estimated gradient (<a href="https://arxiv.org/abs/1506.02438">Schulman et al., 2016</a>),</p>
<script type="math/tex; mode=display">g = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \nabla_\theta \log \pi_\theta (a_t \mid s_t)]</script>
<p>where the candidates for <script type="math/tex">\Psi_t</script> include the trajectory return <script type="math/tex">G_t</script>, the Q value <script type="math/tex">Q(s_t, a_t)</script>, or the advantage value <script type="math/tex">A(s_t, a_t)</script>. The corresponding surrogate loss function for the policy gradient can be reverse-engineered:</p>
<script type="math/tex; mode=display">L_\text{pg} = \mathbb{E}[\sum_{t=0}^\infty \Psi_t \log \pi_\theta (a_t \mid s_t)]</script>
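<p>As a minimal runnable sketch of ascending this surrogate objective, consider a softmax policy on a toy two-armed bandit with <script type="math/tex">\Psi_t</script> set to the sampled reward — the whole setup here is my own toy, not from the paper:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.zeros(2)                       # policy logits

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

true_means = np.array([0.2, 0.8])         # arm reward means (toy)
for _ in range(2000):
    probs = softmax(theta)
    a = rng.choice(2, p=probs)
    r = true_means[a] + 0.1 * rng.standard_normal()
    # gradient of  Psi * log pi(a)  w.r.t. the logits: Psi * (onehot(a) - probs)
    grad = r * (np.eye(2)[a] - probs)
    theta += 0.1 * grad                   # ascend the surrogate objective
```

After training, the policy concentrates its probability mass on the higher-reward arm, which is exactly what maximizing <script type="math/tex">L_\text{pg}</script> is meant to achieve.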
<p>This loss function is a measure over a history of trajectories, <script type="math/tex">(s_0, a_0, r_0, \dots, s_t, a_t, r_t, \dots)</script>. <strong>Evolved Policy Gradient</strong> (<strong>EPG</strong>; <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">Houthooft et al., 2018</a>) takes a step further by defining the policy gradient loss function as a temporal convolution (1-D convolution) over the agent’s past experience, <script type="math/tex">L_\phi</script>. The parameters <script type="math/tex">\phi</script> of the loss function network are evolved in a way that an agent can achieve higher returns.</p>
<p>Similar to many meta-learning algorithms, EPG has two optimization loops:</p>
<ul>
<li>In the internal loop, an agent learns to improve its policy <script type="math/tex">\pi_\theta</script>.</li>
<li>In the outer loop, the model updates the parameters <script type="math/tex">\phi</script> of the loss function <script type="math/tex">L_\phi</script>. Because there is no explicit way to write down a differentiable equation between the return and the loss, EPG turned to <a href="https://en.wikipedia.org/wiki/Evolution_strategy"><em>Evolutionary Strategies</em></a> (ES).</li>
</ul>
<p>The general idea is to train a population of <script type="math/tex">N</script> agents, each trained with a loss function <script type="math/tex">L_{\phi + \sigma \epsilon_i}</script> whose parameters <script type="math/tex">\phi</script> are perturbed by a small Gaussian noise <script type="math/tex">\epsilon_i \sim \mathcal{N}(0, \mathbf{I})</script> of standard deviation <script type="math/tex">\sigma</script>. During the inner loop’s training, EPG tracks a history of experience and updates the policy parameters according to the loss function <script type="math/tex">L_{\phi + \sigma\epsilon_i}</script> for each agent:</p>
<script type="math/tex; mode=display">\theta_i \leftarrow \theta - \alpha_\text{in} \nabla_\theta L_{\phi + \sigma \epsilon_i} (\pi_\theta, \tau_{t-K, \dots, t})</script>
<p>where <script type="math/tex">\alpha_\text{in}</script> is the learning rate of the inner loop and <script type="math/tex">\tau_{t-K, \dots, t}</script> is a sequence of <script type="math/tex">K</script> transitions up to the current time step <script type="math/tex">t</script>.</p>
<p>Once the inner loop policy is mature enough, the policy is evaluated by the mean return <script type="math/tex">\bar{G}_{\phi+\sigma\epsilon_i}</script> over multiple randomly sampled trajectories. Eventually, we are able to estimate the gradient of <script type="math/tex">\phi</script> according to <a href="http://blog.otoro.net/2017/10/29/visual-evolution-strategies/">NES</a> numerically (<a href="https://arxiv.org/abs/1703.03864">Salimans et al, 2017</a>). While repeating this process, both the policy parameters <script type="math/tex">\theta</script> and the loss function weights <script type="math/tex">\phi</script> are being updated simultaneously to achieve higher returns.</p>
<script type="math/tex; mode=display">\phi \leftarrow \phi + \alpha_\text{out} \frac{1}{\sigma N} \sum_{i=1}^N \epsilon_i \bar{G}_{\phi+\sigma\epsilon_i}</script>
<p>where <script type="math/tex">\alpha_\text{out}</script> is the learning rate of the outer loop.</p>
<p>In practice, the loss <script type="math/tex">L_\phi</script> is bootstrapped with an ordinary policy gradient (such as REINFORCE or PPO) surrogate loss <script type="math/tex">L_\text{pg}</script>, <script type="math/tex">\hat{L} = (1-\alpha) L_\phi + \alpha L_\text{pg}</script>. The weight <script type="math/tex">\alpha</script> is annealed from 1 to 0 gradually during training. At test time, the loss function parameter <script type="math/tex">\phi</script> stays fixed and the loss value is computed over a history of experience to update the policy parameters <script type="math/tex">\theta</script>.</p>
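<p>The ES outer loop can be sketched on a toy problem, with <code>mean_return</code> standing in for the expensive step of actually training an agent under <script type="math/tex">L_{\phi+\sigma\epsilon_i}</script> and measuring its return (the quadratic fitness and all constants below are my own assumptions):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
phi = np.zeros(3)                         # toy loss-function parameters
sigma, n_agents, lr_out = 0.1, 50, 0.05
optimum = np.array([1.0, -1.0, 0.5])      # best phi in this toy fitness

def mean_return(phi):
    # stand-in for training an agent with loss L_phi and measuring its return
    return -np.sum((phi - optimum) ** 2)

for _ in range(300):
    eps = rng.standard_normal((n_agents, phi.size))
    returns = np.array([mean_return(phi + sigma * e) for e in eps])
    # NES gradient estimate  (1/σN) Σ ε_i G_i, with a mean-return baseline
    phi = phi + lr_out / (sigma * n_agents) * eps.T @ (returns - returns.mean())
```

Subtracting the mean return is a standard variance-reduction baseline; it leaves the gradient estimate unbiased because the perturbations are zero-mean.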
<h3 id="meta-learning-the-exploration-strategies">Meta-learning the Exploration Strategies</h3>
<p>The <a href="/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html#exploitation-vs-exploration">exploitation vs exploration</a> dilemma is a critical problem in RL. Common ways to do exploration include <script type="math/tex">\epsilon</script>-greedy, random noise on actions, or stochastic policy with built-in randomness on the action space.</p>
<p><strong>MAESN</strong> (<a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al., 2018</a>) is an algorithm to learn structured action noise from prior experience for better and more effective exploration. Simply adding random noise to actions cannot capture task-dependent or time-correlated exploration strategies. MAESN changes the policy to condition on a per-task random variable <script type="math/tex">z_i \sim \mathcal{N}(\mu_i, \sigma_i)</script>, for the <script type="math/tex">i</script>-th task <script type="math/tex">M_i</script>, so we would have a policy <script type="math/tex">a \sim \pi_\theta(a\mid s, z_i)</script>.
The latent variable <script type="math/tex">z_i</script> is sampled once and fixed during one episode. Intuitively, the latent variable determines one type of behavior (or skills) that should be explored more at the beginning of a rollout and the agent would adjust its actions accordingly. Both the policy parameters and latent space are optimized to maximize the total task rewards. In the meantime, the policy learns to make use of the latent variables for exploration.</p>
<p>In addition, the loss function includes a KL divergence between the learned latent variable and a unit Gaussian prior, <script type="math/tex">D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i)\|\mathcal{N}(0, \mathbf{I}))</script>. On one hand, it restricts the learned latent space not too far from a common prior. On the other hand, it creates the variational evidence lower bound (<a href="http://users.umiacs.umd.edu/~xyang35/files/understanding-variational-lower.pdf">ELBO</a>) for the reward function. Interestingly the paper found that <script type="math/tex">(\mu_i, \sigma_i)</script> for each task are usually close to the prior at convergence.</p>
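<p>For diagonal Gaussians this KL term has a closed form; a quick sketch:</p>

```python
import numpy as np

def kl_diag_gaussian_vs_standard(mu, sigma):
    # D_KL( N(mu, diag(sigma^2)) || N(0, I) ), closed form per dimension:
    # 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
    mu, sigma = np.asarray(mu, float), np.asarray(sigma, float)
    return float(0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma)))

kl_diag_gaussian_vs_standard([0.0, 0.0], [1.0, 1.0])  # → 0.0
```

A latent distribution matching the prior gives zero penalty, which is consistent with the paper’s observation that the learned <script type="math/tex">(\mu_i, \sigma_i)</script> often end up close to the prior at convergence.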
<p style="width: 82%;" class="center"><img src="/lil-log/assets/images/MAESN.png" alt="MAESN" /></p>
<p><em>Fig. 5. The policy is conditioned on a latent variable <script type="math/tex">z_i \sim \mathcal{N}(\mu_i, \sigma_i)</script> that is sampled once per episode. Each task has different hyperparameters for the latent variable distribution, <script type="math/tex">(\mu_i, \sigma_i)</script>, and they are optimized in the outer loop. (Image source: <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">Gupta et al., 2018</a>)</em></p>
<h3 id="episodic-control">Episodic Control</h3>
<p>A major criticism of RL is on its sample inefficiency. A large number of samples and small learning steps are required for incremental parameter adjustment in RL in order to maximize generalization and avoid catastrophic forgetting of earlier learning (<a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613\(19\)30061-0">Botvinick et al., 2019</a>).</p>
<p><strong>Episodic control</strong> (<a href="http://papers.nips.cc/paper/3311-hippocampal-contributions-to-control-the-third-way.pdf">Lengyel & Dayan, 2008</a>) is proposed as a solution to avoid forgetting and improve generalization while training at a faster speed. It is partially inspired by hypotheses on instance-based <a href="https://en.wikipedia.org/wiki/Hippocampus">hippocampal</a> learning.</p>
<p>An <em>episodic memory</em> keeps explicit records of past events and uses these records directly as points of reference for making new decisions (i.e. just like <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning). In <strong>MFEC</strong> (Model-Free Episodic Control; <a href="https://arxiv.org/abs/1606.04460">Blundell et al., 2016</a>), the memory is modeled as a big table, storing the state-action pair <script type="math/tex">(s, a)</script> as key and the corresponding Q-value <script type="math/tex">Q_\text{EC}(s, a)</script> as value. When receiving a new observation <script type="math/tex">s</script>, the Q value is estimated in a non-parametric way as the average Q-value of the top <script type="math/tex">k</script> most similar samples:</p>
<script type="math/tex; mode=display">% <![CDATA[
\hat{Q}_\text{EC}(s, a) =
\begin{cases}
Q_\text{EC}(s, a) & \text{if } (s,a) \in Q_\text{EC}, \\
\frac{1}{k} \sum_{i=1}^k Q_\text{EC}(s^{(i)}, a) & \text{otherwise}
\end{cases} %]]></script>
<p>where <script type="math/tex">s^{(i)}, i=1, \dots, k</script> are the top <script type="math/tex">k</script> states with the smallest distances to the state <script type="math/tex">s</script>. The action that yields the highest estimated Q value is then selected, and the memory table is updated according to the return received at <script type="math/tex">s_t</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
Q_\text{EC}(s_t, a_t) \leftarrow
\begin{cases}
\max\{Q_\text{EC}(s_t, a_t), G_t\} & \text{if } (s_t, a_t) \in Q_\text{EC}, \\
G_t & \text{otherwise}
\end{cases} %]]></script>
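<p>Both the estimate and the update rule fit in a few lines of a dictionary-based sketch; the state encoding, distance metric, and empty-memory default here are my own simplifications:</p>

```python
import numpy as np

memory = {}   # (state, action) -> Q_EC; states are tuples of floats here

def estimate_q(s, a, k=3):
    if (s, a) in memory:
        return memory[(s, a)]                 # exact hit: return stored value
    # otherwise average stored Q over the k nearest states with the same action
    neighbors = sorted(
        (float(np.linalg.norm(np.subtract(s, s2))), q)
        for (s2, a2), q in memory.items() if a2 == a)[:k]
    return sum(q for _, q in neighbors) / len(neighbors) if neighbors else 0.0

def update(s, a, episodic_return):
    # keep the best return ever received from (s, a)
    memory[(s, a)] = max(memory.get((s, a), float("-inf")), episodic_return)
```

The `max` in `update` is what makes MFEC aggressively optimistic: a single lucky trajectory immediately raises the stored value for every state-action pair along it.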
<p>As a tabular RL method, MFEC suffers from large memory consumption and a lack of ways to generalize among similar states. The first one can be fixed with an LRU cache. Inspired by <a href="/lil-log/2018/11/30/meta-learning.html#metric-based">metric-based</a> meta-learning, especially Matching Networks (<a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">Vinyals et al., 2016</a>), the generalization problem is improved in a follow-up algorithm, <strong>NEC</strong> (Neural Episodic Control; <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>).</p>
<p>The episodic memory in NEC is a Differentiable Neural Dictionary (<strong>DND</strong>), where the key is a convolutional embedding vector of the input image pixels and the value stores the estimated Q value. Given an inquiry key, the output is a weighted sum of the values of the top similar keys, where each weight is a normalized kernel measure between the query key and the corresponding key in the dictionary. This sounds like a hard <a href="/lil-log/2018/06/24/attention-attention.html">attention</a> mechanism.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/neural-episodic-control.png" alt="Neural episodic control" /></p>
<p><em>Fig. 6. Illustrations of the episodic memory module in NEC and two operations on a differentiable neural dictionary. (Image source: <a href="https://arxiv.org/abs/1703.01988">Pritzel et al., 2017</a>)</em></p>
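<p>A rough sketch of the DND read operation, assuming the inverse-distance kernel <script type="math/tex">k(h, h_i) = 1/(\|h - h_i\|^2 + \delta)</script> used in the paper; the function signature and defaults are illustrative:</p>

```python
import numpy as np

def dnd_lookup(query, keys, values, p=50, delta=1e-3):
    """Sketch of a DND read: output a weighted sum of stored values,
    with weights given by a normalized inverse-distance kernel over
    the p nearest keys. Details other than the kernel are simplified."""
    keys = np.asarray(keys, dtype=float)
    values = np.asarray(values, dtype=float)
    dists = np.sum((keys - query) ** 2, axis=1)   # squared distances to query
    nearest = np.argsort(dists)[:p]               # p nearest neighbours
    w = 1.0 / (dists[nearest] + delta)            # kernel weights
    w /= w.sum()                                  # normalize to sum to 1
    return float(np.dot(w, values[nearest]))
```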
<p>Further, <strong>Episodic LSTM</strong> (<a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>) enhances the basic LSTM architecture with a DND episodic memory, which stores task context embeddings as keys and the LSTM cell states as values. The stored hidden states are retrieved and added directly to the current cell state through the same gating mechanism within LSTM:</p>
<p style="width: 77%;" class="center"><img src="/lil-log/assets/images/episodic-LSTM.png" alt="Episodic LSTM" /></p>
<p><em>Fig. 7. Illustration of the episodic LSTM architecture. The additional structure of episodic memory is in bold. (Image source: <a href="https://arxiv.org/abs/1805.09692">Ritter et al., 2018</a>)</em></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{c}_t &= \mathbf{i}_t \circ \mathbf{c}_\text{in} + \mathbf{f}_t \circ \mathbf{c}_{t-1} + \color{green}{\mathbf{r}_t \circ \mathbf{c}_\text{ep}} &\\
\mathbf{i}_t &= \sigma(\mathbf{W}_{i} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) & \scriptstyle{\text{; input gate}} \\
\mathbf{f}_t &= \sigma(\mathbf{W}_{f} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) & \scriptstyle{\text{; forget gate}} \\
\color{green}{\mathbf{r}_t} & \color{green}{=} \color{green}{\sigma(\mathbf{W}_{r} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_r)} & \scriptstyle{\text{; reinstatement gate}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbf{c}_t</script> and <script type="math/tex">\mathbf{h}_t</script> are the cell and hidden states at time <script type="math/tex">t</script>; <script type="math/tex">\mathbf{i}_t</script>, <script type="math/tex">\mathbf{f}_t</script> and <script type="math/tex">\mathbf{r}_t</script> are the input, forget and reinstatement gates, respectively; and <script type="math/tex">\mathbf{c}_\text{ep}</script> is the cell state retrieved from episodic memory. The newly added episodic memory components are marked in green.</p>
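<p>A minimal numpy sketch of the gated update above, including the extra reinstatement term <script type="math/tex">\mathbf{r}_t \circ \mathbf{c}_\text{ep}</script>; the weight-matrix layout is an assumption for illustration:</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def episodic_lstm_cell(x, h_prev, c_prev, c_ep, W, b):
    """Sketch of the episodic LSTM cell update (Ritter et al., 2018):
    a standard LSTM plus a reinstatement gate r that mixes in a cell
    state c_ep retrieved from the DND episodic memory. W and b hold
    per-gate parameters keyed by 'i', 'f', 'r', 'o', 'c'."""
    z = np.concatenate([h_prev, x])
    i = sigmoid(W["i"] @ z + b["i"])        # input gate
    f = sigmoid(W["f"] @ z + b["f"])        # forget gate
    r = sigmoid(W["r"] @ z + b["r"])        # reinstatement gate
    c_in = np.tanh(W["c"] @ z + b["c"])     # candidate cell state
    c = i * c_in + f * c_prev + r * c_ep    # extra episodic term r * c_ep
    o = sigmoid(W["o"] @ z + b["o"])        # output gate
    h = o * np.tanh(c)
    return h, c
```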
<p>This architecture provides a shortcut to prior experience through context-based retrieval. Meanwhile, explicitly saving the task-dependent experience in an external memory avoids forgetting. In the paper, all the experiments use manually designed context vectors; how to construct an effective and efficient format of task context embeddings for more free-form tasks remains an interesting open question.</p>
<p>Overall, the capacity of episodic control is limited by the complexity of the environment. It is very rare for an agent to repeatedly visit exactly the same states in a real-world task, so properly encoding the states is critical. The learned embedding space compresses the observation data into a lower-dimensional space in which, at the same time, two states that are close are expected to demand similar strategies.</p>
<h2 id="training-task-acquisition">Training Task Acquisition</h2>
<p>Among the three key components, how to design a proper distribution of tasks is the least studied and probably the most specific to meta-RL itself. As described <a href="#formulation">above</a>, each task is an MDP: <script type="math/tex">M_i = \langle \mathcal{S}, \mathcal{A}, P_i, R_i \rangle \in \mathcal{M}</script>. We can build a distribution of MDPs by modifying:</p>
<ul>
<li>The <em>reward configuration</em>: Among different tasks, same behavior might get rewarded differently according to <script type="math/tex">R_i</script>.</li>
<li>Or, the <em>environment</em>: The transition function <script type="math/tex">P_i</script> can be reshaped by initializing the environment with varying shifts between states.</li>
</ul>
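<p>As a toy illustration of varying the reward configuration, consider a family of bandit tasks that share <script type="math/tex">\mathcal{S}</script> and <script type="math/tex">\mathcal{A}</script> but each draw their own arm payout probabilities; all names here are illustrative:</p>

```python
import random

def sample_bandit_task(n_arms=2, rng=None):
    """Sketch: each sampled task is a bandit 'MDP' with its own reward
    configuration R_i (per-arm payout probabilities), while the state and
    action spaces stay fixed. Purely illustrative."""
    rng = rng or random.Random()
    probs = [rng.random() for _ in range(n_arms)]   # task-specific R_i

    def step(action):
        # Bernoulli reward drawn under this task's reward configuration
        return 1.0 if rng.random() < probs[action] else 0.0

    return probs, step
```

A meta-RL agent would be trained across many such sampled tasks and evaluated on freshly sampled ones.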
<h3 id="task-generation-by-domain-randomization">Task Generation by Domain Randomization</h3>
<p>Randomizing parameters in a simulator is an easy way to obtain tasks with modified transition functions. If interested in learning further, check my last <a href="/lil-log/2019/05/05/domain-randomization.html">post</a> on <strong>domain randomization</strong>.</p>
<h3 id="evolutionary-algorithm-on-environment-generation">Evolutionary Algorithm on Environment Generation</h3>
<p>An <a href="https://en.wikipedia.org/wiki/Evolutionary_algorithm">evolutionary algorithm</a> is a gradient-free, heuristic-based optimization method inspired by natural selection. A population of solutions follows a loop of evaluation, selection, reproduction, and mutation; eventually, good solutions survive and get selected.</p>
<p><strong>POET</strong> (<a href="https://arxiv.org/abs/1901.01753">Wang et al, 2019</a>), a framework based on the evolutionary algorithm, attempts to generate tasks while the problems themselves are being solved. The implementation of POET is designed specifically for a simple 2D <a href="https://gym.openai.com/envs/BipedalWalkerHardcore-v2/">bipedal walker</a> environment, but it points out an interesting direction. It is noteworthy that evolutionary algorithms have had some compelling applications in deep learning, such as <a href="#meta-learning-the-loss-function">EPG</a> and PBT (Population-Based Training; <a href="https://arxiv.org/abs/1711.09846">Jaderberg et al, 2017</a>).</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/POET.png" alt="POET" /></p>
<p><em>Fig. 8. An example bipedal walking environment (top) and an overview of POET (bottom). (Image source: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">POET blog post</a>)</em></p>
<p>The 2D bipedal walking environment is evolving: from a simple flat surface to a much more difficult trail with potential gaps, stumps, and rough terrains. POET pairs the generation of environmental challenges and the optimization of agents together so as to (a) select agents that can resolve current challenges and (b) evolve environments to be solvable. The algorithm maintains a list of <em>environment-agent pairs</em> and repeats the following:</p>
<ol>
<li><em>Mutation</em>: Generate new environments from the currently active ones. Note that the mutation operation types here are designed just for the bipedal walker, and a new type of environment would demand a new set of configurations.</li>
<li><em>Optimization</em>: Train paired agents within their respective environments.</li>
<li><em>Selection</em>: Periodically attempt to transfer current agents from one environment to another. Copy and update the best performing agent for every environment. The intuition is that skills learned in one environment might be helpful for a different environment.</li>
</ol>
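<p>The loop above can be sketched roughly as follows, with <code>mutate</code>, <code>optimize</code> and <code>evaluate</code> as hypothetical stand-ins for the real POET operators:</p>

```python
def poet_step(pairs, mutate, optimize, evaluate, max_pairs=5):
    """One iteration of a POET-style loop over environment-agent pairs.
    `mutate`, `optimize` and `evaluate` are hypothetical callables
    standing in for the operators in Wang et al. (2019)."""
    # 1. Mutation: spawn child environments from the active ones,
    #    initially paired with their parents' agents.
    children = [(mutate(env), agent) for env, agent in pairs]
    pairs = (pairs + children)[:max_pairs]
    # 2. Optimization: train each agent within its own environment.
    pairs = [(env, optimize(agent, env)) for env, agent in pairs]
    # 3. Selection / transfer: copy the best-performing agent into every
    #    environment, so skills learned in one env can help in another.
    agents = [agent for _, agent in pairs]
    return [(env, max(agents, key=lambda a: evaluate(a, env)))
            for env, _ in pairs]
```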
<p>The procedure above is quite similar to <a href="https://arxiv.org/abs/1711.09846">PBT</a>, but PBT mutates and evolves hyperparameters instead. To some extent, POET is doing <a href="/lil-log/2019/05/05/domain-randomization.html">domain randomization</a>, as all the gaps, stumps and terrain roughness are controlled by some randomization probability parameters. Different from DR, the agents are not exposed to a fully randomized difficult environment all at once, but instead they are learning gradually with a curriculum configured by the evolutionary algorithm.</p>
<h3 id="learning-with-random-rewards">Learning with Random Rewards</h3>
<p>An MDP without a reward function <script type="math/tex">R</script> is known as a <em>Controlled Markov process</em> (CMP). Given a predefined CMP, <script type="math/tex">\langle \mathcal{S}, \mathcal{A}, P\rangle</script>, we can acquire a variety of tasks by generating a collection of reward functions <script type="math/tex">\mathcal{R}</script> that encourage the training of an effective meta-learning policy.</p>
<p><a href="https://arxiv.org/abs/1806.04640">Gupta et al. (2018)</a> proposed two unsupervised approaches for growing the task distribution in the context of CMP. Assuming there is an underlying latent variable <script type="math/tex">z \sim p(z)</script> associated with every task, it parameterizes/determines a reward function: <script type="math/tex">r_z(s) = \log D(z|s)</script>, where a “discriminator” function <script type="math/tex">D(.)</script> is used to extract the latent variable from the state. The paper described two ways to construct a discriminator function:</p>
<ul>
<li>Sample random weights <script type="math/tex">\phi_\text{rand}</script> of the discriminator, <script type="math/tex">D_{\phi_\text{rand}}(z \mid s)</script>.</li>
<li>Learn a discriminator function to encourage diversity-driven exploration. This method is introduced in more detail in a sister paper, “DIAYN” (<a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2018</a>).</li>
</ul>
<p>DIAYN, short for “Diversity is all you need”, is a framework to encourage a policy to learn useful skills without a reward function. It explicitly models the latent variable <script type="math/tex">z</script> as a <em>skill</em> embedding and makes the policy conditioned on <script type="math/tex">z</script> in addition to the state <script type="math/tex">s</script>: <script type="math/tex">\pi_\theta(a \mid s, z)</script>. (Ok, this part is the same as <a href="#meta-learning-the-exploration-strategies">MAESN</a>, unsurprisingly, as the papers are from the same group.) The design of DIAYN is motivated by a few hypotheses:</p>
<ul>
<li>Skills should be diverse and lead to visitations of different states. → maximize the mutual information between states and skills, <script type="math/tex">I(S; Z)</script></li>
<li>Skills should be distinguishable by states, not actions. → minimize the mutual information between actions and skills, conditioned on states <script type="math/tex">I(A; Z \mid S)</script></li>
</ul>
<p>The objective function to maximize is as follows, where the policy entropy is also added to encourage diversity:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{F}(\theta)
&= I(S; Z) + H[A \mid S] - I(A; Z \mid S) & \\
&= (H(Z) - H(Z \mid S)) + H[A \mid S] - (H[A\mid S] - H[A\mid S, Z]) & \\
&= H[A\mid S, Z] \color{green}{- H(Z \mid S) + H(Z)} & \\
&= H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\log p(z \mid s)] - \mathbb{E}_{z\sim p(z)}[\log p(z)] & \scriptstyle{\text{; can infer skills from states & p(z) is diverse.}} \\
&\ge H[A\mid S, Z] + \mathbb{E}_{z\sim p(z), s\sim\rho(s)}[\color{red}{\log D_\phi(z \mid s) - \log p(z)}] & \scriptstyle{\text{; according to Jensen's inequality; "pseudo-reward" in red.}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">I(.)</script> is mutual information and <script type="math/tex">H[.]</script> is the entropy measure. We cannot integrate over all states to compute <script type="math/tex">p(z \mid s)</script>, so we approximate it with <script type="math/tex">D_\phi(z \mid s)</script>, the diversity-driven discriminator function.</p>
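<p>The pseudo-reward in red can be computed per step with a simple sketch, assuming the discriminator outputs softmax logits over skills; shapes and names are illustrative:</p>

```python
import numpy as np

def diayn_pseudo_reward(logits, z, log_p_z):
    """Sketch of the DIAYN pseudo-reward r = log D_phi(z|s) - log p(z),
    where the discriminator is represented by softmax logits over skills.
    Uses a max-shift for a numerically stable log-softmax."""
    m = np.max(logits)
    log_probs = logits - (m + np.log(np.sum(np.exp(logits - m))))  # log-softmax
    return float(log_probs[z] - log_p_z)
```

When the discriminator cannot tell the skill apart from a uniform prior, the pseudo-reward is zero; confident, correct identification yields a positive reward.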
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/DIAYN.png" alt="DIAYN" /></p>
<p><em>Fig. 9. DIAYN Algorithm. (Image source: <a href="https://arxiv.org/abs/1802.06070">Eysenbach et al., 2019</a>)</em></p>
<p>Once the discriminator function is learned, sampling a new MDP for training is straightforward: first, sample a latent variable <script type="math/tex">z \sim p(z)</script> and construct a reward function <script type="math/tex">r_z(s) = \log(D(z \vert s))</script>. Pairing the reward function with a predefined CMP creates a new MDP.</p>
<!--
---
So far, experiments of meta-RL are still limited to a collection of very similar tasks, originated from the same family; such as multi-armed bandit with different reward probabilities, mazes with different layouts, or same robots but with different physical parameters in simulator. I'm looking forward to more research demonstrating the power of meta-RL over a more diverse set of tasks.
-->
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019metaRL,
title = "Meta Reinforcement Learning",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/06/23/meta-reinforcement-learning.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Richard S. Sutton. <a href="http://incompleteideas.net/IncIdeas/BitterLesson.html">“The Bitter Lesson.”</a> March 13, 2019.</p>
<p>[2] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. <a href="http://snowedin.net/tmp/Hochreiter2001.pdf">“Learning to learn using gradient descent.”</a> Intl. Conf. on Artificial Neural Networks. 2001.</p>
<p>[3] Jane X Wang, et al. <a href="https://arxiv.org/abs/1611.05763">“Learning to reinforcement learn.”</a> arXiv preprint arXiv:1611.05763 (2016).</p>
<p>[4] Yan Duan, et al. <a href="https://arxiv.org/abs/1611.02779">“RL^2: Fast Reinforcement Learning via Slow Reinforcement Learning.”</a> ICLR 2017.</p>
<p>[5] Matthew Botvinick, et al. <a href="https://www.cell.com/trends/cognitive-sciences/fulltext/S1364-6613(19)30061-0">“Reinforcement Learning, Fast and Slow”</a> Cell Review, Volume 23, Issue 5, P408-422, May 01, 2019.</p>
<p>[6] Jeff Clune. <a href="https://arxiv.org/abs/1905.10985">“AI-GAs: AI-generating algorithms, an alternate paradigm for producing general artificial intelligence”</a> arXiv preprint arXiv:1905.10985 (2019).</p>
<p>[7] Zhongwen Xu, et al. <a href="http://papers.nips.cc/paper/7507-meta-gradient-reinforcement-learning.pdf">“Meta-Gradient Reinforcement Learning”</a> NIPS 2018.</p>
<p>[8] Rein Houthooft, et al. <a href="https://papers.nips.cc/paper/7785-evolved-policy-gradients.pdf">“Evolved Policy Gradients.”</a> NIPS 2018.</p>
<p>[9] Tim Salimans, et al. <a href="https://arxiv.org/abs/1703.03864">“Evolution strategies as a scalable alternative to reinforcement learning.”</a> arXiv preprint arXiv:1703.03864 (2017).</p>
<p>[10] Abhishek Gupta, et al. <a href="http://papers.nips.cc/paper/7776-meta-reinforcement-learning-of-structured-exploration-strategies.pdf">“Meta-Reinforcement Learning of Structured Exploration Strategies.”</a> NIPS 2018.</p>
<p>[11] Alexander Pritzel, et al. <a href="https://arxiv.org/abs/1703.01988">“Neural episodic control.”</a> Proc. Intl. Conf. on Machine Learning, Volume 70, 2017.</p>
<p>[12] Charles Blundell, et al. <a href="https://arxiv.org/abs/1606.04460">“Model-free episodic control.”</a> arXiv preprint arXiv:1606.04460 (2016).</p>
<p>[13] Samuel Ritter, et al. <a href="https://arxiv.org/abs/1805.09692">“Been there, done that: Meta-learning with episodic recall.”</a> ICML, 2018.</p>
<p>[14] Rui Wang et al. <a href="https://arxiv.org/abs/1901.01753">“Paired Open-Ended Trailblazer (POET): Endlessly Generating Increasingly Complex and Diverse Learning Environments and Their Solutions”</a> arXiv preprint arXiv:1901.01753 (2019).</p>
<p>[15] Uber Engineering Blog: <a href="https://eng.uber.com/poet-open-ended-deep-learning/">“POET: Endlessly Generating Increasingly Complex and Diverse Learning Environments and their Solutions through the Paired Open-Ended Trailblazer.”</a> Jan 8, 2019.</p>
<p>[16] Abhishek Gupta, et al. <a href="https://arxiv.org/abs/1806.04640">“Unsupervised meta-learning for Reinforcement Learning”</a> arXiv preprint arXiv:1806.04640 (2018).</p>
<p>[17] Eysenbach, Benjamin, et al. <a href="https://arxiv.org/abs/1802.06070">“Diversity is all you need: Learning skills without a reward function.”</a> ICLR 2019.</p>
<p>[18] Max Jaderberg, et al. <a href="https://arxiv.org/abs/1711.09846">“Population Based Training of Neural Networks.”</a> arXiv preprint arXiv:1711.09846 (2017).</p>
<h1 id="domain-randomization-for-sim2real-transfer">Domain Randomization for Sim2Real Transfer</h1>
<p><em>2019-05-05 · <a href="https://lilianweng.github.io/lil-log/2019/05/05/domain-randomization">https://lilianweng.github.io/lil-log/2019/05/05/domain-randomization</a></em></p>
<blockquote>
<p>If a model or policy is mainly trained in a simulator but expected to work on a real robot, it would surely face the sim2real gap. <em>Domain Randomization</em> (DR) is a simple but powerful idea of closing this gap by randomizing properties of the training environment.</p>
</blockquote>
<!--more-->
<p>In robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator, which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots. The gap is caused by inconsistencies in physical parameters (e.g. friction, kp, damping, mass, density) and, more fatally, by incorrect physical modeling (e.g. collision between soft surfaces).</p>
<p>To close the sim2real gap, we need to improve the simulator and make it closer to reality. A couple of approaches:</p>
<ul>
<li><strong>System identification</strong>
<ul>
<li><em>System identification</em> means building a mathematical model of a physical system; in the context of RL, the mathematical model is the simulator. To make the simulator more realistic, careful calibration is necessary.</li>
<li>Unfortunately, calibration is expensive. Furthermore, many physical parameters of the same machine might vary significantly due to temperature, humidity, positioning, or wear-and-tear over time.</li>
</ul>
</li>
<li><strong>Domain adaptation</strong>
<ul>
<li><em>Domain adaptation (DA)</em> refers to a set of transfer learning techniques developed to update the data distribution in sim to match the real one through a mapping or regularization enforced by the task model.</li>
<li>Many DA models, especially for image classification or end-to-end image-based RL task, are built on adversarial loss or <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a>.</li>
</ul>
</li>
<li><strong>Domain randomization</strong>
<ul>
<li>With <em>domain randomization (DR)</em>, we are able to create a variety of simulated environments with randomized properties and train a model that works across all of them.</li>
<li>This model is likely to adapt to the real-world environment, as the real system is expected to be one sample in that rich distribution of training variations.</li>
</ul>
</li>
</ul>
<p>Both DA and DR are unsupervised. Compared to DA which requires a decent amount of real data samples to capture the distribution, DR may need <em>only a little or no</em> real data. DR is the focus of this post.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/sim2real-transfer.png" alt="Approaches for sim2real transfer" /></p>
<p><em>Fig. 1. Conceptual illustrations of three approaches for sim2real transfer.</em></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#what-is-domain-randomization" id="markdown-toc-what-is-domain-randomization">What is Domain Randomization?</a></li>
<li><a href="#uniform-domain-randomization" id="markdown-toc-uniform-domain-randomization">Uniform Domain Randomization</a></li>
<li><a href="#why-does-domain-randomization-work" id="markdown-toc-why-does-domain-randomization-work">Why does Domain Randomization Work?</a> <ul>
<li><a href="#dr-as-optimization" id="markdown-toc-dr-as-optimization">DR as Optimization</a></li>
<li><a href="#dr-as-meta-learning" id="markdown-toc-dr-as-meta-learning">DR as Meta-Learning</a></li>
</ul>
</li>
<li><a href="#guided-domain-randomization" id="markdown-toc-guided-domain-randomization">Guided Domain Randomization</a> <ul>
<li><a href="#optimization-for-task-performance" id="markdown-toc-optimization-for-task-performance">Optimization for Task Performance</a></li>
<li><a href="#match-real-data-distribution" id="markdown-toc-match-real-data-distribution">Match Real Data Distribution</a></li>
<li><a href="#guided-by-data-in-simulator" id="markdown-toc-guided-by-data-in-simulator">Guided by Data in Simulator</a></li>
</ul>
</li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="what-is-domain-randomization">What is Domain Randomization?</h2>
<p>To make the definition more general, let us call the environment that we have full access to (i.e. simulator) <strong>source domain</strong> and the environment that we would like to transfer the model to <strong>target domain</strong> (i.e. physical world). Training happens in the source domain. We can control a set of <script type="math/tex">N</script> randomization parameters in the source domain <script type="math/tex">e_\xi</script> with a configuration <script type="math/tex">\xi</script>, sampled from a randomization space, <script type="math/tex">\xi \in \Xi \subset \mathbb{R}^N</script>.</p>
<p>During policy training, episodes are collected from the source domain with randomization applied. Thus the policy is exposed to a variety of environments and learns to generalize. The policy parameter <script type="math/tex">\theta</script> is trained to maximize the expected reward <script type="math/tex">R(.)</script> averaged across a distribution of configurations:</p>
<script type="math/tex; mode=display">\theta^* = \arg\max_\theta \mathbb{E}_{\xi \sim \Xi} [\mathbb{E}_{\pi_\theta, \tau \sim e_\xi} [R(\tau)]]</script>
<p>where <script type="math/tex">\tau</script> is a trajectory collected in the source domain randomized with <script type="math/tex">\xi</script>. In a way, <em>“discrepancies between the source and target domains are modeled as variability in the source domain.”</em> (quote from <a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>).</p>
<h2 id="uniform-domain-randomization">Uniform Domain Randomization</h2>
<p>In the original form of DR (<a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>; <a href="https://arxiv.org/pdf/1611.04201.pdf">Sadeghi et al. 2016</a>), each randomization parameter <script type="math/tex">\xi_i</script> is bounded by an interval, <script type="math/tex">\xi_i \in [\xi_i^\text{low}, \xi_i^\text{high}], i=1,\dots,N</script> and each parameter is uniformly sampled within the range.</p>
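<p>Uniform DR sampling is just an independent draw from each per-parameter interval; a tiny sketch with hypothetical parameter names:</p>

```python
import random

def sample_uniform_dr(bounds, rng=None):
    """Sketch of uniform domain randomization: each parameter xi_i is
    drawn independently from its interval [low_i, high_i]. `bounds`
    maps parameter names to intervals; the names are illustrative."""
    rng = rng or random.Random()
    return {name: rng.uniform(low, high) for name, (low, high) in bounds.items()}

# Hypothetical randomization space over scene and dynamics parameters:
bounds = {
    "light_intensity": (0.2, 2.0),
    "object_mass": (0.05, 0.5),
    "friction": (0.5, 1.2),
    "action_delay": (0.0, 0.04),
}
xi = sample_uniform_dr(bounds, random.Random(0))  # one sampled config xi
```

A fresh configuration would typically be sampled at the start of every training episode.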
<p>The randomization parameters can control appearances of the scene, including but not limited to the following (see Fig. 2). A model trained on simulated and randomized images is able to transfer to real, non-randomized images.</p>
<ul>
<li>Position, shape, and color of objects,</li>
<li>Material texture,</li>
<li>Lighting condition,</li>
<li>Random noise added to images,</li>
<li>Position, orientation, and field of view of the camera in the simulator.</li>
</ul>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/DR.png" alt="Domain Randomization" /></p>
<p><em>Fig. 2. Images captured in the training environment are randomized. (Image source: <a href="https://arxiv.org/abs/1703.06907">Tobin et al, 2017</a>)</em></p>
<p>Physical dynamics in the simulator can also be randomized (<a href="https://arxiv.org/abs/1710.06537">Peng et al. 2018</a>). Studies have shown that a <em>recurrent</em> policy can adapt to different physical dynamics, including the partially observable reality. A set of physical dynamics features includes but is not limited to:</p>
<ul>
<li>Mass and dimensions of objects,</li>
<li>Mass and dimensions of robot bodies,</li>
<li>Damping, kp, friction of the joints,</li>
<li>Gains for the PID controller (P term),</li>
<li>Joint limit,</li>
<li>Action delay,</li>
<li>Observation noise.</li>
</ul>
<p>With visual and dynamics DR, at OpenAI Robotics, we were able to learn a policy that works on a real dexterous robot hand (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>). Our manipulation task is to teach the robot hand to rotate an object continuously to achieve 50 successive random target orientations. The sim2real gap in this task is very large, due to (a) a high number of simultaneous contacts between the robot and the object and (b) imperfect simulation of object collision and other motions. At first, the policy could barely survive for more than 5 seconds without dropping the object. But with the help of DR, the policy eventually evolved to work surprisingly well in reality.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/DKe8FumoD4E" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
<h2 id="why-does-domain-randomization-work">Why does Domain Randomization Work?</h2>
<p>Now you may ask, why does domain randomization work so well? The idea sounds really simple. Here are two non-exclusive explanations I found most convincing.</p>
<h3 id="dr-as-optimization">DR as Optimization</h3>
<p>One idea (<a href="https://arxiv.org/abs/1903.11774">Vuong, et al, 2019</a>) is to view learning randomization parameters in DR as a <em>bilevel optimization</em>. Assume we have access to the real environment <script type="math/tex">e_\text{real}</script> and the randomization config is sampled from a distribution parameterized by <script type="math/tex">\phi</script>, <script type="math/tex">\xi \sim P_\phi(\xi)</script>. We would like to learn a distribution over which a policy <script type="math/tex">\pi_\theta</script> is trained such that it achieves maximal performance in <script type="math/tex">e_\text{real}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
&\phi^* = \arg\min_{\phi} \mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real}) \\
\text{where } &\theta^*(\phi) = \arg\min_\theta \mathbb{E}_{\xi \sim P_\phi(\xi)}[\mathcal{L}(\pi_\theta; e_\xi)]
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathcal{L}(\pi; e)</script> is the loss function of policy <script type="math/tex">\pi</script> evaluated in the environment <script type="math/tex">e</script>.</p>
<p>The randomization ranges in uniform DR are hand-picked, which often involves domain knowledge and a couple of rounds of trial-and-error adjustment based on the transfer performance. Essentially, this is a manual optimization process for tuning <script type="math/tex">\phi</script> toward the optimal <script type="math/tex">\mathcal{L}(\pi_{\theta^*(\phi)}; e_\text{real})</script>.</p>
<p>Guided domain randomization in the next section is largely inspired by this view, aiming to do bilevel optimization and learn the best parameter distribution automatically.</p>
<h3 id="dr-as-meta-learning">DR as Meta-Learning</h3>
<p>In our learning dexterity project (<a href="https://arxiv.org/abs/1808.00177">OpenAI, 2018</a>), we trained an LSTM policy to generalize across different environmental dynamics. We observed that once a robot achieved the first rotation, the time it needed for the following successes was much shorter. Also, an FF (feed-forward) policy without memory was unable to transfer to a physical robot. Both are evidence of the policy dynamically learning and adapting to a new environment.</p>
<p>In some ways, domain randomization composes a collection of different tasks. Memory in the recurrent network empowers the policy to achieve <a href="/lil-log/2018/11/30/meta-learning.html"><em>meta-learning</em></a> across tasks and further work on a real-world setting.</p>
<h2 id="guided-domain-randomization">Guided Domain Randomization</h2>
<p>The vanilla DR assumes no access to the real data, and thus the randomization config is sampled as broadly and uniformly as possible in sim, hoping that the real environment could be covered under this broad distribution. It is reasonable to think of a more sophisticated strategy — replacing uniform sampling with guidance from <em>task performance</em>, <em>real data</em>, or <em>simulator</em>.</p>
<p>One motivation for guided DR is to save computation resources by avoiding training models in unrealistic environments. Another is to avoid infeasible solutions that might arise from overly wide randomization distributions and hinder successful policy learning.</p>
<h3 id="optimization-for-task-performance">Optimization for Task Performance</h3>
<p>Say we train a family of policies with different randomization parameters <script type="math/tex">\xi \sim P_\phi(\xi)</script>, where <script type="math/tex">P_\phi</script> is the distribution of <script type="math/tex">\xi</script> parameterized by <script type="math/tex">\phi</script>. Later, we try every one of them on the downstream task in the target domain (i.e. control a robot in reality or evaluate on a validation set) to collect feedback. This feedback tells us how good a configuration <script type="math/tex">\xi</script> is and provides signals for optimizing <script type="math/tex">\phi</script>.</p>
<p>Inspired by <a href="https://ai.google/research/pubs/pub45826">NAS</a>, <strong>AutoAugment</strong> (<a href="https://arxiv.org/abs/1805.09501">Cubuk, et al. 2018</a>) frames the problem of learning the best data augmentation operations (e.g. shearing, rotation, inversion) for image classification as an RL problem. Note that AutoAugment is not proposed for sim2real transfer, but it falls in the bucket of DR guided by task performance. An individual augmentation configuration is tested on the evaluation set, and the performance improvement is used as a reward to train a PPO policy. This policy outputs different augmentation strategies for different datasets; for example, for CIFAR-10 AutoAugment mostly picks color-based transformations, while for ImageNet it prefers geometric ones.</p>
<p><a href="https://arxiv.org/abs/1810.02513">Ruiz (2019)</a> considered the <em>task feedback</em> as the <em>reward</em> in an RL problem and proposed an RL-based method, named “learning to simulate”, for adjusting <script type="math/tex">\xi</script>. A policy is trained to predict <script type="math/tex">\xi</script>, modeled as a multivariate Gaussian, using performance metrics on the validation data of the main task as rewards. Overall, the idea is similar to AutoAugment, applying NAS to data generation. According to their experiments, even if the main task model has not converged, it can still provide a reasonable signal to the data generation policy.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/learning-to-simulate.png" alt="Learning to simulate" /></p>
<p><em>Fig. 3. An overview of the “learning to simulate” approach. (Image source: <a href="https://arxiv.org/abs/1810.02513">Ruiz (2019)</a>)</em></p>
<p>An evolutionary algorithm is another way to go, where the <em>feedback</em> is treated as <em>fitness</em> for guiding evolution (<a href="https://openreview.net/forum?id=H1g6osRcFQ">Yu et al, 2019</a>). In this study, they used <a href="https://en.wikipedia.org/wiki/CMA-ES">CMA-ES</a> (covariance matrix adaptation evolution strategy), where fitness is the performance of a <script type="math/tex">\xi</script>-conditional policy in the target environment. In the appendix, they compared CMA-ES with other ways of modeling the dynamics of <script type="math/tex">\xi</script>, including Bayesian optimization and a neural network; the main claim was that those methods are not as stable or sample-efficient as CMA-ES. Interestingly, when modeling <script type="math/tex">P(\xi)</script> as a neural network, an LSTM was found to notably outperform a feed-forward network.</p>
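<p>A hedged sketch of the idea, using a simple (1+λ) evolution strategy in place of full CMA-ES; the fitness function stands in for target-environment performance, and all names are illustrative:</p>

```python
import random

def es_search_phi(fitness, phi0, sigma=0.1, lam=8, iters=20, rng=None):
    """Simplified (1+lambda) evolution-strategy stand-in for a CMA-ES
    search over randomization parameters phi: `fitness(phi)` would be the
    target-environment performance of a policy trained under phi."""
    rng = rng or random.Random(0)
    best, best_f = phi0, fitness(phi0)
    for _ in range(iters):
        # sample lambda perturbed candidates around the current best phi
        cands = [[p + rng.gauss(0.0, sigma) for p in best] for _ in range(lam)]
        for c in cands:
            f = fitness(c)
            if f > best_f:          # keep the fittest phi seen so far
                best, best_f = c, f
    return best, best_f
```

In the real setting each fitness evaluation is a full train-and-evaluate cycle, which is exactly why these methods are sample-hungry.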
<p>Some believe that the sim2real gap is a combination of an appearance gap and a content gap; most GAN-inspired DA models focus on the former. <strong>Meta-Sim</strong> (<a href="https://arxiv.org/abs/1904.11621">Kar, et al. 2019</a>) aims to close the content gap by generating task-specific synthetic datasets. Meta-Sim uses self-driving car training as an example, so the scene could be very complicated. In this case, the synthetic scenes are parameterized by a hierarchy of objects with properties (e.g., location, color) as well as relationships between objects. The hierarchy is specified by a probabilistic scene grammar akin to structured domain randomization (<strong>SDR</strong>; <a href="https://arxiv.org/abs/1810.10093">Prakash et al., 2018</a>) and is assumed to be known beforehand. A model <script type="math/tex">G</script> is trained to augment the distribution of scene properties <script type="math/tex">s</script> as follows:</p>
<ol>
<li>Learn the prior first: pre-train <script type="math/tex">G</script> to learn the identity function <script type="math/tex">G(s) = s</script>.</li>
<li>Minimize the MMD loss between the real and sim data distributions. This involves backpropagating through a non-differentiable renderer; the paper approximates the gradient numerically by perturbing the attributes of <script type="math/tex">G(s)</script>.</li>
<li>Minimize the task loss via REINFORCE, where the model is trained on synthetic data but evaluated on real data. Again, very similar to AutoAugment.</li>
</ol>
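<p>The MMD loss in step 2 above can be sketched in a few lines. The RBF kernel and the toy 2-dimensional “scene attributes” below are illustrative assumptions; the point is only that MMD shrinks as the sim distribution approaches the real one.</p>

```python
import numpy as np

def mmd2(x, y, sigma=1.0):
    """Squared maximum mean discrepancy between two sample sets (RBF kernel)."""
    def k(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * sigma ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(200, 2))      # stand-in "real" attributes
sim_far = rng.normal(3.0, 1.0, size=(200, 2))   # badly mismatched simulator
sim_near = rng.normal(0.1, 1.0, size=(200, 2))  # nearly matched simulator
assert mmd2(real, sim_near) < mmd2(real, sim_far)
```

<p>In Meta-Sim this quantity is minimized with respect to the attributes produced by <script type="math/tex">G(s)</script>, with the renderer in between, hence the numerical gradient trick.</p>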
<p>Unfortunately, this family of methods is not well suited to the sim2real case. Both RL policies and EA models require a large number of real samples, and it is really expensive to include real-time feedback collection on a physical robot in the training loop. Whether trading computational resources for real data collection is worthwhile depends on your task.</p>
<h3 id="match-real-data-distribution">Match Real Data Distribution</h3>
<p>Using real data to guide domain randomization feels a lot like doing system identification or DA. The core idea behind DA is to improve the synthetic data to match the real data distribution. In the case of real-data-guided DR, we would like to learn the randomization parameters <script type="math/tex">\xi</script> that bring the state distribution in simulator close to the state distribution in the real world.</p>
<p>The <strong>SimOpt</strong> model (<a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>) is first trained under an initial randomization distribution <script type="math/tex">P_\phi(\xi)</script>, yielding a policy <script type="math/tex">\pi_{\theta, P_\phi}</script>. Then this policy is deployed on both the simulator and the physical robot to collect trajectories <script type="math/tex">\tau_\text{sim}</script> and <script type="math/tex">\tau_\text{real}</script> respectively. The optimization objective is to minimize the discrepancy between sim and real trajectories:</p>
<script type="math/tex; mode=display">\phi^* = \arg\min_{\phi}\mathbb{E}_{\xi \sim P_\phi(\xi)} [\mathbb{E}_{\pi_{\theta, P_\phi}} [D(\tau_\text{sim}, \tau_\text{real})]]</script>
<p>where <script type="math/tex">D(.)</script> is a trajectory-based discrepancy measure. Like the “learning to simulate” paper, SimOpt also has to solve the tricky problem of propagating gradients through a non-differentiable simulator. It uses a method called <a href="https://www.aaai.org/ocs/index.php/AAAI/AAAI10/paper/viewFile/1851/2264">relative entropy policy search</a>; see the paper for more details.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/simopt.png" alt="SimOpt" /></p>
<p><em>Fig. 4. An overview of the SimOpt framework. (Image source: <a href="https://arxiv.org/abs/1810.05687">Chebotar et al, 2019</a>)</em></p>
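<p>To make the objective concrete, here is a toy sketch of what a trajectory-based discrepancy <script type="math/tex">D(.)</script> could look like. The weighted L1 plus squared-L2 form and the shapes below are illustrative assumptions for this sketch, not necessarily SimOpt’s exact choices.</p>

```python
import numpy as np

def traj_discrepancy(tau_sim, tau_real, w1=1.0, w2=1.0):
    # Sketch of D(tau_sim, tau_real): a weighted sum of per-timestep
    # L1 and squared-L2 gaps between observation sequences of equal length.
    diff = tau_sim - tau_real
    return w1 * np.abs(diff).sum(-1).mean() + w2 * (diff ** 2).sum(-1).mean()

# Two fake rollouts of length 100 with 3-dimensional observations.
rng = np.random.default_rng(0)
tau_real = rng.normal(size=(100, 3))
tau_sim = tau_real + 0.1 * rng.normal(size=(100, 3))
print(traj_discrepancy(tau_sim, tau_real))
```

<p>SimOpt estimates the expectation of this quantity over <script type="math/tex">\xi \sim P_\phi(\xi)</script> and updates <script type="math/tex">\phi</script> to shrink it, which pulls the simulated state distribution toward the real one.</p>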
<p><strong>RCAN</strong> (<a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>), short for “Randomized-to-Canonical Adaptation Networks”, is a nice combination of DA and DR for end-to-end RL tasks. An image-conditional GAN (<a href="https://arxiv.org/abs/1611.07004">cGAN</a>) is trained in sim to translate a domain-randomized image into a non-randomized version (aka the “canonical version”). Later, the same model is used to translate real images into their corresponding simulated versions, so that the agent consumes observations consistent with what it encountered in training. Still, the underlying assumption is that the distribution of domain-randomized sim images is broad enough to cover real-world samples.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/RCAN.png" alt="RCAN" /></p>
<p><em>Fig. 5. RCAN is an image-conditional generator that can convert a domain-randomized or real image into its corresponding non-randomized simulator version. (Image source: <a href="https://arxiv.org/abs/1812.07252">James et al., 2019</a>)</em></p>
<p>The RL model is trained end-to-end in a simulator to do vision-based robot arm grasping. Randomization is applied at each timestep, including the position of the tray divider, the objects to grasp, random textures, as well as the position, direction, and color of the lighting. The canonical version is the default simulator look. RCAN learns a generator</p>
<p><script type="math/tex">G</script>: randomized image <script type="math/tex">\to</script> {canonical image, segmentation, depth}</p>
<p>where segmentation masks and depth images are used as auxiliary tasks. RCAN achieved better zero-shot transfer than uniform DR, although both were shown to be worse than a model trained purely on real images. Conceptually, RCAN operates in the reverse direction of <a href="https://arxiv.org/abs/1709.07857">GraspGAN</a>, which translates synthetic images into realistic ones by domain adaptation.</p>
<h3 id="guided-by-data-in-simulator">Guided by Data in Simulator</h3>
<p>Network-driven domain randomization (<a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>), also known as <strong>DeceptionNet</strong>, is motivated by learning which randomizations are actually useful to bridge the domain gap for image classification tasks.</p>
<p>Randomization is applied through a set of deception modules with an encoder-decoder architecture. The deception modules are specifically designed to transform images, e.g., changing backgrounds, adding distortions, or altering the lighting. A separate recognition network handles the main task by running classification on the transformed images.</p>
<p>The training involves two steps:</p>
<ol>
<li>With the recognition network fixed, <em>maximize the difference</em> between the predictions and the labels by applying reversed gradients during backpropagation, so that the deception modules can learn the most confusing tricks.</li>
<li>With the deception modules fixed, train the recognition network on the altered input images.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/deception-net.png" alt="DeceptionNet" /></p>
<p><em>Fig. 6. How DeceptionNet works. (Image source: <a href="https://arxiv.org/abs/1904.02750">Zakharov et al., 2019</a>)</em></p>
<p>The feedback for training the deception modules is provided by the downstream classifier. But rather than trying to maximize task performance like <a href="#optimization-for-task-performance">the section</a> above, the randomization modules aim to create harder cases. One big disadvantage is that different deception modules have to be manually designed for different datasets or tasks, which makes the approach hard to scale. Although it operates zero-shot without real data, its results are still worse than SOTA DA methods on MNIST and LineMOD.</p>
<p>Similarly, active domain randomization (<strong>ADR</strong>; <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>) also relies on sim data to create harder training samples. ADR searches for the <em>most informative</em> environment variations within the given randomization ranges, where <em>informativeness</em> is measured as the discrepancy between policy rollouts in randomized and reference (original, non-randomized) environment instances. Sounds a bit like <a href="#match-real-data-distribution">SimOpt</a>? Note, though, that SimOpt measures the discrepancy between sim and real rollouts, while ADR measures it between randomized and non-randomized sim rollouts, avoiding the expensive real data collection.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/ADR.png" alt="ADR" /></p>
<p><em>Fig. 7. How active domain randomization (ADR) works. (Image source: <a href="https://arxiv.org/abs/1904.04762">Mehta et al., 2019</a>)</em></p>
<p>More precisely, the training happens as follows:</p>
<ol>
<li>Given a policy, run it on both reference and randomized envs and collect two sets of trajectories respectively.</li>
<li>Train a discriminator model to tell randomized rollout trajectories apart from reference ones. The predicted <script type="math/tex">\log p</script> (log probability of a trajectory coming from a randomized environment) is used as reward: the more the randomized and reference rollouts differ, the easier the prediction and the higher the reward.
<ul>
<li>The intuition is that if an environment is easy, the same policy produces trajectories similar to those in the reference one, so the model should reward and explore hard environments by encouraging different behaviors.</li>
</ul>
</li>
<li>The discriminator reward is fed into <em>Stein Variational Policy Gradient</em> (<a href="https://arxiv.org/abs/1704.02399">SVPG</a>) particles, which output a diverse set of randomization configurations.</li>
</ol>
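<p>Step 2 above can be sketched with a toy discriminator. The trajectory “features”, the logistic-regression discriminator, and all the numbers below are illustrative assumptions, not ADR’s actual architecture; the sketch only shows how the log-probability of being randomized turns into a reward.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy trajectory features: reference rollouts near 0, randomized rollouts shifted.
ref  = rng.normal(0.0, 1.0, size=(200, 4))
rand = rng.normal(1.5, 1.0, size=(200, 4))
X = np.vstack([ref, rand])
y = np.r_[np.zeros(200), np.ones(200)]  # label 1 = randomized

# Logistic-regression discriminator trained with plain gradient descent.
w, b = np.zeros(4), 0.0
for _ in range(500):
    p = 1 / (1 + np.exp(-(X @ w + b)))
    g = p - y
    w -= 0.1 * (X.T @ g) / len(y)
    b -= 0.1 * g.mean()

# ADR-style reward for a randomized rollout: log p(randomized | trajectory).
p_rand = 1 / (1 + np.exp(-(rand @ w + b)))
reward = np.log(p_rand + 1e-8)
print(reward.mean())  # closer to 0 when rollouts are easy to tell apart
```

<p>Environments whose rollouts are indistinguishable from the reference yield probabilities near 0.5 and thus lower (more negative) rewards, which is exactly the signal that pushes SVPG toward harder randomization configurations.</p>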
<p>The idea of ADR is very appealing, with two small concerns. First, similarity between trajectories might not be a good measure of environment difficulty when running a stochastic policy. Second, the sim2real results are unfortunately not as exciting, although the paper points out a win: ADR explores a smaller range of randomization parameters.</p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019DR,
title = "Domain Randomization for Sim2Real Transfer",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/05/04/domain-randomization.html"
}
</code></pre></div></div>
<p>Overall, after reading this post, I hope you like domain randomization as much as I do :).</p>
<h2 id="references">References</h2>
<p>[1] Josh Tobin, et al. <a href="https://arxiv.org/pdf/1703.06907.pdf">“Domain randomization for transferring deep neural networks from simulation to the real world.”</a> IROS, 2017.</p>
<p>[2] Fereshteh Sadeghi and Sergey Levine. <a href="https://arxiv.org/abs/1611.04201">“CAD2RL: Real single-image flight without a single real image.”</a> arXiv:1611.04201 (2016).</p>
<p>[3] Xue Bin Peng, et al. <a href="https://arxiv.org/abs/1710.06537">“Sim-to-real transfer of robotic control with dynamics randomization.”</a> ICRA, 2018.</p>
<p>[4] Nataniel Ruiz, et al. <a href="https://openreview.net/forum?id=HJgkx2Aqt7">“Learning to Simulate.”</a> ICLR 2019</p>
<p>[5] OpenAI. <a href="https://arxiv.org/abs/1808.00177">“Learning Dexterous In-Hand Manipulation.”</a> arXiv:1808.00177 (2018).</p>
<p>[6] OpenAI Blog. <a href="https://openai.com/blog/learning-dexterity/">“Learning dexterity”</a> July 30, 2018.</p>
<p>[7] Quan Vuong, et al. <a href="https://arxiv.org/abs/1903.11774">“How to pick the domain randomization parameters for sim-to-real transfer of reinforcement learning policies?.”</a> arXiv:1903.11774 (2019).</p>
<p>[8] Ekin D. Cubuk, et al. <a href="https://arxiv.org/abs/1805.09501">“AutoAugment: Learning augmentation policies from data.”</a> arXiv:1805.09501 (2018).</p>
<p>[9] Wenhao Yu et al. <a href="https://openreview.net/forum?id=H1g6osRcFQ">“Policy Transfer with Strategy Optimization.”</a> ICLR 2019</p>
<p>[10] Yevgen Chebotar et al. <a href="https://arxiv.org/abs/1810.05687">“Closing the Sim-to-Real Loop: Adapting Simulation Randomization with Real World Experience.”</a> Arxiv: 1810.05687 (2019).</p>
<p>[11] Stephen James et al. <a href="https://arxiv.org/abs/1812.07252">“Sim-to-real via sim-to-sim: Data-efficient robotic grasping via randomized-to-canonical adaptation networks”</a> CVPR 2019.</p>
<p>[12] Bhairav Mehta et al. <a href="https://arxiv.org/abs/1904.04762">“Active Domain Randomization”</a> arXiv:1904.04762</p>
<p>[13] Sergey Zakharov,et al. <a href="https://arxiv.org/abs/1904.02750">“DeceptionNet: Network-Driven Domain Randomization.”</a> arXiv:1904.02750 (2019).</p>
<p>[14] Amlan Kar, et al. <a href="https://arxiv.org/abs/1904.11621">“Meta-Sim: Learning to Generate Synthetic Datasets.”</a> arXiv:1904.11621 (2019).</p>
<p>[15] Aayush Prakash, et al. <a href="https://arxiv.org/abs/1810.10093">“Structured Domain Randomization: Bridging the Reality Gap by Context-Aware Synthetic Data.”</a> arXiv:1810.10093 (2018).</p>Lilian WengAre Deep Neural Networks Dramatically Overfitted?2019-03-14T12:00:00+00:00https://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted<blockquote>
<p>If you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-05-27: add the <a href="#the-lottery-ticket-hypothesis">section</a> on Lottery Ticket Hypothesis.]</span></p>
<p>If you are like me, entering the field of deep learning with experience in traditional machine learning, you may often ponder over this question: since a typical deep neural network has so many parameters and its training error can easily be perfect, it should surely suffer from substantial overfitting. How could it ever generalize to out-of-sample data points?</p>
<p>The effort to understand why deep neural networks can generalize somehow reminds me of this interesting paper on Systems Biology — <a href="https://bml.bioe.uic.edu/BML/Stuff/Stuff_files/biologist%20fix%20radio.pdf">“Can a biologist fix a radio?”</a> (Lazebnik, 2002). If a biologist intends to fix a radio the way she works on a biological system, life could be hard. Because the full mechanism of the radio system is not revealed, poking at small local functionalities might give some hints, but it can hardly uncover all the interactions within the system, let alone the entire working flow. No matter whether you think it is relevant to DL, it is a very fun read.</p>
<p>In this post, I would like to discuss a couple of papers on the generalizability and complexity measurement of deep learning models. Hopefully, it can shed some light on your path towards understanding why DNNs can generalize.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#classic-theorems-on-compression-and-model-selection" id="markdown-toc-classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</a> <ul>
<li><a href="#occams-razor" id="markdown-toc-occams-razor">Occam’s Razor</a></li>
<li><a href="#minimum-description-length-principle" id="markdown-toc-minimum-description-length-principle">Minimum Description Length principle</a></li>
<li><a href="#kolmogorov-complexity" id="markdown-toc-kolmogorov-complexity">Kolmogorov Complexity</a></li>
<li><a href="#solomonoffs-inference-theory" id="markdown-toc-solomonoffs-inference-theory">Solomonoff’s Inference Theory</a></li>
</ul>
</li>
<li><a href="#expressive-power-of-dl-models" id="markdown-toc-expressive-power-of-dl-models">Expressive Power of DL Models</a> <ul>
<li><a href="#universal-approximation-theorem" id="markdown-toc-universal-approximation-theorem">Universal Approximation Theorem</a></li>
<li><a href="#proof-finite-sample-expressivity-of-two-layer-nn" id="markdown-toc-proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</a></li>
<li><a href="#deep-nn-can-learn-random-noise" id="markdown-toc-deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</a></li>
</ul>
</li>
<li><a href="#are-deep-learning-models-dramatically-overfitted" id="markdown-toc-are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</a> <ul>
<li><a href="#modern-risk-curve-for-deep-learning" id="markdown-toc-modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</a></li>
<li><a href="#regularization-is-not-the-key-to-generalization" id="markdown-toc-regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</a></li>
<li><a href="#intrinsic-dimension" id="markdown-toc-intrinsic-dimension">Intrinsic Dimension</a></li>
<li><a href="#heterogeneous-layer-robustness" id="markdown-toc-heterogeneous-layer-robustness">Heterogeneous Layer Robustness</a></li>
<li><a href="#the-lottery-ticket-hypothesis" id="markdown-toc-the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</a></li>
</ul>
</li>
<li><a href="#experiments" id="markdown-toc-experiments">Experiments</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="classic-theorems-on-compression-and-model-selection">Classic Theorems on Compression and Model Selection</h2>
<p>Say we have a classification problem and a dataset. We can develop many models to solve it, from fitting a simple linear regression to memorizing the full dataset in disk space. Which one is better? If we only care about accuracy on the training data (especially given that the test data is likely unknown), the memorization approach seems to be the best — well, it doesn’t sound right.</p>
<p>There are many classic theorems to guide us when deciding what types of properties a good model should possess in such scenarios.</p>
<h3 id="occams-razor">Occam’s Razor</h3>
<p><a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> is an informal principle for problem-solving, proposed by <a href="https://en.wikipedia.org/wiki/William_of_Ockham">William of Ockham</a> in the 14th century:</p>
<blockquote>
<p>“Simpler solutions are more likely to be correct than complex ones.”</p>
</blockquote>
<p>The statement is extremely powerful when we are facing multiple candidate theories to explain the world and have to pick one. A theory with too many unnecessary assumptions might seem plausible for one problem, but it is harder to generalize to other complications or to eventually lead to the basic principles of the universe.</p>
<p>Think of this: it took people hundreds of years to figure out that the sky being blue in the daytime and reddish at sunset are due to the same reason (<a href="https://en.wikipedia.org/wiki/Rayleigh_scattering">Rayleigh scattering</a>), although the two phenomena look very different. People must have proposed many other explanations for them separately, but the unified and simple version won eventually.</p>
<h3 id="minimum-description-length-principle">Minimum Description Length principle</h3>
<p>The principle of Occam’s Razor can be similarly applied to machine learning models. A formalized version of this concept is called the <em>Minimum Description Length (MDL)</em> principle, used for comparing competing models / explanations given the observed data.</p>
<blockquote>
<p>“Comprehension is compression.”</p>
</blockquote>
<p>The fundamental idea in MDL is to <em>view learning as data compression</em>. By compressing the data, we need to discover regularities or patterns in the data with high potential to generalize to unseen samples. <a href="/lil-log/2017/09/28/anatomize-deep-learning-with-information-theory.html">Information bottleneck</a> theory argues that a deep neural network is first trained to represent the data by minimizing the generalization error and then learns to compress this representation by trimming away noise.</p>
<p>Meanwhile, MDL considers the model description as part of the compression delivery, so the model cannot be arbitrarily large.</p>
<p>A <em>two-part version</em> of MDL principle states that: Let <script type="math/tex">\mathcal{H}^{(1)}, \mathcal{H}^{(2)}, \dots</script> be a list of models that can explain the dataset <script type="math/tex">\mathcal{D}</script>. The best hypothesis among them should be the one that minimizes the sum:</p>
<script type="math/tex; mode=display">\mathcal{H}^\text{best} = \arg\min_\mathcal{H} [L(\mathcal{H}) + L(\mathcal{D}\vert\mathcal{H})]</script>
<ul>
<li><script type="math/tex">L(\mathcal{H})</script> is the length of the description of model <script type="math/tex">\mathcal{H}</script> in bits.</li>
<li><script type="math/tex">L(\mathcal{D}\vert\mathcal{H})</script> is the length of the description of the data <script type="math/tex">\mathcal{D}</script> in bits when encoded with <script type="math/tex">\mathcal{H}</script>.</li>
</ul>
<p>In simple words, the <em>best</em> model is the <em>smallest</em> one measured by the encoded data plus the model itself. Following this criterion, the memorization approach I proposed at the beginning of the section sounds horrible, no matter how good an accuracy it can achieve on the training data.</p>
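<p>As a toy illustration of the two-part criterion, we can score polynomial models of increasing degree with a crude code length: 32 bits per coefficient for <script type="math/tex">L(\mathcal{H})</script>, plus the Gaussian code length of the residuals for <script type="math/tex">L(\mathcal{D}\vert\mathcal{H})</script>. Both choices are illustrative assumptions (and only defined up to additive constants), but they are enough to show the trade-off.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
y = 1.0 - 2.0 * x + 0.5 * x**2 + rng.normal(0, 0.3, size=50)  # true model: degree 2

def mdl_score(deg):
    # Crude two-part code length in bits, up to an additive constant:
    # L(H): 32 bits per polynomial coefficient;
    # L(D|H): Gaussian code length of residuals, (n/2) * log2(RSS / n).
    coefs = np.polyfit(x, y, deg)
    rss = ((np.polyval(coefs, x) - y) ** 2).sum()
    return 32 * (deg + 1) + 0.5 * len(x) * np.log2(rss / len(x))

best = min(range(8), key=mdl_score)
print(best)  # the quadratic model should minimize the two-part code length
```

<p>Degrees below 2 pay heavily in residual bits, while degrees above 2 pay 32 bits per extra coefficient for a negligible drop in residuals, so the quadratic wins. Memorizing all 50 points would be the extreme case of the latter.</p>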
<p>People might argue Occam’s Razor is wrong: given that the real world can be arbitrarily complicated, why do we have to find simple models? One interesting view from MDL is to consider models as <strong>“languages”</strong> instead of fundamental generative theories. We would like to find good compression strategies to describe regularity in a small set of samples, and they <strong>do not have to be the “real” generative model</strong> for explaining the phenomenon. Models can be wrong but still useful (i.e., think of any Bayesian prior).</p>
<h3 id="kolmogorov-complexity">Kolmogorov Complexity</h3>
<p>Kolmogorov Complexity relies on the concept of modern computers to define the algorithmic (descriptive) complexity of an object: It is <em>the length of the shortest binary computer program that describes the object</em>. Following MDL, a computer is essentially the most general form of data decompressor.</p>
<p>The formal definition of Kolmogorov Complexity states that: Given a universal computer <script type="math/tex">\mathcal{U}</script> and a program <script type="math/tex">p</script>, let’s denote <script type="math/tex">\mathcal{U}(p)</script> as the output of the computer processing the program and <script type="math/tex">L(p)</script> as the descriptive length of the program. Then Kolmogorov Complexity <script type="math/tex">K_\mathcal{U}</script> of a string <script type="math/tex">s</script> with respect to a universal computer <script type="math/tex">\mathcal{U}</script> is:</p>
<script type="math/tex; mode=display">K_\mathcal{U}(s) = \min_{p: \mathcal{U}(p)=s} L(p)</script>
<p>Note that a universal computer is one that can mimic the actions of any other computers. All modern computers are universal as they can all be reduced to Turing machines. The definition is universal no matter which computers we are using, because another universal computer can always be programmed to clone the behavior of <script type="math/tex">\mathcal{U}</script>, while encoding this clone program is just a constant.</p>
<p>There are a lot of connections between Kolmogorov Complexity and Shannon Information Theory, as both are tied to universal coding. It is an amazing fact that the expected Kolmogorov Complexity of a random variable is approximately equal to its Shannon entropy (see Sec 2.3 of <a href="https://homepages.cwi.nl/~paulv/papers/info.pdf">the report</a>). More on this topic is out of the scope here, but there are many interesting readings online. Help yourself :)</p>
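<p>Kolmogorov complexity itself is uncomputable, but any off-the-shelf compressor gives a computable upper-bound proxy: the compressed size plus the (fixed) size of the decompressor is the length of one program that prints the string. A quick sketch with <code class="highlighter-rouge">zlib</code>, using a highly regular string versus pseudo-random bytes:</p>

```python
import random
import zlib

regular = b"01" * 5000                        # highly regular 10,000 bytes
random_ = random.Random(0).randbytes(10_000)  # pseudo-random 10,000 bytes

# Compressed size upper-bounds Kolmogorov complexity (plus the fixed size
# of the decompressor): regularity compresses, randomness does not.
print(len(zlib.compress(regular, 9)))   # tiny: essentially "repeat '01' 5000 times"
print(len(zlib.compress(random_, 9)))   # close to the original 10,000 bytes
```

<p>This is the same intuition as MDL: the regular string has a short description because a short program generates it, while the random string has no description much shorter than itself.</p>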
<h3 id="solomonoffs-inference-theory">Solomonoff’s Inference Theory</h3>
<p>Another mathematical formalization of Occam’s Razor is Solomonoff’s theory of universal inductive inference (<a href="https://www.sciencedirect.com/science/article/pii/S0019995864902232">Solomonoff</a>, <a href="https://www.sciencedirect.com/science/article/pii/S0019995864901317">1964</a>). The principle is to favor models that correspond to the “shortest program” to produce the training data, based on their Kolmogorov complexity.</p>
<h2 id="expressive-power-of-dl-models">Expressive Power of DL Models</h2>
<p>Deep neural networks have an extremely large number of parameters compared to the traditional statistical models. If we use MDL to measure the complexity of a deep neural network and consider the number of parameters as the model description length, it would look awful. The model description <script type="math/tex">L(\mathcal{H})</script> can easily grow out of control.</p>
<p>However, having numerous parameters is <em>necessary</em> for a neural network to obtain high expressive power. Thanks to their great capability of capturing flexible data representations, deep neural networks have achieved great success in many applications.</p>
<h3 id="universal-approximation-theorem">Universal Approximation Theorem</h3>
<p>The <em>Universal Approximation Theorem</em> states that a feedforward network with: 1) a linear output layer, 2) at least one hidden layer containing a finite number of neurons and 3) some activation function can approximate <strong>any</strong> continuous functions on a compact subset of <script type="math/tex">\mathbb{R}^n</script> to arbitrary accuracy. The theorem was first proved for sigmoid activation function (<a href="https://pdfs.semanticscholar.org/05ce/b32839c26c8d2cb38d5529cf7720a68c3fab.pdf">Cybenko, 1989</a>). Later it was shown that the universal approximation property is not specific to the choice of activation (<a href="http://zmjones.com/static/statistical-learning/hornik-nn-1991.pdf">Hornik, 1991</a>) but the multilayer feedforward architecture.</p>
<p>Although a feedforward network with a single hidden layer is sufficient to represent any function, its width may have to be exponentially large. The universal approximation theorem does not guarantee whether the model can be learned or can generalize properly. Often, adding more layers helps to reduce the number of hidden neurons needed in a shallow network.</p>
<p>To take advantage of the universal approximation theorem, we can always find a neural network to represent the target function with error under any desired threshold, but we need to pay the price — the network might grow super large.</p>
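<p>For one-dimensional inputs the trade-off is easy to see constructively: a single-hidden-layer ReLU network can represent any piecewise-linear interpolant exactly, so approximating a smooth target just takes enough hidden units (one per knot). A minimal sketch, where the “network” is built by hand rather than trained:</p>

```python
import numpy as np

def relu_net(x, knots, targets):
    # One-hidden-layer ReLU network that exactly interpolates (knots, targets)
    # piecewise-linearly: f(x) = targets[0] + sum_i v_i * relu(x - knots[i]),
    # where v_i is the change in slope at knot i.
    slopes = np.diff(targets) / np.diff(knots)
    v = np.diff(slopes, prepend=0.0)
    return targets[0] + np.maximum(x[:, None] - knots[:-1][None, :], 0) @ v

knots = np.linspace(0, 2 * np.pi, 100)   # 100 hidden units
targets = np.sin(knots)
x = np.linspace(0, 2 * np.pi, 1000)
err = np.abs(relu_net(x, knots, targets) - np.sin(x)).max()
print(err)  # shrinks as the hidden layer (number of knots) grows
```

<p>Halving the knot spacing roughly quarters the error here, so driving the error below any threshold is always possible, but only by growing the hidden layer: exactly the price mentioned above.</p>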
<h3 id="proof-finite-sample-expressivity-of-two-layer-nn">Proof: Finite Sample Expressivity of Two-layer NN</h3>
<p>The Universal Approximation Theorem we have discussed so far does not consider a finite sample set. <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a> provided a neat proof on the finite-sample expressivity of two-layer neural networks.</p>
<p>A neural network <script type="math/tex">C</script> can represent any function given a sample size <script type="math/tex">n</script> in <script type="math/tex">d</script> dimensions if: for every finite sample set <script type="math/tex">S \subseteq \mathbb{R}^d</script> with <script type="math/tex">\vert S \vert = n</script> and every function defined on this sample set, <script type="math/tex">f: S \mapsto \mathbb{R}</script>, we can find a weight configuration for <script type="math/tex">C</script> so that <script type="math/tex">C(\boldsymbol{x}) = f(\boldsymbol{x}), \forall \boldsymbol{x} \in S</script>.</p>
<p>The paper proposed a theorem:</p>
<blockquote>
<p>There exists a two-layer neural network with ReLU activations and <script type="math/tex">2n + d</script> weights that can represent any function on a sample of size <script type="math/tex">n</script> in <script type="math/tex">d</script> dimensions.</p>
</blockquote>
<p><em>Proof.</em> First we would like to construct a two-layer neural network <script type="math/tex">C: \mathbb{R}^d \mapsto \mathbb{R}</script>. The input is a <script type="math/tex">d</script>-dimensional vector, <script type="math/tex">\boldsymbol{x} \in \mathbb{R}^d</script>. The hidden layer has <script type="math/tex">h</script> hidden units, associated with a weight matrix <script type="math/tex">\mathbf{W} \in \mathbb{R}^{d\times h}</script>, a bias vector <script type="math/tex">-\mathbf{b} \in \mathbb{R}^h</script> and ReLU activation function. The second layer outputs a scalar value with weight vector <script type="math/tex">\boldsymbol{v} \in \mathbb{R}^h</script> and zero biases.</p>
<p>The output of network <script type="math/tex">C</script> for an input vector <script type="math/tex">\boldsymbol{x}</script> can be represented as follows:</p>
<script type="math/tex; mode=display">C(\boldsymbol{x})
= \boldsymbol{v} \max\{ \boldsymbol{x}\mathbf{W} - \boldsymbol{b}, 0\}^\top
= \sum_{i=1}^h v_i \max\{\boldsymbol{x}\boldsymbol{W}_{(:,i)} - b_i, 0\}</script>
<p>where <script type="math/tex">\boldsymbol{W}_{(:,i)}</script> is the <script type="math/tex">i</script>-th column in the <script type="math/tex">d \times h</script> matrix.</p>
<p>Given a sample set <script type="math/tex">S = \{\boldsymbol{x}_1, \dots, \boldsymbol{x}_n\}</script> and target values <script type="math/tex">\boldsymbol{y} = \{y_1, \dots, y_n \}</script>, we would like to find proper weights <script type="math/tex">\mathbf{W} \in \mathbb{R}^{d\times h}</script>, <script type="math/tex">\boldsymbol{b}, \boldsymbol{v} \in \mathbb{R}^h</script> so that <script type="math/tex">C(\boldsymbol{x}_i) = y_i, \forall i=1,\dots,n</script>.</p>
<p>Let’s combine all sample points into one input matrix <script type="math/tex">\mathbf{X} \in \mathbb{R}^{n \times d}</script>. If we set <script type="math/tex">h=n</script>, <script type="math/tex">\mathbf{X}\mathbf{W} - \boldsymbol{b}</script> would be a square matrix of size <script type="math/tex">n \times n</script>.</p>
<script type="math/tex; mode=display">\mathbf{M}_\text{ReLU}
= \max\{\mathbf{X}\mathbf{W} - \boldsymbol{b}, 0 \}
= \begin{bmatrix}
\boldsymbol{x}_1\mathbf{W} - \boldsymbol{b} \\
\dots \\
\boldsymbol{x}_n\mathbf{W} - \boldsymbol{b} \\
\end{bmatrix}
= [\boldsymbol{x}_i\boldsymbol{W}_{(:,j)} - b_j]_{i \times j}</script>
<p>We can simplify <script type="math/tex">\mathbf{W}</script> to have the same column vectors across all the columns:</p>
<script type="math/tex; mode=display">\mathbf{W}_{(:,j)} = \boldsymbol{w} \in \mathbb{R}^{d}, \forall j = 1, \dots, n</script>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/nn-expressivity-proof.png" alt="Two-layer NN expressivity proof" /></p>
<p>Let <script type="math/tex">a_i = \boldsymbol{x}_i \boldsymbol{w}</script>, we would like to find a suitable <script type="math/tex">\boldsymbol{w}</script> and <script type="math/tex">\boldsymbol{b}</script> such that <script type="math/tex">% <![CDATA[
b_1 < a_1 < b_2 < a_2 < \dots < b_n < a_n %]]></script>. This is always achievable because we try to solve <script type="math/tex">n+d</script> unknown variables with <script type="math/tex">n</script> constraints and <script type="math/tex">\boldsymbol{x}_i</script> are independent (i.e. pick a random <script type="math/tex">\boldsymbol{w}</script>, sort <script type="math/tex">\boldsymbol{x}_i \boldsymbol{w}</script> and then set <script type="math/tex">b_j</script>’s as values in between). Then <script type="math/tex">\mathbf{M}_\text{ReLU}</script> becomes a lower triangular matrix:</p>
<script type="math/tex; mode=display">% <![CDATA[
\mathbf{M}_\text{ReLU} = [a_i - b_j]_{i \times j}
= \begin{bmatrix}
a_1 - b_1 & 0 & 0 & \dots & 0 \\
\vdots & \ddots & & & \vdots \\
a_i - b_1 & \dots & a_i - b_i & \dots & 0\\
\vdots & & & \ddots & \vdots \\
a_n - b_1 & a_n - b_2 & \dots & \dots & a_n - b_n \\
\end{bmatrix} %]]></script>
<p>It is a nonsingular square matrix as <script type="math/tex">\det(\mathbf{M}_\text{ReLU}) \neq 0</script>, so we can always find suitable <script type="math/tex">\boldsymbol{v}</script> to solve <script type="math/tex">\boldsymbol{v}\mathbf{M}_\text{ReLU}=\boldsymbol{y}</script> (In other words, the column space of <script type="math/tex">\mathbf{M}_\text{ReLU}</script> is all of <script type="math/tex">\mathbb{R}^n</script> and we can find a linear combination of column vectors to obtain any <script type="math/tex">\boldsymbol{y}</script>).</p>
<h3 id="deep-nn-can-learn-random-noise">Deep NN can Learn Random Noise</h3>
<p>Since two-layer neural networks are universal approximators, it is not too surprising that they are able to learn unstructured random noise perfectly, as shown in <a href="https://arxiv.org/abs/1611.03530">Zhang, et al. (2017)</a>. If the labels of an image classification dataset are randomly shuffled, the high expressive power of deep neural networks still lets them achieve near-zero training loss. These results do not change when regularization terms are added.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/fit-random-labels.png" alt="Fitting random labels" /></p>
<p><em>Fig. 1. Fit models on CIFAR10 with random labels or random pixels: (a) learning curves; (b-c) label corruption ratio is the percentage of randomly shuffled labels. (Image source: <a href="https://arxiv.org/abs/1611.03530">Zhang’s paper</a>)</em></p>
<h2 id="are-deep-learning-models-dramatically-overfitted">Are Deep Learning Models Dramatically Overfitted?</h2>
<p>Deep learning models are heavily over-parameterized and can often achieve perfect results on training data. In the traditional view, shaped by the bias-variance trade-off, this would be a disaster: such a model should fail to generalize to unseen test data. However, as is often the case, these “overfitted” (training error = 0) deep learning models still present decent performance on out-of-sample test data. Hmm … interesting and why?</p>
<h3 id="modern-risk-curve-for-deep-learning">Modern Risk Curve for Deep Learning</h3>
<p>Traditional machine learning uses the following U-shaped risk curve to measure the bias-variance trade-off and quantify how generalizable a model is. If I got asked how to tell whether a model is overfitted, this would be the first thing popping into my mind.</p>
<p>As the model grows larger (more parameters added), the training error decreases toward zero, but the test error (generalization error) starts to increase once the model complexity passes the threshold between “underfitting” and “overfitting”. In a way, this is well aligned with Occam’s Razor.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/bias-variance-risk-curve.png" alt="Bias-variance risk curve" /></p>
<p><em>Fig. 2. U-shaped bias-variance risk curve. (Image source: (left) <a href="https://arxiv.org/abs/1812.11118">paper</a> (right) <a href="http://scott.fortmann-roe.com/docs/BiasVariance.html">fig. 6 of this post</a>)</em></p>
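<p>The classic U-shape of Fig. 2 is easy to reproduce in a toy setting. The sketch below (my own illustration, not taken from any paper) fits polynomials of increasing degree to noisy samples of a sine wave; degree plays the role of model complexity:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
f_true = lambda x: np.sin(2 * np.pi * x)
x_tr = np.linspace(0.0, 1.0, 30)
y_tr = f_true(x_tr) + 0.3 * rng.normal(size=x_tr.size)   # noisy training set
x_te = np.linspace(0.0, 1.0, 200)
y_te = f_true(x_te)                                      # noise-free test targets

train_err, test_err = [], []
for degree in range(1, 11):                              # model complexity = polynomial degree
    coefs = np.polyfit(x_tr, y_tr, degree)
    train_err.append(np.mean((np.polyval(coefs, x_tr) - y_tr) ** 2))
    test_err.append(np.mean((np.polyval(coefs, x_te) - y_te) ** 2))
# train error shrinks monotonically with degree; test error typically traces a U-shape,
# dropping while the fit improves and rising again once the model starts chasing the noise
```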
<p>Unfortunately this does not apply to deep learning models. <a href="https://arxiv.org/abs/1812.11118">Belkin et al. (2018)</a> reconciled the traditional bias-variance trade-offs and proposed a new double-U-shaped risk curve for deep neural networks. Once the number of network parameters is high enough, the risk curve enters another regime.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/new-bias-variance-risk-curve.png" alt="new risk curve" /></p>
<p><em>Fig. 3. A new double-U-shaped bias-variance risk curve for deep neural networks. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
<p>The paper claimed that it is likely due to two reasons:</p>
<ul>
<li>The number of parameters is not a good measure of <em>inductive bias</em>, defined as the set of assumptions of a learning algorithm used to predict for unknown samples. See more discussion on DL model complexity in <a href="#intrinsic-dimension">later</a> <a href="#heterogeneous-layer-robustness">sections</a>.</li>
<li>Equipped with a larger model, we might be able to discover larger function classes and further find interpolating functions that have smaller norm and are thus “simpler”.</li>
</ul>
<p>The double-U-shaped risk curve was observed empirically, as shown in the paper. However, I struggled quite a bit to reproduce the results. There are some signs of life, but to generate a smooth curve resembling the theoretical one, <a href="#experiments">many details</a> in the experiment have to be taken care of.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/new-risk-curve-mnist.png" alt="New risk curve on MNIST" /></p>
<p><em>Fig. 4. Training and evaluation errors of a one hidden layer fc network of different numbers of hidden units, trained on 4000 data points sampled from MNIST. (Image source: <a href="https://arxiv.org/abs/1812.11118">original paper</a>)</em></p>
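<p>One lightweight way to see the second descent, without training neural networks at all, is minimum-norm least squares on random ReLU features. This is only a stand-in for the paper’s MNIST setup, and every size below is an arbitrary choice:</p>

```python
import numpy as np

rng = np.random.default_rng(2)
n, d, n_test = 40, 8, 400
beta = rng.normal(size=d)                        # ground-truth linear signal
X = rng.normal(size=(n, d))
y = X @ beta + 0.5 * rng.normal(size=n)          # small noisy training set
X_te = rng.normal(size=(n_test, d))
y_te = X_te @ beta

def fit_eval(p):
    """Regression on p random ReLU features, using the minimum-norm least-squares fit."""
    W = rng.normal(size=(d, p)) / np.sqrt(d)     # frozen random first layer
    phi = lambda Z: np.maximum(0.0, Z @ W)
    theta = np.linalg.pinv(phi(X)) @ y           # pinv gives the minimum-norm solution
    train = np.mean((phi(X) @ theta - y) ** 2)
    test = np.mean((phi(X_te) @ theta - y_te) ** 2)
    return train, test

widths = [5, 10, 20, 40, 80, 160, 320]           # interpolation threshold near p = n = 40
errs = [fit_eval(p) for p in widths]
# train error reaches ~0 once p >= n; test error typically spikes near p = n, then falls again
```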
<h3 id="regularization-is-not-the-key-to-generalization">Regularization is not the Key to Generalization</h3>
<p>Regularization is a common way to control overfitting and improve model generalization performance. Interestingly, some research (<a href="https://arxiv.org/abs/1611.03530">Zhang, et al. 2017</a>) has shown that explicit regularization (e.g. data augmentation, weight decay and dropout) is neither necessary nor sufficient for reducing generalization error.</p>
<p>Taking the Inception model trained on CIFAR10 as an example (see Fig. 5), regularization techniques help with out-of-sample generalization but not much. No single regularization seems to be critical independent of other terms. Thus, it is unlikely that regularizers are the <em>fundamental reason</em> for generalization.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/regularization-generalization-test.png" alt="regularization test" /></p>
<p><em>Fig. 5. The accuracy of the Inception model trained on CIFAR10 with different combinations of data augmentation and weight decay turned on or off. (Image source: Table 1 in the <a href="https://arxiv.org/abs/1611.03530">original paper</a>)</em></p>
<h3 id="intrinsic-dimension">Intrinsic Dimension</h3>
<p>The number of parameters is not correlated with model overfitting in the field of deep learning, suggesting that parameter counting cannot indicate the true complexity of deep neural networks.</p>
<p>Apart from parameter counting, researchers have proposed many ways to quantify the complexity of these models, such as the number of degrees of freedom of models (<a href="https://arxiv.org/abs/1603.09260">Gao & Jojic, 2016</a>), or prequential code (<a href="https://arxiv.org/abs/1802.07044">Blier & Ollivier, 2018</a>).</p>
<p>I would like to discuss a recent method on this matter, named <strong>intrinsic dimension</strong> (<a href="https://arxiv.org/abs/1804.08838">Li et al, 2018</a>). Intrinsic dimension is intuitive and easy to measure, while still revealing many interesting properties of models of different sizes.</p>
<p>Considering a neural network with a great number of parameters, forming a high-dimensional parameter space, the learning happens on this high-dimensional <em>objective landscape</em>.
The shape of the parameter space manifold is critical. For example, a smoother manifold is beneficial for optimization by providing more predictive gradients and allowing for larger learning rates—this was claimed to be the reason why batch normalization has succeeded in stabilizing training (<a href="https://arxiv.org/abs/1805.11604">Santurkar, et al, 2019</a>).</p>
<p>Even though the parameter space is huge, fortunately we don’t have to worry too much about the optimization process getting stuck in local optima, as it has been <a href="https://arxiv.org/abs/1406.2572">shown</a> that critical points in the objective landscape are almost always saddle points rather than valleys. In other words, there is almost always a subset of dimensions containing paths to leave a local optimum and keep on exploring.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/optimization-landscape-shape.png" alt="parameter landscape shape" /></p>
<p><em>Fig. 6. Illustrations of various types of critical points on the parameter optimization landscape. (Image source: <a href="https://www.offconvex.org/2016/03/22/saddlepoints/">here</a>)</em></p>
<p>One intuition behind the measurement of intrinsic dimension is that, since the parameter space has such high dimensionality, it is probably not necessary to exploit all the dimensions to learn efficiently. If we only travel through a slice of objective landscape and still can learn a good solution, the complexity of the resulting model is likely lower than what it appears to be by parameter-counting. This is essentially what intrinsic dimension tries to assess.</p>
<p>Say a model has <script type="math/tex">D</script> dimensions and its parameters are denoted as <script type="math/tex">\theta^{(D)}</script>. For learning, a smaller <script type="math/tex">d</script>-dimensional subspace is randomly sampled, <script type="math/tex">\theta^{(d)}</script>, where <script type="math/tex">% <![CDATA[
d < D %]]></script>. During one optimization update, rather than taking a gradient step according to all <script type="math/tex">D</script> dimensions, only the smaller subspace <script type="math/tex">\theta^{(d)}</script> is used and remapped to update model parameters.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-illustration.png" alt="illustration" /></p>
<p><em>Fig. 7. Illustration of parameter vectors for direct optimization when <script type="math/tex">D=3</script>. (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The parameter update formula is as follows:</p>
<script type="math/tex; mode=display">\theta^{(D)} = \theta_0^{(D)} + \mathbf{P} \theta^{(d)}</script>
<p>where <script type="math/tex">\theta_0^{(D)}</script> holds the initialization values and <script type="math/tex">\mathbf{P}</script> is a <script type="math/tex">D \times d</script> projection matrix that is randomly sampled before training. Neither <script type="math/tex">\theta_0^{(D)}</script> nor <script type="math/tex">\mathbf{P}</script> is trainable; both stay fixed during training. <script type="math/tex">\theta^{(d)}</script> is initialized to all zeros.</p>
<p>By searching through the value of <script type="math/tex">d = 1, 2, \dots, D</script>, the corresponding <script type="math/tex">d</script> when the solution emerges is defined as the <em>intrinsic dimension</em>.</p>
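<p>The procedure can be sketched in a few lines of NumPy. The toy objective below is a least-squares problem standing in for a network loss; only the <script type="math/tex">d</script>-dimensional vector is updated, while <script type="math/tex">\theta_0^{(D)}</script> and <script type="math/tex">\mathbf{P}</script> stay frozen (all sizes and the learning rate are arbitrary choices):</p>

```python
import numpy as np

rng = np.random.default_rng(3)
D, d = 1000, 20                                  # full vs. subspace dimension (toy sizes)

# Stand-in objective over theta^(D): mean squared error of an overdetermined linear model
A = rng.normal(size=(50, D))
b = rng.normal(size=50)
loss = lambda th: np.mean((A @ th - b) ** 2)
grad = lambda th: 2.0 * A.T @ (A @ th - b) / b.size

theta0 = 0.01 * rng.normal(size=D)               # theta_0^(D): frozen random init
P, _ = np.linalg.qr(rng.normal(size=(D, d)))     # frozen projection with orthonormal columns
theta_d = np.zeros(d)                            # theta^(d): the only trainable parameters

for _ in range(500):
    # chain rule: dL/dtheta^(d) = P^T dL/dtheta^(D), evaluated at theta0 + P theta^(d)
    theta_d -= 0.1 * (P.T @ grad(theta0 + P @ theta_d))

final = loss(theta0 + P @ theta_d)               # lower than loss(theta0): learning happened
```

<p>Sweeping <script type="math/tex">d</script> upward and recording when <code>final</code> first reaches a good solution is exactly the intrinsic dimension search.</p>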
<p>It turns out many problems have much smaller intrinsic dimensions than the number of parameters. For example, on CIFAR10 image classification, a fully-connected network with 650k+ parameters has only 9k intrinsic dimension and a convolutional network containing 62k parameters has an even lower intrinsic dimension of 2.9k.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension.png" alt="intrinsic dimension results" /></p>
<p><em>Fig. 8. The measured intrinsic dimensions <script type="math/tex">d</script> for various models achieving 90% of the best performance. (Image source: <a href="https://arxiv.org/abs/1804.08838">original paper</a>)</em></p>
<p>The measurement of intrinsic dimensions suggests that deep learning models are significantly simpler than what they might appear to be.</p>
<h3 id="heterogeneous-layer-robustness">Heterogeneous Layer Robustness</h3>
<p><a href="https://arxiv.org/abs/1902.01996">Zhang et al. (2019)</a> investigated the role of parameters in different layers. The fundamental question raised by the paper is: <em>“are all layers created equal?”</em> The short answer is: No. The model is more sensitive to changes in some layers but not others.</p>
<p>The paper proposed two types of operations that can be applied to parameters of the <script type="math/tex">\ell</script>-th layer, <script type="math/tex">\ell = 1, \dots, L</script>, at time <script type="math/tex">t</script>, <script type="math/tex">\theta^{(\ell)}_t</script> to test their impacts on model robustness:</p>
<ul>
<li>
<p><strong>Re-initialization</strong>: Reset the parameters to the initial values, <script type="math/tex">\theta^{(\ell)}_t \leftarrow \theta^{(\ell)}_0</script>. The performance of a network in which layer <script type="math/tex">\ell</script> was re-initialized is referred to as the <em>re-initialization robustness</em> of layer <script type="math/tex">\ell</script>.</p>
</li>
<li>
<p><strong>Re-randomization</strong>: Re-sampling the layer’s parameters at random, <script type="math/tex">\theta^{(\ell)}_t \leftarrow \tilde{\theta}^{(\ell)} \sim \mathcal{P}^{(\ell)}</script>. The corresponding network performance is called the <em>re-randomization robustness</em> of layer <script type="math/tex">\ell</script>.</p>
</li>
</ul>
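<p>As a sketch, with a network represented as a plain list of per-layer weight arrays, the two operations look like this (the Gaussian init distribution and the scale are assumptions for illustration):</p>

```python
import numpy as np

def re_initialize(theta_t, theta_0, layer):
    """Reset one layer's weights to their values at initialization: theta_t^(l) <- theta_0^(l)."""
    out = [w.copy() for w in theta_t]
    out[layer] = theta_0[layer].copy()
    return out

def re_randomize(theta_t, layer, rng, scale=0.1):
    """Re-sample one layer from the init distribution P^(l) (Gaussian assumed here)."""
    out = [w.copy() for w in theta_t]
    out[layer] = scale * rng.normal(size=out[layer].shape)
    return out

rng = np.random.default_rng(4)
theta_0 = [0.1 * rng.normal(size=(32, 32)) for _ in range(3)]    # snapshot at init
theta_t = [w + 0.1 * rng.normal(size=w.shape) for w in theta_0]  # stand-in "trained" weights
probe = re_initialize(theta_t, theta_0, layer=1)
# evaluate the probed network on held-out data; a large accuracy drop marks layer 1 "critical"
```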
<p>Layers can be categorized into two categories with the help of these two operations:</p>
<ul>
<li><strong>Robust Layers</strong>: The network has no or only negligible performance degradation after re-initializing or re-randomizing the layer.</li>
<li><strong>Critical Layers</strong>: Otherwise.</li>
</ul>
<p>Similar patterns are observed on fully-connected and convolutional networks. Re-randomizing any of the layers <em>completely destroys</em> the model performance, as the prediction drops to random guessing immediately. More interestingly and surprisingly, when applying re-initialization, only the first or the first few layers (those closest to the input layer) are critical, while re-initializing higher levels causes <em>only negligible decrease</em> in performance.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer-robustness-results.png" alt="Re-initialization robustness" /></p>
<p><em>Fig. 9. (a) A fc network trained on MNIST. Each row corresponds to one layer in the network. The first column is re-randomization robustness of each layer and the rest of the columns indicate re-initialization robustness at different training time. (b) VGG11 model (conv net) trained on CIFAR 10. Similar representation as in (a) but rows and columns are transposed. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>ResNet is able to use shortcuts between non-adjacent layers to re-distribute the sensitive layers across the network rather than concentrating them at the bottom. With the help of the residual block architecture, the network can be <em>evenly robust to re-randomization</em>. Only the first layer of each residual block is still sensitive to both re-initialization and re-randomization. If we consider each residual block as a local sub-network, the robustness pattern resembles that of the fc and conv nets above.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/layer-robustness-resnet.png" alt="ResNet robustness" /></p>
<p><em>Fig. 10. Re-randomization (first row) and re-initialization (the rest of the rows) robustness of layers in a ResNet-50 model trained on CIFAR10. (Image source: <a href="https://arxiv.org/abs/1902.01996">original paper</a>)</em></p>
<p>Based on the fact that many top layers in deep neural networks are not critical to the model performance after re-initialization, the paper loosely concluded that:</p>
<blockquote>
<p>“Over-capacitated deep networks trained with stochastic gradient have low-complexity due to self-restricting the number of critical layers.”</p>
</blockquote>
<p>We can consider re-initialization as a way to reduce the effective number of parameters, and thus the observation is aligned with what intrinsic dimension has demonstrated.</p>
<h3 id="the-lottery-ticket-hypothesis">The Lottery Ticket Hypothesis</h3>
<p>The lottery ticket hypothesis (<a href="https://arxiv.org/abs/1803.03635">Frankle & Carbin, 2019</a>) is another intriguing and inspiring discovery, suggesting that only a subset of the network parameters have an impact on model performance and thus the network is not overfitted. The lottery ticket hypothesis states that a randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset are <em>“winning tickets”</em> which can achieve the optimal performance when <em>trained in isolation</em>.</p>
<p>The idea is motivated by network pruning techniques — removing unnecessary weights (i.e. tiny weights that are almost negligible) without harming the model performance. Although the final network size can be reduced dramatically, it is hard to train such a pruned network architecture successfully from scratch. It feels like in order to successfully train a neural network, we need a large number of parameters, but we don’t need that many parameters to keep the accuracy high once the model is trained. Why is that?</p>
<p>The lottery ticket hypothesis did the following experiments:</p>
<ol>
<li>Randomly initialize a dense feed-forward network with initialization values <script type="math/tex">\theta_0</script>;</li>
<li>Train the network for multiple iterations to achieve a good performance with parameter config <script type="math/tex">\theta</script>;</li>
<li>Run pruning on <script type="math/tex">\theta</script> to create a mask <script type="math/tex">m</script>.</li>
<li>The “winning ticket” initialization config is <script type="math/tex">m \odot \theta_0</script>.</li>
</ol>
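<p>Steps 3–4 can be sketched with simple magnitude pruning. Note the paper prunes iteratively and per layer; the one-shot global variant below is a simplification, and the pruning fraction is an arbitrary choice:</p>

```python
import numpy as np

def winning_ticket(theta_0, theta_trained, prune_frac=0.8):
    """Mask out the smallest-magnitude trained weights, rewind survivors to their init values."""
    threshold = np.quantile(np.abs(theta_trained), prune_frac)
    m = (np.abs(theta_trained) > threshold).astype(theta_0.dtype)  # the pruning mask m
    return m * theta_0                                             # m ⊙ theta_0

rng = np.random.default_rng(5)
theta_0 = rng.normal(size=(100, 100))        # step 1: random init
theta = rng.normal(size=(100, 100))          # stand-in for the trained weights of step 2
ticket = winning_ticket(theta_0, theta)      # steps 3-4: ~80% of weights zeroed
# surviving entries keep their *original* init values, not the trained ones
```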
<p>By training only the small “winning ticket” subset of parameters, with the initial values as found in step 1, the model is able to achieve the same level of accuracy as in step 2. It turns out a large parameter space is not needed in the final solution representation, but it is needed for training, as it provides a big pool of initialization configs for many much smaller subnetworks.</p>
<p>The lottery ticket hypothesis opens a new perspective on interpreting and dissecting deep neural network results. Many interesting follow-up works are on the way.</p>
<h2 id="experiments">Experiments</h2>
<p>After seeing all the interesting findings above, it should be pretty fun to reproduce them. Some results are easier to reproduce than others. Details are described below. My code is available on GitHub: <a href="https://github.com/lilianweng/generalization-experiment">lilianweng/generalization-experiment</a>.</p>
<p><strong>New Risk Curve for DL Models</strong></p>
<p>This is the trickiest one to reproduce. The authors gave me a lot of good advice and I appreciate it a lot. Here are a couple of notable settings in their experiments:</p>
<ul>
<li>There are no regularization terms like weight decay, dropout.</li>
<li>In Fig 3, the training set contains 4k samples. It is only sampled once and fixed for all the models. The evaluation uses the full MNIST test set.</li>
<li>Each network is trained for a long time to achieve near-zero training risk. The learning rate is adjusted differently for models of different sizes.</li>
<li>To make the model less sensitive to the initialization in the under-parameterization region, their experiments adopted a <em>“weight reuse”</em> scheme: the parameters obtained from training a smaller neural network are used as initialization for training larger networks.</li>
</ul>
<p>I did not train or tune each model long enough to get perfect training performance, but the evaluation error indeed shows a special twist around the interpolation threshold, different from the training error. For example, for MNIST the threshold is the number of training samples (4000) times the number of classes (10), that is, 40000 parameters.</p>
<p>The x-axis is the number of model parameters, (28 * 28 + 1) * num. units + num. units * 10, on a logarithmic scale.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/risk_curve_loss-mse_sample-4000_epoch-500.png" alt="risk curve experiment 1" /></p>
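<p>The parameter count and the resulting threshold location can be checked with a one-liner (the formula above omits the output-layer bias, and the snippet follows suit):</p>

```python
def n_params(num_units):
    """Parameters of a one-hidden-layer fc net on MNIST: (28*28 + 1)*h + h*10."""
    return (28 * 28 + 1) * num_units + num_units * 10

# n_params(h) = 795 * h, so the 40000-parameter interpolation threshold
# sits near 40000 / 795, i.e. around 50 hidden units.
```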
<p><br /></p>
<p><strong>Layers are not Created Equal</strong></p>
<p>This one is fairly easy to reproduce. See my implementation <a href="https://github.com/lilianweng/generalization-experiment/blob/master/layer_equality.py">here</a>.</p>
<p>In the first experiment, I used a three-layer fc network with 256 units in each layer. Layer 0 is the input layer while layer 3 is the output. The network is trained on MNIST for 100 epochs.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_256x3.png" alt="Layer equality experiment 1" /></p>
<p>In the second experiment, I used a four-layer fc network with 128 units in each layer. Other settings are the same as in experiment 1.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/layer_equality_128x4.png" alt="Layer equality experiment 2" /></p>
<p><br /></p>
<p><strong>Intrinsic Dimension Measurement</strong></p>
<p>To correctly map the <script type="math/tex">d</script>-dimensional subspace to the full parameter space, the projection matrix <script type="math/tex">\mathbf{P}</script> should have orthogonal columns. Because the product <script type="math/tex">\mathbf{P}\theta^{(d)}</script> is the sum of the columns of <script type="math/tex">\mathbf{P}</script> scaled by the corresponding entries of the <script type="math/tex">d</script>-dim vector, <script type="math/tex">\sum_{i=1}^d \theta^{(d)}_i \mathbf{P}_{(:,i)}</script>, it is better to fully utilize the subspace by making the columns of <script type="math/tex">\mathbf{P}</script> orthogonal.</p>
<p>My implementation follows a naive approach: sample a large matrix with independent entries from a standard normal distribution. In a high-dimensional space such columns are expected to be nearly orthogonal. This works when the dimension is not too large. When exploring large <script type="math/tex">d</script>, there are methods for creating sparse projection matrices, which is what the intrinsic dimension paper suggested.</p>
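<p>For moderate <script type="math/tex">d</script>, exactly orthonormal columns are also cheap to obtain from a QR decomposition of the Gaussian sample (sizes below are arbitrary):</p>

```python
import numpy as np

rng = np.random.default_rng(8)
D, d = 5000, 100
P_naive = rng.normal(size=(D, d))      # i.i.d. Gaussian: columns only *approximately* orthogonal
P_ortho, _ = np.linalg.qr(P_naive)     # exactly orthonormal columns via (reduced) QR

gram = P_ortho.T @ P_ortho             # the d x d identity, up to float error
```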
<p>Here are experiment runs on two networks: (left) a two-layer fc network with 64 units in each layer and (right) a one-layer fc network with 128 hidden units, trained on 10% of MNIST. For every <script type="math/tex">d</script>, the model is trained for 100 epochs. See the <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions.py">code</a> <a href="https://github.com/lilianweng/generalization-experiment/blob/master/intrinsic_dimensions_measurement.py">here</a>.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/intrinsic-dimension-net-64-64-and-128.png" alt="intrinsic dimension experiment 1" /></p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019overfit,
title = "Are Deep Neural Networks Dramatically Overfitted?",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/03/14/are-deep-neural-networks-dramatically-overfitted.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Wikipedia page on <a href="https://en.wikipedia.org/wiki/Occam%27s_razor">Occam’s Razor</a>.</p>
<p>[2] <a href="http://pespmc1.vub.ac.be/OCCAMRAZ.html">Occam’s Razor</a> on Principia Cybernetica Web.</p>
<p>[3] Peter Grunwald. <a href="https://arxiv.org/abs/math/0406077">“A Tutorial Introduction to the Minimum Description Length Principle”</a>. 2004.</p>
<p>[4] Ian Goodfellow, et al. <a href="https://www.deeplearningbook.org/">Deep Learning</a>. 2016. <a href="https://www.deeplearningbook.org/contents/mlp.html">Sec 6.4.1</a>.</p>
<p>[5] Zhang, Chiyuan, et al. <a href="https://arxiv.org/abs/1611.03530">“Understanding deep learning requires rethinking generalization.”</a> ICLR 2017.</p>
<p>[6] Shibani Santurkar, et al. <a href="https://arxiv.org/abs/1805.11604">“How does batch normalization help optimization?.”</a> NIPS 2018.</p>
<p>[7] Mikhail Belkin, et al. <a href="https://arxiv.org/abs/1812.11118">“Reconciling modern machine learning and the bias-variance trade-off.”</a> arXiv:1812.11118, 2018.</p>
<p>[8] Chiyuan Zhang, et al. <a href="https://arxiv.org/abs/1902.01996">“Are All Layers Created Equal?”</a> arXiv:1902.01996, 2019.</p>
<p>[9] Chunyuan Li, et al. <a href="https://arxiv.org/abs/1804.08838">“Measuring the intrinsic dimension of objective landscapes.”</a> ICLR 2018.</p>
<p>[10] Jonathan Frankle and Michael Carbin. <a href="https://arxiv.org/abs/1803.03635">“The lottery ticket hypothesis: Finding sparse, trainable neural networks.”</a> ICLR 2019.</p>
Lilian Weng
If you are, like me, confused by why deep neural networks can generalize to out-of-sample data points without drastic overfitting, keep on reading.
Generalized Language Models
2019-01-31T12:00:00+00:00
https://lilianweng.github.io/lil-log/2019/01/31/generalized-language-models
<blockquote>
<p>As a follow-up to the word embedding post, we will discuss models for learning contextualized word vectors, as well as the new trend of large unsupervised pre-trained language models which have achieved amazing SOTA results on a variety of language tasks.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-02-14: add <a href="#ulmfit">ULMFiT</a> and <a href="#openai-gpt-2">OpenAI GPT-2</a>.]</span></p>
<p style="width: 60%;" class="center"><br />
<img src="/lil-log/assets/images/elmo-and-bert.png" alt="Elmo & Bert" /></p>
<p><em>Fig. 0. I guess they are Elmo & Bert? (Image source: <a href="https://www.youtube.com/watch?v=l5einDQ-Ttc">here</a>)</em>
<br /></p>
<p>We have seen amazing progress in NLP in 2018. Large-scale pre-trained language models like <a href="https://blog.openai.com/language-unsupervised/">OpenAI GPT</a> and <a href="https://arxiv.org/abs/1810.04805">BERT</a> have achieved great performance on a variety of language tasks using generic model architectures. The idea is similar to how ImageNet classification pre-training helps many vision tasks (*). Even better than vision classification pre-training, this simple and powerful approach in NLP does not require labeled data for pre-training, allowing us to experiment with increased training scale, up to our very limit.</p>
<p><em>(*) Although recently He et al. (2018) <a href="https://arxiv.org/abs/1811.08883">found</a> that pre-training might not be necessary for the image segmentation task.</em></p>
<p>In my previous NLP <a href="/lil-log/2017/10/15/learning-word-embedding.html">post on word embedding</a>, the introduced embeddings are not context-specific — they are learned based on word co-occurrence rather than sequential context. So in the two sentences, “<em>I am eating an apple</em>” and “<em>I have an Apple phone</em>”, the two “apple” words refer to very different things but they would still share the same word embedding vector.</p>
<p>Despite this, the early way of adopting word embeddings was to use them as additional features for an existing task-specific model, so in a way the improvement was bounded.</p>
<p>In this post, we will discuss how various approaches were proposed to make embeddings dependent on context, and to make them easier and cheaper to be applied to downstream tasks in general form.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#cove" id="markdown-toc-cove">CoVe</a> <ul>
<li><a href="#nmt-recap" id="markdown-toc-nmt-recap">NMT Recap</a></li>
<li><a href="#use-cove-in-downstream-tasks" id="markdown-toc-use-cove-in-downstream-tasks">Use CoVe in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#elmo" id="markdown-toc-elmo">ELMo</a> <ul>
<li><a href="#bidirectional-language-model" id="markdown-toc-bidirectional-language-model">Bidirectional Language Model</a></li>
<li><a href="#elmo-representations" id="markdown-toc-elmo-representations">ELMo Representations</a></li>
<li><a href="#use-elmo-in-downstream-tasks" id="markdown-toc-use-elmo-in-downstream-tasks">Use ELMo in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#cross-view-training" id="markdown-toc-cross-view-training">Cross-View Training</a> <ul>
<li><a href="#model-architecture" id="markdown-toc-model-architecture">Model Architecture</a></li>
<li><a href="#multi-task-learning" id="markdown-toc-multi-task-learning">Multi-Task Learning</a></li>
<li><a href="#use-cvt-in-downstream-tasks" id="markdown-toc-use-cvt-in-downstream-tasks">Use CVT in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#ulmfit" id="markdown-toc-ulmfit">ULMFiT</a></li>
<li><a href="#openai-gpt" id="markdown-toc-openai-gpt">OpenAI GPT</a> <ul>
<li><a href="#transformer-decoder-as-language-model" id="markdown-toc-transformer-decoder-as-language-model">Transformer Decoder as Language Model</a></li>
<li><a href="#bpe" id="markdown-toc-bpe">BPE</a></li>
<li><a href="#supervised-fine-tuning" id="markdown-toc-supervised-fine-tuning">Supervised Fine-Tuning</a></li>
</ul>
</li>
<li><a href="#bert" id="markdown-toc-bert">BERT</a> <ul>
<li><a href="#pre-training-tasks" id="markdown-toc-pre-training-tasks">Pre-training Tasks</a></li>
<li><a href="#input-embedding" id="markdown-toc-input-embedding">Input Embedding</a></li>
<li><a href="#use-bert-in-downstream-tasks" id="markdown-toc-use-bert-in-downstream-tasks">Use BERT in Downstream Tasks</a></li>
</ul>
</li>
<li><a href="#openai-gpt-2" id="markdown-toc-openai-gpt-2">OpenAI GPT-2</a> <ul>
<li><a href="#zero-shot-transfer" id="markdown-toc-zero-shot-transfer">Zero-Shot Transfer</a></li>
<li><a href="#bpe-on-byte-sequences" id="markdown-toc-bpe-on-byte-sequences">BPE on Byte Sequences</a></li>
<li><a href="#model-modifications" id="markdown-toc-model-modifications">Model Modifications</a></li>
</ul>
</li>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
<li><a href="#metric-perplexity" id="markdown-toc-metric-perplexity">Metric: Perplexity</a></li>
<li><a href="#common-tasks-and-datasets" id="markdown-toc-common-tasks-and-datasets">Common Tasks and Datasets</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="cove">CoVe</h2>
<p><strong>CoVe</strong> (<a href="https://arxiv.org/abs/1708.00107">McCann et al. 2017</a>), short for <strong>Contextual Word Vectors</strong>, is a type of word embeddings learned by an encoder in an <a href="/lil-log/2018/06/24/attention-attention.html#born-for-translation">attentional seq-to-seq</a> machine translation model.
Different from traditional word embeddings introduced <a href="/lil-log/2017/10/15/learning-word-embedding.html">here</a>, CoVe word representations are functions of the entire input sentence.</p>
<h3 id="nmt-recap">NMT Recap</h3>
<p>Here the Neural Machine Translation (<a href="https://github.com/THUNLP-MT/MT-Reading-List">NMT</a>) model is composed of a standard, two-layer, bidirectional LSTM encoder and an attentional two-layer unidirectional LSTM decoder. It is pre-trained on the English-German translation task. The encoder learns and optimizes the embedding vectors of English words in order to translate them to German. With the intuition that the encoder should capture high-level semantic and syntactic meanings before transforming words into another language, the encoder output is used to provide contextualized word embeddings for various downstream language tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/nmt-recap.png" alt="NMT Recap" /></p>
<p><em>Fig. 1. The NMT base model used in CoVe.</em></p>
<ul>
<li>A sequence of <script type="math/tex">n</script> words in source language (English): <script type="math/tex">x = [x_1, \dots, x_n]</script>.</li>
<li>A sequence of <script type="math/tex">m</script> words in target language (German): <script type="math/tex">y = [y_1, \dots, y_m]</script>.</li>
<li>The <a href="/lil-log/2017/10/15/learning-word-embedding.html#glove-global-vectors">GloVe</a> vectors of source words: <script type="math/tex">\text{GloVe}(x)</script>.</li>
<li>Randomly initialized embedding vectors of target words: <script type="math/tex">z = [z_1, \dots, z_m]</script>.</li>
<li>The biLSTM encoder outputs a sequence of hidden states: <script type="math/tex">h = [h_1, \dots, h_n] = \text{biLSTM}(\text{GloVe}(x))</script> and <script type="math/tex">h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t]</script> where the forward LSTM computes <script type="math/tex">\overrightarrow{h}_t = \text{LSTM}(x_t, \overrightarrow{h}_{t-1})</script> and the backward computation gives us <script type="math/tex">\overleftarrow{h}_t = \text{LSTM}(x_t, \overleftarrow{h}_{t-1})</script>.</li>
<li>The attentional decoder outputs a distribution over words: <script type="math/tex">p(y_t \mid H, y_1, \dots, y_{t-1})</script> where <script type="math/tex">H</script> is a stack of hidden states <script type="math/tex">\{h\}</script> along the time dimension:</li>
</ul>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{decoder hidden state: } s_t &= \text{LSTM}([z_{t-1}; \tilde{h}_{t-1}], s_{t-1}) \\
\text{attention weights: } \alpha_t &= \text{softmax}(H(W_1 s_t + b_1)) \\
\text{context-adjusted hidden state: } \tilde{h}_t &= \tanh(W_2[H^\top\alpha_t;s_t] + b_2) \\
\text{decoder output: } p(y_t\mid H, y_1, \dots, y_{t-1}) &= \text{softmax}(W_\text{out} \tilde{h}_t + b_\text{out})
\end{aligned} %]]></script>
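<p>The attentional decoder step above can be sketched numerically. Below is a minimal NumPy sketch of one attention step, with toy dimensions and random matrices standing in for trained parameters (all names here are illustrative, not taken from the paper's code):</p>

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

# Toy dimensions: n source steps, hidden size d.
n, d = 5, 4
rng = np.random.default_rng(0)
H = rng.normal(size=(n, d))          # stack of encoder hidden states h_1..h_n
s_t = rng.normal(size=d)             # current decoder hidden state
W1, b1 = rng.normal(size=(d, d)), np.zeros(d)
W2, b2 = rng.normal(size=(d, 2 * d)), np.zeros(d)

# attention weights: alpha_t = softmax(H (W1 s_t + b1))
alpha_t = softmax(H @ (W1 @ s_t + b1))            # shape (n,)
# context-adjusted hidden state: h~_t = tanh(W2 [H^T alpha_t; s_t] + b2)
context = H.T @ alpha_t                           # shape (d,)
h_tilde = np.tanh(W2 @ np.concatenate([context, s_t]) + b2)
```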
<h3 id="use-cove-in-downstream-tasks">Use CoVe in Downstream Tasks</h3>
<p>The hidden states of NMT encoder are defined as <strong>context vectors</strong> for other language tasks:</p>
<script type="math/tex; mode=display">\text{CoVe}(x) = \text{biLSTM}(\text{GloVe}(x))</script>
<p>The paper proposed to use the concatenation of GloVe and CoVe for question-answering and classification tasks. GloVe learns from the ratios of global word co-occurrences, so it carries no sentence-level context, while CoVe, generated by processing text sequences, is able to capture contextual information.</p>
<script type="math/tex; mode=display">v = [\text{GloVe}(x); \text{CoVe}(x)]</script>
<p>Given a downstream task, we first generate the concatenation of GloVe + CoVe vectors of input words and then feed them into the task-specific models as additional features.</p>
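<p>As a concrete sketch of the concatenation (the dimensions below assume 300-d GloVe vectors and a 600-d biLSTM output; treat both as illustrative):</p>

```python
import numpy as np

# Hypothetical per-token vectors for a 3-word input sentence.
glove_x = np.random.default_rng(0).normal(size=(3, 300))  # GloVe(x)
cove_x = np.random.default_rng(1).normal(size=(3, 600))   # CoVe(x) = biLSTM output

# v = [GloVe(x); CoVe(x)]: concatenate along the feature dimension.
v = np.concatenate([glove_x, cove_x], axis=-1)
```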
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CoVe.png" alt="CoVe model" /></p>
<p><em>Fig. 2. The CoVe embeddings are generated by an encoder trained for machine translation task. The encoder can be plugged into any downstream task-specific model. (Image source: <a href="https://arxiv.org/abs/1708.00107">original paper</a>)</em></p>
<p><strong>Summary</strong>: The limitations of CoVe are obvious: (1) pre-training is bounded by available datasets on the supervised translation task; (2) the contribution of CoVe to the final performance is constrained by the task-specific model architecture.</p>
<p>In the following sections, we will see that ELMo overcomes issue (1) by unsupervised pre-training and OpenAI GPT & BERT further overcome both problems by unsupervised pre-training + using generative model architecture for different downstream tasks.</p>
<h2 id="elmo">ELMo</h2>
<p><strong>ELMo</strong>, short for <strong>Embeddings from Language Model</strong> (<a href="https://arxiv.org/abs/1802.05365">Peters et al., 2018</a>), learns contextualized word representation by pre-training a language model in an <em>unsupervised</em> way.</p>
<h3 id="bidirectional-language-model">Bidirectional Language Model</h3>
<p>The bidirectional Language Model (<strong>biLM</strong>) is the foundation for ELMo. Given an input sequence of <script type="math/tex">n</script> tokens, <script type="math/tex">(x_1, \dots, x_n)</script>, the language model learns to predict the probability of the next token given the history.</p>
<p>In the forward pass, the history contains words before the target token,</p>
<script type="math/tex; mode=display">p(x_1, \dots, x_n) = \prod_{i=1}^n p(x_i \mid x_1, \dots, x_{i-1})</script>
<p>In the backward pass, the history contains words after the target token,</p>
<script type="math/tex; mode=display">p(x_1, \dots, x_n) = \prod_{i=1}^n p(x_i \mid x_{i+1}, \dots, x_n)</script>
<p>The predictions in both directions are modeled by multi-layer LSTMs with hidden states <script type="math/tex">\overrightarrow{\mathbf{h}}_{i,\ell}</script> and <script type="math/tex">\overleftarrow{\mathbf{h}}_{i,\ell}</script> for input token <script type="math/tex">x_i</script> at the layer level <script type="math/tex">\ell=1,\dots,L</script>.
The final layer’s hidden state <script type="math/tex">\mathbf{h}_{i,L} = [\overrightarrow{\mathbf{h}}_{i,L}; \overleftarrow{\mathbf{h}}_{i,L}]</script> is used to output the probabilities over tokens after softmax normalization. They share the embedding layer and the softmax layer, parameterized by <script type="math/tex">\Theta_e</script> and <script type="math/tex">\Theta_s</script> respectively.</p>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/ELMo-biLSTM.png" alt="ELMo biLSTM" /></p>
<p><em>Fig. 3. The biLSTM base model of ELMo. (Image source: recreated based on the figure in <a href="http://colah.github.io/posts/2015-09-NN-Types-FP/">“Neural Networks, Types, and Functional Programming”</a> by Christopher Olah.)</em></p>
<p>The model is trained to minimize the negative log likelihood (= maximize the log likelihood for true words) in both directions:</p>
<script type="math/tex; mode=display">\begin{aligned}
\mathcal{L} = - \sum_{i=1}^n \Big(
\log p(x_i \mid x_1, \dots, x_{i-1}; \Theta_e, \overrightarrow{\Theta}_\text{LSTM}, \Theta_s) + \\
\log p(x_i \mid x_{i+1}, \dots, x_n; \Theta_e, \overleftarrow{\Theta}_\text{LSTM}, \Theta_s) \Big)
\end{aligned}</script>
<h3 id="elmo-representations">ELMo Representations</h3>
<p>On top of an <script type="math/tex">L</script>-layer biLM, ELMo stacks all the hidden states across layers together by learning a task-specific linear combination. The hidden state representation for the token <script type="math/tex">x_i</script> contains <script type="math/tex">2L+1</script> vectors:</p>
<p><script type="math/tex">R_i = \{ \mathbf{h}_{i,\ell} \mid \ell = 0, \dots, L \}</script>
where <script type="math/tex">\mathbf{h}_{i,0}</script> is the embedding layer output and <script type="math/tex">\mathbf{h}_{i, \ell} = [\overrightarrow{\mathbf{h}}_{i,\ell}; \overleftarrow{\mathbf{h}}_{i,\ell}]</script>.</p>
<p>The weights, <script type="math/tex">\mathbf{s}^\text{task}</script>, in the linear combination are learned for each end task and normalized by softmax. The scaling factor <script type="math/tex">\gamma^\text{task}</script> is used to correct the misalignment between the distribution of biLM hidden states and the distribution of task specific representations.</p>
<script type="math/tex; mode=display">v_i = f(R_i; \Theta^\text{task}) = \gamma^\text{task} \sum_{\ell=0}^L s^\text{task}_\ell \mathbf{h}_{i,\ell}</script>
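<p>A minimal sketch of this task-specific combination, assuming toy sizes; <code class="highlighter-rouge">s_raw</code> and <code class="highlighter-rouge">gamma</code> stand in for the learned task-specific parameters:</p>

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

L, d = 2, 6                           # toy: 2 biLM layers, hidden size d
rng = np.random.default_rng(1)
# R_i: embedding-layer output plus L biLM layer states for token x_i -> (L+1, d)
R_i = rng.normal(size=(L + 1, d))
s_raw = rng.normal(size=L + 1)        # task-specific scalars, softmax-normalized
gamma = 0.5                           # task-specific scaling factor

s = softmax(s_raw)                    # s^task, sums to 1
v_i = gamma * (s[:, None] * R_i).sum(axis=0)   # weighted sum over layers
```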
<p>To evaluate what kind of information is captured by hidden states across different layers, ELMo is applied on semantic-intensive and syntax-intensive tasks respectively using representations in different layers of biLM:</p>
<ul>
<li><strong>Semantic task</strong>: The <em>word sense disambiguation (WSD)</em> task emphasizes the meaning of a word given a context. The biLM top layer is better at this task than the first layer.</li>
<li><strong>Syntax task</strong>: The <em><a href="https://en.wikipedia.org/wiki/Part-of-speech_tagging">part-of-speech</a> (POS) tagging</em> task aims to infer the grammatical role of a word in one sentence. A higher accuracy can be achieved by using the biLM first layer than the top layer.</li>
</ul>
<p>The comparison study indicates that syntactic information is better represented at lower layers while semantic information is captured by higher layers. Because different layers tend to carry different types of information, <em>stacking them together helps</em>.</p>
<h3 id="use-elmo-in-downstream-tasks">Use ELMo in Downstream Tasks</h3>
<p>Similar to how <a href="#use-cove-in-downstream-tasks">CoVe</a> can help different downstream tasks, ELMo embedding vectors are included in the input or lower levels of task-specific models. Moreover, for some tasks (i.e., <a href="#nli">SNLI</a> and <a href="#qa">SQuAD</a>, but not <a href="#srl">SRL</a>), adding them into the output level helps too.</p>
<p>The improvements brought by ELMo are largest for tasks with a small supervised dataset. With ELMo, we can also achieve similar performance with much less labeled data.</p>
<p><strong>Summary</strong>: The language model pre-training is unsupervised and theoretically the pre-training can be scaled up as much as possible since the unlabeled text corpora are abundant. However, it still has the dependency on task-customized models and thus the improvement is only incremental, while searching for a good model architecture for every task remains non-trivial.</p>
<h2 id="cross-view-training">Cross-View Training</h2>
<p>In ELMo the unsupervised pre-training and task-specific learning happen for two independent models in two separate training stages. <strong>Cross-View Training</strong> (abbr. <strong>CVT</strong>; <a href="https://arxiv.org/abs/1809.08370">Clark et al., 2018</a>) combines them into one unified semi-supervised learning procedure where the representation of a biLSTM encoder is improved by both supervised learning with labeled data and unsupervised learning with unlabeled data on auxiliary tasks.</p>
<h3 id="model-architecture">Model Architecture</h3>
<p>The model consists of a two-layer bidirectional LSTM encoder and a primary prediction module. During training, the model is fed with labeled and unlabeled data batches alternately.</p>
<ul>
<li>On <em>labeled examples</em>, all the model parameters are updated by standard supervised learning. The loss is the standard cross entropy.</li>
<li>On <em>unlabeled examples</em>, the primary prediction module still can produce a “soft” target, even though we cannot know exactly how accurate it is. In a couple of auxiliary tasks, the predictor only sees and processes a restricted view of the input, such as only using encoder hidden state representation in one direction. The auxiliary task outputs are expected to match the primary prediction target given the full view of the input. <br />In this way, the encoder is forced to distill the knowledge of the full context into partial representations. At this stage, the biLSTM encoder is backpropagated but the primary prediction module is <em>fixed</em>. The loss is to minimize the distance between the auxiliary and primary predictions.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/CVT.png" alt="CVT" /></p>
<p><em>Fig. 4. The overview of semi-supervised language model cross-view training. (Image source: <a href="https://arxiv.org/abs/1809.08370">original paper</a>)</em></p>
<h3 id="multi-task-learning">Multi-Task Learning</h3>
<p>When training for multiple tasks simultaneously, CVT adds several extra primary prediction models for additional tasks. They all share the same sentence representation encoder.
During supervised training, one task is randomly selected at each step, and only the parameters in its corresponding predictor and the representation encoder are updated.
With unlabeled data samples, the encoder is optimized jointly across all the tasks by minimizing the differences between auxiliary outputs and primary prediction for every task.</p>
<p>The multi-task learning encourages better generality of representation and in the meantime produces a nice side-product: all-tasks-labeled examples from unlabeled data. They are precious data labels considering that cross-task labels are useful but fairly rare.</p>
<h3 id="use-cvt-in-downstream-tasks">Use CVT in Downstream Tasks</h3>
<p>Theoretically the primary prediction module can take any form, generic or task-specific design. The examples presented in the CVT paper include both cases.</p>
<p>In sequential tagging tasks (classification for every token) like <a href="#ner">NER</a> or <a href="#pos">POS</a> tagging, the predictor module contains two fully connected layers and a softmax layer on the output to produce a probability distribution over class labels.
For each token <script type="math/tex">\mathbf{x}_i</script>, we take the corresponding hidden states in two layers, <script type="math/tex">\mathbf{h}_1^{(i)}</script> and <script type="math/tex">\mathbf{h}_2^{(i)}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_\theta(y_i \mid \mathbf{x}_i)
&= \text{NN}(\mathbf{h}^{(i)}) \\
&= \text{NN}([\mathbf{h}_1^{(i)}; \mathbf{h}_2^{(i)}]) \\
&= \text{softmax} \big( \mathbf{W}\cdot\text{ReLU}(\mathbf{W'}\cdot[\mathbf{h}_1^{(i)}; \mathbf{h}_2^{(i)}]) + \mathbf{b} \big)
\end{aligned} %]]></script>
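<p>The two-layer predictor above can be sketched in NumPy with toy dimensions; the weight names mirror the equation, and the random values are stand-ins for trained parameters:</p>

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

d, d_mid, n_classes = 4, 8, 3                      # toy sizes
rng = np.random.default_rng(0)
h1, h2 = rng.normal(size=d), rng.normal(size=d)    # layer-1 / layer-2 states for token i
W_prime = rng.normal(size=(d_mid, 2 * d))
W = rng.normal(size=(n_classes, d_mid))
b = np.zeros(n_classes)

h = np.concatenate([h1, h2])                       # [h_1^(i); h_2^(i)]
p = softmax(W @ np.maximum(0.0, W_prime @ h) + b)  # softmax(W ReLU(W' h) + b)
```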
<p>The auxiliary tasks are only fed with forward or backward LSTM state in the first layer. Because they only observe partial context, either on the left or right, they have to learn like a language model, trying to predict the next token given the context. The <code class="highlighter-rouge">fwd</code> and <code class="highlighter-rouge">bwd</code> auxiliary tasks only take one direction. The <code class="highlighter-rouge">future</code> and <code class="highlighter-rouge">past</code> tasks take one step further in forward and backward direction, respectively.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_\theta^\text{fwd}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{fwd}(\overrightarrow{\mathbf{h}}^{(i)}) \\
p_\theta^\text{bwd}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{bwd}(\overleftarrow{\mathbf{h}}^{(i)}) \\
p_\theta^\text{future}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{future}(\overrightarrow{\mathbf{h}}^{(i-1)}) \\
p_\theta^\text{past}(y_i \mid \mathbf{x}_i) &= \text{NN}^\text{past}(\overleftarrow{\mathbf{h}}^{(i+1)})
\end{aligned} %]]></script>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/CVT-example.png" alt="CVT sequential tagging" /></p>
<p><em>Fig. 5. The sequential tagging task depends on four auxiliary prediction models, their inputs only involving hidden states in one direction: forward, backward, future and past. (Image source: <a href="https://arxiv.org/abs/1809.08370">original paper</a>)</em></p>
<p>Note that if the primary prediction module has dropout, the dropout layer works as usual when training with labeled data, but it is not applied when generating “soft” target for auxiliary tasks during training with unlabeled data.</p>
<p>In the machine translation task, the primary prediction module is replaced with a standard unidirectional LSTM decoder with attention. There are two auxiliary tasks: (1) apply dropout on the attention weight vector by randomly zeroing out some values; (2) predict the future word in the target sequence. The primary prediction for auxiliary tasks to match is the best predicted target sequence produced by running the fixed primary decoder on the input sequence with <a href="https://en.wikipedia.org/wiki/Beam_search">beam search</a>.</p>
<h2 id="ulmfit">ULMFiT</h2>
<p>The idea of using generative pretrained LM + task-specific fine-tuning was first explored in ULMFiT (<a href="https://arxiv.org/abs/1801.06146">Howard & Ruder, 2018</a>), directly motivated by the success of using ImageNet pre-training for computer vision tasks. The base model is <a href="https://arxiv.org/abs/1708.02182">AWD-LSTM</a>.</p>
<p>ULMFiT follows three steps to achieve good transfer learning results on downstream language classification tasks:</p>
<p>1) <em>General LM pre-training</em>: on Wikipedia text.</p>
<p>2) <em>Target task LM fine-tuning</em>: ULMFiT proposed two training techniques for stabilizing the fine-tuning process. See below.</p>
<ul>
<li>
<p><strong>Discriminative fine-tuning</strong> is motivated by the fact that different layers of the LM capture different types of information (see <a href="#elmo-representations">discussion</a> above). ULMFiT proposed to tune each layer with a different learning rate, <script type="math/tex">\{\eta^1, \dots, \eta^\ell, \dots, \eta^L\}</script>, where <script type="math/tex">\eta^1</script> is the base learning rate for the first layer, <script type="math/tex">\eta^\ell</script> is for the <script type="math/tex">\ell</script>-th layer and there are <script type="math/tex">L</script> layers in total.</p>
</li>
<li>
<p><strong>Slanted triangular learning rates (STLR)</strong> refer to a special learning rate scheduling that first linearly increases the learning rate and then linearly decays it. The increase stage is short so that the model can converge to a parameter space suitable for the task fast, while the decay period is long allowing for better fine-tuning.</p>
</li>
</ul>
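<p>The STLR schedule can be written down directly. The sketch below follows the schedule formula described in the ULMFiT paper; the default hyperparameter values (<code class="highlighter-rouge">cut_frac=0.1</code>, <code class="highlighter-rouge">ratio=32</code>) are the paper's suggestions, but treat the exact numbers as illustrative:</p>

```python
import math

def stlr(t, T=1000, cut_frac=0.1, ratio=32, lr_max=0.01):
    """Slanted triangular learning rate at iteration t (0-indexed).

    T: total number of training iterations
    cut_frac: fraction of iterations spent increasing the learning rate
    ratio: how much smaller the lowest rate is compared to lr_max
    """
    cut = math.floor(T * cut_frac)
    if t < cut:
        p = t / cut                                      # short linear warm-up
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))   # long linear decay
    return lr_max * (1 + p * (ratio - 1)) / ratio
```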
<p>3) <em>Target task classifier fine-tuning</em>: The pretrained LM is augmented with two standard feed-forward layers and a softmax normalization at the end to predict a target label distribution.</p>
<ul>
<li>
<p><strong>Concat pooling</strong> extracts max-pooling and mean-pooling over the history of hidden states and concatenates them with the final hidden state.</p>
</li>
<li>
<p><strong>Gradual unfreezing</strong> helps to avoid catastrophic forgetting by gradually unfreezing the model layers starting from the last one. First the last layer is unfrozen and fine-tuned for one epoch. Then the next lower layer is unfrozen. This process is repeated until all the layers are tuned.</p>
</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/ULMFiT.png" alt="ULMFiT" /></p>
<p><em>Fig. 6. Three training stages of ULMFiT. (Image source: <a href="https://arxiv.org/abs/1801.06146">original paper</a>)</em></p>
<h2 id="openai-gpt">OpenAI GPT</h2>
<p>Following the similar idea of ELMo, OpenAI <strong>GPT</strong>, short for <strong>Generative Pre-training Transformer</strong> (<a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">Radford et al., 2018</a>), expands the unsupervised language model to a much larger scale by training on a giant collection of free text corpora. Despite the similarity, GPT has two major differences from ELMo.</p>
<ol>
<li>The model architectures are different: ELMo uses a shallow concatenation of independently trained left-to-right and right-to-left multi-layer LSTMs, while GPT is a multi-layer transformer decoder.</li>
<li>The use of contextualized embeddings in downstream tasks are different: ELMo feeds embeddings into models customized for specific tasks as additional features, while GPT fine-tunes the same base model for all end tasks.</li>
</ol>
<h3 id="transformer-decoder-as-language-model">Transformer Decoder as Language Model</h3>
<p>Compared to the <a href="https://arxiv.org/abs/1706.03762">original transformer</a> architecture, the <a href="https://arxiv.org/abs/1801.10198">transformer decoder</a> model discards the encoder part, so there is only one single input sentence rather than two separate source and target sequences.</p>
<p>This model applies multiple transformer blocks over the embeddings of input sequences. Each block contains a masked <em>multi-headed self-attention</em> layer and a <em>pointwise feed-forward</em> layer. The final output produces a distribution over target tokens after softmax normalization.</p>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/OpenAI-GPT-transformer-decoder.png" alt="OpenAI GPT transformer decoder" /></p>
<p><em>Fig. 7. The transformer decoder model architecture in OpenAI GPT.</em></p>
<p>The loss is the negative log-likelihood, same as <a href="#elmo">ELMo</a>, but without the backward computation. Say the context window of size <script type="math/tex">k</script> is located before the target word; then the loss looks like:</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{LM} = -\sum_{i} \log p(x_i\mid x_{i-k}, \dots, x_{i-1})</script>
<h3 id="bpe">BPE</h3>
<p><strong>Byte Pair Encoding</strong> (<a href="https://arxiv.org/abs/1508.07909"><strong>BPE</strong></a>) is used to encode the input sequences. BPE was originally proposed as a data compression algorithm in the 1990s and was later adopted to solve the open-vocabulary issue in machine translation, as we can easily run into rare and unknown words when translating into a new language. Motivated by the intuition that rare and unknown words can often be decomposed into multiple subwords, BPE finds the best word segmentation by iteratively and greedily merging frequent pairs of characters.</p>
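<p>The greedy merge procedure can be sketched compactly. The toy corpus in the usage below follows the style of the example in the BPE paper; <code class="highlighter-rouge">corpus</code> maps each word, represented as a tuple of symbols, to its frequency:</p>

```python
from collections import Counter

def bpe_merges(corpus, num_merges):
    """Greedy BPE: repeatedly merge the most frequent adjacent symbol pair."""
    merges = []
    vocab = dict(corpus)
    for _ in range(num_merges):
        # Count all adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for word, freq in vocab.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        merges.append(best)
        # Apply the merge everywhere it occurs.
        merged = {}
        for word, freq in vocab.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            merged[tuple(out)] = freq
        vocab = merged
    return merges, vocab
```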
<h3 id="supervised-fine-tuning">Supervised Fine-Tuning</h3>
<p>The most substantial upgrade that OpenAI GPT proposed is to get rid of the task-specific model and use the pre-trained language model directly!</p>
<p>Let’s take classification as an example. Say, in the labeled dataset, each input has <script type="math/tex">n</script> tokens, <script type="math/tex">\mathbf{x} = (x_1, \dots, x_n)</script>, and one label <script type="math/tex">y</script>. GPT first processes the input sequence <script type="math/tex">\mathbf{x}</script> through the pre-trained transformer decoder and the last layer output for the last token <script type="math/tex">x_n</script> is <script type="math/tex">\mathbf{h}_L^{(n)}</script>. Then with only one new trainable weight matrix <script type="math/tex">\mathbf{W}_y</script>, it can predict a distribution over class labels.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/GPT-classification.png" alt="GPT classification" /></p>
<script type="math/tex; mode=display">P(y\mid x_1, \dots, x_n) = \text{softmax}(\mathbf{h}_L^{(n)}\mathbf{W}_y)</script>
<p>The loss is to minimize the negative log-likelihood for true labels. In addition, adding the LM loss as an auxiliary loss is found to be beneficial, because:</p>
<ul>
<li>(1) it helps accelerate convergence during training and</li>
<li>(2) it is expected to improve the generalization of the supervised model.</li>
</ul>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{cls} &= \sum_{(\mathbf{x}, y) \in \mathcal{D}} \log P(y\mid x_1, \dots, x_n) = \sum_{(\mathbf{x}, y) \in \mathcal{D}} \log \text{softmax}(\mathbf{h}_L^{(n)}(\mathbf{x})\mathbf{W}_y) \\
\mathcal{L}_\text{LM} &= -\sum_{i} \log p(x_i\mid x_{i-k}, \dots, x_{i-1}) \\
\mathcal{L} &= \mathcal{L}_\text{cls} + \lambda \mathcal{L}_\text{LM}
\end{aligned} %]]></script>
<p>With similar designs, no customized model structure is needed for other end tasks (see Fig. 8). If the task input contains multiple sentences, a special delimiter token (<code class="highlighter-rouge">$</code>) is added between each pair of sentences. The embedding for this delimiter token is a new parameter we need to learn, but it should be pretty minimal.</p>
<p>For the sentence similarity task, because the ordering does not matter, both orderings are included. For the multiple choice task, the context is paired with every answer candidate.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/GPT-downstream-tasks.png" alt="GPT downstream tasks" /></p>
<p><em>Fig. 8. Training objects in slightly modified GPT transformer models for downstream tasks. (Image source: <a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">original paper</a>)</em></p>
<p><strong>Summary</strong>: It is super neat and encouraging to see that such a general framework was capable of beating SOTA on most language tasks at that time (June 2018). At the first stage, generative pre-training of a language model can absorb as much free text as possible. Then at the second stage, the model is fine-tuned on specific tasks with a small labeled dataset and a minimal set of new parameters to learn.</p>
<p>One limitation of GPT is its uni-directional nature — the model is only trained to predict the future left-to-right context.</p>
<h2 id="bert">BERT</h2>
<p><strong>BERT</strong>, short for <strong>Bidirectional Encoder Representations from Transformers</strong> (<a href="https://arxiv.org/abs/1810.04805">Devlin et al., 2019</a>) is a direct descendant of <a href="#openai-gpt">GPT</a>: train a large language model on free text and then fine-tune on specific tasks without customized network architectures.</p>
<p>Compared to GPT, the largest difference and improvement of BERT is to make training <strong>bi-directional</strong>. The model learns to predict both the context on the left and on the right. According to the ablation study, the paper claimed that:</p>
<blockquote>
<p>“bidirectional nature of our model is the single most important new contribution”</p>
</blockquote>
<h3 id="pre-training-tasks">Pre-training Tasks</h3>
<p>The model architecture of BERT is a multi-layer bidirectional Transformer encoder.</p>
<p style="width: 25%;" class="center"><img src="/lil-log/assets/images/transformer-encoder-2.png" alt="transformer encoder" /></p>
<p><em>Fig. 9. Recap of Transformer Encoder model architecture. (Image source: <a href="https://arxiv.org/abs/1706.03762">Transformer paper</a>)</em></p>
<p>To encourage bi-directional prediction and sentence-level understanding, BERT is trained with two auxiliary tasks instead of the basic language modeling task (that is, predicting the next token given the context).</p>
<p><strong>Task 1: Mask language model (MLM)</strong></p>
<blockquote>
<p>From <a href="https://en.wikipedia.org/wiki/Cloze_test">Wikipedia</a>: “A cloze test (also cloze deletion test) is an exercise, test, or assessment consisting of a portion of language with certain items, words, or signs removed (cloze text), where the participant is asked to replace the missing language item. … The exercise was first described by W.L. Taylor in 1953.”</p>
</blockquote>
<p>It is reasonable to believe that a representation that learns from the context on both sides of a word, rather than only the words that precede it, is able to better capture its meaning, both syntactically and semantically. BERT encourages the model to do so by training on the <em>“mask language model” task</em>:</p>
<ol>
<li>Randomly mask 15% of the tokens in each sequence. If we simply replaced all masked tokens with a special placeholder <code class="highlighter-rouge">[MASK]</code>, the special token would never be encountered during fine-tuning. Hence, BERT employed several heuristic tricks:
<ul>
<li>(a) with 80% probability, replace the chosen words with <code class="highlighter-rouge">[MASK]</code>;</li>
<li>(b) with 10% probability, replace with a random word;</li>
<li>(c) with 10% probability, keep it the same.</li>
</ul>
</li>
<li>The model only predicts the missing words, but it has no information on which words have been replaced or which words should be predicted. The output size is only 15% of the input size.</li>
</ol>
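<p>The 80%/10%/10% corruption scheme above can be sketched as a small helper; the function name and signature are illustrative, not BERT's actual implementation:</p>

```python
import random

def mask_tokens(tokens, vocab, mask_prob=0.15, rng=None):
    """Sketch of BERT's MLM corruption: pick ~15% of positions; of those,
    80% -> [MASK], 10% -> a random word, 10% -> kept unchanged.
    Returns the corrupted tokens and a dict of {position: original token}."""
    rng = rng or random.Random(0)
    out, targets = list(tokens), {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            targets[i] = tok            # the model must predict this token
            r = rng.random()
            if r < 0.8:
                out[i] = "[MASK]"
            elif r < 0.9:
                out[i] = rng.choice(vocab)
            # else: keep the original token
    return out, targets
```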
<p><strong>Task 2: Next sentence prediction</strong></p>
<p>Motivated by the fact that many downstream tasks involve the understanding of relationships between sentences (i.e., <a href="#qa">QA</a>, <a href="#nli">NLI</a>), BERT added another auxiliary task on training a <em>binary classifier</em> for telling whether one sentence is the next sentence of the other:</p>
<ol>
<li>Sample sentence pairs (A, B) so that:
<ul>
<li>(a) 50% of the time, B follows A;</li>
<li>(b) 50% of the time, B does not follow A.</li>
</ul>
</li>
<li>The model processes both sentences and outputs a binary label indicating whether B is the next sentence of A.</li>
</ol>
<p>The training data for both auxiliary tasks above can be trivially generated from any monolingual corpus. Hence the scale of training is unbounded. The training loss is the sum of the mean masked LM likelihood and mean next sentence prediction likelihood.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/language-model-comparison.png" alt="Language model comparison" /></p>
<p><em>Fig. 10. Comparison of BERT, OpenAI GPT and ELMo model architectures. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<h3 id="input-embedding">Input Embedding</h3>
<p>The input embedding is the sum of three parts:</p>
<ol>
<li><em>WordPiece tokenization embeddings</em>: The <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">WordPiece</a> <a href="https://arxiv.org/pdf/1609.08144.pdf">model</a> was originally proposed for the Japanese and Korean segmentation problem. Instead of using naturally split English words, WordPiece further divides them into smaller sub-word units so that rare or unknown words are handled more effectively. Please read the <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">linked</a> <a href="https://arxiv.org/pdf/1609.08144.pdf">papers</a> for the optimal way to split words if interested.</li>
<li><em>Segment embeddings</em>: If the input contains two sentences, they have sentence A embeddings and sentence B embeddings respectively and they are separated by a special character <code class="highlighter-rouge">[SEP]</code>; Only sentence A embeddings are used if the input only contains one sentence.</li>
<li><em>Position embeddings</em>: Positional embeddings are learned rather than hard-coded.</li>
</ol>
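<p>The three-way sum can be sketched with lookup tables; all sizes and the input token ids below are toy values for illustration:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, max_len, d = 100, 32, 8          # toy sizes
tok_emb = rng.normal(size=(vocab_size, d))   # WordPiece token embeddings
seg_emb = rng.normal(size=(2, d))            # sentence A / sentence B embeddings
pos_emb = rng.normal(size=(max_len, d))      # learned position embeddings

# Hypothetical input: [CLS] sent-A tokens [SEP] sent-B tokens [SEP]
token_ids = np.array([1, 7, 8, 2, 9, 10, 2])
segment_ids = np.array([0, 0, 0, 0, 1, 1, 1])
positions = np.arange(len(token_ids))

# The input embedding is the element-wise sum of the three parts.
x = tok_emb[token_ids] + seg_emb[segment_ids] + pos_emb[positions]
```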
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERT-input-embedding.png" alt="BERT input embedding" /></p>
<p><em>Fig. 11. BERT input representation. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<p>Note that the first token is always forced to be <code class="highlighter-rouge">[CLS]</code> — a placeholder that will be used later for prediction in downstream tasks.</p>
<h3 id="use-bert-in-downstream-tasks">Use BERT in Downstream Tasks</h3>
<p>BERT fine-tuning requires only a few new parameters added, just like OpenAI GPT.</p>
<p>For classification tasks, we get the prediction by taking the final hidden state of the special first token <code class="highlighter-rouge">[CLS]</code>, <script type="math/tex">\mathbf{h}^\text{[CLS]}_L</script>, and multiplying it with a small weight matrix, <script type="math/tex">\text{softmax}(\mathbf{h}^\text{[CLS]}_L \mathbf{W}_\text{cls})</script>.</p>
<p>For <a href="#qa">QA</a> tasks like SQuAD, we need to predict the text span in the given paragraph for an given question. BERT predicts two probability distributions of every token, being the start and the end of the text span. Only two new small matrices, <script type="math/tex">\mathbf{W}_\text{s}</script> and <script type="math/tex">\mathbf{W}_\text{e}</script>, are newly learned during fine-tuning and <script type="math/tex">\text{softmax}(\mathbf{h}^\text{(i)}_L \mathbf{W}_\text{s})</script> and <script type="math/tex">\text{softmax}(\mathbf{h}^\text{(i)}_L \mathbf{W}_\text{e})</script> define two probability distributions.</p>
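<p>The span prediction above is just two independent softmaxes over token positions. A minimal NumPy sketch, with toy dimensions and random values standing in for the fine-tuned parameters:</p>

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

n, d = 6, 4                                  # toy: 6 paragraph tokens, hidden size d
rng = np.random.default_rng(2)
H_L = rng.normal(size=(n, d))                # final-layer hidden states h_L^(i)
W_s = rng.normal(size=d)                     # start-of-span weight vector
W_e = rng.normal(size=d)                     # end-of-span weight vector

p_start = softmax(H_L @ W_s)                 # P(token i is the span start)
p_end = softmax(H_L @ W_e)                   # P(token i is the span end)
start, end = int(p_start.argmax()), int(p_end.argmax())
```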
<p>Overall the add-on part for end task fine-tuning is very minimal — one or two weight matrices to convert the Transformer hidden states into an interpretable format. Check the paper for implementation details of other cases.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/BERT-downstream-tasks.png" alt="BERT downstream tasks" /></p>
<p><em>Fig. 12. Training objects in slightly modified BERT models for downstream tasks. (Image source: <a href="https://arxiv.org/abs/1810.04805">original paper</a>)</em></p>
<p>The summary table below compares the differences between fine-tuning of OpenAI GPT and BERT.</p>
<table class="info">
<tbody>
<tr>
<td> </td>
<td><strong>OpenAI GPT</strong></td>
<td><strong>BERT</strong></td>
</tr>
<tr>
<td>Special char</td>
<td><code class="highlighter-rouge">[SEP]</code> and <code class="highlighter-rouge">[CLS]</code> are only introduced at fine-tuning stage.</td>
<td><code class="highlighter-rouge">[SEP]</code> and <code class="highlighter-rouge">[CLS]</code> and sentence A/B embeddings are learned at the pre-training stage.</td>
</tr>
<tr>
<td>Training process</td>
<td>1M steps, batch size 32k words.</td>
<td>1M steps, batch size 128k words.</td>
</tr>
<tr>
<td>Fine-tuning</td>
<td>lr = 5e-5 for all fine-tuning tasks.</td>
<td>Use task-specific lr for fine-tuning.</td>
</tr>
</tbody>
</table>
<h2 id="openai-gpt-2">OpenAI GPT-2</h2>
<p>The <a href="https://blog.openai.com/better-language-models/">OpenAI</a> <a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">GPT-2</a> language model is a direct successor to <a href="#openai-gpt">GPT</a>. GPT-2 has 1.5B parameters, 10x more than the original GPT, and it achieves SOTA results on 7 out of 8 tested language modeling datasets in a <em>zero-shot transfer setting</em> without any task-specific fine-tuning. The pre-training dataset contains 8 million Web pages collected by crawling qualified outbound links from <a href="https://www.reddit.com/">Reddit</a>. Large improvements by OpenAI GPT-2 are specially noticeable on small datasets and datasets used for measuring <em>long-term dependency</em>.</p>
<h3 id="zero-shot-transfer">Zero-Shot Transfer</h3>
<p>The pre-training task for GPT-2 is solely language modeling. All the downstream language tasks are framed as predicting conditional probabilities and there is no task-specific fine-tuning.</p>
<ul>
<li>Text generation is straightforward using LM.</li>
<li>Machine translation task, for example, English to Chinese, is induced by conditioning LM on pairs of “English sentence = Chinese sentence” and “the target English sentence =” at the end.
<ul>
<li>For example, the conditional probability to predict might look like: <code class="highlighter-rouge">P(? | I like green apples. = 我喜欢绿苹果。 A cat meows at him. = 一只猫对他喵。It is raining cats and dogs. =")</code></li>
</ul>
</li>
<li>QA task is formatted similar to translation with pairs of questions and answers in the context.</li>
<li>Summarization task is induced by adding <code class="highlighter-rouge">TL;DR:</code> after the articles in the context.</li>
</ul>
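<p>The prompt framing described above can be sketched in a few lines of plain Python. This is an illustrative sketch of the idea, not code from the paper; the helper name and exact spacing are my own assumptions:</p>

```python
def make_translation_prompt(example_pairs, source_sentence):
    """Frame translation as language modeling: condition the LM on a few
    'source = target' examples, then ask it to continue after 'new source ='."""
    context = " ".join(f"{src} = {tgt}" for src, tgt in example_pairs)
    return f"{context} {source_sentence} ="

prompt = make_translation_prompt(
    [("I like green apples.", "我喜欢绿苹果。"),
     ("A cat meows at him.", "一只猫对他喵。")],
    "It is raining cats and dogs.",
)
# The LM's continuation of `prompt` is read off as the translation.
```

The same pattern covers QA (question/answer pairs in the context) and summarization (appending `TL;DR:` after the article).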
<h3 id="bpe-on-byte-sequences">BPE on Byte Sequences</h3>
<p>Same as the original GPT, GPT-2 uses <a href="#bpe">BPE</a> but on <a href="https://en.wikipedia.org/wiki/UTF-8">UTF-8</a> byte sequences. Each byte can represent 256 different values in 8 bits, while UTF-8 uses up to 4 bytes for one character, supporting up to <script type="math/tex">2^{21}</script> characters in total. Therefore, with the byte sequence representation we only need a vocabulary of size 256 and do not need to worry about pre-processing, tokenization, etc. Despite this benefit, current byte-level LMs still have a non-negligible performance gap with the SOTA word-level LMs.</p>
<p>BPE merges frequently co-occurring byte pairs in a greedy manner. To prevent it from generating multiple versions of common words (e.g. <code class="highlighter-rouge">dog.</code>, <code class="highlighter-rouge">dog!</code> and <code class="highlighter-rouge">dog?</code> for the word <code class="highlighter-rouge">dog</code>), GPT-2 prevents BPE from merging characters across categories (thus <code class="highlighter-rouge">dog</code> would not be merged with punctuation like <code class="highlighter-rouge">.</code>, <code class="highlighter-rouge">!</code> and <code class="highlighter-rouge">?</code>). This trick helps increase the quality of the final byte segmentation.</p>
<p>Using the byte sequence representation, GPT-2 is able to assign a probability to any Unicode string, regardless of any pre-processing steps.</p>
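<p>As a minimal illustration of why byte-level inputs need only a 256-symbol base vocabulary, any Unicode string maps to UTF-8 byte IDs without any pre-processing (the function name is my own, not GPT-2 code):</p>

```python
def to_byte_ids(text):
    """Map any Unicode string to integers in [0, 256): a base vocabulary
    of only 256 symbols covers every possible input."""
    return list(text.encode("utf-8"))

to_byte_ids("dog?")  # [100, 111, 103, 63], one byte per ASCII character
to_byte_ids("苹")    # a single CJK character expands to 3 UTF-8 bytes
```

GPT-2's actual tokenizer then runs BPE merges on top of these byte IDs; the point here is only that the base alphabet is closed under all inputs.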
<h3 id="model-modifications">Model Modifications</h3>
<p>Compared to GPT, other than having many more transformer layers and parameters, GPT-2 incorporates only a few architecture modifications:</p>
<ul>
<li><a href="https://arxiv.org/abs/1607.06450">Layer normalization</a> was moved to the input of each sub-block, similar to a <a href="https://arxiv.org/abs/1603.05027">pre-activation residual unit</a> (unlike the original <a href="https://arxiv.org/abs/1512.03385">post-activation design</a>, normalization is applied before the weight layers rather than after the residual addition).</li>
<li>An additional layer normalization was added after the final self-attention block.</li>
<li>A modified initialization was constructed as a function of the model depth.</li>
<li>The weights of residual layers were initially scaled by a factor of <script type="math/tex">1/ \sqrt{N}</script> where N is the number of residual layers.</li>
<li>A larger vocabulary size and context size are used.</li>
</ul>
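<p>The first two modifications amount to moving layer normalization from the output of each residual sub-block ("post-LN") to its input ("pre-LN"). A minimal NumPy sketch of the two orderings, with a generic function <code>f</code> standing in for the attention or feed-forward sub-block (the simplification to un-parameterized layer norm is mine):</p>

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize over the last (feature) axis; gain/bias omitted for brevity."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_ln_sublayer(x, f):
    # GPT / original Transformer: normalize the output of the residual sum
    return layer_norm(x + f(x))

def pre_ln_sublayer(x, f):
    # GPT-2: normalize the input of each sub-block instead
    return x + f(layer_norm(x))
```

Pre-LN keeps an un-normalized residual path from input to output, which is generally credited with making very deep Transformers easier to train.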
<h2 id="summary">Summary</h2>
<table class="info">
<thead>
<tr>
<th> </th>
<th>Base model</th>
<th>pre-training</th>
<th>Downstream tasks</th>
<th>Downstream model</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>CoVe</td>
<td>seq2seq NMT model</td>
<td>supervised</td>
<td>feature-based</td>
<td>task-specific</td>
<td>/</td>
</tr>
<tr>
<td>ELMo</td>
<td>two-layer biLSTM</td>
<td>unsupervised</td>
<td>feature-based</td>
<td>task-specific</td>
<td>/</td>
</tr>
<tr>
<td>CVT</td>
<td>two-layer biLSTM</td>
<td>semi-supervised</td>
<td>model-based</td>
<td>task-specific / task-agnostic</td>
<td>/</td>
</tr>
<tr>
<td>ULMFiT</td>
<td>AWD-LSTM</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>all layers; with various training tricks</td>
</tr>
<tr>
<td>GPT</td>
<td>Transformer decoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>pre-trained layers + top task layer(s)</td>
</tr>
<tr>
<td>BERT</td>
<td>Transformer encoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>pre-trained layers + top task layer(s)</td>
</tr>
<tr>
<td>GPT-2</td>
<td>Transformer decoder</td>
<td>unsupervised</td>
<td>model-based</td>
<td>task-agnostic</td>
<td>pre-trained layers + top task layer(s)</td>
</tr>
</tbody>
</table>
<h2 id="metric-perplexity">Metric: Perplexity</h2>
<p>Perplexity is often used as an intrinsic evaluation metric for gauging how well a language model can capture the real word distribution conditioned on the context.</p>
<p>The <a href="https://en.wikipedia.org/wiki/Perplexity">perplexity</a> of a discrete probability distribution <script type="math/tex">p</script> is defined as the exponentiation of the entropy:</p>
<script type="math/tex; mode=display">2^{H(p)} = 2^{-\sum_x p(x) \log_2 p(x)}</script>
<p>Given a sentence with <script type="math/tex">N</script> words, <script type="math/tex">s = (w_1, \dots, w_N)</script>, the entropy looks as follows, simply assuming that each word has the same frequency, <script type="math/tex">\frac{1}{N}</script>:</p>
<script type="math/tex; mode=display">H(s) = -\sum_{i=1}^N P(w_i) \log_2 p(w_i) = -\sum_{i=1}^N \frac{1}{N} \log_2 p(w_i)</script>
<p>The perplexity for the sentence becomes:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
2^{H(s)} &= 2^{-\frac{1}{N} \sum_{i=1}^N \log_2 p(w_i)}
= (2^{\sum_{i=1}^N \log_2 p(w_i)})^{-\frac{1}{N}}
= (p(w_1) \dots p(w_N))^{-\frac{1}{N}}
\end{aligned} %]]></script>
<p>A good language model should assign high probabilities to the true words. Therefore, the smaller the perplexity, the better.</p>
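<p>The per-sentence formula above is straightforward to compute directly, assuming we already have the model's probability for each word:</p>

```python
import math

def perplexity(word_probs):
    """2 ** H(s) = (p(w_1) * ... * p(w_N)) ** (-1/N), computed in log
    space for numerical stability."""
    n = len(word_probs)
    return 2 ** (-sum(math.log2(p) for p in word_probs) / n)

perplexity([0.5, 0.5])  # 2.0: the model is as uncertain as a fair coin per word
perplexity([1.0, 1.0])  # 1.0: perfect predictions, the minimum perplexity
```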
<h2 id="common-tasks-and-datasets">Common Tasks and Datasets</h2>
<p><a name="qa"></a>
<strong>Question-Answering</strong></p>
<ul>
<li><a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD</a> (Stanford Question Answering Dataset): A reading comprehension dataset, consisting of questions posed on a set of Wikipedia articles, where the answer to every question is a span of text.</li>
<li><a href="http://www.qizhexie.com/data/RACE_leaderboard">RACE</a> (ReAding Comprehension from Examinations): A large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students.</li>
</ul>
<p><strong>Commonsense Reasoning</strong></p>
<ul>
<li><a href="http://cs.rochester.edu/nlp/rocstories/">Story Cloze Test</a>: A commonsense reasoning framework for evaluating story understanding and generation. The test requires a system to choose the correct ending to multi-sentence stories from two options.</li>
<li><a href="https://rowanzellers.com/swag/">SWAG</a> (Situations With Adversarial Generations): A multiple-choice dataset containing 113k sentence-pair completion examples that evaluate grounded commonsense inference.</li>
</ul>
<p><a name="nli"></a>
<strong>Natural Language Inference (NLI)</strong>: also known as <strong>Text Entailment</strong>, an exercise to discern in logic whether one sentence can be inferred from another.</p>
<ul>
<li><a href="https://aclweb.org/aclwiki/Textual_Entailment_Resource_Pool">RTE</a> (Recognizing Textual Entailment): A set of datasets initiated by text entailment challenges.</li>
<li><a href="https://nlp.stanford.edu/projects/snli/">SNLI</a> (Stanford Natural Language Inference): A collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels <code class="highlighter-rouge">entailment</code>, <code class="highlighter-rouge">contradiction</code>, and <code class="highlighter-rouge">neutral</code>.</li>
<li><a href="https://www.nyu.edu/projects/bowman/multinli/">MNLI</a> (Multi-Genre NLI): Similar to SNLI, but with a more diverse variety of text styles and topics, collected from transcribed speech, popular fiction, and government reports.</li>
<li><a href="https://gluebenchmark.com/tasks">QNLI</a> (Question NLI): Converted from SQuAD dataset to be a binary classification task over pairs of (question, sentence).</li>
<li><a href="http://data.allenai.org/scitail/">SciTail</a>: An entailment dataset created from multiple-choice science exams and web sentences.</li>
</ul>
<p><a name="ner"></a>
<strong>Named Entity Recognition (NER)</strong>: labels sequences of words in a text which are the names of things, such as person and company names, or gene and protein names</p>
<ul>
<li><a href="https://www.clips.uantwerpen.be/conll2003/">CoNLL 2003 NER task</a>: consists of newswire from the Reuters, concentrating on four types of named entities: persons, locations, organizations and names of miscellaneous entities.</li>
<li><a href="https://catalog.ldc.upenn.edu/LDC2013T19">OntoNotes 5.0</a>: This corpus contains text in English, Arabic and Chinese, tagged with four different entity types (PER, LOC, ORG, MISC).</li>
<li><a href="https://trec.nist.gov/data/reuters/reuters.html">Reuters Corpus</a>: A large collection of Reuters News stories.</li>
<li>Fine-Grained NER (FGN)</li>
</ul>
<p><strong>Sentiment Analysis</strong></p>
<ul>
<li><a href="https://nlp.stanford.edu/sentiment/index.html">SST</a> (Stanford Sentiment Treebank)</li>
<li><a href="http://ai.stanford.edu/~amaas/data/sentiment/">IMDb</a>: A large dataset of movie reviews with binary sentiment classification labels.</li>
</ul>
<p><a name="srl"></a>
<strong>Semantic Role Labeling (SRL)</strong>: models the predicate-argument structure of a sentence, and is often described as answering “Who did what to whom”.</p>
<ul>
<li><a href="http://www.lsi.upc.edu/~srlconll/">CoNLL-2004 & CoNLL-2005</a></li>
</ul>
<p><strong>Sentence similarity</strong>: also known as <em>paraphrase detection</em></p>
<ul>
<li><a href="https://www.microsoft.com/en-us/download/details.aspx?id=52398">MRPC</a> (Microsoft Research Paraphrase Corpus): It contains pairs of sentences extracted from news sources on the web, with annotations indicating whether each pair is semantically equivalent.</li>
<li><a href="https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs">QQP</a> (Quora Question Pairs)</li>
<li>STS Benchmark: Semantic Textual Similarity</li>
</ul>
<p><strong>Sentence Acceptability</strong>: a task to annotate sentences for grammatical acceptability.</p>
<ul>
<li><a href="https://nyu-mll.github.io/CoLA/">CoLA</a> (Corpus of Linguistic Acceptability): a binary single-sentence classification task.</li>
</ul>
<p><strong>Text Chunking</strong>: divides a text into syntactically correlated groups of words.</p>
<ul>
<li><a href="https://www.clips.uantwerpen.be/conll2000/chunking/">CoNLL-2000</a></li>
</ul>
<p><a name="pos"></a>
<strong>Part-of-Speech (POS) Tagging</strong>: tag each token with its part of speech, such as noun, verb, adjective, etc.</p>
<ul>
<li>The Wall Street Journal portion of the Penn Treebank (Marcus et al., 1993).</li>
</ul>
<p><strong>Machine Translation</strong>: See the <a href="https://nlp.stanford.edu/projects/nmt/">Stanford NLP</a> page.</p>
<ul>
<li>WMT 2015 English-Czech data (Large)</li>
<li>WMT 2014 English-German data (Medium)</li>
<li>IWSLT 2015 English-Vietnamese data (Small)</li>
</ul>
<p><strong>Coreference Resolution</strong>: cluster mentions in text that refer to the same underlying real world entities.</p>
<ul>
<li><a href="http://conll.cemantix.org/2012/data.html">CoNLL-2012</a></li>
</ul>
<p><strong>Long-range Dependency</strong></p>
<ul>
<li><a href="http://clic.cimec.unitn.it/lambada/">LAMBADA</a> (LAnguage Modeling Broadened to Account for Discourse Aspects): A collection of narrative passages extracted from the BookCorpus; the task is to predict the last word, which requires at least 50 tokens of context for a human to predict successfully.</li>
<li><a href="https://research.fb.com/downloads/babi/">Children’s Book Test</a>: Built from books that are freely available in <a href="https://www.gutenberg.org/">Project Gutenberg</a>. The task is to predict the missing word among 10 candidates.</li>
</ul>
<p><strong>Multi-task benchmark</strong></p>
<ul>
<li>GLUE multi-task benchmark: <a href="https://gluebenchmark.com/">https://gluebenchmark.com</a></li>
<li>decaNLP benchmark: <a href="https://decanlp.com/">https://decanlp.com</a></li>
</ul>
<p><strong>Unsupervised pretraining dataset</strong></p>
<ul>
<li><a href="https://googlebooks.byu.edu/">Books corpus</a>: The corpus contains “over 7,000 unique unpublished books from a variety of genres including Adventure, Fantasy, and Romance.”</li>
<li><a href="http://www.statmt.org/lm-benchmark/">1B Word Language Model Benchmark</a></li>
<li><a href="https://en.wikipedia.org/wiki/Wikipedia:Database_download#English-language_Wikipedia">English Wikipedia</a>: ~2500M words</li>
</ul>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2019LM,
title = "Generalized Language Models",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2019",
url = "http://lilianweng.github.io/lil-log/2019/01/31/generalized-language-models.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Bryan McCann, et al. <a href="https://arxiv.org/abs/1708.00107">“Learned in translation: Contextualized word vectors.”</a> NIPS. 2017.</p>
<p>[2] Kevin Clark et al. <a href="https://arxiv.org/abs/1809.08370">“Semi-Supervised Sequence Modeling with Cross-View Training.”</a> EMNLP 2018.</p>
<p>[3] Matthew E. Peters, et al. <a href="https://arxiv.org/abs/1802.05365">“Deep contextualized word representations.”</a> NAACL-HLT 2017.</p>
<p>[4] OpenAI Blog <a href="https://blog.openai.com/language-unsupervised/">“Improving Language Understanding with Unsupervised Learning”</a>, June 11, 2018.</p>
<p>[5] OpenAI Blog <a href="https://blog.openai.com/better-language-models/">“Better Language Models and Their Implications.”</a> Feb 14, 2019.</p>
<p>[6] Jeremy Howard and Sebastian Ruder. <a href="https://arxiv.org/abs/1801.06146">“Universal language model fine-tuning for text classification.”</a> ACL 2018.</p>
<p>[7] Alec Radford et al. <a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">“Improving Language Understanding by Generative Pre-Training”</a>. OpenAI Blog, June 11, 2018.</p>
<p>[8] Jacob Devlin, et al. <a href="https://arxiv.org/abs/1810.04805">“BERT: Pre-training of deep bidirectional transformers for language understanding.”</a> arXiv:1810.04805 (2018).</p>
<p>[9] Mike Schuster, and Kaisuke Nakajima. <a href="https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf">“Japanese and Korean voice search.”</a> ICASSP. 2012.</p>
<p>[10] Yonghui Wu, et al. <a href="https://arxiv.org/abs/1609.08144">“Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.”</a> arXiv:1609.08144 (2016).</p>
<p>[11] Ashish Vaswani, et al. <a href="https://arxiv.org/abs/1706.03762">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[12] Peter J. Liu, et al. <a href="https://arxiv.org/abs/1801.10198">“Generating wikipedia by summarizing long sequences.”</a> ICLR 2018.</p>
<p>[13] Sebastian Ruder. <a href="http://ruder.io/10-exciting-ideas-of-2018-in-nlp/">“10 Exciting Ideas of 2018 in NLP”</a> Dec 2018.</p>
<p>[14] Alec Radford, et al. <a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">“Language Models are Unsupervised Multitask Learners.”</a>. 2019.</p>
<p>[15] Rico Sennrich, et al. <a href="https://arxiv.org/abs/1508.07909">“Neural machine translation of rare words with subword units.”</a> arXiv preprint arXiv:1508.07909. 2015.</p>Lilian WengAs a follow up of word embedding post, we will discuss the models on learning contextualized word vectors, as well as the new trend in large unsupervised pre-trained language models which have achieved amazing SOTA results on a variety of language tasks.Object Detection Part 4: Fast Detection Models2018-12-27T12:00:00+00:002018-12-27T12:00:00+00:00https://lilianweng.github.io/lil-log/2018/12/27/object-detection-part-4<blockquote>
<p>Part 4 of the “Object Detection for Dummies” series focuses on one-stage models for fast detection, including SSD, RetinaNet, and models in the YOLO family. These models skip the explicit region proposal stage but apply the detection directly on dense sampled areas.</p>
</blockquote>
<!--more-->
<p>In <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html">Part 3</a>, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.</p>
<p>Links to all the posts in the series:
[<a href="/lil-log/2017/10/29/object-recognition-for-dummies-part-1.html">Part 1</a>]
[<a href="/lil-log/2017/12/15/object-recognition-for-dummies-part-2.html">Part 2</a>]
[<a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html">Part 3</a>]
[<a href="/lil-log/2018/12/27/object-detection-part-4.html">Part 4</a>].</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#two-stage-vs-one-stage-detectors" id="markdown-toc-two-stage-vs-one-stage-detectors">Two-stage vs One-stage Detectors</a></li>
<li><a href="#yolo-you-only-look-once" id="markdown-toc-yolo-you-only-look-once">YOLO: You Only Look Once</a> <ul>
<li><a href="#workflow" id="markdown-toc-workflow">Workflow</a></li>
<li><a href="#network-architecture" id="markdown-toc-network-architecture">Network Architecture</a></li>
<li><a href="#loss-function" id="markdown-toc-loss-function">Loss Function</a></li>
</ul>
</li>
<li><a href="#ssd-single-shot-multibox-detector" id="markdown-toc-ssd-single-shot-multibox-detector">SSD: Single Shot MultiBox Detector</a> <ul>
<li><a href="#image-pyramid" id="markdown-toc-image-pyramid">Image Pyramid</a></li>
<li><a href="#workflow-1" id="markdown-toc-workflow-1">Workflow</a></li>
<li><a href="#loss-function-1" id="markdown-toc-loss-function-1">Loss Function</a></li>
</ul>
</li>
<li><a href="#yolov2--yolo9000" id="markdown-toc-yolov2--yolo9000">YOLOv2 / YOLO9000</a> <ul>
<li><a href="#yolov2-improvement" id="markdown-toc-yolov2-improvement">YOLOv2 Improvement</a></li>
<li><a href="#yolo9000-rich-dataset-training" id="markdown-toc-yolo9000-rich-dataset-training">YOLO9000: Rich Dataset Training</a></li>
</ul>
</li>
<li><a href="#retinanet" id="markdown-toc-retinanet">RetinaNet</a> <ul>
<li><a href="#focal-loss" id="markdown-toc-focal-loss">Focal Loss</a></li>
<li><a href="#featurized-image-pyramid" id="markdown-toc-featurized-image-pyramid">Featurized Image Pyramid</a></li>
<li><a href="#model-architecture" id="markdown-toc-model-architecture">Model Architecture</a></li>
</ul>
</li>
<li><a href="#yolov3" id="markdown-toc-yolov3">YOLOv3</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="two-stage-vs-one-stage-detectors">Two-stage vs One-stage Detectors</h2>
<p>Models in the R-CNN family are all region-based. The detection happens in two stages: (1) First, the model proposes a set of regions of interest by selective search or a region proposal network. The proposed regions are sparse as the potential bounding box candidates can be infinite. (2) Then a classifier only processes the region candidates.</p>
<p>A different approach skips the region proposal stage and runs detection directly over a dense sampling of possible locations. This is how a one-stage object detection algorithm works. It is faster and simpler, but may drag down the performance a bit.</p>
<p>All the models introduced in this post are one-stage detectors.</p>
<h2 id="yolo-you-only-look-once">YOLO: You Only Look Once</h2>
<p>The <strong>YOLO</strong> model (<strong>“You Only Look Once”</strong>; <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">Redmon et al., 2016</a>) is the very first attempt at building a fast real-time object detector. Because YOLO does not undergo the region proposal step and only predicts over a limited number of bounding boxes, it is able to do inference super fast.</p>
<h3 id="workflow">Workflow</h3>
<ol>
<li>
<p><strong>Pre-train</strong> a CNN network on image classification task.</p>
</li>
<li>Split an image into <script type="math/tex">S \times S</script> cells. If an object’s center falls into a cell, that cell is “responsible” for detecting the existence of that object. Each cell predicts (a) the location of <script type="math/tex">B</script> bounding boxes, (b) a confidence score, and (c) a probability of object class conditioned on the existence of an object in the bounding box.
<br />
<br />
<ul>
<li>The <strong>coordinates</strong> of bounding box are defined by a tuple of 4 values, (center x-coord, center y-coord, width, height) — <script type="math/tex">(x, y, w, h)</script>, where <script type="math/tex">x</script> and <script type="math/tex">y</script> are set to be offset of a cell location. Moreover, <script type="math/tex">x</script>, <script type="math/tex">y</script>, <script type="math/tex">w</script> and <script type="math/tex">h</script> are normalized by the image width and height, and thus all between (0, 1].</li>
<li>A <strong>confidence score</strong> indicates the likelihood that the cell contains an object: <code class="highlighter-rouge">Pr(containing an object) x IoU(pred, truth)</code>; where <code class="highlighter-rouge">Pr</code> = probability and <code class="highlighter-rouge">IoU</code> = intersection over union.</li>
<li>If the cell contains an object, it predicts a <strong>probability</strong> of this object belonging to every class <script type="math/tex">C_i, i=1, \dots, K</script>: <code class="highlighter-rouge">Pr(the object belongs to the class C_i | containing an object)</code>. At this stage, the model only predicts one set of class probabilities per cell, regardless of the number of bounding boxes, <script type="math/tex">B</script>.</li>
<li>In total, one image contains <script type="math/tex">S \times S \times B</script> bounding boxes, each box corresponding to 4 location predictions, 1 confidence score, and K conditional probabilities for object classification. The total prediction values for one image is <script type="math/tex">S \times S \times (5B + K)</script>, which is the tensor shape of the final conv layer of the model.
<br />
<br /></li>
</ul>
</li>
<li>The final layer of the pre-trained CNN is modified to output a prediction tensor of size <script type="math/tex">S \times S \times (5B + K)</script>.</li>
</ol>
<p class="center"><img src="/lil-log/assets/images/yolo.png" alt="YOLO workflow" /></p>
<p><em>Fig. 1. The workflow of YOLO model. (Image source: <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">original paper</a>)</em></p>
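<p>Two pieces of the workflow above are easy to make concrete: the IoU term inside the confidence score, and the shape of the output tensor. A short sketch (the corner-coordinate box format <code>(x1, y1, x2, y2)</code> is my own choice for illustration):</p>

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))   # overlap width
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))   # overlap height
    inter = iw * ih
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0

# Prediction tensor shape for the paper's setting:
# S=7 grid, B=2 boxes per cell, K=20 PASCAL VOC classes
S, B, K = 7, 2, 20
output_shape = (S, S, 5 * B + K)  # (7, 7, 30)
```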
<h3 id="network-architecture">Network Architecture</h3>
<p>The base model is similar to <a href="https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf">GoogLeNet</a> with the inception modules replaced by 1x1 and 3x3 conv layers. The final prediction of shape <script type="math/tex">S \times S \times (5B + K)</script> is produced by two fully connected layers over the whole conv feature map.</p>
<p class="center"><img src="/lil-log/assets/images/yolo-network-architecture.png" alt="YOLO architecture" /></p>
<p><em>Fig. 2. The network architecture of YOLO.</em></p>
<h3 id="loss-function">Loss Function</h3>
<p>The loss consists of two parts, the <em>localization loss</em> for bounding box offset prediction and the <em>classification loss</em> for conditional class probabilities. Both parts are computed as the sum of squared errors. Two scale parameters are used to control how much we want to increase the loss from bounding box coordinate predictions (<script type="math/tex">\lambda_\text{coord}</script>) and how much we want to decrease the loss of confidence score predictions for boxes without objects (<script type="math/tex">\lambda_\text{noobj}</script>). Down-weighting the loss contributed by background boxes is important as most of the bounding boxes involve no instance. In the paper, the model sets <script type="math/tex">\lambda_\text{coord} = 5</script> and <script type="math/tex">\lambda_\text{noobj} = 0.5</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{loc} &= \lambda_\text{coord} \sum_{i=0}^{S^2} \sum_{j=0}^B \mathbb{1}_{ij}^\text{obj} [(x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 + (\sqrt{w_i} - \sqrt{\hat{w}_i})^2 + (\sqrt{h_i} - \sqrt{\hat{h}_i})^2 ] \\
\mathcal{L}_\text{cls} &= \sum_{i=0}^{S^2} \sum_{j=0}^B \big( \mathbb{1}_{ij}^\text{obj} + \lambda_\text{noobj} (1 - \mathbb{1}_{ij}^\text{obj})\big) (C_{ij} - \hat{C}_{ij})^2 + \sum_{i=0}^{S^2} \sum_{c \in \mathcal{C}} \mathbb{1}_i^\text{obj} (p_i(c) - \hat{p}_i(c))^2\\
\mathcal{L} &= \mathcal{L}_\text{loc} + \mathcal{L}_\text{cls}
\end{aligned} %]]></script>
<blockquote>
<p>NOTE: In the original YOLO paper, the loss function uses <script type="math/tex">C_i</script> instead of <script type="math/tex">C_{ij}</script> as the confidence score. I made the correction based on my own understanding, since every bounding box should have its own confidence score. Please kindly let me know if you do not agree. Many thanks.</p>
</blockquote>
<p>where,</p>
<ul>
<li><script type="math/tex">\mathbb{1}_i^\text{obj}</script>: An indicator function of whether the cell i contains an object.</li>
<li><script type="math/tex">\mathbb{1}_{ij}^\text{obj}</script>: It indicates whether the j-th bounding box of the cell i is “responsible” for the object prediction (see Fig. 3).</li>
<li><script type="math/tex">C_{ij}</script>: The confidence score of the j-th bounding box in cell i, <code class="highlighter-rouge">Pr(containing an object) * IoU(pred, truth)</code>.</li>
<li><script type="math/tex">\hat{C}_{ij}</script>: The predicted confidence score.</li>
<li><script type="math/tex">\mathcal{C}</script>: The set of all classes.</li>
<li><script type="math/tex">p_i(c)</script>: The conditional probability of whether cell i contains an object of class <script type="math/tex">c \in \mathcal{C}</script>.</li>
<li><script type="math/tex">\hat{p}_i(c)</script>: The predicted conditional class probability.</li>
</ul>
<p style="width: 85%;" class="center"><img src="/lil-log/assets/images/yolo-responsible-predictor.png" alt="YOLO responsible predictor" /></p>
<p><em>Fig. 3. At one location, in cell i, the model proposes B bounding box candidates and the one that has highest overlap with the ground truth is the “responsible” predictor.</em></p>
<p>The loss function only penalizes classification error if an object is present in that grid cell, <script type="math/tex">\mathbb{1}_i^\text{obj} = 1</script>. It also only penalizes bounding box coordinate error if that predictor is “responsible” for the ground truth box, <script type="math/tex">\mathbb{1}_{ij}^\text{obj} = 1</script>.</p>
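<p>The loss above can be written compactly with NumPy. This is an illustrative sketch under my own tensor layout, with the indicator masks and matched targets assumed to be precomputed; it is not the reference implementation:</p>

```python
import numpy as np

def yolo_loss(box_true, box_pred, conf_true, conf_pred, cls_true, cls_pred,
              obj_box, obj_cell, lambda_coord=5.0, lambda_noobj=0.5):
    """box_*: (S*S, B, 4) as (x, y, w, h) with w, h > 0;
    conf_*: (S*S, B); cls_*: (S*S, K);
    obj_box: (S*S, B) responsible-box indicator;
    obj_cell: (S*S,) object-in-cell indicator."""
    # Localization: squared error on centers, and on sqrt of width/height
    xy = ((box_true[..., :2] - box_pred[..., :2]) ** 2).sum(-1)
    wh = ((np.sqrt(box_true[..., 2:]) - np.sqrt(box_pred[..., 2:])) ** 2).sum(-1)
    loc = lambda_coord * (obj_box * (xy + wh)).sum()
    # Confidence: background boxes are down-weighted by lambda_noobj
    conf_weight = obj_box + lambda_noobj * (1.0 - obj_box)
    conf = (conf_weight * (conf_true - conf_pred) ** 2).sum()
    # Classification: only cells that contain an object contribute
    cls = (obj_cell[:, None] * (cls_true - cls_pred) ** 2).sum()
    return loc + conf + cls
```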
<p>As a one-stage object detector, YOLO is super fast, but it is not good at recognizing irregularly shaped objects or a group of small objects due to a limited number of bounding box candidates.</p>
<h2 id="ssd-single-shot-multibox-detector">SSD: Single Shot MultiBox Detector</h2>
<p>The <strong>Single Shot Detector</strong> (<strong>SSD</strong>; <a href="https://arxiv.org/abs/1512.02325">Liu et al, 2016</a>) is one of the first attempts at using convolutional neural network’s pyramidal feature hierarchy for efficient detection of objects of various sizes.</p>
<h3 id="image-pyramid">Image Pyramid</h3>
<p>SSD uses the <a href="https://arxiv.org/abs/1409.1556">VGG-16</a> model pre-trained on ImageNet as its base model for extracting useful image features.
On top of VGG-16, SSD adds several conv feature layers of decreasing sizes. They can be seen as a <em>pyramid representation</em> of images at different scales. Intuitively, large fine-grained feature maps at earlier levels are good at capturing small objects, and small coarse-grained feature maps at later levels can detect large objects well. In SSD, the detection happens at every pyramid level, targeting objects of various sizes.</p>
<p class="center"><img src="/lil-log/assets/images/SSD-architecture.png" alt="SSD architecture" /></p>
<p><em>Fig. 4. The model architecture of SSD.</em></p>
<h3 id="workflow-1">Workflow</h3>
<p>Unlike YOLO, SSD does not split the image into grids of arbitrary size but predicts offset of predefined <em>anchor boxes</em> (this is called “default boxes” in the paper) for every location of the feature map. Each box has a fixed size and position relative to its corresponding cell. All the anchor boxes tile the whole feature map in a convolutional manner.</p>
<p>Feature maps at different levels have different receptive field sizes. The anchor boxes on different levels are rescaled so that one feature map is only responsible for objects at one particular scale. For example, in Fig. 5 the dog can only be detected in the 4x4 feature map (higher level) while the cat is just captured by the 8x8 feature map (lower level).</p>
<p class="center"><img src="/lil-log/assets/images/SSD-framework.png" alt="SSD framework" /></p>
<p><em>Fig. 5. The SSD framework. (a) The training data contains images and ground truth boxes for every object. (b) In a fine-grained feature maps (8 x 8), the anchor boxes of different aspect ratios correspond to smaller area of the raw input. (c) In a coarse-grained feature map (4 x 4), the anchor boxes cover larger area of the raw input. (Image source: <a href="https://arxiv.org/abs/1512.02325">original paper</a>)</em></p>
<p>The width, height and the center location of an anchor box are all normalized to be (0, 1). At a location <script type="math/tex">(i, j)</script> of the <script type="math/tex">\ell</script>-th feature layer of size <script type="math/tex">m \times n</script>, <script type="math/tex">i=1,\dots,n, j=1,\dots,m</script>, we have a unique linear scale proportional to the layer level and 5 different box aspect ratios (width-to-height ratios), plus a special scale when the aspect ratio is 1 (the paper does not explain why it is needed; perhaps it is just a heuristic trick). This gives us 6 anchor boxes in total per feature cell.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{level index: } &\ell = 1, \dots, L \\
\text{scale of boxes: } &s_\ell = s_\text{min} + \frac{s_\text{max} - s_\text{min}}{L - 1} (\ell - 1) \\
\text{aspect ratio: } &r \in \{1, 2, 3, 1/2, 1/3\}\\
\text{additional scale: } & s'_\ell = \sqrt{s_\ell s_{\ell + 1}} \text{ when } r = 1 \text{, thus 6 boxes in total,}\\
\text{width: } &w_\ell^r = s_\ell \sqrt{r} \\
\text{height: } &h_\ell^r = s_\ell / \sqrt{r} \\
\text{center location: } & (x^i_\ell, y^j_\ell) = (\frac{i+0.5}{m}, \frac{j+0.5}{n})
\end{aligned} %]]></script>
<p class="center"><img src="/lil-log/assets/images/SSD-box-scales.png" alt="Box scales" /></p>
<p><em>Fig. 6. An example of how the anchor box size is scaled up with the layer index <script type="math/tex">\ell</script> for <script type="math/tex">L=6, s_\text{min} = 0.2, s_\text{max} = 0.9</script>. Only the boxes of aspect ratio <script type="math/tex">r=1</script> are illustrated.</em></p>
<p>At every location, the model outputs 4 offsets and <script type="math/tex">c</script> class probabilities by applying a <script type="math/tex">3 \times 3 \times p</script> conv filter (where <script type="math/tex">p</script> is the number of channels in the feature map) for every one of <script type="math/tex">k</script> anchor boxes. Therefore, given a feature map of size <script type="math/tex">m \times n</script>, we need <script type="math/tex">kmn(c+4)</script> prediction filters.</p>
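<p>The anchor-box construction above can be sketched directly from the formulas. The cell iteration order and the exact normalization are my own simplifications; real SSD implementations differ in details such as box clipping:</p>

```python
import math

def ssd_anchors(level, m, n, L=6, s_min=0.2, s_max=0.9,
                ratios=(1.0, 2.0, 3.0, 0.5, 1.0 / 3.0)):
    """Anchor boxes (cx, cy, w, h), normalized coordinates, for one
    m x n feature map at the given pyramid level (1-indexed)."""
    s = s_min + (s_max - s_min) * (level - 1) / (L - 1)
    s_next = s_min + (s_max - s_min) * level / (L - 1)
    boxes = []
    for i in range(m):
        for j in range(n):
            cx, cy = (i + 0.5) / m, (j + 0.5) / n
            for r in ratios:
                boxes.append((cx, cy, s * math.sqrt(r), s / math.sqrt(r)))
            # extra scale s' = sqrt(s_l * s_{l+1}), used at aspect ratio r = 1
            s_prime = math.sqrt(s * s_next)
            boxes.append((cx, cy, s_prime, s_prime))
    return boxes

anchors = ssd_anchors(level=1, m=4, n=4)
len(anchors)  # 4 * 4 * 6 = 96 boxes, 6 per feature-map cell
```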
<h3 id="loss-function-1">Loss Function</h3>
<p>Same as YOLO, the loss function is the sum of a localization loss and a classification loss.</p>
<script type="math/tex; mode=display">\mathcal{L} = \frac{1}{N}(\mathcal{L}_\text{cls} + \alpha \mathcal{L}_\text{loc})</script>
<p>where <script type="math/tex">N</script> is the number of matched bounding boxes and <script type="math/tex">\alpha</script> balances the weights between two losses, picked by cross validation.</p>
<p>The <em>localization loss</em> is a <a href="https://github.com/rbgirshick/py-faster-rcnn/files/764206/SmoothL1Loss.1.pdf">smooth L1 loss</a> between the predicted bounding box correction and the true values. The coordinate correction transformation is the same as what <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#r-cnn">R-CNN</a> does in <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#bounding-box-regression">bounding box regression</a>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{L}_\text{loc} &= \sum_{i,j} \sum_{m\in\{x, y, w, h\}} \mathbb{1}_{ij}^\text{match}
L_1^\text{smooth}(d_m^i - t_m^j)\\
L_1^\text{smooth}(x) &= \begin{cases}
0.5 x^2 & \text{if } \vert x \vert < 1\\
\vert x \vert - 0.5 & \text{otherwise}
\end{cases} \\
t^j_x &= (g^j_x - p^i_x) / p^i_w \\
t^j_y &= (g^j_y - p^i_y) / p^i_h \\
t^j_w &= \log(g^j_w / p^i_w) \\
t^j_h &= \log(g^j_h / p^i_h)
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbb{1}_{ij}^\text{match}</script> indicates whether the <script type="math/tex">i</script>-th bounding box with coordinates <script type="math/tex">(p^i_x, p^i_y, p^i_w, p^i_h)</script> is matched to the <script type="math/tex">j</script>-th ground truth box with coordinates <script type="math/tex">(g^j_x, g^j_y, g^j_w, g^j_h)</script> for any object. <script type="math/tex">d^i_m, m\in\{x, y, w, h\}</script> are the predicted correction terms. See <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#bounding-box-regression">this</a> for how the transformation works.</p>
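<p>The smooth L1 function and the coordinate targets above translate directly into code. A minimal sketch (the helper names are mine, not from the paper):</p>

```python
import math

def smooth_l1(x):
    # Quadratic near zero, linear in the tails: less sensitive to outliers
    # than a pure L2 loss.
    return 0.5 * x ** 2 if abs(x) < 1 else abs(x) - 0.5

def box_targets(p, g):
    """Regression targets (t_x, t_y, t_w, t_h) for a matched anchor p and
    ground truth box g, each given as (x, y, w, h)."""
    px, py, pw, ph = p
    gx, gy, gw, gh = g
    return ((gx - px) / pw, (gy - py) / ph,
            math.log(gw / pw), math.log(gh / ph))
```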
<p>The <em>classification loss</em> is a softmax loss over multiple classes (<a href="https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits">softmax_cross_entropy_with_logits</a> in tensorflow):</p>
<script type="math/tex; mode=display">\mathcal{L}_\text{cls} = -\sum_{i \in \text{pos}} \mathbb{1}_{ij}^k \log(\hat{c}_i^k) - \sum_{i \in \text{neg}} \log(\hat{c}_i^0)\text{, where }\hat{c}_i^k = \text{softmax}(c_i^k)</script>
<p>where <script type="math/tex">\mathbb{1}_{ij}^k</script> indicates whether the <script type="math/tex">i</script>-th bounding box and the <script type="math/tex">j</script>-th ground truth box are matched for an object in class <script type="math/tex">k</script>. <script type="math/tex">\text{pos}</script> is the set of matched bounding boxes (<script type="math/tex">N</script> items in total) and <script type="math/tex">\text{neg}</script> is the set of negative examples. SSD uses <a href="/lil-log/2017/12/31/object-recognition-for-dummies-part-3.html#common-tricks">hard negative mining</a> to select easily misclassified negative examples to construct this <script type="math/tex">\text{neg}</script> set: once all the anchor boxes are sorted by objectness confidence score, the model picks the top candidates for training so that neg:pos is at most 3:1.</p>
<h2 id="yolov2--yolo9000">YOLOv2 / YOLO9000</h2>
<p><strong>YOLOv2</strong> (<a href="https://arxiv.org/abs/1612.08242">Redmon & Farhadi, 2017</a>) is an enhanced version of YOLO. <strong>YOLO9000</strong> is built on top of YOLOv2 but trained on a joint dataset combining the COCO detection dataset and the top 9000 classes from ImageNet.</p>
<h3 id="yolov2-improvement">YOLOv2 Improvement</h3>
<p>A variety of modifications are applied to make YOLO prediction more accurate and faster, including:</p>
<p><strong>1. BatchNorm helps</strong>: Add <em>batch norm</em> on all the convolutional layers, leading to significant improvement over convergence.</p>
<p><strong>2. Image resolution matters</strong>: Fine-tuning the base model with <em>high resolution</em> images improves the detection performance.</p>
<p><strong>3. Convolutional anchor box detection</strong>: Rather than predicting the bounding box position with fully-connected layers over the whole feature map, YOLOv2 uses <em>convolutional layers</em> to predict the locations of <em>anchor boxes</em>, as in Faster R-CNN. The predictions of spatial locations and class probabilities are decoupled. Overall, the change leads to a slight decrease in mAP, but an increase in recall.</p>
<p><strong>4. K-means clustering of box dimensions</strong>: Unlike Faster R-CNN, which uses hand-picked sizes of anchor boxes, YOLOv2 runs k-means clustering on the training data to find good priors on anchor box dimensions. The distance metric is designed to <em>rely on IoU scores</em>:</p>
<script type="math/tex; mode=display">\text{dist}(x, c_i) = 1 - \text{IoU}(x, c_i), i=1,\dots,k</script>
<p>where <script type="math/tex">x</script> is a ground truth box candidate and <script type="math/tex">c_i</script> is one of the centroids. The best number of centroids (anchor boxes) <script type="math/tex">k</script> can be chosen by the <a href="https://en.wikipedia.org/wiki/Elbow_method_(clustering)">elbow method</a>.</p>
<p>The anchor boxes generated by clustering provide better average IoU conditioned on a fixed number of boxes.</p>
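<p>The clustering distance is straightforward to implement. Since only box dimensions (not positions) are clustered, a common convention, which I assume here, is to compute IoU as if the two boxes share the same center:</p>

```python
def iou_wh(box, centroid):
    """IoU between two (w, h) boxes, assuming they share the same center."""
    w1, h1 = box
    w2, h2 = centroid
    inter = min(w1, w2) * min(h1, h2)
    return inter / (w1 * h1 + w2 * h2 - inter)

def box_dist(box, centroid):
    # dist(x, c_i) = 1 - IoU(x, c_i)
    return 1.0 - iou_wh(box, centroid)

def assign_clusters(boxes, centroids):
    """One k-means assignment step under the IoU-based distance."""
    return [min(range(len(centroids)), key=lambda k: box_dist(b, centroids[k]))
            for b in boxes]
```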
<p><strong>5. Direct location prediction</strong>: YOLOv2 formulates the bounding box prediction in a way that it would <em>not diverge</em> from the center location too much. If the box location prediction could place the box in any part of the image, as in a region proposal network, the model training could become unstable.</p>
<p>Given the anchor box of size <script type="math/tex">(p_w, p_h)</script> at the grid cell with its top left corner at <script type="math/tex">(c_x, c_y)</script>, the model predicts the offset and the scale, <script type="math/tex">(t_x, t_y, t_w, t_h)</script> and the corresponding predicted bounding box <script type="math/tex">b</script> has center <script type="math/tex">(b_x, b_y)</script> and size <script type="math/tex">(b_w, b_h)</script>. The confidence score is the sigmoid (<script type="math/tex">\sigma</script>) of another output <script type="math/tex">t_o</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
b_x &= \sigma(t_x) + c_x\\
b_y &= \sigma(t_y) + c_y\\
b_w &= p_w e^{t_w}\\
b_h &= p_h e^{t_h}\\
\text{Pr}(\text{object}) &\cdot \text{IoU}(b, \text{object}) = \sigma(t_o)
\end{aligned} %]]></script>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/yolov2-loc-prediction.png" alt="YOLOv2 bbox location prediction" /></p>
<p><em>Fig. 7. YOLOv2 bounding box location prediction. (Image source: <a href="https://arxiv.org/abs/1612.08242">original paper</a>)</em></p>
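<p>Decoding the raw network outputs into a box follows the equations above directly. A minimal sketch (function names are mine):</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def decode_box(t_x, t_y, t_w, t_h, t_o, c_x, c_y, p_w, p_h):
    """Map raw YOLOv2 outputs to a predicted box (b_x, b_y, b_w, b_h) plus
    a confidence score; (c_x, c_y) is the grid cell's top-left corner."""
    b_x = sigmoid(t_x) + c_x   # the sigmoid keeps the center inside the cell
    b_y = sigmoid(t_y) + c_y
    b_w = p_w * math.exp(t_w)  # exp keeps width/height positive
    b_h = p_h * math.exp(t_h)
    conf = sigmoid(t_o)        # Pr(object) * IoU(b, object)
    return b_x, b_y, b_w, b_h, conf
```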
<p><strong>6. Add fine-grained features</strong>: YOLOv2 adds a passthrough layer to bring <em>fine-grained features</em> from an earlier layer to the last output layer. The mechanism of this passthrough layer is similar to the <em>identity mappings in ResNet</em> to extract higher-dimensional features from previous layers. This leads to a 1% performance increase.</p>
<p><strong>7. Multi-scale training</strong>: In order to train the model to be robust to input images of different sizes, a <em>new size</em> of input dimension is <em>randomly sampled</em> every 10 batches. Since conv layers of YOLOv2 downsample the input dimension by a factor of 32, the newly sampled size is a multiple of 32.</p>
<p><strong>8. Light-weighted base model</strong>: To make prediction even faster, YOLOv2 adopts a light-weighted base model, DarkNet-19, which has 19 conv layers and 5 max-pooling layers. The key point is to insert avg poolings and 1x1 conv filters between 3x3 conv layers.</p>
<h3 id="yolo9000-rich-dataset-training">YOLO9000: Rich Dataset Training</h3>
<p>Because drawing bounding boxes on images for object detection is much more expensive than tagging images for classification, the paper proposed a way to combine a small object detection dataset with the large ImageNet classification dataset so that the model can be exposed to a much larger number of object categories. The name of YOLO9000 comes from the top 9000 classes in ImageNet. During joint training, if an input image comes from the classification dataset, it only backpropagates the classification loss.</p>
<p>The detection dataset has much fewer and more general labels and, moreover, labels across multiple datasets are often not mutually exclusive. For example, ImageNet has a label “Persian cat” while in COCO the same image would be labeled as “cat”. Without mutual exclusiveness, it does not make sense to apply softmax over all the classes.</p>
<p>In order to efficiently merge ImageNet labels (1000 classes, fine-grained) with COCO/PASCAL (< 100 classes, coarse-grained), YOLO9000 built a hierarchical tree structure with reference to <a href="https://wordnet.princeton.edu/">WordNet</a> so that general labels are closer to the root and the fine-grained class labels are leaves. In this way, “cat” is the parent node of “Persian cat”.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/word-tree.png" alt="WordTree" /></p>
<p><em>Fig. 8. The WordTree hierarchy merges labels from COCO and ImageNet. Blue nodes are COCO labels and red nodes are ImageNet labels. (Image source: <a href="https://arxiv.org/abs/1612.08242">original paper</a>)</em></p>
<p>To predict the probability of a class node, we can follow the path from the node to the root:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Pr("persian cat" | contain a "physical object")
= Pr("persian cat" | "cat")
  * Pr("cat" | "animal")
  * Pr("animal" | "physical object")
  * Pr(contain a "physical object") # confidence score.
</code></pre></div></div>
<p>Note that <code class="highlighter-rouge">Pr(contain a "physical object")</code> is the confidence score, predicted separately in the bounding box detection pipeline. The path of conditional probability prediction can stop at any step, depending on which labels are available.</p>
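<p>The path product can be sketched with the tree stored as a child-to-parent map. The probability values below are made up purely for illustration:</p>

```python
def wordtree_prob(label, parent, cond_prob, p_object):
    """P(label) = Pr(contain a "physical object") times the product of the
    conditional probabilities along the path from `label` up to the root.
    parent: child -> parent map; cond_prob: child -> Pr(child | parent)."""
    p = p_object
    node = label
    while node in parent:      # stop once we reach the root
        p *= cond_prob[node]
        node = parent[node]
    return p

parent = {"persian cat": "cat", "cat": "animal", "animal": "physical object"}
cond = {"persian cat": 0.4, "cat": 0.9, "animal": 0.7}
wordtree_prob("persian cat", parent, cond, p_object=0.8)  # 0.8 * 0.4 * 0.9 * 0.7
```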
<h2 id="retinanet">RetinaNet</h2>
<p>The <strong>RetinaNet</strong> (<a href="https://arxiv.org/abs/1708.02002">Lin et al., 2018</a>) is a one-stage dense object detector. Two crucial building blocks are <em>featurized image pyramid</em> and the use of <em>focal loss</em>.</p>
<h3 id="focal-loss">Focal Loss</h3>
<p>One issue for object detection model training is an extreme imbalance between background that contains no object and foreground that holds objects of interest. <strong>Focal loss</strong> is designed to assign more weight to hard, easily misclassified examples (i.e. background with noisy texture or partial object) and to down-weight easy examples (i.e. obviously empty background).</p>
<p>Starting with a normal cross entropy loss for binary classification,</p>
<script type="math/tex; mode=display">\text{CE}(p, y) = -y\log p - (1-y)\log(1-p)</script>
<p>where <script type="math/tex">y \in \{0, 1\}</script> is a ground truth binary label, indicating whether a bounding box contains an object, and <script type="math/tex">p \in [0, 1]</script> is the predicted probability of objectness (aka confidence score).</p>
<p>For notational convenience,</p>
<script type="math/tex; mode=display">% <![CDATA[
\text{let } p_t = \begin{cases}
p & \text{if } y = 1\\
1-p & \text{otherwise}
\end{cases},
\text{then } \text{CE}(p, y)=\text{CE}(p_t) = -\log p_t %]]></script>
<p>Easily classified examples with large <script type="math/tex">p_t \gg 0.5</script>, that is, when <script type="math/tex">p</script> is very close to 0 (when y=0) or 1 (when y=1), can still incur a loss with non-trivial magnitude. Focal loss explicitly adds a weighting factor <script type="math/tex">(1-p_t)^\gamma, \gamma \geq 0</script> to each term in cross entropy so that the weight is small when <script type="math/tex">p_t</script> is large and therefore easy examples are down-weighted.</p>
<script type="math/tex; mode=display">\text{FL}(p_t) = -(1-p_t)^\gamma \log p_t</script>
<p style="width:65%;" class="center"><img src="/lil-log/assets/images/focal-loss.png" alt="Focal Loss" /></p>
<p><em>Fig. 9. The focal loss focuses less on easy examples with a factor of <script type="math/tex">(1-p_t)^\gamma</script>. (Image source: <a href="https://arxiv.org/abs/1708.02002">original paper</a>)</em></p>
<p>For a better control of the shape of the weighting function (see Fig. 10.), RetinaNet uses an <script type="math/tex">\alpha</script>-balanced variant of the focal loss, where <script type="math/tex">\alpha=0.25, \gamma=2</script> works the best.</p>
<script type="math/tex; mode=display">\text{FL}(p_t) = -\alpha (1-p_t)^\gamma \log p_t</script>
<p style="width:90%;" class="center"><img src="/lil-log/assets/images/focal-loss-weights.png" alt="WordTree" /></p>
<p><em>Fig. 10. The plot of focal loss weights <script type="math/tex">\alpha (1-p_t)^\gamma</script> as a function of <script type="math/tex">p_t</script>, given different values of <script type="math/tex">\alpha</script> and <script type="math/tex">\gamma</script>.</em></p>
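<p>A minimal Python version of the <script type="math/tex">\alpha</script>-balanced focal loss for a single binary prediction, following the formula above (note that applying a constant <script type="math/tex">\alpha</script> to both positives and negatives is exactly what the formula above says; the paper itself defines an <script type="math/tex">\alpha_t</script> that mirrors <script type="math/tex">p_t</script>):</p>

```python
import math

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Alpha-balanced focal loss for one binary prediction.
    p: predicted objectness probability in (0, 1); y: ground truth in {0, 1}."""
    p_t = p if y == 1 else 1.0 - p
    # (1 - p_t)^gamma down-weights easy, well-classified examples.
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)
```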
<h3 id="featurized-image-pyramid">Featurized Image Pyramid</h3>
<p>The <strong>featurized image pyramid</strong> (<a href="https://arxiv.org/abs/1612.03144">Lin et al., 2017</a>) is the backbone network for RetinaNet. Following the same spirit as the <a href="#image-pyramid">image pyramid</a> in SSD, featurized image pyramids provide a basic vision component for object detection at different scales.</p>
<p>The key idea of feature pyramid network is demonstrated in Fig. 11. The base structure contains a sequence of <em>pyramid levels</em>, each corresponding to one network <em>stage</em>. One stage contains multiple convolutional layers of the same size and the stage sizes are scaled down by a factor of 2. Let’s denote the last layer of the <script type="math/tex">i</script>-th stage as <script type="math/tex">C_i</script>.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/featurized-image-pyramid.png" alt="Featurized image pyramid" /></p>
<p><em>Fig. 11. The illustration of the featurized image pyramid module. (Replot based on figure 3 in <a href="https://arxiv.org/abs/1612.03144">FPN paper</a>)</em></p>
<p>Two pathways connect conv layers:</p>
<ul>
<li><strong>Bottom-up pathway</strong> is the normal feedforward computation.</li>
<li><strong>Top-down pathway</strong> goes in the inverse direction, adding coarse but semantically stronger feature maps back into the previous pyramid levels of a larger size via lateral connections.
<ul>
<li>First, the spatially coarser, higher-level feature maps are upsampled to be 2x larger. For image upscaling, the paper used nearest neighbor upsampling. While there are many <a href="https://en.wikipedia.org/wiki/Image_scaling#Algorithms">image upscaling algorithms</a> such as using <a href="https://www.tensorflow.org/api_docs/python/tf/layers/conv2d_transpose">deconv</a>, adopting another image scaling method might or might not improve the performance of RetinaNet.</li>
<li>The larger feature map undergoes a 1x1 conv layer to reduce the channel dimension.</li>
<li>Finally, these two feature maps are merged by element-wise addition.
<br />
<br />
The lateral connections only happen at the last layer in stages, denoted as <script type="math/tex">\{C_i\}</script>, and the process continues until the finest (largest) merged feature map is generated. The prediction is made out of every merged map after a 3x3 conv layer, <script type="math/tex">\{P_i\}</script>.</li>
</ul>
</li>
</ul>
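<p>The top-down merge step can be sketched with plain Python lists, treating a feature map as a single-channel 2D grid. This deliberately ignores channels and the convolutions themselves (the 1x1 lateral conv is stood in for by a per-element function), so it only illustrates the data flow:</p>

```python
def upsample_2x(fmap):
    """2x nearest-neighbor upsampling of a 2D feature map (list of lists)."""
    out = []
    for row in fmap:
        wide = [v for v in row for _ in range(2)]   # repeat each column
        out.append(wide)
        out.append(list(wide))                      # repeat each row
    return out

def merge_level(c_i, p_up, lateral):
    """One top-down step: P_i = lateral(C_i) + upsample(P_{i+1}).
    `lateral` stands in for the 1x1 conv that reduces the channel dimension."""
    up = upsample_2x(p_up)
    lat = [[lateral(v) for v in row] for row in c_i]
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(lat, up)]
```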
<p>According to ablation studies, the importance rank of components of the featurized image pyramid design is as follows: <strong>1x1 lateral connection</strong> > detect object across multiple layers > top-down enrichment > pyramid representation (compared to only check the finest layer).</p>
<h3 id="model-architecture">Model Architecture</h3>
<p>The featurized pyramid is constructed on top of the ResNet architecture. Recall that <a href="TBA">ResNet</a> has 5 conv blocks (= network stages / pyramid levels). The last layer of the <script type="math/tex">i</script>-th pyramid level, <script type="math/tex">C_i</script>, has a resolution <script type="math/tex">2^i</script> times lower than the raw input dimension.</p>
<p>RetinaNet utilizes feature pyramid levels <script type="math/tex">P_3</script> to <script type="math/tex">P_7</script>:</p>
<ul>
<li><script type="math/tex">P_3</script> to <script type="math/tex">P_5</script> are computed from the corresponding ResNet residual stage from <script type="math/tex">C_3</script> to <script type="math/tex">C_5</script>. They are connected by both top-down and bottom-up pathways.</li>
<li><script type="math/tex">P_6</script> is obtained via a 3×3 stride-2 conv on top of <script type="math/tex">C_5</script>.</li>
<li><script type="math/tex">P_7</script> applies ReLU and a 3×3 stride-2 conv on <script type="math/tex">P_6</script>.</li>
</ul>
<p>Adding higher pyramid levels on ResNet improves the performance for detecting large objects.</p>
<p>Same as in SSD, detection happens in all pyramid levels by making a prediction out of every merged feature map. Because the predictions share the same classifier and box regressor, the merged feature maps are all constructed to have the same channel dimension d=256.</p>
<p>There are A=9 anchor boxes per level:</p>
<ul>
<li>The base size corresponds to areas of <script type="math/tex">32^2</script> to <script type="math/tex">512^2</script> pixels on <script type="math/tex">P_3</script> to <script type="math/tex">P_7</script> respectively. There are three size ratios, <script type="math/tex">\{2^0, 2^{1/3}, 2^{2/3}\}</script>.</li>
<li>For each size, there are three aspect ratios {1/2, 1, 2}.</li>
</ul>
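<p>The resulting 9 anchor shapes per level can be enumerated directly from the sizes above. A minimal sketch; keeping the area fixed while varying the aspect ratio is my assumption of the usual convention:</p>

```python
def retinanet_anchor_shapes():
    """The A = 9 anchor (w, h) shapes in pixels for levels P3..P7."""
    shapes = {}
    for level in range(3, 8):
        base = 2 ** (level + 2)                # 32 at P3 up to 512 at P7
        per_level = []
        for scale in (2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)):
            s = base * scale
            for ar in (0.5, 1.0, 2.0):
                # preserve the area s^2 while setting w / h = ar
                per_level.append((s * ar ** 0.5, s / ar ** 0.5))
        shapes["P%d" % level] = per_level
    return shapes
```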
<p>As usual, for each anchor box, the model outputs a class probability for each of <script type="math/tex">K</script> classes in the classification subnet and regresses the offset from this anchor box to the nearest ground truth object in the box regression subnet. The classification subnet adopts the focal loss introduced above.</p>
<p style="width:100%;" class="center"><img src="/lil-log/assets/images/retina-net.png" alt="RetinaNet" /></p>
<p><em>Fig. 12. The RetinaNet model architecture uses a <a href="https://arxiv.org/abs/1612.03144">FPN</a> backbone on top of ResNet. (Image source: the <a href="https://arxiv.org/abs/1612.03144">FPN</a> paper)</em></p>
<h2 id="yolov3">YOLOv3</h2>
<p><a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">YOLOv3</a> is created by applying a bunch of design tricks on YOLOv2. The changes are inspired by recent advances in the object detection world.</p>
<p>Here are a list of changes:</p>
<p><strong>1. Logistic regression for confidence scores</strong>: YOLOv3 predicts a confidence score for each bounding box using <em>logistic regression</em>, while YOLO and YOLOv2 use sum of squared errors for classification terms (see the <a href="#loss-function">loss function</a> above). Using linear regression for offset prediction instead leads to a decrease in mAP.</p>
<p><strong>2. No more softmax for class prediction</strong>: When predicting class confidence, YOLOv3 uses <em>independent logistic classifiers</em>, one for each class, rather than one softmax layer. This is very helpful especially considering that one image might have multiple labels and not all the labels are guaranteed to be mutually exclusive.</p>
<p><strong>3. Darknet + ResNet as the base model</strong>: The new Darknet-53 still relies on successive 3x3 and 1x1 conv layers, just like the original Darknet architecture, but has residual blocks added.</p>
<p><strong>4. Multi-scale prediction</strong>: Inspired by the featurized image pyramid, YOLOv3 adds several conv layers after the base feature extractor model and makes predictions at three different scales among these conv layers. In this way, it has to deal with many more bounding box candidates of various sizes overall.</p>
<p><strong>5. Skip-layer concatenation</strong>: YOLOv3 also adds cross-layer connections between two prediction layers (except for the output layer) and earlier finer-grained feature maps. The model first up-samples the coarse feature maps and then merges them with the previous features by concatenation. The combination with finer-grained information makes it better at detecting small objects.</p>
<p>Interestingly, focal loss does not help YOLOv3, potentially due to its usage of <script type="math/tex">\lambda_\text{noobj}</script> and <script type="math/tex">\lambda_\text{coord}</script>, which increase the loss from bounding box location predictions and decrease the loss from confidence predictions for background boxes.</p>
<p>Overall YOLOv3 performs better and faster than SSD, and worse than RetinaNet but 3.8x faster.</p>
<p style="width:80%;" class="center"><img src="/lil-log/assets/images/yolov3-perf.png" alt="YOLOv3 performance" /></p>
<p><em>Fig. 13. The comparison of various fast object detection models on speed and mAP performance. (Image source: <a href="https://arxiv.org/abs/1708.02002">focal loss</a> paper with additional labels from the <a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">YOLOv3</a> paper.)</em></p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018detection4,
title = "Object Detection Part 4: Fast Detection Models",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/12/27/object-detection-part-4.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Joseph Redmon, et al. <a href="https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf">“You only look once: Unified, real-time object detection.”</a> CVPR 2016.</p>
<p>[2] Joseph Redmon and Ali Farhadi. <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Redmon_YOLO9000_Better_Faster_CVPR_2017_paper.pdf">“YOLO9000: Better, Faster, Stronger.”</a> CVPR 2017.</p>
<p>[3] Joseph Redmon, Ali Farhadi. <a href="https://pjreddie.com/media/files/papers/YOLOv3.pdf">“YOLOv3: An incremental improvement.”</a>.</p>
<p>[4] Wei Liu et al. <a href="https://arxiv.org/abs/1512.02325">“SSD: Single Shot MultiBox Detector.”</a> ECCV 2016.</p>
<p>[5] Tsung-Yi Lin, et al. <a href="https://arxiv.org/abs/1612.03144">“Feature Pyramid Networks for Object Detection.”</a> CVPR 2017.</p>
<p>[6] Tsung-Yi Lin, et al. <a href="https://arxiv.org/abs/1708.02002">“Focal Loss for Dense Object Detection.”</a> IEEE transactions on pattern analysis and machine intelligence, 2018.</p>
<p>[7] <a href="https://towardsdatascience.com/yolo-v3-object-detection-53fb7d3bfe6b">“What’s new in YOLO v3?”</a> by Ayoosh Kathuria on “Towards Data Science”, Apr 23, 2018.</p>
Lilian Weng
Part 4 of the “Object Detection for Dummies” series focuses on one-stage models for fast detection, including SSD, RetinaNet, and models in the YOLO family. These models skip the explicit region proposal stage but apply the detection directly on densely sampled areas.
Meta-Learning: Learning to Learn Fast
2018-11-30T00:00:00+00:00
https://lilianweng.github.io/lil-log/2018/11/30/meta-learning
<blockquote>
<p>Meta-learning, also known as “learning to learn”, intends to design models that can learn new skills or adapt to new environments rapidly with a few training examples. There are three common approaches: 1) learn an efficient distance metric (metric-based); 2) use (recurrent) network with external or internal memory (model-based); 3) optimize the model parameters explicitly for fast learning (optimization-based).</p>
</blockquote>
<!--more-->
<p>A good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to discover the way to ride a motorcycle fast with little or even no demonstration. Is it possible to design a machine learning model with similar properties — learning new concepts and skills fast with a few training examples? That’s essentially what <strong>meta-learning</strong> aims to solve.</p>
<p>We expect a good meta-learning model capable of well adapting or generalizing to new tasks and new environments that have never been encountered during training time. The adaptation process, essentially a mini learning session, happens during test but with a limited exposure to the new task configurations. Eventually, the adapted model can complete new tasks. This is why meta-learning is also known as <a href="https://www.cs.cmu.edu/~rsalakhu/papers/LakeEtAl2015Science.pdf">learning to learn</a>.</p>
<p>The tasks can be any well-defined family of machine learning problems: supervised learning, reinforcement learning, etc. For example, here are a couple concrete meta-learning tasks:</p>
<ul>
<li>A classifier trained on non-cat images can tell whether a given image contains a cat after seeing a handful of cat pictures.</li>
<li>A game bot is able to quickly master a new game.</li>
<li>A mini robot completes the desired task on an uphill surface during test even though it was only trained in a flat surface environment.</li>
</ul>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#define-the-meta-learning-problem" id="markdown-toc-define-the-meta-learning-problem">Define the Meta-Learning Problem</a> <ul>
<li><a href="#a-simple-view" id="markdown-toc-a-simple-view">A Simple View</a></li>
<li><a href="#training-in-the-same-way-as-testing" id="markdown-toc-training-in-the-same-way-as-testing">Training in the Same Way as Testing</a></li>
<li><a href="#learner-and-meta-learner" id="markdown-toc-learner-and-meta-learner">Learner and Meta-Learner</a></li>
<li><a href="#common-approaches" id="markdown-toc-common-approaches">Common Approaches</a></li>
</ul>
</li>
<li><a href="#metric-based" id="markdown-toc-metric-based">Metric-Based</a> <ul>
<li><a href="#convolutional-siamese-neural-network" id="markdown-toc-convolutional-siamese-neural-network">Convolutional Siamese Neural Network</a></li>
<li><a href="#matching-networks" id="markdown-toc-matching-networks">Matching Networks</a> <ul>
<li><a href="#simple-embedding" id="markdown-toc-simple-embedding">Simple Embedding</a></li>
<li><a href="#full-context-embeddings" id="markdown-toc-full-context-embeddings">Full Context Embeddings</a></li>
</ul>
</li>
<li><a href="#relation-network" id="markdown-toc-relation-network">Relation Network</a></li>
<li><a href="#prototypical-networks" id="markdown-toc-prototypical-networks">Prototypical Networks</a></li>
</ul>
</li>
<li><a href="#model-based" id="markdown-toc-model-based">Model-Based</a> <ul>
<li><a href="#memory-augmented-neural-networks" id="markdown-toc-memory-augmented-neural-networks">Memory-Augmented Neural Networks</a> <ul>
<li><a href="#mann-for-meta-learning" id="markdown-toc-mann-for-meta-learning">MANN for Meta-Learning</a></li>
<li><a href="#addressing-mechanism-for-meta-learning" id="markdown-toc-addressing-mechanism-for-meta-learning">Addressing Mechanism for Meta-Learning</a></li>
</ul>
</li>
<li><a href="#meta-networks" id="markdown-toc-meta-networks">Meta Networks</a> <ul>
<li><a href="#fast-weights" id="markdown-toc-fast-weights">Fast Weights</a></li>
<li><a href="#model-components" id="markdown-toc-model-components">Model Components</a></li>
<li><a href="#training-process" id="markdown-toc-training-process">Training Process</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#optimization-based" id="markdown-toc-optimization-based">Optimization-Based</a> <ul>
<li><a href="#lstm-meta-learner" id="markdown-toc-lstm-meta-learner">LSTM Meta-Learner</a> <ul>
<li><a href="#why-lstm" id="markdown-toc-why-lstm">Why LSTM?</a></li>
<li><a href="#model-setup" id="markdown-toc-model-setup">Model Setup</a></li>
</ul>
</li>
<li><a href="#maml" id="markdown-toc-maml">MAML</a> <ul>
<li><a href="#first-order-maml" id="markdown-toc-first-order-maml">First-Order MAML</a></li>
</ul>
</li>
<li><a href="#reptile" id="markdown-toc-reptile">Reptile</a> <ul>
<li><a href="#the-optimization-assumption" id="markdown-toc-the-optimization-assumption">The Optimization Assumption</a></li>
<li><a href="#reptile-vs-fomaml" id="markdown-toc-reptile-vs-fomaml">Reptile vs FOMAML</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="define-the-meta-learning-problem">Define the Meta-Learning Problem</h2>
<p>In this post, we focus on the case when each desired task is a supervised learning problem like image classification. There is a lot of interesting literature on meta-learning with reinforcement learning problems (aka “Meta Reinforcement Learning”), but we would not cover them here.</p>
<h3 id="a-simple-view">A Simple View</h3>
<p>A good meta-learning model should be trained over a variety of learning tasks and optimized for the best performance on a distribution of tasks, including potentially unseen tasks. Each task is associated with a dataset <script type="math/tex">\mathcal{D}</script>, containing both feature vectors and true labels. The optimal model parameters are:</p>
<script type="math/tex; mode=display">\theta^* = \arg\min_\theta \mathbb{E}_{\mathcal{D}\sim p(\mathcal{D})} [\mathcal{L}_\theta(\mathcal{D})]</script>
<p>It looks very similar to a normal learning task, but <em>one dataset</em> is considered as <em>one data sample</em>.</p>
<p><em>Few-shot classification</em> is an instantiation of meta-learning in the field of supervised learning. The dataset <script type="math/tex">\mathcal{D}</script> is often split into two parts, a support set <script type="math/tex">S</script> for learning and a prediction set <script type="math/tex">B</script> for training or testing, <script type="math/tex">\mathcal{D}=\langle S, B\rangle</script>. Often we consider a <em>K-shot N-class classification</em> task: the support set contains K labelled examples for each of N classes.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/few-shot-classification.png" alt="few-shot-classification" /></p>
<p><em>Fig. 1. An example of 4-shot 2-class image classification. (Image thumbnails are from <a href="https://www.pinterest.com/">Pinterest</a>)</em></p>
<h3 id="training-in-the-same-way-as-testing">Training in the Same Way as Testing</h3>
<p>A dataset <script type="math/tex">\mathcal{D}</script> contains pairs of feature vectors and labels, <script type="math/tex">\mathcal{D} = \{(\mathbf{x}_i, y_i)\}</script> and each label belongs to a known label set <script type="math/tex">\mathcal{L}</script>. Let’s say, our classifier <script type="math/tex">f_\theta</script> with parameter <script type="math/tex">\theta</script> outputs a probability of a data point belonging to the class <script type="math/tex">y</script> given the feature vector <script type="math/tex">\mathbf{x}</script>, <script type="math/tex">P_\theta(y\vert\mathbf{x})</script>.</p>
<p>The optimal parameters should maximize the probability of true labels across multiple training batches <script type="math/tex">B \subset \mathcal{D}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\theta^* &= {\arg\max}_{\theta} \mathbb{E}_{(\mathbf{x}, y)\in \mathcal{D}}[P_\theta(y \vert \mathbf{x})] &\\
\theta^* &= {\arg\max}_{\theta} \mathbb{E}_{B\subset \mathcal{D}}[\sum_{(\mathbf{x}, y)\in B}P_\theta(y \vert \mathbf{x})] & \scriptstyle{\text{; trained with mini-batches.}}
\end{aligned} %]]></script>
<p>In few-shot classification, the goal is to reduce the prediction error on data samples with unknown labels given a small support set for “fast learning” (think of how “fine-tuning” works). To make the training process mimic what happens during inference, we would like to “fake” datasets with a subset of labels to avoid exposing all the labels to the model and modify the optimization procedure accordingly to encourage fast learning:</p>
<ol>
<li>Sample a subset of labels, <script type="math/tex">L\subset\mathcal{L}</script>.</li>
<li>Sample a support set <script type="math/tex">S^L \subset \mathcal{D}</script> and a training batch <script type="math/tex">B^L \subset \mathcal{D}</script>. Both of them only contain data points with labels belonging to the sampled label set <script type="math/tex">L</script>, <script type="math/tex">y \in L, \forall (x, y) \in S^L, B^L</script>.</li>
<li>The support set is part of the model input. <!-- , $$\hat{y}=f_\theta(\mathbf{x}, S^L)$$ --></li>
<li>The final optimization uses the mini-batch <script type="math/tex">B^L</script> to compute the loss and update the model parameters through backpropagation, in the same way as how we use it in the supervised learning.</li>
</ol>
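<p>The episodic sampling procedure above can be sketched in a few lines of Python. This is a minimal illustration, not code from any paper; for simplicity the training batch is also sampled class-balanced, and all names (<code>sample_episode</code>, etc.) are made up:</p>

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way, k_shot, batch_size):
    """Build one training episode: sample a label subset L, then a
    support set S^L and a mini-batch B^L restricted to those labels."""
    by_label = defaultdict(list)
    for x, y in dataset:
        by_label[y].append(x)

    labels = random.sample(sorted(by_label), n_way)   # 1. L ⊂ 𝓛
    support, batch = [], []
    for y in labels:
        xs = random.sample(by_label[y], k_shot + batch_size)
        support += [(x, y) for x in xs[:k_shot]]      # 2. S^L
        batch += [(x, y) for x in xs[k_shot:]]        # B^L
    # 3-4. S^L is fed as model input; B^L drives the loss.
    return support, batch

# toy dataset: 5 classes, 10 examples each
data = [("img_%d_%d" % (c, i), c) for c in range(5) for i in range(10)]
S, B = sample_episode(data, n_way=2, k_shot=4, batch_size=3)
```

Each call yields one "data point" in the meta-learning sense: a paired <code>(S^L, B^L)</code> the model should generalize across.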
<p>You may consider each pair of sampled dataset <script type="math/tex">(S^L, B^L)</script> as one data point. The model is trained such that it can generalize to other datasets. Symbols in red are added for meta-learning in addition to the supervised learning objective.</p>
<script type="math/tex; mode=display">\theta^* = \arg\max_\theta \color{red}{E_{L\subset\mathcal{L}}[} E_{\color{red}{S^L \subset\mathcal{D}, }B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_\theta(y \vert \mathbf{x}\color{red}{, S^L})] \color{red}{]}</script>
<p>The idea is to some extent similar to using a pre-trained model in image classification (ImageNet) or language modeling (big text corpora) when only a limited set of task-specific data samples are available. Meta-learning takes this idea one step further: rather than fine-tuning for one downstream task, it optimizes the model to be good at many tasks, if not all of them.</p>
<h3 id="learner-and-meta-learner">Learner and Meta-Learner</h3>
<p>Another popular view of meta-learning decomposes the model update into two stages:</p>
<ul>
<li>A classifier <script type="math/tex">f_\theta</script> is the “learner” model, trained for operating a given task;</li>
<li>In the meantime, an optimizer <script type="math/tex">g_\phi</script> learns how to update the learner model’s parameters via the support set <script type="math/tex">S</script>, <script type="math/tex">\theta' = g_\phi(\theta, S)</script>.</li>
</ul>
<p>Then in the final optimization step, we need to update both <script type="math/tex">\theta</script> and <script type="math/tex">\phi</script> to maximize:</p>
<script type="math/tex; mode=display">\mathbb{E}_{L\subset\mathcal{L}}[ \mathbb{E}_{S^L \subset\mathcal{D}, B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})]]</script>
<h3 id="common-approaches">Common Approaches</h3>
<p>There are three common approaches to meta-learning: metric-based, model-based, and optimization-based. Oriol Vinyals has a nice summary in his <a href="http://metalearning-symposium.ml/files/vinyals.pdf">talk</a> at meta-learning symposium @ NIPS 2018:</p>
<table class="info">
<thead>
<tr>
<th> </th>
<th>Model-based</th>
<th>Metric-based</th>
<th>Optimization-based</th>
</tr>
</thead>
<tbody>
<tr>
<td><strong>Key idea</strong></td>
<td>RNN; memory</td>
<td>Metric learning</td>
<td>Gradient descent</td>
</tr>
<tr>
<td><strong>How <script type="math/tex">P_\theta(y \vert \mathbf{x})</script> is modeled?</strong></td>
<td><script type="math/tex">f_\theta(\mathbf{x}, S)</script></td>
<td><script type="math/tex">\sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i</script> (*)</td>
<td><script type="math/tex">P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})</script></td>
</tr>
</tbody>
</table>
<p>(*) <script type="math/tex">k_\theta</script> is a kernel function measuring the similarity between <script type="math/tex">\mathbf{x}_i</script> and <script type="math/tex">\mathbf{x}</script>.</p>
<p>Next we are gonna review classic models in each approach.</p>
<h2 id="metric-based">Metric-Based</h2>
<p>The core idea in metric-based meta-learning is similar to nearest neighbors algorithms (i.e., <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">k-NN</a> classifier and <a href="https://en.wikipedia.org/wiki/K-means_clustering">k-means</a> clustering) and <a href="https://en.wikipedia.org/wiki/Kernel_density_estimation">kernel density estimation</a>. The predicted probability over a set of known labels <script type="math/tex">y</script> is a weighted sum of labels of support set samples. The weight is generated by a kernel function <script type="math/tex">k_\theta</script>, measuring the similarity between two data samples.</p>
<script type="math/tex; mode=display">P_\theta(y \vert \mathbf{x}, S) = \sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i</script>
<p>Learning a good kernel is crucial to the success of a metric-based meta-learning model. <a href="https://en.wikipedia.org/wiki/Similarity_learning#Metric_learning">Metric learning</a> is well aligned with this intention, as it aims to learn a metric or distance function over objects. The notion of a good metric is problem-dependent. It should represent the relationship between inputs in the task space and facilitate problem solving.</p>
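<p>As a toy illustration of the weighted-vote formula above, here is a NumPy sketch. An RBF kernel stands in for a learned kernel, and the weights are renormalized so the output is a proper distribution (the formula leaves normalization to the kernel itself); all names are illustrative:</p>

```python
import numpy as np

def metric_predict(x, support_x, support_y, kernel, n_classes):
    """P(y | x, S) = sum_i k(x, x_i) * onehot(y_i): a weighted vote
    over support-set labels, renormalized to sum to one."""
    weights = np.array([kernel(x, xi) for xi in support_x])
    weights = weights / weights.sum()
    probs = np.zeros(n_classes)
    for w, yi in zip(weights, support_y):
        probs[yi] += w   # each support sample votes for its own class
    return probs

# Gaussian (RBF) kernel as a stand-in for a learned k_theta
rbf = lambda a, b: np.exp(-np.sum((np.asarray(a) - np.asarray(b)) ** 2))

S_x = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]]
S_y = [0, 0, 1]
p = metric_predict([0.05, 0.0], S_x, S_y, rbf, n_classes=2)
```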
<p>All the models introduced below learn embedding vectors of input data explicitly and use them to design proper kernel functions.</p>
<h3 id="convolutional-siamese-neural-network">Convolutional Siamese Neural Network</h3>
<p>The <a href="https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf">Siamese Neural Network</a> is composed of two twin networks and their outputs are jointly trained on top with a function to learn the relationship between pairs of input data samples. The twin networks are identical, sharing the same weights and network parameters. In other words, both refer to the same embedding network that learns an efficient embedding to reveal relationship between pairs of data points.</p>
<p><a href="http://www.cs.toronto.edu/~rsalakhu/papers/oneshot1.pdf">Koch, Zemel & Salakhutdinov (2015)</a> proposed a method to use the siamese neural network to do one-shot image classification. First, the siamese network is trained for a verification task for telling whether two input images are in the same class. It outputs the probability of two images belonging to the same class. Then, during test time, the siamese network processes all the image pairs between a test image and every image in the support set. The final prediction is the class of the support image with the highest probability.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/siamese-conv-net.png" alt="siamese" /></p>
<p><em>Fig. 2. The architecture of convolutional siamese neural network for few-shot image classification.</em></p>
<ol>
<li>First, convolutional siamese network learns to encode two images into feature vectors via an embedding function <script type="math/tex">f_\theta</script> which contains a couple of convolutional layers.</li>
<li>The L1-distance between two embeddings is <script type="math/tex">\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert</script>.</li>
<li>The distance is converted to a probability <script type="math/tex">p</script> by a linear feedforward layer and sigmoid. It is the probability of whether two images are drawn from the same class.</li>
<li>Intuitively the loss is cross entropy because the label is binary.</li>
</ol>
<!-- In this way, an efficient image embedding is trained so that the distance between two embeddings is proportional to the similarity between two images. -->
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p(\mathbf{x}_i, \mathbf{x}_j) &= \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert) \\
\mathcal{L}(B) &= \sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} \mathbf{1}_{y_i=y_j}\log p(\mathbf{x}_i, \mathbf{x}_j) + (1-\mathbf{1}_{y_i=y_j})\log (1-p(\mathbf{x}_i, \mathbf{x}_j))
\end{aligned} %]]></script>
<p>Images in the training batch <script type="math/tex">B</script> can be augmented with distortion. Of course, you can replace the L1 distance with another distance metric, such as L2 or cosine; just make sure it is differentiable, and then everything else works the same.</p>
<p>Given a support set <script type="math/tex">S</script> and a test image <script type="math/tex">\mathbf{x}</script>, the final predicted class is:</p>
<script type="math/tex; mode=display">\hat{c}_S(\mathbf{x}) = c(\arg\max_{\mathbf{x}_i \in S} P(\mathbf{x}, \mathbf{x}_i))</script>
<p>where <script type="math/tex">c(\mathbf{x})</script> is the class label of an image <script type="math/tex">\mathbf{x}</script> and <script type="math/tex">\hat{c}(.)</script> is the predicted label.</p>
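<p>Here is a minimal NumPy sketch of this prediction rule. The embeddings and the weight vector <code>W</code> are hand-picked stand-ins for a trained embedding network and linear layer, chosen so that closer pairs get higher verification probability:</p>

```python
import numpy as np

def verification_prob(emb_a, emb_b, W):
    """p = sigmoid(W |f(x_i) - f(x_j)|): probability the pair is same-class."""
    z = W @ np.abs(emb_a - emb_b)
    return 1.0 / (1.0 + np.exp(-z))

def one_shot_predict(test_emb, support_embs, support_labels, W):
    """Predicted class: the support sample with the highest pair probability."""
    probs = [verification_prob(test_emb, e, W) for e in support_embs]
    return support_labels[int(np.argmax(probs))]

W = np.array([-1.0, -1.0])   # toy weights: smaller distance -> higher prob
support = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
labels = ["cat", "dog"]
pred = one_shot_predict(np.array([0.2, 0.1]), support, labels, W)
```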
<p>The assumption is that the learned embedding can be generalized to be useful for measuring the distance between images of unknown categories. This is the same assumption behind transfer learning via the adoption of a pre-trained model; for example, the convolutional features learned in the model pre-trained with ImageNet are expected to help other image tasks. However, the benefit of a pre-trained model decreases when the new task diverges from the original task that the model was trained on.</p>
<h3 id="matching-networks">Matching Networks</h3>
<p>The task of <strong>Matching Networks</strong> (<a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">Vinyals et al., 2016</a>) is to learn a classifier <script type="math/tex">c_S</script> for any given (small) support set <script type="math/tex">S=\{x_i, y_i\}_{i=1}^k</script> (<em>k-shot</em> classification). This classifier defines a probability distribution over output labels <script type="math/tex">y</script> given a test example <script type="math/tex">\mathbf{x}</script>. Similar to other metric-based models, the classifier output is defined as a sum of labels of support samples weighted by attention kernel <script type="math/tex">a(\mathbf{x}, \mathbf{x}_i)</script> - which should be proportional to the similarity between <script type="math/tex">\mathbf{x}</script> and <script type="math/tex">\mathbf{x}_i</script>.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/matching-networks.png" alt="siamese" /></p>
<p><em>Fig. 3. The architecture of Matching Networks. (Image source: <a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">original paper</a>)</em></p>
<script type="math/tex; mode=display">c_S(\mathbf{x}) = P(y \vert \mathbf{x}, S) = \sum_{i=1}^k a(\mathbf{x}, \mathbf{x}_i) y_i
\text{, where }S=\{(\mathbf{x}_i, y_i)\}_{i=1}^k</script>
<p>The attention kernel depends on two embedding functions, <script type="math/tex">f</script> and <script type="math/tex">g</script>, for encoding the test sample and the support set samples respectively. The attention weight between two data points is the cosine similarity, <script type="math/tex">\text{cosine}(.)</script>, between their embedding vectors, normalized by softmax:</p>
<script type="math/tex; mode=display">a(\mathbf{x}, \mathbf{x}_i) = \frac{\exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_i))}{\sum_{j=1}^k\exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_j))}</script>
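<p>The attention kernel can be sketched directly from the two formulas above, assuming the embedding functions have already mapped inputs to vectors (names are illustrative):</p>

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # shift for numerical stability
    return e / e.sum()

def attention_kernel(f_x, g_xs):
    """a(x, x_i): cosine similarity between the test embedding f(x) and
    support embeddings g(x_i), normalized by softmax."""
    cos = np.array([f_x @ g / (np.linalg.norm(f_x) * np.linalg.norm(g))
                    for g in g_xs])
    return softmax(cos)

def matching_predict(f_x, g_xs, onehot_ys):
    """c_S(x) = sum_i a(x, x_i) y_i, with one-hot support labels."""
    a = attention_kernel(f_x, g_xs)
    return a @ np.asarray(onehot_ys, dtype=float)

g_s = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
ys = [[1, 0], [0, 1]]   # one-hot labels for 2 classes
p = matching_predict(np.array([0.9, 0.1]), g_s, ys)
```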
<h4 id="simple-embedding">Simple Embedding</h4>
<p>In the simple version, an embedding function is a neural network with a single data sample as input. Potentially we can set <script type="math/tex">f=g</script>.</p>
<h4 id="full-context-embeddings">Full Context Embeddings</h4>
<p>The embedding vectors are critical inputs for building a good classifier. Taking a single data point as input might not be enough to efficiently gauge the entire feature space. Therefore, the Matching Network model further proposed to enhance the embedding functions by taking as input the whole support set <script type="math/tex">S</script> in addition to the original input, so that the learned embedding can be adjusted based on the relationship with other support samples.</p>
<ul>
<li><script type="math/tex">g_\theta(\mathbf{x}_i, S)</script> uses a bidirectional LSTM to encode <script type="math/tex">\mathbf{x}_i</script> in the context of the entire support set <script type="math/tex">S</script>.</li>
<li><script type="math/tex">f_\theta(\mathbf{x}, S)</script> encodes the test sample <script type="math/tex">\mathbf{x}</script> via an LSTM with read attention over the support set <script type="math/tex">S</script>.
<ol>
<li>First the test sample goes through a simple neural network, such as a CNN, to extract basic features, <script type="math/tex">f'(\mathbf{x})</script>.</li>
<li>Then an LSTM is trained with a read attention vector over the support set as part of the hidden state: <br />
<script type="math/tex">% <![CDATA[
\begin{aligned}
\hat{\mathbf{h}}_t, \mathbf{c}_t &= \text{LSTM}(f'(\mathbf{x}), [\mathbf{h}_{t-1}, \mathbf{r}_{t-1}], \mathbf{c}_{t-1}) \\
\mathbf{h}_t &= \hat{\mathbf{h}}_t + f'(\mathbf{x}) \\
\mathbf{r}_{t-1} &= \sum_{i=1}^k a(\mathbf{h}_{t-1}, g(\mathbf{x}_i)) g(\mathbf{x}_i) \\
a(\mathbf{h}_{t-1}, g(\mathbf{x}_i)) &= \text{softmax}(\mathbf{h}_{t-1}^\top g(\mathbf{x}_i)) = \frac{\exp(\mathbf{h}_{t-1}^\top g(\mathbf{x}_i))}{\sum_{j=1}^k \exp(\mathbf{h}_{t-1}^\top g(\mathbf{x}_j))}
\end{aligned} %]]></script></li>
<li>Eventually <script type="math/tex">f(\mathbf{x}, S)=\mathbf{h}_K</script> if we do K steps of “read”.</li>
</ol>
</li>
</ul>
<p>This embedding method is called “Full Contextual Embeddings (FCE)”. Interestingly it does help improve the performance on a hard task (few-shot classification on mini ImageNet), but makes no difference on a simple task (Omniglot).</p>
<p>The training process in Matching Networks is designed to match inference at test time; see the details in the earlier <a href="#training-in-the-same-way-as-testing">section</a>. It is worth mentioning that the Matching Networks paper refined the idea that training and testing conditions should match.</p>
<script type="math/tex; mode=display">\theta^* = \arg\max_\theta \mathbb{E}_{L\subset\mathcal{L}}[ \mathbb{E}_{S^L \subset\mathcal{D}, B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_\theta(y\vert\mathbf{x}, S^L)]]</script>
<h3 id="relation-network">Relation Network</h3>
<p><strong>Relation Network (RN)</strong> (<a href="http://openaccess.thecvf.com/content_cvpr_2018/papers_backup/Sung_Learning_to_Compare_CVPR_2018_paper.pdf">Sung et al., 2018</a>) is similar to <a href="#convolutional-siamese-neural-network">siamese network</a> but with a few differences:</p>
<ol>
<li>The relationship is not captured by a simple L1 distance in the feature space, but predicted by a CNN classifier <script type="math/tex">g_\phi</script>. The relation score between a pair of inputs, <script type="math/tex">\mathbf{x}_i</script> and <script type="math/tex">\mathbf{x}_j</script>, is <script type="math/tex">r_{ij} = g_\phi([\mathbf{x}_i, \mathbf{x}_j])</script> where <script type="math/tex">[.,.]</script> is concatenation.</li>
<li>The objective function is MSE loss instead of cross-entropy, because conceptually RN focuses more on predicting relation scores which is more like regression, rather than binary classification, <script type="math/tex">\mathcal{L}(B) = \sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} (r_{ij} - \mathbf{1}_{y_i=y_j})^2</script>.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/relation-network.png" alt="relation-network" /></p>
<p><em>Fig. 4. Relation Network architecture for a 5-way 1-shot problem with one query example. (Image source: <a href="http://openaccess.thecvf.com/content_cvpr_2018/papers_backup/Sung_Learning_to_Compare_CVPR_2018_paper.pdf">original paper</a>)</em></p>
<p>(Note: There is another <a href="https://deepmind.com/blog/neural-approach-relational-reasoning/">Relation Network</a> for relational reasoning, proposed by DeepMind. Don’t get confused.)</p>
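<p>The MSE objective itself is a one-liner; here is a toy sketch where the relation scores are assumed to be the precomputed outputs of the learned comparator (what the paper calls <script type="math/tex">g_\phi</script>):</p>

```python
import numpy as np

def relation_mse_loss(scores, same_class):
    """RN objective: regress relation score r_ij toward 1 for same-class
    pairs and toward 0 otherwise."""
    targets = np.asarray(same_class, dtype=float)   # 1{y_i = y_j}
    return float(((np.asarray(scores) - targets) ** 2).sum())

# three pairs with relation scores from a hypothetical comparator
loss = relation_mse_loss([0.9, 0.2, 0.6], same_class=[True, False, True])
```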
<h3 id="prototypical-networks">Prototypical Networks</h3>
<p><strong>Prototypical Networks</strong> (<a href="http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf">Snell, Swersky & Zemel, 2017</a>) use an embedding function <script type="math/tex">f_\theta</script> to encode each input into a <script type="math/tex">M</script>-dimensional feature vector. A <em>prototype</em> feature vector is defined for every class <script type="math/tex">c \in \mathcal{C}</script>, as the mean vector of the embedded support data samples in this class.</p>
<script type="math/tex; mode=display">\mathbf{v}_c = \frac{1}{|S_c|} \sum_{(\mathbf{x}_i, y_i) \in S_c} f_\theta(\mathbf{x}_i)</script>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/prototypical-networks.png" alt="prototypical-networks" /></p>
<p><em>Fig. 5. Prototypical networks in the few-shot and zero-shot scenarios. (Image source: <a href="http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf">original paper</a>)</em></p>
<p>The distribution over classes for a given test input <script type="math/tex">\mathbf{x}</script> is a softmax over the inverse of distances between the test data embedding and prototype vectors.</p>
<script type="math/tex; mode=display">P(y=c\vert\mathbf{x})=\text{softmax}(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c)) = \frac{\exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c))}{\sum_{c' \in \mathcal{C}}\exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_{c'}))}</script>
<p>where <script type="math/tex">d_\varphi</script> can be any distance function as long as <script type="math/tex">\varphi</script> is differentiable. In the paper, they used the squared Euclidean distance.</p>
<p>The loss function is the negative log-likelihood: <script type="math/tex">\mathcal{L}(\theta) = -\log P_\theta(y=c\vert\mathbf{x})</script>.</p>
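<p>Both steps (prototype averaging and the softmax over negative distances) fit in a short NumPy sketch; the embeddings are hand-picked stand-ins for the outputs of a trained embedding function:</p>

```python
import numpy as np

def prototypes(support_embs, support_ys, n_classes):
    """v_c: mean embedding of the support samples in each class."""
    embs = np.asarray(support_embs, dtype=float)
    ys = np.asarray(support_ys)
    return np.stack([embs[ys == c].mean(axis=0) for c in range(n_classes)])

def proto_predict(query_emb, protos):
    """P(y=c | x): softmax over negative squared Euclidean distances."""
    d = ((protos - query_emb) ** 2).sum(axis=1)
    e = np.exp(-(d - d.min()))   # shift for numerical stability
    return e / e.sum()

S_emb = [[0.0, 0.0], [0.2, 0.0], [4.0, 4.0], [4.2, 4.0]]
S_y = [0, 0, 1, 1]
probs = proto_predict(np.array([0.1, 0.1]), prototypes(S_emb, S_y, 2))
```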
<h2 id="model-based">Model-Based</h2>
<p>Model-based meta-learning models make no assumption on the form of <script type="math/tex">P_\theta(y\vert\mathbf{x})</script>. Rather, they depend on a model designed specifically for fast learning — a model that updates its parameters rapidly with a few training steps. This rapid parameter update can be achieved by its internal architecture or controlled by another meta-learner model.</p>
<h3 id="memory-augmented-neural-networks">Memory-Augmented Neural Networks</h3>
<p>A family of model architectures use external memory storage to facilitate the learning process of neural networks, including <a href="/lil-log/2018/06/24/attention-attention.html#neural-turing-machines">Neural Turing Machines</a> and <a href="https://arxiv.org/abs/1410.3916">Memory Networks</a>. With an explicit storage buffer, it is easier for the network to rapidly incorporate new information and not to forget in the future. Such a model is known as <strong>MANN</strong>, short for “<strong>Memory-Augmented Neural Network</strong>”. Note that recurrent neural networks with only <em>internal memory</em> such as vanilla RNN or LSTM are not MANNs.</p>
<p>Because MANN is expected to encode new information fast and thus to adapt to new tasks after only a few samples, it fits well for meta-learning. Taking the Neural Turing Machine (NTM) as the base model, <a href="http://proceedings.mlr.press/v48/santoro16.pdf">Santoro et al. (2016)</a> proposed a set of modifications on the training setup and the memory retrieval mechanisms (or “addressing mechanisms”, deciding how to assign attention weights to memory vectors). If you are not familiar with NTM, please go through <a href="/lil-log/2018/06/24/attention-attention.html#neural-turing-machines">the NTM section</a> in my other post before reading on.</p>
<p>As a quick recap, NTM couples a controller neural network with external memory storage. The controller learns to read and write memory rows by soft attention, while the memory serves as a knowledge repository. The attention weights are generated by its addressing mechanism: content-based + location based.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/NTM.png" alt="NTM" /></p>
<p><em>Fig. 6. The architecture of Neural Turing Machine (NTM). The memory at time t, <script type="math/tex">\mathbf{M}_t</script> is a matrix of size <script type="math/tex">N \times M</script>, containing N vector rows and each has M dimensions.</em></p>
<h4 id="mann-for-meta-learning">MANN for Meta-Learning</h4>
<p>To use MANN for meta-learning tasks, we need to train it in a way that the memory can encode and capture information of new tasks fast and, in the meantime, any stored representation is easily and stably accessible.</p>
<p>The training described in <a href="http://proceedings.mlr.press/v48/santoro16.pdf">Santoro et al., 2016</a> happens in an interesting way so that the memory is forced to hold information for longer until the appropriate labels are presented later. In each training episode, the true label <script type="math/tex">y_t</script> is presented with <strong>one step offset</strong>, <script type="math/tex">(\mathbf{x}_{t+1}, y_t)</script>: it is the true label for the input at the previous time step t, but presented as part of the input at time step t+1.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/mann-meta-learning.png" alt="NTM" /></p>
<p><em>Fig. 7. Task setup in MANN for meta-learning (Image source: <a href="http://proceedings.mlr.press/v48/santoro16.pdf">original paper</a>).</em></p>
<p>In this way, MANN is motivated to memorize the information of a new dataset, because the memory has to hold the current input until the label is present later and then retrieve the old information to make a prediction accordingly.</p>
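<p>The one-step label offset is a simple transformation of the episode stream. A tiny sketch (names are illustrative; the first step has no previous label, so a null placeholder is fed):</p>

```python
def offset_episode(xs, ys, null_label=None):
    """Pair each input x_t with the previous step's label y_{t-1},
    as in the one-step-offset training setup described above."""
    shifted = [null_label] + list(ys[:-1])
    return list(zip(xs, shifted))

pairs = offset_episode(["x1", "x2", "x3"], ["a", "b", "c"])
```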
<p>Next let us see how the memory is updated for efficient information retrieval and storage.</p>
<h4 id="addressing-mechanism-for-meta-learning">Addressing Mechanism for Meta-Learning</h4>
<p>Aside from the training process, a new pure content-based addressing mechanism is utilized to make the model better suited for meta-learning.</p>
<p><strong>» How to read from memory?</strong>
<br />
The read attention is constructed purely based on the content similarity.</p>
<p>First, a key feature vector <script type="math/tex">\mathbf{k}_t</script> is produced at the time step t by the controller as a function of the input <script type="math/tex">\mathbf{x}</script>. Similar to NTM, a read weighting vector <script type="math/tex">\mathbf{w}_t^r</script> of N elements is computed as the cosine similarity between the key vector and every memory vector row, normalized by softmax. The read vector <script type="math/tex">\mathbf{r}_t</script> is a sum of memory records weighted by these weights:</p>
<script type="math/tex; mode=display">\mathbf{r}_t = \sum_{i=1}^N w_t^r(i)\mathbf{M}_t(i)
\text{, where } w_t^r(i) = \text{softmax}(\frac{\mathbf{k}_t \cdot \mathbf{M}_t(i)}{\|\mathbf{k}_t\| \cdot \|\mathbf{M}_t(i)\|})</script>
<p>where <script type="math/tex">M_t</script> is the memory matrix at time t and <script type="math/tex">M_t(i)</script> is the i-th row in this matrix.</p>
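<p>The read operation reduces to a cosine-similarity softmax over memory rows; a NumPy sketch with toy memory contents:</p>

```python
import numpy as np

def mann_read(key, memory):
    """Pure content-based read: cosine similarity between the key k_t and
    every memory row M_t(i), softmax-normalized, then a weighted sum."""
    sims = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key))
    w = np.exp(sims - sims.max())
    w = w / w.sum()            # read weights w_t^r
    return w @ memory, w       # read vector r_t and the weights

M = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
r, w_r = mann_read(np.array([1.0, 0.1]), M)
```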
<p><strong>» How to write into memory?</strong>
<br />
The addressing mechanism for writing newly received information into memory operates a lot like a <a href="https://en.wikipedia.org/wiki/Cache_replacement_policies">cache replacement</a> policy. The <strong>Least Recently Used Access (LRUA)</strong> writer is designed for MANN to better work in the scenario of meta-learning. An LRUA write head prefers to write new content to either the <em>least used</em> memory location or the <em>most recently used</em> memory location.</p>
<ul>
<li>Rarely used locations: so that we can preserve frequently used information (see <a href="https://en.wikipedia.org/wiki/Least_frequently_used">LFU</a>);</li>
<li>The last used location: the motivation is that once a piece of information has been retrieved, it probably won’t be called again for a while (see <a href="https://en.wikipedia.org/wiki/Cache_replacement_policies#Most_recently_used_(MRU)">MRU</a>).</li>
</ul>
<p>There are many cache replacement algorithms, and each of them could potentially replace the design here with better performance in different use cases. Furthermore, it would be a good idea to learn the memory usage pattern and addressing strategies rather than setting them arbitrarily.</p>
<p>The preference of LRUA is carried out in a way that everything is differentiable:</p>
<ol>
<li>The usage weight <script type="math/tex">\mathbf{w}^u_t</script> at time t is a sum of current read and write vectors, in addition to the decayed last usage weight, <script type="math/tex">\gamma \mathbf{w}^u_{t-1}</script>, where <script type="math/tex">\gamma</script> is a decay factor.</li>
<li>The write vector is an interpolation between the previous read weight (prefer “the last used location”) and the previous least-used weight (prefer “rarely used location”). The interpolation parameter is the sigmoid of a hyperparameter <script type="math/tex">\alpha</script>.</li>
<li>The least-used weight <script type="math/tex">\mathbf{w}^{lu}</script> is computed from the usage weights <script type="math/tex">\mathbf{w}_t^u</script>: a dimension is set to 1 if its usage weight is no greater than the n-th smallest element in the vector, and 0 otherwise.</li>
</ol>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{w}_t^u &= \gamma \mathbf{w}_{t-1}^u + \mathbf{w}_t^r + \mathbf{w}_t^w \\
\mathbf{w}_t^r &= \text{softmax}(\text{cosine}(\mathbf{k}_t, \mathbf{M}_t(i))) \\
\mathbf{w}_t^w &= \sigma(\alpha)\mathbf{w}_{t-1}^r + (1-\sigma(\alpha))\mathbf{w}^{lu}_{t-1}\\
\mathbf{w}_t^{lu} &= \mathbf{1}_{w_t^u(i) \leq m(\mathbf{w}_t^u, n)}
\text{, where }m(\mathbf{w}_t^u, n)\text{ is the }n\text{-th smallest element in vector }\mathbf{w}_t^u\text{.}
\end{aligned} %]]></script>
<p>Finally, after the least used memory location, indicated by <script type="math/tex">\mathbf{w}_t^{lu}</script>, is set to zero, every memory row is updated:</p>
<script type="math/tex; mode=display">\mathbf{M}_t(i) = \mathbf{M}_{t-1}(i) + w_t^w(i)\mathbf{k}_t, \forall i</script>
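<p>The LRUA update equations can be sketched as follows. As a simplification, this sketch reuses one read-weight vector where the equations distinguish the current and previous read weights, and the hyperparameter values are illustrative:</p>

```python
import numpy as np

def lrua_step(memory, k_t, w_r, w_u_prev, alpha=0.0, gamma=0.95, n=1):
    """One LRUA write step, loosely following the update equations above."""
    sig = 1.0 / (1.0 + np.exp(-alpha))              # sigmoid(alpha)
    # least-used weight from the previous usage weights
    nth_smallest = np.sort(w_u_prev)[n - 1]
    w_lu_prev = (w_u_prev <= nth_smallest).astype(float)
    # write weight: blend of last-read location and least-used location
    w_w = sig * w_r + (1.0 - sig) * w_lu_prev
    # zero out the least-used row, then add the keyed write
    memory = memory * (1.0 - w_lu_prev)[:, None]
    memory = memory + w_w[:, None] * k_t
    # usage decays, plus current read and write weights
    w_u = gamma * w_u_prev + w_r + w_w
    return memory, w_u

M = np.zeros((3, 2))
M, w_u = lrua_step(M, np.array([1.0, 2.0]),
                   w_r=np.array([0.7, 0.2, 0.1]),
                   w_u_prev=np.array([0.5, 0.3, 0.1]))
```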
<h3 id="meta-networks">Meta Networks</h3>
<p><strong>Meta Networks</strong> (<a href="https://arxiv.org/abs/1703.00837">Munkhdalai & Yu, 2017</a>), abbreviated as <strong>MetaNet</strong>, is a meta-learning model with architecture and training process designed for <em>rapid</em> generalization across tasks.</p>
<h4 id="fast-weights">Fast Weights</h4>
<p>The rapid generalization of MetaNet relies on “fast weights”. There are a handful of papers on this topic, but I haven’t read all of them in detail and I failed to find a very concrete definition, only a vague agreement on the concept. Normally weights in neural networks are updated by stochastic gradient descent on an objective function, and this process is known to be slow. One faster way to learn is to utilize one neural network to predict the parameters of another neural network; the generated weights are called <em>fast weights</em>. In comparison, the ordinary SGD-based weights are named <em>slow weights</em>.</p>
<p>In MetaNet, loss gradients are used as <em>meta information</em> to populate models that learn fast weights. Slow and fast weights are combined to make predictions in neural networks.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/combine-slow-fast-weights.png" alt="slow-fast-weights" /></p>
<p><em>Fig. 8. Combining slow and fast weights in a MLP. <script type="math/tex">\bigoplus</script> is element-wise sum. (Image source: <a href="https://arxiv.org/abs/1703.00837">original paper</a>).</em></p>
<h4 id="model-components">Model Components</h4>
<blockquote>
<p>Disclaimer: Below you will find my annotations are different from those in the paper. imo, the paper is poorly written, but the idea is still interesting. So I’m presenting the idea in my own language.</p>
</blockquote>
<p>Key components of MetaNet are:</p>
<ul>
<li>An embedding function <script type="math/tex">f_\theta</script>, parameterized by <script type="math/tex">\theta</script>, encodes raw inputs into feature vectors. Similar to <a href="#convolutional-siamese-neural-network">Siamese Neural Network</a>, these embeddings are trained to be useful for telling whether two inputs are of the same class (verification task).</li>
<li>A base learner model <script type="math/tex">g_\phi</script>, parameterized by weights <script type="math/tex">\phi</script>, completes the actual learning task.</li>
</ul>
<p>If we stop here, it looks just like <a href="#relation-network">Relation Network</a>. MetaNet, in addition, explicitly models the fast weights of both functions and then aggregates them back into the model (See Fig. 8).</p>
<p>Therefore we need two additional functions to output fast weights for <script type="math/tex">f</script> and <script type="math/tex">g</script> respectively.</p>
<ul>
<li><script type="math/tex">F_w</script>: an LSTM parameterized by <script type="math/tex">w</script> for learning fast weights <script type="math/tex">\theta^+</script> of the embedding function <script type="math/tex">f</script>. It takes as input gradients of <script type="math/tex">f</script>’s embedding loss for the verification task.</li>
<li><script type="math/tex">G_v</script>: a neural network parameterized by <script type="math/tex">v</script> learning fast weights <script type="math/tex">\phi^+</script> for the base learner <script type="math/tex">g</script> from its loss gradients. In MetaNet, the learner’s loss gradients are viewed as the <em>meta information</em> of the task.</li>
</ul>
<p>Ok, now let’s see how meta networks are trained. The training data contains multiple pairs of datasets: a support set <script type="math/tex">S=\{\mathbf{x}'_i, y'_i\}_{i=1}^K</script> and a test set <script type="math/tex">U=\{\mathbf{x}_i, y_i\}_{i=1}^L</script>. Recall that we have four networks and four sets of model parameters to learn, <script type="math/tex">(\theta, \phi, w, v)</script>.</p>
<p style="width: 90%;" class="center"><img src="/lil-log/assets/images/meta-network.png" alt="meta-net" /></p>
<p><em>Fig.9. The MetaNet architecture.</em></p>
<h4 id="training-process">Training Process</h4>
<ol>
<li>Sample a random pair of inputs at each time step t from the support set <script type="math/tex">S</script>, <script type="math/tex">(\mathbf{x}'_i, y'_i)</script> and <script type="math/tex">(\mathbf{x}'_j, y'_j)</script>. Let <script type="math/tex">\mathbf{x}_{(t,1)}=\mathbf{x}'_i</script> and <script type="math/tex">\mathbf{x}_{(t,2)}=\mathbf{x}'_j</script>.<br />
for <script type="math/tex">t = 1, \dots, K</script>:
<ul>
<li>a. Compute a loss for representation learning; i.e., cross entropy for the verification task:<br />
<script type="math/tex">\mathcal{L}^\text{emb}_t = \mathbf{1}_{y'_i=y'_j} \log P_t + (1 - \mathbf{1}_{y'_i=y'_j})\log(1 - P_t)\text{, where }P_t = \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_{(t,1)}) - f_\theta(\mathbf{x}_{(t,2)})\vert)</script></li>
</ul>
</li>
<li>Compute the task-level fast weights:
<script type="math/tex">\theta^+ = F_w(\nabla_\theta \mathcal{L}^\text{emb}_1, \dots, \nabla_\theta \mathcal{L}^\text{emb}_K)</script></li>
<li>Next go through examples in the support set <script type="math/tex">S</script> and compute the example-level fast weights. Meanwhile, update the memory with learned representations.<br />
for <script type="math/tex">i=1, \dots, K</script>:
<ul>
<li>a. The base learner outputs a probability distribution: <script type="math/tex">P(\hat{y}_i \vert \mathbf{x}'_i) = g_\phi(\mathbf{x}'_i)</script> and the loss can be cross-entropy or MSE: <script type="math/tex">\mathcal{L}^\text{task}_i = y'_i \log g_\phi(\mathbf{x}'_i) + (1- y'_i) \log (1 - g_\phi(\mathbf{x}'_i))</script></li>
<li>b. Extract meta information (loss gradients) of the task and compute the example-level fast weights:
<script type="math/tex">\phi_i^+ = G_v(\nabla_\phi\mathcal{L}^\text{task}_i)</script>
<ul>
<li>Then store <script type="math/tex">\phi^+_i</script> into <script type="math/tex">i</script>-th location of the “value” memory <script type="math/tex">\mathbf{M}</script>.<br /></li>
</ul>
</li>
<li>c. Encode the support sample into a task-specific input representation using both slow and fast weights: <script type="math/tex">r'_i = f_{\theta, \theta^+}(\mathbf{x}'_i)</script>
<ul>
<li>Then store <script type="math/tex">r'_i</script> into <script type="math/tex">i</script>-th location of the “key” memory <script type="math/tex">\mathbf{R}</script>.</li>
</ul>
</li>
</ul>
</li>
<li>Finally, it is time to construct the training loss using the test set <script type="math/tex">U=\{\mathbf{x}_i, y_i\}_{i=1}^L</script>.<br />
Starting with <script type="math/tex">\mathcal{L}_\text{train}=0</script>:<br />
for <script type="math/tex">j=1, \dots, L</script>:
<ul>
<li>a. Encode the test sample into a task-specific input representation:
<script type="math/tex">r_j = f_{\theta, \theta^+}(\mathbf{x}_j)</script></li>
<li>b. The fast weights are computed by attending to representations of support set samples in memory <script type="math/tex">\mathbf{R}</script>. The attention function is of your choice. Here MetaNet uses cosine similarity:<br />
<script type="math/tex">% <![CDATA[
\begin{aligned}
a_j &= \text{cosine}(\mathbf{R}, r_j) = [\frac{r'_1\cdot r_j}{\|r'_1\|\cdot\|r_j\|}, \dots, \frac{r'_N\cdot r_j}{\|r'_N\|\cdot\|r_j\|}]\\
\phi^+_j &= \text{softmax}(a_j)^\top \mathbf{M}
\end{aligned} %]]></script></li>
<li>c. Update the training loss: <script type="math/tex">\mathcal{L}_\text{train} \leftarrow \mathcal{L}_\text{train} + \mathcal{L}^\text{task}(g_{\phi, \phi^+}(\mathbf{x}_j), y_j)</script></li>
</ul>
</li>
<li>Update all the parameters <script type="math/tex">(\theta, \phi, w, v)</script> using <script type="math/tex">\mathcal{L}_\text{train}</script>.</li>
</ol>
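<p>As an illustration of step 4b above, the soft-attention read-out of fast weights from memory can be sketched in plain Python. The dimensions and memory contents below are made-up toy values, not from the paper:</p>

```python
import math

def cosine(u, v):
    # Cosine similarity between two vectors.
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

def read_fast_weights(R, M, r_j):
    # a_j = cosine(R, r_j); phi_j^+ = softmax(a_j)^T M
    a = softmax([cosine(r_i, r_j) for r_i in R])
    dim = len(M[0])
    return [sum(a[i] * M[i][d] for i in range(len(M))) for d in range(dim)]

# Toy memory with two support samples.
R = [[1.0, 0.0], [0.0, 1.0]]    # "key" memory: representations r'_i
M = [[10.0, 0.0], [0.0, 10.0]]  # "value" memory: fast weights phi^+_i
phi = read_fast_weights(R, M, [1.0, 0.0])
# The query matches the first key more closely, so phi leans toward M[0].
```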
<h2 id="optimization-based">Optimization-Based</h2>
<p>Deep learning models learn through backpropagation of gradients. However, gradient-based optimization is neither designed to cope with a small number of training samples, nor to converge within a small number of optimization steps. Is there a way to adjust the optimization algorithm so that the model can learn well from only a few examples? This is what optimization-based meta-learning algorithms aim for.</p>
<h3 id="lstm-meta-learner">LSTM Meta-Learner</h3>
<p>The optimization algorithm can be explicitly modeled. <a href="https://openreview.net/pdf?id=rJY0-Kcll">Ravi & Larochelle (2017)</a> did so and named it “meta-learner”, while the original model for handling the task is called “learner”. The goal of the meta-learner is to efficiently update the learner’s parameters using a small support set so that the learner can adapt to the new task quickly.</p>
<p>Let’s denote the learner model as <script type="math/tex">M_\theta</script> parameterized by <script type="math/tex">\theta</script>, the meta-learner as <script type="math/tex">R_\Theta</script> with parameters <script type="math/tex">\Theta</script>, and the loss function <script type="math/tex">\mathcal{L}</script>.</p>
<h4 id="why-lstm">Why LSTM?</h4>
<p>The meta-learner is modeled as an LSTM, because:</p>
<ol>
<li>There is similarity between the gradient-based update in backpropagation and the cell-state update in LSTM.</li>
<li>Knowing a history of gradients benefits the gradient update; think about how <a href="http://ruder.io/optimizing-gradient-descent/index.html#momentum">momentum</a> works.</li>
</ol>
<p>The update for the learner’s parameters at time step t with a learning rate <script type="math/tex">\alpha_t</script> is:</p>
<script type="math/tex; mode=display">\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}}\mathcal{L}_t</script>
<p>It has the same form as the cell state update in LSTM, if we set forget gate <script type="math/tex">f_t=1</script>, input gate <script type="math/tex">i_t = \alpha_t</script>, cell state <script type="math/tex">c_t = \theta_t</script>, and new cell state <script type="math/tex">\tilde{c}_t = -\nabla_{\theta_{t-1}}\mathcal{L}_t</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\
&= \theta_{t-1} - \alpha_t\nabla_{\theta_{t-1}}\mathcal{L}_t
\end{aligned} %]]></script>
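<p>As a quick numerical sanity check of this correspondence, here is a toy 1-D sketch with a made-up parameter value and gradient:</p>

```python
def lstm_cell_update(c_prev, c_tilde, f, i):
    # c_t = f * c_{t-1} + i * c_tilde, scalar version of the cell-state update
    return f * c_prev + i * c_tilde

theta_prev, grad, alpha = 2.0, 0.5, 0.1
sgd_step = theta_prev - alpha * grad
# Substitute f_t = 1, i_t = alpha, c_{t-1} = theta_{t-1}, c_tilde = -grad:
lstm_step = lstm_cell_update(theta_prev, -grad, f=1.0, i=alpha)
assert sgd_step == lstm_step  # identical update: theta_{t-1} - alpha * grad
```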
<p>While fixing <script type="math/tex">f_t=1</script> and <script type="math/tex">i_t=\alpha_t</script> might not be optimal, both of them can be learnable and adaptable to different datasets.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
f_t &= \sigma(\mathbf{W}_f \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, f_{t-1}] + \mathbf{b}_f) & \scriptstyle{\text{; how much to forget the old value of parameters.}}\\
i_t &= \sigma(\mathbf{W}_i \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, i_{t-1}] + \mathbf{b}_i) & \scriptstyle{\text{; corresponding to the learning rate at time step t.}}\\
\tilde{\theta}_t &= -\nabla_{\theta_{t-1}}\mathcal{L}_t &\\
\theta_t &= f_t \odot \theta_{t-1} + i_t \odot \tilde{\theta}_t &\\
\end{aligned} %]]></script>
<h4 id="model-setup">Model Setup</h4>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/lstm-meta-learner.png" alt="lstm-meta-learner" /></p>
<p><em>Fig. 10. How the learner <script type="math/tex">M_\theta</script> and the meta-learner <script type="math/tex">R_\Theta</script> are trained. (Image source: <a href="https://openreview.net/pdf?id=rJY0-Kcll">original paper</a> with more annotations)</em></p>
<p>The training process mimics what happens during test time, since this has been shown to be beneficial in <a href="#matching-networks">Matching Networks</a>. During each training epoch, we first sample a dataset <script type="math/tex">\mathcal{D} = (\mathcal{D}_\text{train}, \mathcal{D}_\text{test}) \in \hat{\mathcal{D}}_\text{meta-train}</script> and then sample mini-batches out of <script type="math/tex">\mathcal{D}_\text{train}</script> to update <script type="math/tex">\theta</script> for <script type="math/tex">T</script> rounds. The final state of the learner parameter <script type="math/tex">\theta_T</script> is used to train the meta-learner on the test data <script type="math/tex">\mathcal{D}_\text{test}</script>.</p>
<p>Two implementation details to pay extra attention to:</p>
<ol>
<li>How to compress the parameter space in the LSTM meta-learner? As the meta-learner is modeling parameters of another neural network, it would have hundreds of thousands of variables to learn. Following the <a href="https://arxiv.org/abs/1606.04474">idea</a> of sharing parameters across coordinates, the meta-learner operates coordinate-wise on the learner’s parameters, so the same small LSTM update rule is applied to every parameter independently.</li>
<li>To simplify the training process, the meta-learner assumes that the loss <script type="math/tex">\mathcal{L}_t</script> and the gradient <script type="math/tex">\nabla_{\theta_{t-1}} \mathcal{L}_t</script> are independent.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/train-meta-learner.png" alt="train-meta-learner" /></p>
<h3 id="maml">MAML</h3>
<p><strong>MAML</strong>, short for <strong>Model-Agnostic Meta-Learning</strong> (<a href="https://arxiv.org/abs/1703.03400">Finn, et al. 2017</a>) is a fairly general optimization algorithm, compatible with any model that learns through gradient descent.</p>
<p>Let’s say our model is <script type="math/tex">f_\theta</script> with parameters <script type="math/tex">\theta</script>. Given a task <script type="math/tex">\tau_i</script> and its associated dataset <script type="math/tex">(\mathcal{D}^{(i)}_\text{train}, \mathcal{D}^{(i)}_\text{test})</script>, we can update the model parameters by one or more gradient descent steps (the following example only contains one step):</p>
<script type="math/tex; mode=display">\theta'_i = \theta - \alpha \nabla_\theta\mathcal{L}^{(0)}_{\tau_i}(f_\theta)</script>
<p>where <script type="math/tex">\mathcal{L}^{(0)}</script> is the loss computed using the mini data batch with id (0).</p>
<p style="width: 45%;" class="center"><img src="/lil-log/assets/images/maml.png" alt="MAML" /></p>
<p><em>Fig. 11. Diagram of MAML. (Image source: <a href="https://arxiv.org/abs/1703.03400">original paper</a>)</em></p>
<p>Well, the above formula only optimizes for one task. To achieve a good generalization across a variety of tasks, we would like to find the optimal <script type="math/tex">\theta^*</script> so that the task-specific fine-tuning is more efficient. Now, we sample a new data batch with id (1) for updating the meta-objective. The loss, denoted as <script type="math/tex">\mathcal{L}^{(1)}</script>, depends on the mini batch (1). The superscripts in <script type="math/tex">\mathcal{L}^{(0)}</script> and <script type="math/tex">\mathcal{L}^{(1)}</script> only indicate different data batches, and they refer to the same loss objective for the same task.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\theta^*
&= \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta'_i}) = \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta - \alpha\nabla_\theta \mathcal{L}_{\tau_i}^{(0)}(f_\theta)}) & \\
\theta &\leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta - \alpha\nabla_\theta \mathcal{L}_{\tau_i}^{(0)}(f_\theta)}) & \scriptstyle{\text{; updating rule}}
\end{aligned} %]]></script>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/maml-algo.png" alt="MAML Algorithm" /></p>
<p><em>Fig. 12. The general form of MAML algorithm. (Image source: <a href="https://arxiv.org/abs/1703.03400">original paper</a>)</em></p>
<h4 id="first-order-maml">First-Order MAML</h4>
<p>The meta-optimization step above relies on second derivatives. To make the computation less expensive, a modified version of MAML omits second derivatives, resulting in a simplified and cheaper implementation, known as <strong>First-Order MAML (FOMAML)</strong>.</p>
<p>Let’s consider the case of performing <script type="math/tex">k</script> inner gradient steps, <script type="math/tex">k\geq1</script>. Starting with the initial model parameter <script type="math/tex">\theta_\text{meta}</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\theta_0 &= \theta_\text{meta}\\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0)\\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_1)\\
&\dots\\
\theta_k &= \theta_{k-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{k-1})
\end{aligned} %]]></script>
<p>Then in the outer loop, we sample a new data batch for updating the meta-objective.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\theta_\text{meta} &\leftarrow \theta_\text{meta} - \beta g_\text{MAML} & \scriptstyle{\text{; update for meta-objective}} \\[2mm]
\text{where } g_\text{MAML}
&= \nabla_{\theta} \mathcal{L}^{(1)}(\theta_k) &\\[2mm]
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot (\nabla_{\theta_{k-1}} \theta_k) \dots (\nabla_{\theta_0} \theta_1) \cdot (\nabla_{\theta} \theta_0) & \scriptstyle{\text{; following the chain rule}} \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \nabla_{\theta_{i-1}} \theta_i & \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \nabla_{\theta_{i-1}} (\theta_{i-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1})) & \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k (I - \alpha\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))) &
\end{aligned} %]]></script>
<p>The MAML gradient is:</p>
<script type="math/tex; mode=display">g_\text{MAML} = \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k (I - \alpha \color{red}{\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))})</script>
<p>The First-Order MAML ignores the second derivative part in red. It is simplified as follows, equivalent to the derivative of the last inner gradient update result.</p>
<script type="math/tex; mode=display">g_\text{FOMAML} = \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k)</script>
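<p>To see the difference between the two gradients concretely, consider a hypothetical 1-D example with one inner step (k=1) and quadratic batch losses, where every term can be written out analytically:</p>

```python
def meta_gradients(theta, a0, a1, alpha):
    # Batch losses: L0(t) = 0.5*(t - a0)^2 and L1(t) = 0.5*(t - a1)^2,
    # so the gradients are (t - a0) and (t - a1), and the Hessian of L0 is 1.
    theta_k = theta - alpha * (theta - a0)  # one inner gradient step on L0
    g_outer = theta_k - a1                  # gradient of L1 at theta_k
    g_maml = g_outer * (1.0 - alpha * 1.0)  # chain rule: multiply by (I - alpha*H)
    g_fomaml = g_outer                      # first-order: drop the Hessian factor
    return g_maml, g_fomaml

g_maml, g_fomaml = meta_gradients(theta=0.0, a0=1.0, a1=2.0, alpha=0.1)
# g_fomaml = theta_k - a1 = -1.9; g_maml shrinks it by (1 - alpha) to about -1.71
```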
<h3 id="reptile">Reptile</h3>
<p><strong>Reptile</strong> (<a href="https://arxiv.org/abs/1803.02999">Nichol, Achiam & Schulman, 2018</a>) is a remarkably simple meta-learning optimization algorithm. It is similar to MAML in many ways, given that both rely on meta-optimization through gradient descent and both are model-agnostic.</p>
<p>Reptile works by repeatedly:</p>
<ul>
<li>1) sampling a task,</li>
<li>2) training on it by multiple gradient descent steps,</li>
<li>3) and then moving the model weights towards the new parameters.</li>
</ul>
<p>See the algorithm below:
<script type="math/tex">\text{SGD}(\mathcal{L}_{\tau_i}, \theta, k)</script> performs stochastic gradient updates for k steps on the loss <script type="math/tex">\mathcal{L}_{\tau_i}</script> starting with initial parameter <script type="math/tex">\theta</script> and returns the final parameter vector. The batched version samples multiple tasks instead of one within each iteration. The Reptile gradient is defined as <script type="math/tex">(\theta - W)/\alpha</script>, where <script type="math/tex">\alpha</script> is the step size used by the SGD operation.</p>
<p style="width: 52%;" class="center"><img src="/lil-log/assets/images/reptile-algo.png" alt="Reptile Algorithm" /></p>
<p><em>Fig. 13. The batched version of Reptile algorithm. (Image source: <a href="https://arxiv.org/abs/1803.02999">original paper</a>)</em></p>
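<p>A minimal sketch of this loop (serial version, one task per iteration) on toy 1-D quadratic tasks; all hyperparameters and the task distribution below are made up for illustration:</p>

```python
def reptile(task_optima, k=5, inner_lr=0.1, outer_lr=0.5, iters=1000):
    # Each task tau has loss L_tau(theta) = 0.5 * (theta - a)^2 with optimum a.
    theta = 0.0
    for step in range(iters):
        a = task_optima[step % len(task_optima)]  # sample (here: cycle through) tasks
        w = theta
        for _ in range(k):                        # W = SGD(L_tau, theta, k)
            w -= inner_lr * (w - a)
        theta += outer_lr * (w - theta)           # move theta toward the adapted weights
    return theta

# With tasks centered at -1 and +1, the meta-initialization settles near 0,
# i.e. close to the optima of both tasks.
theta_star = reptile([-1.0, 1.0])
```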
<p>At a glance, the algorithm looks a lot like ordinary SGD. However, because the task-specific optimization can take more than one step, it eventually makes <script type="math/tex">\text{SGD}(\mathbb{E}
_\tau[\mathcal{L}_{\tau}], \theta, k)</script> diverge from <script type="math/tex">\mathbb{E}_\tau [\text{SGD}(\mathcal{L}_{\tau}, \theta, k)]</script> when k > 1.</p>
<h4 id="the-optimization-assumption">The Optimization Assumption</h4>
<p>Assume that a task <script type="math/tex">\tau \sim p(\tau)</script> has a manifold of optimal network configurations, <script type="math/tex">\mathcal{W}_{\tau}^*</script>. The model <script type="math/tex">f_\theta</script> achieves the best performance for task <script type="math/tex">\tau</script> when <script type="math/tex">\theta</script> lies on the surface of <script type="math/tex">\mathcal{W}_{\tau}^*</script>. To find a solution that is good across tasks, we would like to find a parameter close to the optimal manifolds of all tasks:</p>
<script type="math/tex; mode=display">\theta^* = \arg\min_\theta \mathbb{E}_{\tau \sim p(\tau)} [\frac{1}{2} \text{dist}(\theta, \mathcal{W}_\tau^*)^2]</script>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/reptile-optim.png" alt="Reptile Algorithm" /></p>
<p><em>Fig. 14. The Reptile algorithm updates the parameter alternatively to be closer to the optimal manifolds of different tasks. (Image source: <a href="https://arxiv.org/abs/1803.02999">original paper</a>)</em></p>
<p>Let’s use the L2 distance as <script type="math/tex">\text{dist}(.)</script>; the distance between a point <script type="math/tex">\theta</script> and a set <script type="math/tex">\mathcal{W}_\tau^*</script> equals the distance between <script type="math/tex">\theta</script> and the point <script type="math/tex">W_{\tau}^*(\theta)</script> on the manifold that is closest to <script type="math/tex">\theta</script>:</p>
<script type="math/tex; mode=display">\text{dist}(\theta, \mathcal{W}_{\tau}^*) = \text{dist}(\theta, W_{\tau}^*(\theta)) \text{, where }W_{\tau}^*(\theta) = \arg\min_{W\in\mathcal{W}_{\tau}^*} \text{dist}(\theta, W)</script>
<p>The gradient of the squared Euclidean distance is:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\nabla_\theta[\frac{1}{2}\text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2]
&= \nabla_\theta[\frac{1}{2}\text{dist}(\theta, W_{\tau_i}^*(\theta))^2] & \\
&= \nabla_\theta[\frac{1}{2}(\theta - W_{\tau_i}^*(\theta))^2] & \\
&= \theta - W_{\tau_i}^*(\theta) & \scriptstyle{\text{; See notes.}}
\end{aligned} %]]></script>
<p>Notes: According to the Reptile paper, “<em>the gradient of the squared euclidean distance between a point Θ and a set S is the vector 2(Θ − p), where p is the closest point in S to Θ</em>”. Technically the closest point in S is also a function of Θ; however, by the envelope (Danskin’s) theorem, the gradient of a pointwise minimum equals the gradient evaluated at the minimizer, so the derivative of p contributes nothing to first order and can be ignored.</p>
<p>Thus the update rule for one stochastic gradient step is:</p>
<script type="math/tex; mode=display">\theta = \theta - \alpha \nabla_\theta[\frac{1}{2} \text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2] = \theta - \alpha(\theta - W_{\tau_i}^*(\theta)) = (1-\alpha)\theta + \alpha W_{\tau_i}^*(\theta)</script>
<p>The closest point on the optimal task manifold <script type="math/tex">W_{\tau_i}^*(\theta)</script> cannot be computed exactly, but Reptile approximates it using <script type="math/tex">\text{SGD}(\mathcal{L}_\tau, \theta, k)</script>.</p>
<h4 id="reptile-vs-fomaml">Reptile vs FOMAML</h4>
<p>To demonstrate the deeper connection between Reptile and MAML, let’s expand the update formula with an example performing two gradient steps, k=2 in <script type="math/tex">\text{SGD}(.)</script>. Same as defined <a href="#maml">above</a>, <script type="math/tex">\mathcal{L}^{(0)}</script> and <script type="math/tex">\mathcal{L}^{(1)}</script> are losses computed using different mini-batches of data. For ease of reading, we adopt two simplified notations: <script type="math/tex">g^{(i)}_j = \nabla_{\theta} \mathcal{L}^{(i)}(\theta_j)</script> and <script type="math/tex">H^{(i)}_j = \nabla^2_{\theta} \mathcal{L}^{(i)}(\theta_j)</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\theta_0 &= \theta_\text{meta}\\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0)= \theta_0 - \alpha g^{(0)}_0 \\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(1)}(\theta_1) = \theta_0 - \alpha g^{(0)}_0 - \alpha g^{(1)}_1
\end{aligned} %]]></script>
<p>According to the <a href="#first-order-maml">earlier section</a>, the gradient of FOMAML is the last inner gradient update result. Therefore, when k=1:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_\text{FOMAML} &= \nabla_{\theta_1} \mathcal{L}^{(1)}(\theta_1) = g^{(1)}_1 \\
g_\text{MAML} &= \nabla_{\theta_1} \mathcal{L}^{(1)}(\theta_1) \cdot (I - \alpha\nabla^2_{\theta} \mathcal{L}^{(0)}(\theta_0)) = g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1
\end{aligned} %]]></script>
<p>The Reptile gradient is defined as:</p>
<script type="math/tex; mode=display">g_\text{Reptile} = (\theta_0 - \theta_2) / \alpha = g^{(0)}_0 + g^{(1)}_1</script>
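<p>This identity is easy to verify numerically, with toy quadratic losses standing in for the two mini-batches (all values below are made up):</p>

```python
# Two mini-batch losses for one task: L0(t) = 0.5*(t - 1)^2, L1(t) = 0.5*(t - 2)^2
grad0 = lambda t: t - 1.0  # gradient of L0
grad1 = lambda t: t - 2.0  # gradient of L1

alpha, theta0 = 0.1, 0.0
theta1 = theta0 - alpha * grad0(theta0)  # first inner step, on batch (0)
theta2 = theta1 - alpha * grad1(theta1)  # second inner step, on batch (1)

g_reptile = (theta0 - theta2) / alpha
# Telescoping the two SGD steps gives exactly g0 + g1:
assert abs(g_reptile - (grad0(theta0) + grad1(theta1))) < 1e-9
```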
<p>Up to now we have:</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/reptile_vs_FOMAML.png" alt="Reptile vs FOMAML" /></p>
<p><em>Fig. 15. Reptile versus FOMAML in one loop of meta-optimization. (Image source: <a href="https://www.slideshare.net/YoonhoLee4/on-firstorder-metalearning-algorithms">slides</a> on Reptile by Yoonho Lee.)</em></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 \\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
g_\text{Reptile} &= g^{(0)}_0 + g^{(1)}_1
\end{aligned} %]]></script>
<p>Next let’s further expand <script type="math/tex">g^{(1)}_1</script> using a <a href="https://en.wikipedia.org/wiki/Taylor_series">Taylor expansion</a>. Recall that the Taylor expansion of a function <script type="math/tex">f(x)</script> that is infinitely differentiable at a number <script type="math/tex">a</script> is:</p>
<script type="math/tex; mode=display">f(x) = f(a) + \frac{f'(a)}{1!}(x-a) + \frac{f''(a)}{2!}(x-a)^2 + \dots = \sum_{i=0}^\infty \frac{f^{(i)}(a)}{i!}(x-a)^i</script>
<p>We can consider <script type="math/tex">\nabla_{\theta}\mathcal{L}^{(1)}(.)</script> as a function and <script type="math/tex">\theta_0</script> as a value point. The Taylor expansion of <script type="math/tex">g_1^{(1)}</script> at the value point <script type="math/tex">\theta_0</script> is:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_1^{(1)} &= \nabla_{\theta}\mathcal{L}^{(1)}(\theta_1) \\
&= \nabla_{\theta}\mathcal{L}^{(1)}(\theta_0) + \nabla^2_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0) + \frac{1}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0)^2 + \dots & \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + \frac{\alpha^2}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0) (g_0^{(0)})^2 + \dots & \scriptstyle{\text{; because }\theta_1-\theta_0=-\alpha g_0^{(0)}} \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned} %]]></script>
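<p>For quadratic losses the Hessian is constant and all higher derivatives vanish, so this expansion is exact and easy to check numerically (a toy 1-D example with made-up losses):</p>

```python
# Toy batch losses: L0(t) = 0.5*(t - 1)^2 and L1(t) = 0.5*(t - 2)^2,
# so grad L1(t) = t - 2, Hessian H1 = 1, and the O(alpha^2) term is zero.
grad0 = lambda t: t - 1.0
grad1 = lambda t: t - 2.0
H1 = 1.0

alpha, theta0 = 0.1, 0.0
theta1 = theta0 - alpha * grad0(theta0)  # theta1 - theta0 = -alpha * g0
# Check: g_1^(1) = g_0^(1) - alpha * H_0^(1) * g_0^(0)
lhs = grad1(theta1)
rhs = grad1(theta0) - alpha * H1 * grad0(theta0)
assert abs(lhs - rhs) < 1e-12
```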
<p>Plug in the expanded form of <script type="math/tex">g_1^{(1)}</script> into the MAML gradients with one step inner gradient update:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 = g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)\\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2) - \alpha H^{(0)}_0 (g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2))\\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + \alpha^2 H^{(0)}_0 H^{(1)}_0 g_0^{(0)} + O(\alpha^2)\\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + O(\alpha^2)
\end{aligned} %]]></script>
<p>The Reptile gradient becomes:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_\text{Reptile}
&= g^{(0)}_0 + g^{(1)}_1 \\
&= g^{(0)}_0 + g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned} %]]></script>
<p>So far we have the formula of three types of gradients:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
g_\text{FOMAML} &= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)\\
g_\text{MAML} &= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + O(\alpha^2)\\
g_\text{Reptile} &= g^{(0)}_0 + g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned} %]]></script>
<p>During training, we often average over multiple data batches. In our example, the mini batches (0) and (1) are interchangeable since both are drawn at random. The expectation <script type="math/tex">\mathbb{E}_{\tau,0,1}</script> is averaged over two data batches, ids (0) and (1), for task <script type="math/tex">\tau</script>.</p>
<p>Let,</p>
<ul>
<li><script type="math/tex">A = \mathbb{E}_{\tau,0,1} [g_0^{(0)}] = \mathbb{E}_{\tau,0,1} [g_0^{(1)}]</script>; it is the average gradient of task loss. We expect to improve the model parameter to achieve better task performance by following this direction pointed by <script type="math/tex">A</script>.</li>
<li><script type="math/tex">B = \mathbb{E}_{\tau,0,1} [H^{(1)}_0 g_0^{(0)}] = \frac{1}{2}\mathbb{E}_{\tau,0,1} [H^{(1)}_0 g_0^{(0)} + H^{(0)}_0 g_0^{(1)}] = \frac{1}{2}\mathbb{E}_{\tau,0,1} [\nabla_\theta(g^{(0)}_0 g_0^{(1)})]</script>; it is the direction (gradient) that increases the inner product of gradients of two different mini batches for the same task. We expect to improve the model parameter to achieve better generalization over different data by following this direction pointed by <script type="math/tex">B</script>.</li>
</ul>
<p>To conclude, both MAML and Reptile aim to optimize for the same goal, better task performance (guided by A) and better generalization (guided by B), when the gradient update is approximated by the first three leading terms.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbb{E}_{\tau,0,1}[g_\text{FOMAML}] &= A - \alpha B + O(\alpha^2)\\
\mathbb{E}_{\tau,0,1}[g_\text{MAML}] &= A - 2\alpha B + O(\alpha^2)\\
\mathbb{E}_{\tau,0,1}[g_\text{Reptile}] &= 2A - \alpha B + O(\alpha^2)
\end{aligned} %]]></script>
<p>It is not clear to me whether the ignored term <script type="math/tex">O(\alpha^2)</script> might have a big impact on the parameter learning. But given that FOMAML is able to obtain performance similar to the full version of MAML, it might be safe to say that higher-order derivatives are not critical during the gradient descent update.</p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018metalearning,
title = "Meta-Learning: Learning to Learn Fast",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/11/29/meta-learning.html"
}
</code></pre></div></div>
<p><em>If you notice mistakes and errors in this post, don’t hesitate to leave a comment or contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them asap.</em></p>
<p>See you in the next post!</p>
<h2 id="reference">Reference</h2>
<p>[1] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. <a href="https://www.cs.cmu.edu/~rsalakhu/papers/LakeEtAl2015Science.pdf">“Human-level concept learning through probabilistic program induction.”</a> Science 350.6266 (2015): 1332-1338.</p>
<p>[2] Oriol Vinyals’ talk on <a href="http://metalearning-symposium.ml/files/vinyals.pdf">“Model vs Optimization Meta Learning”</a></p>
<p>[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. <a href="http://www.cs.toronto.edu/~rsalakhu/papers/oneshot1.pdf">“Siamese neural networks for one-shot image recognition.”</a> ICML Deep Learning Workshop. 2015.</p>
<p>[4] Oriol Vinyals, et al. <a href="http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf">“Matching networks for one shot learning.”</a> NIPS. 2016.</p>
<p>[5] Flood Sung, et al. <a href="http://openaccess.thecvf.com/content_cvpr_2018/papers_backup/Sung_Learning_to_Compare_CVPR_2018_paper.pdf">“Learning to compare: Relation network for few-shot learning.”</a> CVPR. 2018.</p>
<p>[6] Jake Snell, Kevin Swersky, and Richard Zemel. <a href="http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf">“Prototypical Networks for Few-shot Learning.”</a> CVPR. 2018.</p>
<p>[7] Adam Santoro, et al. <a href="http://proceedings.mlr.press/v48/santoro16.pdf">“Meta-learning with memory-augmented neural networks.”</a> ICML. 2016.</p>
<p>[8] Alex Graves, Greg Wayne, and Ivo Danihelka. <a href="https://arxiv.org/abs/1410.5401">“Neural turing machines.”</a> arXiv preprint arXiv:1410.5401 (2014).</p>
<p>[9] Tsendsuren Munkhdalai and Hong Yu. <a href="https://arxiv.org/abs/1703.00837">“Meta Networks.”</a> ICML. 2017.</p>
<p>[10] Sachin Ravi and Hugo Larochelle. <a href="https://openreview.net/pdf?id=rJY0-Kcll">“Optimization as a Model for Few-Shot Learning.”</a> ICLR. 2017.</p>
<p>[11] Chelsea Finn’s BAIR blog on <a href="https://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/">“Learning to Learn”</a>.</p>
<p>[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine. <a href="https://arxiv.org/abs/1703.03400">“Model-agnostic meta-learning for fast adaptation of deep networks.”</a> ICML 2017.</p>
<p>[13] Alex Nichol, Joshua Achiam, John Schulman. <a href="https://arxiv.org/abs/1803.02999">“On First-Order Meta-Learning Algorithms.”</a> arXiv preprint arXiv:1803.02999 (2018).</p>
<p>[14] <a href="https://www.slideshare.net/YoonhoLee4/on-firstorder-metalearning-algorithms">Slides on Reptile</a> by Yoonho Lee.</p>
<h1 id="flow-based-deep-generative-models">Flow-based Deep Generative Models</h1>
<p><em>2018-10-13 · Lilian Weng · <a href="https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models">lilianweng.github.io</a></em></p>
<blockquote>
<p>In this post, we are looking into the third type of generative models: flow-based generative models. Different from GAN and VAE, they explicitly learn the probability density function of the input data.</p>
</blockquote>
<!--more-->
<p>So far, I’ve written about two types of generative models, <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a> and <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">VAE</a>. Neither of them explicitly learns the probability density function of real data, <script type="math/tex">p(\mathbf{x})</script> (where <script type="math/tex">\mathbf{x} \in \mathcal{D}</script>) — because it is really hard! Taking the generative model with latent variables as an example, <script type="math/tex">p(\mathbf{x}) = \int p(\mathbf{x}\vert\mathbf{z})p(\mathbf{z})d\mathbf{z}</script> can hardly be calculated as it is intractable to go through all possible values of the latent code <script type="math/tex">\mathbf{z}</script>.</p>
<p>Flow-based deep generative models conquer this hard problem with the help of <a href="https://arxiv.org/abs/1505.05770">normalizing flows</a>, a powerful statistics tool for density estimation. A good estimation of <script type="math/tex">p(\mathbf{x})</script> makes it possible to efficiently complete many downstream tasks: sample unobserved but realistic new data points (data generation), predict the rareness of future events (density estimation), infer latent variables, fill in incomplete data samples, etc.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#types-of-generative-models" id="markdown-toc-types-of-generative-models">Types of Generative Models</a></li>
<li><a href="#linear-algebra-basics-recap" id="markdown-toc-linear-algebra-basics-recap">Linear Algebra Basics Recap</a> <ul>
<li><a href="#jacobian-matrix-and-determinant" id="markdown-toc-jacobian-matrix-and-determinant">Jacobian Matrix and Determinant</a></li>
<li><a href="#change-of-variable-theorem" id="markdown-toc-change-of-variable-theorem">Change of Variable Theorem</a></li>
</ul>
</li>
<li><a href="#what-is-normalizing-flows" id="markdown-toc-what-is-normalizing-flows">What is Normalizing Flows?</a></li>
<li><a href="#models-with-normalizing-flows" id="markdown-toc-models-with-normalizing-flows">Models with Normalizing Flows</a> <ul>
<li><a href="#realnvp" id="markdown-toc-realnvp">RealNVP</a></li>
<li><a href="#nice" id="markdown-toc-nice">NICE</a></li>
<li><a href="#glow" id="markdown-toc-glow">Glow</a></li>
</ul>
</li>
<li><a href="#models-with-autoregressive-flows" id="markdown-toc-models-with-autoregressive-flows">Models with Autoregressive Flows</a> <ul>
<li><a href="#made" id="markdown-toc-made">MADE</a></li>
<li><a href="#pixelrnn" id="markdown-toc-pixelrnn">PixelRNN</a></li>
<li><a href="#wavenet" id="markdown-toc-wavenet">WaveNet</a></li>
<li><a href="#masked-autoregressive-flow" id="markdown-toc-masked-autoregressive-flow">Masked Autoregressive Flow</a></li>
<li><a href="#inverse-autoregressive-flow" id="markdown-toc-inverse-autoregressive-flow">Inverse Autoregressive Flow</a></li>
</ul>
</li>
<li><a href="#vae--flows" id="markdown-toc-vae--flows">VAE + Flows</a></li>
<li><a href="#reference" id="markdown-toc-reference">Reference</a></li>
</ul>
<h2 id="types-of-generative-models">Types of Generative Models</h2>
<p>Here is a quick summary of the difference between GAN, VAE, and flow-based generative models:</p>
<ol>
<li>Generative adversarial networks: GAN provides a smart solution to model the data generation, an unsupervised learning problem, as a supervised one. The discriminator model learns to distinguish the real data from the fake samples that are produced by the generator model. Two models are trained as they are playing a <a href="https://en.wikipedia.org/wiki/Minimax">minimax</a> game.</li>
<li>Variational autoencoders: VAE implicitly optimizes the log-likelihood of the data by maximizing the evidence lower bound (ELBO).</li>
<li>Flow-based generative models: A flow-based generative model is constructed by a sequence of invertible transformations. Unlike the other two, the model explicitly learns the data distribution <script type="math/tex">p(\mathbf{x})</script> and therefore the loss function is simply the negative log-likelihood.</li>
</ol>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/three-generative-models.png" alt="Categories of generative models" /></p>
<p><em>Fig. 1. Comparison of three categories of generative models.</em></p>
<h2 id="linear-algebra-basics-recap">Linear Algebra Basics Recap</h2>
<p>We should understand two key concepts before getting into the flow-based generative model: the Jacobian determinant and the change of variable rule. Pretty basic, so feel free to skip.</p>
<h3 id="jacobian-matrix-and-determinant">Jacobian Matrix and Determinant</h3>
<p>Given a function mapping an <script type="math/tex">n</script>-dimensional input vector <script type="math/tex">\mathbf{x}</script> to an <script type="math/tex">m</script>-dimensional output vector, <script type="math/tex">\mathbf{f}: \mathbb{R}^n \mapsto \mathbb{R}^m</script>, the matrix of all first-order partial derivatives of this function is called the <strong>Jacobian matrix</strong>, <script type="math/tex">\mathbf{J}</script>, where the entry on the i-th row and j-th column is <script type="math/tex">\mathbf{J}_{ij} = \frac{\partial f_i}{\partial x_j}</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\mathbf{J} = \begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \dots & \frac{\partial f_1}{\partial x_n} \\[6pt]
\vdots & \ddots & \vdots \\[6pt]
\frac{\partial f_m}{\partial x_1} & \dots & \frac{\partial f_m}{\partial x_n} \\[6pt]
\end{bmatrix} %]]></script>
<p>The determinant is a real number computed as a function of all the elements in a square matrix. Note that the determinant <em>only exists for <strong>square</strong> matrices</em>. The absolute value of the determinant can be thought of as a measure of <em>“how much multiplication by the matrix expands or contracts space”.</em></p>
<p>The determinant of an <script type="math/tex">n \times n</script> matrix <script type="math/tex">M</script> is:</p>
<script type="math/tex; mode=display">% <![CDATA[
\det M = \det \begin{bmatrix}
a_{11} & a_{12} & \dots & a_{1n} \\
a_{21} & a_{22} & \dots & a_{2n} \\
\vdots & \vdots & & \vdots \\
a_{n1} & a_{n2} & \dots & a_{nn} \\
\end{bmatrix} = \sum_{j_1 j_2 \dots j_n} (-1)^{\tau(j_1 j_2 \dots j_n)} a_{1j_1} a_{2j_2} \dots a_{nj_n} %]]></script>
<p>where the subscript under the summation <script type="math/tex">j_1 j_2 \dots j_n</script> are all permutations of the set {1, 2, …, n}, so there are <script type="math/tex">n!</script> items in total; <script type="math/tex">\tau(.)</script> indicates the <a href="https://en.wikipedia.org/wiki/Parity_of_a_permutation">signature</a> of a permutation.</p>
<p>The determinant of a square matrix <script type="math/tex">M</script> detects whether it is invertible: If <script type="math/tex">\det(M)=0</script> then <script type="math/tex">M</script> is not invertible (a <em>singular</em> matrix with linearly dependent rows or columns; or any row or column is all 0); otherwise, if <script type="math/tex">\det(M)\neq 0</script>, <script type="math/tex">M</script> is invertible.</p>
<p>The determinant of the product is equivalent to the product of the determinants: <script type="math/tex">\det(AB) = \det(A)\det(B)</script>. (<a href="https://proofwiki.org/wiki/Determinant_of_Matrix_Product">proof</a>)</p>
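<p>As a quick numerical sanity check (a NumPy snippet of my own, not part of the derivation), we can verify the two determinant identities used throughout this post: the product rule and, as a consequence, the inverse rule.</p>

```python
# Verify det(AB) = det(A)det(B) and det(M^{-1}) = 1/det(M) on random matrices.
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 4))

# Product rule: the determinant of a product is the product of determinants.
assert np.isclose(np.linalg.det(A @ B), np.linalg.det(A) * np.linalg.det(B))
# Inverse rule, which follows from det(M)det(M^{-1}) = det(I) = 1.
assert np.isclose(np.linalg.det(np.linalg.inv(A)), 1.0 / np.linalg.det(A))
```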
<h3 id="change-of-variable-theorem">Change of Variable Theorem</h3>
<p>Let’s review the change of variable theorem specifically in the context of probability density estimation, starting with a single variable case.</p>
<p>Given a random variable <script type="math/tex">z</script> and its known probability density function <script type="math/tex">z \sim \pi(z)</script>, we would like to construct a new random variable using a 1-1 mapping function <script type="math/tex">x = f(z)</script>. The function <script type="math/tex">f</script> is invertible, so <script type="math/tex">z=f^{-1}(x)</script>. Now the question is <em>how to infer the unknown probability density function of the new variable</em>, <script type="math/tex">p(x)</script>?</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& \int p(x)dx = \int \pi(z)dz = 1 \scriptstyle{\text{ ; Definition of probability distribution.}}\\
& p(x) = \pi(z) \left\vert\frac{dz}{dx}\right\vert = \pi(f^{-1}(x)) \left\vert\frac{d f^{-1}}{dx}\right\vert = \pi(f^{-1}(x)) \vert (f^{-1})'(x) \vert
\end{aligned} %]]></script>
<p>By definition, the integral <script type="math/tex">\int \pi(z)dz</script> is the sum of an infinite number of rectangles of infinitesimal width <script type="math/tex">\Delta z</script>. The height of such a rectangle at position <script type="math/tex">z</script> is the value of the density function <script type="math/tex">\pi(z)</script>. Substituting the variable <script type="math/tex">z = f^{-1}(x)</script> yields <script type="math/tex">\frac{\Delta z}{\Delta x} = (f^{-1}(x))'</script> and <script type="math/tex">\Delta z = (f^{-1}(x))' \Delta x</script>. Here <script type="math/tex">\vert(f^{-1}(x))'\vert</script> indicates the ratio between the areas of rectangles defined in the two coordinates of <script type="math/tex">z</script> and <script type="math/tex">x</script> respectively.</p>
<p>The multivariable version has a similar format:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{z} &\sim \pi(\mathbf{z}), \mathbf{x} = f(\mathbf{z}), \mathbf{z} = f^{-1}(\mathbf{x}) \\
p(\mathbf{x})
&= \pi(\mathbf{z}) \left\vert \det \dfrac{d \mathbf{z}}{d \mathbf{x}} \right\vert
= \pi(f^{-1}(\mathbf{x})) \left\vert \det \dfrac{d f^{-1}}{d \mathbf{x}} \right\vert
\end{aligned} %]]></script>
<p>where <script type="math/tex">\det \frac{\partial f^{-1}}{\partial\mathbf{x}}</script> is the Jacobian determinant of the inverse function <script type="math/tex">f^{-1}</script>. The full proof of the multivariate version is out of the scope of this post; ask Google if interested ;)</p>
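<p>A hypothetical single-variable check (the choice <script type="math/tex">x = e^z</script> is my own example): with <script type="math/tex">z \sim \mathcal{N}(0,1)</script> and <script type="math/tex">x = f(z) = e^z</script>, the change-of-variable formula should reproduce the standard log-normal density exactly.</p>

```python
# Change of variables in 1D: p(x) = pi(f^{-1}(x)) |(f^{-1})'(x)| with f(z)=exp(z).
import numpy as np

def pi(z):
    # Base density: standard normal.
    return np.exp(-z**2 / 2) / np.sqrt(2 * np.pi)

def p(x):
    # Transformed density via the change-of-variable rule.
    z = np.log(x)                     # f^{-1}(x)
    return pi(z) * np.abs(1.0 / x)    # |d f^{-1} / dx| = 1/x

x = np.linspace(0.1, 5.0, 50)
# Analytic standard log-normal density for comparison.
lognormal = np.exp(-np.log(x)**2 / 2) / (x * np.sqrt(2 * np.pi))
assert np.allclose(p(x), lognormal)
```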
<h2 id="what-is-normalizing-flows">What is Normalizing Flows?</h2>
<p>Being able to do good density estimation has direct applications in many machine learning problems, but it is very hard. For example, since we need to run backward propagation in deep learning models, the embedded probability distribution (i.e. the posterior <script type="math/tex">p(\mathbf{z}\vert\mathbf{x})</script>) is expected to be simple enough to calculate the derivative easily and efficiently. That is why the Gaussian distribution is often used in latent variable generative models, even though most real-world distributions are much more complicated than Gaussian.</p>
<p>Here comes a <strong>Normalizing Flow</strong> (NF) model for better and more powerful distribution approximation. A normalizing flow transforms a simple distribution into a complex one by applying a sequence of invertible transformation functions. Flowing through a chain of transformations, we repeatedly substitute the variable for the new one according to the change of variables theorem and eventually obtain a probability distribution of the final target variable.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/normalizing-flow.png" alt="Normalizing flow" /></p>
<p><em>Fig. 2. Illustration of a normalizing flow model, transforming a simple distribution <script type="math/tex">p_0(\mathbf{z}_0)</script> to a complex one <script type="math/tex">p_K(\mathbf{z}_K)</script> step by step.</em></p>
<p>As defined in Fig. 2,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{z}_{i-1} &\sim p_{i-1}(\mathbf{z}_{i-1}) \\
\mathbf{z}_i &= f_i(\mathbf{z}_{i-1})\text{, thus }\mathbf{z}_{i-1} = f_i^{-1}(\mathbf{z}_i) \\
p_i(\mathbf{z}_i)
&= p_{i-1}(f_i^{-1}(\mathbf{z}_i)) \left\vert \det\dfrac{d f_i^{-1}}{d \mathbf{z}_i} \right\vert
\end{aligned} %]]></script>
<p>Then let’s convert the equation to be a function of <script type="math/tex">\mathbf{z}_i</script> so that we can do inference with the base distribution.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
p_i(\mathbf{z}_i)
&= p_{i-1}(f_i^{-1}(\mathbf{z}_i)) \left\vert \det\dfrac{d f_i^{-1}}{d \mathbf{z}_i} \right\vert \\
&= p_{i-1}(\mathbf{z}_{i-1}) \left\vert \det \color{red}{\Big(\dfrac{d f_i}{d\mathbf{z}_{i-1}}\Big)^{-1}} \right\vert & \scriptstyle{\text{; According to the inverse func theorem.}} \\
&= p_{i-1}(\mathbf{z}_{i-1}) \color{red}{\left\vert \det \dfrac{d f_i}{d\mathbf{z}_{i-1}} \right\vert^{-1}} & \scriptstyle{\text{; According to a property of Jacobians of invertible func.}} \\
\log p_i(\mathbf{z}_i) &= \log p_{i-1}(\mathbf{z}_{i-1}) - \log \left\vert \det \dfrac{d f_i}{d\mathbf{z}_{i-1}} \right\vert
\end{aligned} %]]></script>
<p>(*) A note on the <em>“inverse function theorem”</em>: If <script type="math/tex">y=f(x)</script> and <script type="math/tex">x=f^{-1}(y)</script>, we have:</p>
<script type="math/tex; mode=display">\dfrac{df^{-1}(y)}{dy} = \dfrac{dx}{dy} = (\dfrac{dy}{dx})^{-1} = (\dfrac{df(x)}{dx})^{-1}</script>
<p>(*) A note on <em>“Jacobians of invertible function”</em>: The determinant of the inverse of an invertible matrix is the inverse of the determinant: <script type="math/tex">\det(M^{-1}) = (\det(M))^{-1}</script>, <a href="#jacobian-matrix-and-determinant">because</a> <script type="math/tex">\det(M)\det(M^{-1}) = \det(M \cdot M^{-1}) = \det(I) = 1</script>.</p>
<p>Given such a chain of probability density functions, we know the relationship between each pair of consecutive variables. We can expand the equation of the output <script type="math/tex">\mathbf{x}</script> step by step until tracing back to the initial distribution <script type="math/tex">\mathbf{z}_0</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{x} = \mathbf{z}_K &= f_K \circ f_{K-1} \circ \dots \circ f_1 (\mathbf{z}_0) \\
\log p(\mathbf{x}) = \log \pi_K(\mathbf{z}_K)
&= \log \pi_{K-1}(\mathbf{z}_{K-1}) - \log\left\vert\det\dfrac{d f_K}{d \mathbf{z}_{K-1}}\right\vert \\
&= \log \pi_{K-2}(\mathbf{z}_{K-2}) - \log\left\vert\det\dfrac{d f_{K-1}}{d\mathbf{z}_{K-2}}\right\vert - \log\left\vert\det\dfrac{d f_K}{d\mathbf{z}_{K-1}}\right\vert \\
&= \dots \\
&= \log \pi_0(\mathbf{z}_0) - \sum_{i=1}^K \log\left\vert\det\dfrac{d f_i}{d\mathbf{z}_{i-1}}\right\vert
\end{aligned} %]]></script>
<p>The path traversed by the random variables <script type="math/tex">\mathbf{z}_i = f_i(\mathbf{z}_{i-1})</script> is the <strong>flow</strong> and the full chain formed by the successive distributions <script type="math/tex">\pi_i</script> is called a <strong>normalizing flow</strong>. Required by the computation in the equation, a transformation function <script type="math/tex">f_i</script> should satisfy two properties:</p>
<ol>
<li>It is easily invertible.</li>
<li>Its Jacobian determinant is easy to compute.</li>
</ol>
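<p>The log-density chain above can be sketched in a few lines of code. This toy example (the elementwise “leaky” transform is an invented stand-in for a real flow layer, chosen because both properties hold trivially) accumulates the per-step log-determinants exactly as in the expansion of <script type="math/tex">\log p(\mathbf{x})</script>:</p>

```python
# Each flow step returns its output plus log|det Jacobian|; the final
# log-density is the base log-density minus the accumulated log-determinants.
import numpy as np

def leaky_step(z, slope=0.5):
    """An invertible elementwise transform with an easy log-det:
    identity for positive entries, multiply by `slope` for negative ones."""
    x = np.where(z > 0, z, slope * z)
    log_det = np.sum(np.where(z > 0, 0.0, np.log(slope)))
    return x, log_det

z0 = np.array([-1.0, 2.0, -0.5])
# log-density of z0 under a standard normal base distribution.
log_p0 = np.sum(-z0**2 / 2 - 0.5 * np.log(2 * np.pi))

z1, ld1 = leaky_step(z0)
z2, ld2 = leaky_step(z1)
log_px = log_p0 - (ld1 + ld2)   # log p(x) = log pi_0(z_0) - sum_i log|det J_i|
```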
<h2 id="models-with-normalizing-flows">Models with Normalizing Flows</h2>
<p>With normalizing flows in our toolbox, the exact log-likelihood of input data <script type="math/tex">\log p(\mathbf{x})</script> becomes tractable. As a result, the training criterion of flow-based generative model is simply the negative log-likelihood (NLL) over the training dataset <script type="math/tex">\mathcal{D}</script>:</p>
<script type="math/tex; mode=display">\mathcal{L}(\mathcal{D}) = - \frac{1}{\vert\mathcal{D}\vert}\sum_{\mathbf{x} \in \mathcal{D}} \log p(\mathbf{x})</script>
<h3 id="realnvp">RealNVP</h3>
<p>The <strong>RealNVP</strong> (Real-valued Non-Volume Preserving; <a href="https://arxiv.org/abs/1605.08803">Dinh et al., 2017</a>) model implements a normalizing flow by stacking a sequence of invertible bijective transformation functions. In each bijection <script type="math/tex">f: \mathbf{x} \mapsto \mathbf{y}</script>, known as <em>affine coupling layer</em>, the input dimensions are split into two parts:</p>
<ul>
<li>The first <script type="math/tex">d</script> dimensions stay the same;</li>
<li>The second part, <script type="math/tex">d+1</script> to <script type="math/tex">D</script> dimensions, undergo an affine transformation (“scale-and-shift”) and both the scale and shift parameters are functions of the first <script type="math/tex">d</script> dimensions.</li>
</ul>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{y}_{1:d} &= \mathbf{x}_{1:d} \\
\mathbf{y}_{d+1:D} &= \mathbf{x}_{d+1:D} \odot \exp({s(\mathbf{x}_{1:d})}) + t(\mathbf{x}_{1:d})
\end{aligned} %]]></script>
<p>where <script type="math/tex">s(.)</script> and <script type="math/tex">t(.)</script> are <em>scale</em> and <em>translation</em> functions and both map <script type="math/tex">\mathbb{R}^d \mapsto \mathbb{R}^{D-d}</script>. The <script type="math/tex">\odot</script> operation is the element-wise product.</p>
<p>Now let’s check whether this transformation satisfies the two basic properties of a flow transformation.</p>
<p><strong>Condition 1</strong>: “It is easily invertible.”</p>
<p>Yes and it is fairly straightforward.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{cases}
\mathbf{y}_{1:d} &= \mathbf{x}_{1:d} \\
\mathbf{y}_{d+1:D} &= \mathbf{x}_{d+1:D} \odot \exp({s(\mathbf{x}_{1:d})}) + t(\mathbf{x}_{1:d})
\end{cases}
\Leftrightarrow
\begin{cases}
\mathbf{x}_{1:d} &= \mathbf{y}_{1:d} \\
\mathbf{x}_{d+1:D} &= (\mathbf{y}_{d+1:D} - t(\mathbf{y}_{1:d})) \odot \exp(-s(\mathbf{y}_{1:d}))
\end{cases} %]]></script>
<p><strong>Condition 2</strong>: “Its Jacobian determinant is easy to compute.”</p>
<p>Yes. It is not hard to get the Jacobian matrix and determinant of this transformation. The Jacobian is a lower triangular matrix.</p>
<script type="math/tex; mode=display">% <![CDATA[
\mathbf{J} =
\begin{bmatrix}
\mathbb{I}_d & \mathbf{0}_{d\times(D-d)} \\[5pt]
\frac{\partial \mathbf{y}_{d+1:D}}{\partial \mathbf{x}_{1:d}} & \text{diag}(\exp(s(\mathbf{x}_{1:d})))
\end{bmatrix} %]]></script>
<p>Hence the determinant is simply the product of terms on the diagonal.</p>
<script type="math/tex; mode=display">\det(\mathbf{J})
= \prod_{j=1}^{D-d}\exp(s(\mathbf{x}_{1:d}))_j
= \exp(\sum_{j=1}^{D-d} s(\mathbf{x}_{1:d})_j)</script>
<p>So far, the affine coupling layer looks perfect for constructing a normalizing flow :)</p>
<p>Even better, since (i) computing <script type="math/tex">f^{-1}</script> does not require computing the inverse of <script type="math/tex">s</script> or <script type="math/tex">t</script> and (ii) computing the Jacobian determinant does not involve computing the Jacobian of <script type="math/tex">s</script> or <script type="math/tex">t</script>, those functions can be <em>arbitrarily complex</em>; i.e. both <script type="math/tex">s</script> and <script type="math/tex">t</script> can be modeled by deep neural networks.</p>
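<p>A minimal NumPy sketch of one affine coupling layer; here <code>s(.)</code> and <code>t(.)</code> are arbitrary placeholder functions standing in for the neural networks, just to demonstrate both conditions numerically.</p>

```python
# One RealNVP-style affine coupling layer: forward, inverse, and log-det.
import numpy as np

def s(x1):
    # Placeholder scale network, mapping R^d -> R^{D-d} (here d = D-d).
    return np.tanh(x1)

def t(x1):
    # Placeholder translation network.
    return 0.5 * x1

def coupling_forward(x, d):
    y = x.copy()
    y[d:] = x[d:] * np.exp(s(x[:d])) + t(x[:d])
    log_det = np.sum(s(x[:d]))   # determinant is exp of the summed log-scales
    return y, log_det

def coupling_inverse(y, d):
    x = y.copy()
    x[d:] = (y[d:] - t(y[:d])) * np.exp(-s(y[:d]))
    return x

x = np.array([0.3, -1.2, 0.7, 2.0])
y, log_det = coupling_forward(x, d=2)
assert np.allclose(coupling_inverse(y, d=2), x)   # Condition 1: invertible
```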
<p>In one affine coupling layer, some dimensions (channels) remain unchanged. To make sure all the inputs have a chance to be altered, the model reverses the ordering in each layer so that different components are left unchanged. Following such an alternating pattern, the set of units which remain identical in one transformation layer are always modified in the next. Batch normalization is found to help training models with a very deep stack of coupling layers.</p>
<p>Furthermore, RealNVP can work in a multi-scale architecture to build a more efficient model for large inputs. The multi-scale architecture applies several “sampling” operations to normal affine layers, including spatial checkerboard pattern masking, squeezing operation, and channel-wise masking. Read the <a href="https://arxiv.org/abs/1605.08803">paper</a> for more details on the multi-scale architecture.</p>
<h3 id="nice">NICE</h3>
<p>The <strong>NICE</strong> (Non-linear Independent Component Estimation; <a href="https://arxiv.org/abs/1410.8516">Dinh, et al. 2015</a>) model is a predecessor of <a href="#realnvp">RealNVP</a>. The transformation in NICE is the affine coupling layer without the scale term, known as <em>additive coupling layer</em>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{cases}
\mathbf{y}_{1:d} &= \mathbf{x}_{1:d} \\
\mathbf{y}_{d+1:D} &= \mathbf{x}_{d+1:D} + m(\mathbf{x}_{1:d})
\end{cases}
\Leftrightarrow
\begin{cases}
\mathbf{x}_{1:d} &= \mathbf{y}_{1:d} \\
\mathbf{x}_{d+1:D} &= \mathbf{y}_{d+1:D} - m(\mathbf{y}_{1:d})
\end{cases} %]]></script>
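<p>A minimal sketch of the additive coupling layer (with <code>m(.)</code> as an arbitrary placeholder for a neural network): without the scale term the Jacobian is unit-triangular, so the log-determinant of every layer is exactly 0, i.e. NICE is volume-preserving.</p>

```python
# NICE additive coupling: y = [x_{1:d}, x_{d+1:D} + m(x_{1:d})]; log|det J| = 0.
import numpy as np

def m(x1):
    # Arbitrary coupling function standing in for a neural network.
    return np.sin(x1)

def nice_forward(x, d):
    y = x.copy()
    y[d:] = x[d:] + m(x[:d])
    return y            # log-det contribution is exactly 0

def nice_inverse(y, d):
    x = y.copy()
    x[d:] = y[d:] - m(y[:d])
    return x

x = np.array([1.0, -0.4, 0.2, 0.9])
assert np.allclose(nice_inverse(nice_forward(x, 2), 2), x)
```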
<h3 id="glow">Glow</h3>
<p>The <strong>Glow</strong> (<a href="https://arxiv.org/abs/1807.03039">Kingma and Dhariwal, 2018</a>) model extends the previous reversible generative models, NICE and RealNVP, and simplifies the architecture by replacing the reverse permutation operation on the channel ordering with invertible 1x1 convolutions.</p>
<p style="width: 45%;" class="center"><img src="/lil-log/assets/images/one-glow-step.png" alt="Glow step" /></p>
<p><em>Fig. 3. One step of flow in the Glow model. (Image source: <a href="https://arxiv.org/abs/1807.03039">Kingma and Dhariwal, 2018</a>)</em></p>
<p>There are three substeps in one step of flow in Glow.</p>
<p>Substep 1: <strong>Activation normalization</strong> (short for “actnorm”)</p>
<p>It performs an affine transformation using a scale and bias parameter per channel, similar to batch normalization, but works for mini-batch size 1. The parameters are trainable but initialized so that the first minibatch of data has mean 0 and standard deviation 1 after actnorm.</p>
<p>Substep 2: <strong>Invertible 1x1 conv</strong></p>
<p>Between layers of the RealNVP flow, the ordering of channels is reversed so that all the data dimensions have a chance to be altered. A 1×1 convolution with equal number of input and output channels is <em>a generalization of any permutation</em> of the channel ordering.</p>
<p>Say, we have an invertible 1x1 convolution of an input <script type="math/tex">h \times w \times c</script> tensor <script type="math/tex">\mathbf{h}</script> with a weight matrix <script type="math/tex">\mathbf{W}</script> of size <script type="math/tex">c \times c</script>. The output is a <script type="math/tex">h \times w \times c</script> tensor, labeled as <script type="math/tex">f = \texttt{conv2d}(\mathbf{h}; \mathbf{W})</script>. In order to apply the change of variable rule, we need to compute the Jacobian determinant <script type="math/tex">\vert \det\partial f / \partial\mathbf{h}\vert</script>.</p>
<p>Both the input and output of 1x1 convolution here can be viewed as a matrix of size <script type="math/tex">h \times w</script>. Each entry <script type="math/tex">\mathbf{x}_{ij}</script> (<script type="math/tex">i=1,\dots,h, j=1,\dots,w</script>) in <script type="math/tex">\mathbf{h}</script> is a vector of <script type="math/tex">c</script> channels and each entry is multiplied by the weight matrix <script type="math/tex">\mathbf{W}</script> to obtain the corresponding entry <script type="math/tex">\mathbf{y}_{ij}</script> in the output matrix respectively. The derivative of each entry is <script type="math/tex">\partial \mathbf{x}_{ij} \mathbf{W} / \partial\mathbf{x}_{ij} = \mathbf{W}</script> and there are <script type="math/tex">h \times w</script> such entries in total:</p>
<script type="math/tex; mode=display">\log \left\vert\det \frac{\partial\texttt{conv2d}(\mathbf{h}; \mathbf{W})}{\partial\mathbf{h}}\right\vert
= \log (\vert\det\mathbf{W}\vert^{h \cdot w}) = h \cdot w \cdot \log \vert\det\mathbf{W}\vert</script>
<p>The inverse 1x1 convolution depends on the inverse matrix <script type="math/tex">\mathbf{W}^{-1}</script>. Since the weight matrix is relatively small, the amount of computation for the matrix determinant (<a href="https://www.tensorflow.org/api_docs/python/tf/linalg/det">tf.linalg.det</a>) and inversion (<a href="https://www.tensorflow.org/api_docs/python/tf/linalg/inv">tf.linalg.inv</a>) is still under control.</p>
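<p>The log-determinant formula above can be checked numerically. In this sketch (shapes chosen for illustration), the 1x1 convolution is written as applying the same <script type="math/tex">c \times c</script> matrix <script type="math/tex">\mathbf{W}</script> to every spatial position, so the full Jacobian is block-diagonal with <script type="math/tex">h \cdot w</script> copies of <script type="math/tex">\mathbf{W}^\top</script>.</p>

```python
# Check log|det(Jacobian of 1x1 conv)| = h * w * log|det W| on a small tensor.
import numpy as np

h, w, c = 3, 4, 2
rng = np.random.default_rng(1)
W = rng.standard_normal((c, c))
x = rng.standard_normal((h * w, c))   # spatial positions flattened into rows
y = x @ W                             # the 1x1 convolution as a matrix product

# Full Jacobian of x -> x @ W: block-diagonal with h*w copies of W^T.
J = np.kron(np.eye(h * w), W.T)
_, full_logdet = np.linalg.slogdet(J)
_, w_logdet = np.linalg.slogdet(W)
assert np.isclose(full_logdet, h * w * w_logdet)
```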
<p>Substep 3: <strong>Affine coupling layer</strong></p>
<p>The design is same as in RealNVP.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/glow-table.png" alt="Glow three substeps" /></p>
<p><em>Fig. 4. Three substeps in one step of flow in Glow. (Image source: <a href="https://arxiv.org/abs/1807.03039">Kingma and Dhariwal, 2018</a>)</em></p>
<h2 id="models-with-autoregressive-flows">Models with Autoregressive Flows</h2>
<p>The <strong>autoregressive</strong> constraint is a way to model sequential data, <script type="math/tex">\mathbf{x} = [x_1, \dots, x_D]</script>: each output only depends on the data observed in the past, but not on the future ones. In other words, the probability of observing <script type="math/tex">x_i</script> is conditioned on <script type="math/tex">x_1, \dots, x_{i-1}</script> and the product of these conditional probabilities gives us the probability of observing the full sequence:</p>
<script type="math/tex; mode=display">p(\mathbf{x}) = \prod_{i=1}^{D} p(x_i\vert x_1, \dots, x_{i-1}) = \prod_{i=1}^{D} p(x_i\vert x_{1:i-1})</script>
<p>How to model the conditional density is of your choice. It can be a univariate Gaussian with mean and standard deviation computed as a function of <script type="math/tex">x_{1:i-1}</script>, or a multilayer neural network with <script type="math/tex">x_{1:i-1}</script> as the input.</p>
<p>If a flow transformation in a normalizing flow is framed as an autoregressive model — each dimension in a vector variable is conditioned on the previous dimensions — this is an <strong>autoregressive flow</strong>.</p>
<h3 id="made">MADE</h3>
<p><strong>MADE</strong> (Masked Autoencoder for Distribution Estimation; <a href="https://arxiv.org/abs/1502.03509">Germain et al., 2015</a>) is a specially designed architecture to enforce the autoregressive property in the autoencoder <em>efficiently</em>. When using an autoencoder to predict the conditional probabilities, rather than feeding the autoencoder with input of different observation windows <script type="math/tex">D</script> times, MADE removes the contribution from certain hidden units by multiplying binary mask matrices so that each input dimension is reconstructed only from previous dimensions in a <em>given</em> ordering in a <em>single pass</em>.</p>
<p>In a multilayer fully-connected neural network, say, we have <script type="math/tex">L</script> hidden layers with weight matrices <script type="math/tex">\mathbf{W}^1, \dots, \mathbf{W}^L</script> and an output layer with weight matrix <script type="math/tex">\mathbf{V}</script>. The output <script type="math/tex">\hat{\mathbf{x}}</script> has each dimension <script type="math/tex">\hat{x}_i = p(x_i\vert x_{1:i-1})</script>.</p>
<p>Without any mask, the computation through layers looks like the following:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{h}^0 &= \mathbf{x} \\
\mathbf{h}^l &= \text{activation}^l(\mathbf{W}^l\mathbf{h}^{l-1} + \mathbf{b}^l) \\
\hat{\mathbf{x}} &= \sigma(\mathbf{V}\mathbf{h}^L + \mathbf{c})
\end{aligned} %]]></script>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/MADE.png" alt="MADE" /></p>
<p><em>Fig. 5. Demonstration of how MADE works in a three-layer feed-forward neural network. (Image source: <a href="https://arxiv.org/abs/1502.03509">Germain et al., 2015</a>)</em></p>
<p>To zero out some connections between layers, we can simply element-wise multiply every weight matrix by a binary mask matrix. Each hidden node is assigned a random “connectivity integer” between 1 and D-1; the assigned value for the <script type="math/tex">k</script>-th unit in the <script type="math/tex">l</script>-th layer is denoted by <script type="math/tex">m^l_k</script>. The binary mask matrix is determined by element-wise comparing the values of two nodes in two layers.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{h}^l &= \text{activation}^l((\mathbf{W}^l \color{red}{\odot \mathbf{M}^{\mathbf{W}^l}}) \mathbf{h}^{l-1} + \mathbf{b}^l) \\
\hat{\mathbf{x}} &= \sigma((\mathbf{V} \color{red}{\odot \mathbf{M}^{\mathbf{V}}}) \mathbf{h}^L + \mathbf{c}) \\
M^{\mathbf{W}^l}_{k', k}
&= \mathbf{1}_{m^l_{k'} \geq m^{l-1}_k}
= \begin{cases}
1, & \text{if } m^l_{k'} \geq m^{l-1}_k\\
0, & \text{otherwise}
\end{cases} \\
M^{\mathbf{V}}_{d, k}
&= \mathbf{1}_{d > m^L_k}
= \begin{cases}
1, & \text{if } d > m^L_k\\
0, & \text{otherwise}
\end{cases}
\end{aligned} %]]></script>
<p>A unit in the current layer can only be connected to other units with equal or smaller numbers in the previous layer, and this type of dependency easily propagates through the network up to the output layer. Note that the comparison is strict at the output layer, so that the <script type="math/tex">d</script>-th output depends only on inputs with indices smaller than <script type="math/tex">d</script>. Once the numbers are assigned to all the units and layers, the ordering of input dimensions is fixed and the conditional probability is produced with respect to it. See a great illustration in Fig. 5. To make sure all the hidden units are connected to the input and output layers through some paths, <script type="math/tex">m^l_k</script> is sampled to be equal to or greater than the minimal connectivity integer in the previous layer, <script type="math/tex">\min_{k'} m_{k'}^{l-1}</script>.</p>
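<p>A small sketch of the masking scheme (with hidden degrees assigned deterministically for clarity, rather than sampled as in the paper). It builds the masks for one hidden layer and verifies the autoregressive property: no path connects input <script type="math/tex">x_j</script> to output <script type="math/tex">\hat{x}_d</script> unless <script type="math/tex">j < d</script>.</p>

```python
# Build MADE-style masks for a single hidden layer and check connectivity.
import numpy as np

D, H = 4, 8
m_in = np.arange(1, D + 1)                 # input degrees: 1..D
m_hid = (np.arange(H) % (D - 1)) + 1       # hidden degrees cycle through 1..D-1

# Hidden unit k may see input j only if m_hid[k] >= m_in[j].
mask_W = (m_hid[:, None] >= m_in[None, :]).astype(float)   # shape (H, D)
# Output d may see hidden unit k only if d > m_hid[k] (strict at the output).
mask_V = (np.arange(1, D + 1)[:, None] > m_hid[None, :]).astype(float)

# A path from input j to output d exists iff some hidden unit connects both.
reach = (mask_V @ mask_W) > 0              # shape (D, D): output x input
for d_idx in range(D):
    for j in range(D):
        if j >= d_idx:                     # output d must not see x_d, x_{d+1}, ...
            assert not reach[d_idx, j]
```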
<p>MADE training can be further facilitated by:</p>
<ul>
<li><em>Order-agnostic training</em>: shuffle the input dimensions, so that MADE is able to model any arbitrary ordering; can create an ensemble of autoregressive models at the runtime.</li>
<li><em>Connectivity-agnostic training</em>: to avoid a model being tied up to a specific connectivity pattern constraints, resample <script type="math/tex">m^l_k</script> for each training minibatch.</li>
</ul>
<h3 id="pixelrnn">PixelRNN</h3>
<p>PixelRNN (<a href="https://arxiv.org/abs/1601.06759">Oord et al, 2016</a>) is a deep generative model for images. The image is generated one pixel at a time and each new pixel is sampled conditional on the pixels that have been seen before.</p>
<p>Let’s consider an image of size <script type="math/tex">n \times n</script>, <script type="math/tex">\mathbf{x} = \{x_1, \dots, x_{n^2}\}</script>, the model starts generating pixels from the top left corner, from left to right and top to bottom (See Fig. 6).</p>
<p style="width: 30%;" class="center"><img src="/lil-log/assets/images/pixel-rnn-context.png" alt="Context in PixelRNN" /></p>
<p><em>Fig. 6. The context for generating one pixel in PixelRNN. (Image source: <a href="https://arxiv.org/abs/1601.06759">Oord et al, 2016</a>)</em></p>
<p>Every pixel <script type="math/tex">x_i</script> is sampled from a probability distribution conditioned on the past context: the pixels above it, or those on its left within the same row. The definition of such a context looks pretty arbitrary, because how visual <a href="/lil-log/2018/06/24/attention-attention.html">attention</a> is attended to an image is much more flexible. Somehow magically, a generative model with such a strong assumption works.</p>
<p>One implementation that could capture the entire context is the <em>Diagonal BiLSTM</em>. First, apply the <strong>skewing</strong> operation by offsetting each row of the input feature map by one position with respect to the previous row, so that computation for each row can be parallelized. Then the LSTM states are computed with respect to the current pixel and the pixels on the left.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/diagonal-biLSTM.png" alt="Diagonal BiLSTM" /></p>
<p><em>Fig. 7. (a) PixelRNN with diagonal BiLSTM. (b) Skewing operation that offsets each row in the feature map by one with regards to the row above. (Image source: <a href="https://arxiv.org/abs/1601.06759">Oord et al, 2016</a>)</em></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\lbrack \mathbf{o}_i, \mathbf{f}_i, \mathbf{i}_i, \mathbf{g}_i \rbrack &= \sigma(\mathbf{K}^{ss} \circledast \mathbf{h}_{i-1} + \mathbf{K}^{is} \circledast \mathbf{x}_i) & \scriptstyle{\text{; }\sigma\scriptstyle{\text{ is tanh for g, but otherwise sigmoid; }}\circledast\scriptstyle{\text{ is convolution operation.}}} \\
\mathbf{c}_i &= \mathbf{f}_i \odot \mathbf{c}_{i-1} + \mathbf{i}_i \odot \mathbf{g}_i & \scriptstyle{\text{; }}\odot\scriptstyle{\text{ is elementwise product.}}\\
\mathbf{h}_i &= \mathbf{o}_i \odot \tanh(\mathbf{c}_i)
\end{aligned} %]]></script>
<p>where <script type="math/tex">\circledast</script> denotes the convolution operation and <script type="math/tex">\odot</script> is the element-wise multiplication. The input-to-state component <script type="math/tex">\mathbf{K}^{is}</script> is a 1x1 convolution, while the state-to-state recurrent component is computed with a column-wise convolution <script type="math/tex">\mathbf{K}^{ss}</script> with a kernel of size 2x1.</p>
<p>The diagonal BiLSTM layers are capable of processing an unbounded context field, but are expensive to compute due to the sequential dependency between states. A faster implementation uses multiple convolutional layers without pooling to define a bounded context box. The convolution kernel is masked so that the future context is not seen, similar to <a href="#made">MADE</a>. This convolution version is called <strong>PixelCNN</strong>.</p>
<p style="width: 50%;" class="center"><img src="/lil-log/assets/images/pixel-cnn.png" alt="PixelCNN" /></p>
<p><em>Fig. 8. PixelCNN with masked convolution constructed by an elementwise product of a mask tensor and the convolution kernel before applying it. (Image source: http://slazebni.cs.illinois.edu/spring17/lec13_advanced.pdf)</em></p>
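<p>A sketch of the PixelCNN mask for a single channel (kernel size and mask-type naming follow the common “type A / type B” convention): a type-A mask zeroes out the center pixel and everything after it in raster order, so the convolution at a pixel only sees already-generated context.</p>

```python
# Build the binary mask that is multiplied elementwise with a conv kernel.
import numpy as np

def pixelcnn_mask(k, mask_type="A"):
    mask = np.ones((k, k))
    center = k // 2
    # Zero out the center row from the center onward; type B keeps the center.
    mask[center, center + (mask_type == "B"):] = 0.0
    mask[center + 1:, :] = 0.0   # zero out all rows below the center
    return mask

m = pixelcnn_mask(3, "A")
# A 3x3 type-A mask keeps only the 4 context pixels above/left of the center.
assert m.sum() == 4
```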
<h3 id="wavenet">WaveNet</h3>
<p><strong>WaveNet</strong> (<a href="https://arxiv.org/abs/1609.03499">Van Den Oord, et al. 2016</a>) is very similar to PixelCNN but applied to 1-D audio signals. WaveNet consists of a stack of <em>causal convolutions</em>, a convolution operation designed to respect the ordering: the prediction at a certain timestep can only consume data observed in the past, with no dependency on the future. In PixelCNN, the causal convolution is implemented by a masked convolution kernel. In WaveNet, it is implemented simply by shifting the output by a number of timesteps so that the output is aligned with the last input element.</p>
<p>One big drawback of a convolution layer is the very limited size of its receptive field. The output can hardly depend on inputs hundreds or thousands of timesteps ago, which can be a crucial requirement for modeling long sequences. WaveNet therefore adopts <em>dilated convolution</em> (<a href="https://github.com/vdumoulin/conv_arithmetic#dilated-convolution-animations">animation</a>), where the kernel is applied to an evenly-distributed subset of samples in a much larger receptive field of the input.</p>
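<p>A back-of-the-envelope calculation (my own illustration, not from the paper) of how the receptive field grows: stacking kernel-size-2 causal convolutions with dilations 1, 2, 4, … makes the reachable context grow exponentially with depth rather than linearly.</p>

```python
# Receptive field of a stack of causal convolutions with given dilation rates.
def receptive_field(dilations, kernel_size=2):
    rf = 1
    for d in dilations:
        rf += (kernel_size - 1) * d   # each layer extends the context by (k-1)*d
    return rf

assert receptive_field([1]) == 2
assert receptive_field([1, 2, 4, 8]) == 16        # exponential in depth
assert receptive_field([1, 2, 4, 8] * 3) == 46    # repeating the dilation cycle
```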
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/wavenet.png" alt="WaveNet" /></p>
<p><em>Fig. 9. Visualization of WaveNet models with a stack of (top) causal convolution layers and (bottom) dilated convolution layers. (Image source: <a href="https://arxiv.org/abs/1609.03499">Van Den Oord, et al. 2016</a>)</em></p>
<p>WaveNet uses the gated activation unit as the non-linear layer, as it is found to work significantly better than ReLU for modeling 1-D audio data. The residual connection is applied after the gated activation.</p>
<script type="math/tex; mode=display">\mathbf{z} = \tanh(\mathbf{W}_{f,k}\circledast\mathbf{x})\odot\sigma(\mathbf{W}_{g,k}\circledast\mathbf{x})</script>
<p>where <script type="math/tex">\mathbf{W}_{f,k}</script> and <script type="math/tex">\mathbf{W}_{g,k}</script> are convolution filter and gate weight matrix of the <script type="math/tex">k</script>-th layer, respectively; both are learnable.</p>
<h3 id="masked-autoregressive-flow">Masked Autoregressive Flow</h3>
<p><strong>Masked Autoregressive Flow</strong> (<strong>MAF</strong>; <a href="https://arxiv.org/abs/1705.07057">Papamakarios et al., 2017</a>) is a type of normalizing flow, where the transformation layer is built as an autoregressive neural network. MAF is very similar to <strong>Inverse Autoregressive Flow</strong> (IAF) introduced later. See more discussion on the relationship between MAF and IAF in the next section.</p>
<p>Given two random variables, <script type="math/tex">\mathbf{z} \sim \pi(\mathbf{z})</script> and <script type="math/tex">\mathbf{x} \sim p(\mathbf{x})</script>, where the probability density function <script type="math/tex">\pi(\mathbf{z})</script> is known, MAF aims to learn <script type="math/tex">p(\mathbf{x})</script>. MAF generates each <script type="math/tex">x_i</script> conditioned on the past dimensions <script type="math/tex">\mathbf{x}_{1:i-1}</script>.</p>
<p>Precisely the conditional probability is an affine transformation of <script type="math/tex">\mathbf{z}</script>, where the scale and shift terms are functions of the observed part of <script type="math/tex">\mathbf{x}</script>.</p>
<ul>
<li>Data generation, producing a new <script type="math/tex">\mathbf{x}</script>:</li>
</ul>
<script type="math/tex; mode=display">x_i \sim p(x_i\vert\mathbf{x}_{1:i-1}) = z_i \odot \sigma_i(\mathbf{x}_{1:i-1}) + \mu_i(\mathbf{x}_{1:i-1})\text{, where }\mathbf{z} \sim \pi(\mathbf{z})</script>
<ul>
<li>Density estimation, given a known <script type="math/tex">\mathbf{x}</script>:</li>
</ul>
<script type="math/tex; mode=display">p(\mathbf{x}) = \prod_{i=1}^D p(x_i\vert\mathbf{x}_{1:i-1})</script>
<p>The generation procedure is sequential, so it is slow by design, while density estimation only needs one pass through the network using an architecture like <a href="#MADE">MADE</a>. The transformation function is trivial to invert and its Jacobian determinant is easy to compute too.</p>
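<p>A toy numeric sketch may help here; the affine conditioners below are hand-written functions standing in for a MADE network, purely to show why sampling is sequential while density evaluation is not:</p>

```python
import math

D = 3  # toy dimensionality

def mu(i, x_prev):       # stand-in shift conditioner mu_i(x_{1:i-1})
    return 0.5 * sum(x_prev)

def sigma(i, x_prev):    # stand-in scale conditioner; exp keeps it positive
    return math.exp(0.1 * sum(x_prev))

def maf_generate(z):
    """Sampling is sequential: x_i cannot be computed before x_{1:i-1}."""
    x = []
    for i in range(D):
        x.append(z[i] * sigma(i, x) + mu(i, x))
    return x

def maf_inverse(x):
    """Density-estimation direction: every z_i only needs the observed
    x_{1:i-1}, so all of them can be computed in a single pass."""
    return [(x[i] - mu(i, x[:i])) / sigma(i, x[:i]) for i in range(D)]

def maf_log_density(x):
    """log p(x) = log pi(z) - sum_i log sigma_i, with a standard Gaussian base."""
    z = maf_inverse(x)
    log_pi = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
    return log_pi - sum(math.log(sigma(i, x[:i])) for i in range(D))
```

<p>Because the transformation is autoregressive, its Jacobian is triangular and the log-determinant reduces to the sum of the log scales.</p>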
<h3 id="inverse-autoregressive-flow">Inverse Autoregressive Flow</h3>
<p>Similar to MAF, <strong>Inverse autoregressive flow</strong> (<strong>IAF</strong>; <a href="https://arxiv.org/abs/1606.04934">Kingma et al., 2016</a>) also models the conditional probability of the target variable as an autoregressive model, but with a reversed flow, thus achieving a much more efficient sampling process.</p>
<p>First, let’s reverse the affine transformation in MAF:</p>
<script type="math/tex; mode=display">z_i = \frac{x_i - \mu_i(\mathbf{x}_{1:i-1})}{\sigma_i(\mathbf{x}_{1:i-1})} = -\frac{\mu_i(\mathbf{x}_{1:i-1})}{\sigma_i(\mathbf{x}_{1:i-1})} + x_i \odot \frac{1}{\sigma_i(\mathbf{x}_{1:i-1})}</script>
<p>If we let:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& \tilde{\mathbf{x}} = \mathbf{z}\text{, }\tilde{p}(.) = \pi(.)\text{, }\tilde{\mathbf{x}} \sim \tilde{p}(\tilde{\mathbf{x}}) \\
& \tilde{\mathbf{z}} = \mathbf{x} \text{, }\tilde{\pi}(.) = p(.)\text{, }\tilde{\mathbf{z}} \sim \tilde{\pi}(\tilde{\mathbf{z}})\\
& \tilde{\mu}_i(\tilde{\mathbf{z}}_{1:i-1}) = \tilde{\mu}_i(\mathbf{x}_{1:i-1}) = -\frac{\mu_i(\mathbf{x}_{1:i-1})}{\sigma_i(\mathbf{x}_{1:i-1})} \\
& \tilde{\sigma}_i(\tilde{\mathbf{z}}_{1:i-1}) = \tilde{\sigma}_i(\mathbf{x}_{1:i-1}) = \frac{1}{\sigma_i(\mathbf{x}_{1:i-1})}
\end{aligned} %]]></script>
<p>Then we would have,</p>
<script type="math/tex; mode=display">\tilde{x}_i \sim p(\tilde{x}_i\vert\tilde{\mathbf{z}}_{1:i}) = \tilde{z}_i \odot \tilde{\sigma}_i(\tilde{\mathbf{z}}_{1:i-1}) + \tilde{\mu}_i(\tilde{\mathbf{z}}_{1:i-1})
\text{, where }\tilde{\mathbf{z}} \sim \tilde{\pi}(\tilde{\mathbf{z}})</script>
<p>IAF intends to estimate the probability density function of <script type="math/tex">\tilde{\mathbf{x}}</script> given that <script type="math/tex">\tilde{\pi}(\tilde{\mathbf{z}})</script> is already known. The inverse flow is an autoregressive affine transformation too, same as in MAF, but the scale and shift terms are autoregressive functions of observed variables from the known distribution <script type="math/tex">\tilde{\pi}(\tilde{\mathbf{z}})</script>. See the comparison between MAF and IAF in Fig. 10.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/MAF-vs-IAF.png" alt="MAF and IAF" /></p>
<p><em>Fig. 10. Comparison of MAF and IAF. The variable with known density is in green while the unknown one is in red.</em></p>
<p>Computations of the individual elements <script type="math/tex">\tilde{x}_i</script> do not depend on each other, so they are easily parallelizable (only one pass using MADE). The density estimation for a known <script type="math/tex">\tilde{\mathbf{x}}</script> is not efficient, because we have to recover the value of <script type="math/tex">\tilde{z}_i</script> in a sequential order, <script type="math/tex">\tilde{z}_i = (\tilde{x}_i - \tilde{\mu}_i(\tilde{\mathbf{z}}_{1:i-1})) / \tilde{\sigma}_i(\tilde{\mathbf{z}}_{1:i-1})</script>, taking <script type="math/tex">D</script> sequential steps in total.</p>
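<p>The same toy setup as before, with the conditioners now reading the base variables, shows the reversed trade-off (again a hypothetical sketch, not the paper’s code):</p>

```python
import math

D = 3

def mu(i, z_prev):       # conditioners now look at z_{1:i-1}, the known base
    return 0.5 * sum(z_prev)

def sigma(i, z_prev):
    return math.exp(0.1 * sum(z_prev))

def iaf_generate(z):
    """One pass: after sampling z from the base distribution, every x_i
    depends only on z, not on the other outputs."""
    return [z[i] * sigma(i, z[:i]) + mu(i, z[:i]) for i in range(D)]

def iaf_inverse(x):
    """Density estimation is sequential: recovering z_i requires z_{1:i-1}."""
    z = []
    for i in range(D):
        z.append((x[i] - mu(i, z)) / sigma(i, z))
    return z
```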
<table class="info">
<thead>
<tr>
<th> </th>
<th>Base distribution</th>
<th>Target distribution</th>
<th>Model</th>
<th>Data generation</th>
<th>Density estimation</th>
</tr>
</thead>
<tbody>
<tr>
<td>MAF</td>
<td><script type="math/tex">\mathbf{z}\sim\pi(\mathbf{z})</script></td>
<td><script type="math/tex">\mathbf{x}\sim p(\mathbf{x})</script></td>
<td><script type="math/tex">x_i = z_i \odot \sigma_i(\mathbf{x}_{1:i-1}) + \mu_i(\mathbf{x}_{1:i-1})</script></td>
<td>Sequential; slow</td>
<td>One pass; fast</td>
</tr>
<tr>
<td>IAF</td>
<td><script type="math/tex">\tilde{\mathbf{z}}\sim\tilde{\pi}(\tilde{\mathbf{z}})</script></td>
<td><script type="math/tex">\tilde{\mathbf{x}}\sim\tilde{p}(\tilde{\mathbf{x}})</script></td>
<td><script type="math/tex">\tilde{x}_i = \tilde{z}_i \odot \tilde{\sigma}_i(\tilde{\mathbf{z}}_{1:i-1}) + \tilde{\mu}_i(\tilde{\mathbf{z}}_{1:i-1})</script></td>
<td>One pass; fast</td>
<td>Sequential; slow</td>
</tr>
</tbody>
</table>
<h2 id="vae--flows">VAE + Flows</h2>
<p>In <a href="/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#vae-variational-autoencoder">Variational Autoencoder</a>, we may want to model the posterior <script type="math/tex">p(\mathbf{z}\vert\mathbf{x})</script> as a more complicated distribution rather than a simple Gaussian. Intuitively we can use a normalizing flow to transform the base Gaussian for better density approximation. The encoder then predicts a set of scale and shift terms <script type="math/tex">(\mu_i, \sigma_i)</script>, which are all functions of the input <script type="math/tex">\mathbf{x}</script>. Read the <a href="https://arxiv.org/abs/1809.05861">paper</a> for more details if interested.</p>
<hr />
<p><em>If you notice mistakes and errors in this post, don’t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them right away!</em></p>
<p>See you in the next post :D</p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018flow,
title = "Flow-based Deep Generative Models",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html"
}
</code></pre></div></div>
<h2 id="reference">Reference</h2>
<p>[1] Danilo Jimenez Rezende, and Shakir Mohamed. <a href="https://arxiv.org/abs/1505.05770">“Variational inference with normalizing flows.”</a> ICML 2015.</p>
<p>[2] <a href="https://blog.evjang.com/2018/01/nf1.html">Normalizing Flows Tutorial, Part 1: Distributions and Determinants</a> by Eric Jang.</p>
<p>[3] <a href="https://blog.evjang.com/2018/01/nf2.html">Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows</a> by Eric Jang.</p>
<p>[4] <a href="http://akosiorek.github.io/ml/2018/04/03/norm_flows.html">Normalizing Flows</a> by Adam Kosiorek.</p>
<p>[5] Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. <a href="https://arxiv.org/abs/1605.08803">“Density estimation using Real NVP.”</a> ICLR 2017.</p>
<p>[6] Laurent Dinh, David Krueger, and Yoshua Bengio. <a href="https://arxiv.org/abs/1410.8516">“NICE: Non-linear independent components estimation.”</a> ICLR 2015 Workshop track.</p>
<p>[7] Diederik P. Kingma, and Prafulla Dhariwal. <a href="https://arxiv.org/abs/1807.03039">“Glow: Generative flow with invertible 1x1 convolutions.”</a> arXiv:1807.03039 (2018).</p>
<p>[8] Germain, Mathieu, Karol Gregor, Iain Murray, and Hugo Larochelle. <a href="https://arxiv.org/abs/1502.03509">“Made: Masked autoencoder for distribution estimation.”</a> ICML 2015.</p>
<p>[9] Aaron van den Oord, Nal Kalchbrenner, and Koray Kavukcuoglu. <a href="https://arxiv.org/abs/1601.06759">“Pixel recurrent neural networks.”</a> ICML 2016.</p>
<p>[10] Diederik P. Kingma, et al. <a href="https://arxiv.org/abs/1606.04934">“Improved variational inference with inverse autoregressive flow.”</a> NIPS. 2016.</p>
<p>[11] George Papamakarios, Iain Murray, and Theo Pavlakou. <a href="https://arxiv.org/abs/1705.07057">“Masked autoregressive flow for density estimation.”</a> NIPS 2017.</p>
<p>[12] Jianlin Su, and Guang Wu. <a href="https://arxiv.org/abs/1809.05861">“f-VAEs: Improve VAEs with Conditional Flows.”</a> arXiv:1809.05861 (2018).</p>
<p>[13] Van Den Oord, Aaron, et al. <a href="https://arxiv.org/abs/1609.03499">“WaveNet: A generative model for raw audio.”</a> SSW. 2016.</p>Lilian WengIn this post, we are looking into the third type of generative models: flow-based generative models. Different from GAN and VAE, they explicitly learn the probability density function of the input data.From Autoencoder to Beta-VAE2018-08-12T10:18:00+00:002018-08-12T10:18:00+00:00https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae<blockquote>
<p>Autoencoders are a family of neural network models aiming to learn compressed latent variables of high-dimensional data. Starting from the basic autoencoder model, this post reviews several variations, including denoising, sparse, and contractive autoencoders, and then the Variational Autoencoder (VAE) and its modification beta-VAE.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2019-07-18: add a section on <a href="#vq-vae-and-vq-vae-2">VQ-VAE & VQ-VAE-2</a>.]</span>
<br />
<span style="color: #286ee0;">[Updated on 2019-07-26: add a section on <a href="#td-vae">TD-VAE</a>.]</span>
<br /></p>
<p>The autoencoder was invented to reconstruct high-dimensional data using a neural network model with a narrow bottleneck layer in the middle (oops, this is probably not true for <a href="#vae-variational-autoencoder">Variational Autoencoder</a>, and we will investigate it in detail in later sections). A nice byproduct is dimension reduction: the bottleneck layer captures a compressed latent encoding. Such a low-dimensional representation can be used as an embedding vector in various applications (e.g. search), help data compression, or reveal the underlying data generative factors.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#notation" id="markdown-toc-notation">Notation</a></li>
<li><a href="#autoencoder" id="markdown-toc-autoencoder">Autoencoder</a></li>
<li><a href="#denoising-autoencoder" id="markdown-toc-denoising-autoencoder">Denoising Autoencoder</a></li>
<li><a href="#sparse-autoencoder" id="markdown-toc-sparse-autoencoder">Sparse Autoencoder</a></li>
<li><a href="#contractive-autoencoder" id="markdown-toc-contractive-autoencoder">Contractive Autoencoder</a></li>
<li><a href="#vae-variational-autoencoder" id="markdown-toc-vae-variational-autoencoder">VAE: Variational Autoencoder</a> <ul>
<li><a href="#loss-function-elbo" id="markdown-toc-loss-function-elbo">Loss Function: ELBO</a></li>
<li><a href="#reparameterization-trick" id="markdown-toc-reparameterization-trick">Reparameterization Trick</a></li>
</ul>
</li>
<li><a href="#beta-vae" id="markdown-toc-beta-vae">Beta-VAE</a></li>
<li><a href="#vq-vae-and-vq-vae-2" id="markdown-toc-vq-vae-and-vq-vae-2">VQ-VAE and VQ-VAE-2</a></li>
<li><a href="#td-vae" id="markdown-toc-td-vae">TD-VAE</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<h2 id="notation">Notation</h2>
<table class="info">
<thead>
<tr>
<th>Symbol</th>
<th>Mean</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">\mathcal{D}</script></td>
<td>The dataset, <script type="math/tex">\mathcal{D} = \{ \mathbf{x}^{(1)}, \mathbf{x}^{(2)}, \dots, \mathbf{x}^{(n)} \}</script>, contains <script type="math/tex">n</script> data samples; <script type="math/tex">\vert\mathcal{D}\vert =n</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{x}^{(i)}</script></td>
<td>Each data point is a vector of <script type="math/tex">d</script> dimensions, <script type="math/tex">\mathbf{x}^{(i)} = [x^{(i)}_1, x^{(i)}_2, \dots, x^{(i)}_d]</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{x}</script></td>
<td>One data sample from the dataset, <script type="math/tex">\mathbf{x} \in \mathcal{D}</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{x}’</script></td>
<td>The reconstructed version of <script type="math/tex">\mathbf{x}</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\tilde{\mathbf{x}}</script></td>
<td>The corrupted version of <script type="math/tex">\mathbf{x}</script>.</td>
</tr>
<tr>
<td><script type="math/tex">\mathbf{z}</script></td>
<td>The compressed code learned in the bottleneck layer.</td>
</tr>
<tr>
<td><script type="math/tex">a_j^{(l)}</script></td>
<td>The activation function for the <script type="math/tex">j</script>-th neuron in the <script type="math/tex">l</script>-th hidden layer.</td>
</tr>
<tr>
<td><script type="math/tex">g_{\phi}(.)</script></td>
<td>The <strong>encoding</strong> function parameterized by <script type="math/tex">\phi</script>.</td>
</tr>
<tr>
<td><script type="math/tex">f_{\theta}(.)</script></td>
<td>The <strong>decoding</strong> function parameterized by <script type="math/tex">\theta</script>.</td>
</tr>
<tr>
<td><script type="math/tex">q_{\phi}(\mathbf{z}\vert\mathbf{x})</script></td>
<td>Estimated posterior probability function, also known as <strong>probabilistic encoder</strong>.</td>
</tr>
<tr>
<td><script type="math/tex">p_{\theta}(\mathbf{x}\vert\mathbf{z})</script></td>
<td>Likelihood of generating true data sample given the latent code, also known as <strong>probabilistic decoder</strong>.</td>
</tr>
</tbody>
</table>
<h2 id="autoencoder">Autoencoder</h2>
<p><strong>Autoencoder</strong> is a neural network designed to learn an identity function in an unsupervised way, reconstructing the original input while compressing the data in the process so as to discover a more efficient and compressed representation. The idea originated in <a href="https://en.wikipedia.org/wiki/Autoencoder">the 1980s</a> and was later promoted by the seminal paper by <a href="https://pdfs.semanticscholar.org/c50d/ca78e97e335d362d6b991ae0e1448914e9a3.pdf">Hinton & Salakhutdinov, 2006</a>.</p>
<p>It consists of two networks:</p>
<ul>
<li><em>Encoder</em> network: It translates the original high-dimension input into the latent low-dimensional code. The input size is larger than the output size.</li>
<li><em>Decoder</em> network: The decoder network recovers the data from the code, likely with larger and larger output layers.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/autoencoder-architecture.png" alt="Autoencoder architecture" /></p>
<p><em>Fig. 1. Illustration of autoencoder model architecture.</em></p>
<p>The encoder network essentially accomplishes <a href="https://en.wikipedia.org/wiki/Dimensionality_reduction">dimensionality reduction</a>, just as we would use Principal Component Analysis (PCA) or Matrix Factorization (MF). In addition, the autoencoder is explicitly optimized for reconstructing the data from the code. A good intermediate representation not only captures latent variables, but also benefits a full <a href="https://ai.googleblog.com/2016/09/image-compression-with-neural-networks.html">decompression</a> process.</p>
<p>The model contains an encoder function <script type="math/tex">g(.)</script> parameterized by <script type="math/tex">\phi</script> and a decoder function <script type="math/tex">f(.)</script> parameterized by <script type="math/tex">\theta</script>. The low-dimensional code learned for input <script type="math/tex">\mathbf{x}</script> in the bottleneck layer is <script type="math/tex">\mathbf{z} = g_\phi(\mathbf{x})</script> and the reconstructed input is <script type="math/tex">\mathbf{x}' = f_\theta(g_\phi(\mathbf{x}))</script>.</p>
<p>The parameters <script type="math/tex">(\theta, \phi)</script> are learned together to output a reconstructed data sample that is the same as the original input, <script type="math/tex">\mathbf{x} \approx f_\theta(g_\phi(\mathbf{x}))</script>, or in other words, to learn an identity function. There are various metrics to quantify the difference between two vectors, such as cross entropy when the activation function is sigmoid, or as simple as the MSE loss:</p>
<script type="math/tex; mode=display">L_\text{AE}(\theta, \phi) = \frac{1}{n}\sum_{i=1}^n (\mathbf{x}^{(i)} - f_\theta(g_\phi(\mathbf{x}^{(i)})))^2</script>
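<p>As a minimal runnable sketch of this setup (a linear one-unit autoencoder with hand-rolled gradients on made-up 2-D data; real autoencoders are of course deeper and nonlinear), training by gradient descent on the MSE loss looks like:</p>

```python
import random

# Tiny linear autoencoder: 2-D input -> 1-D code -> 2-D reconstruction.
# Encoder g_phi(x) = w . x ; decoder f_theta(z) = [v1*z, v2*z].
random.seed(0)
data = [[t, 2.0 * t] for t in (random.uniform(-1, 1) for _ in range(50))]

w = [0.1, 0.2]   # encoder weights (phi)
v = [0.3, 0.1]   # decoder weights (theta)
lr = 0.05

def mse_loss():
    total = 0.0
    for x in data:
        z = w[0] * x[0] + w[1] * x[1]
        total += (x[0] - v[0] * z) ** 2 + (x[1] - v[1] * z) ** 2
    return total / len(data)

loss_before = mse_loss()
for _ in range(200):
    gw, gv = [0.0, 0.0], [0.0, 0.0]
    for x in data:
        z = w[0] * x[0] + w[1] * x[1]
        err = [v[0] * z - x[0], v[1] * z - x[1]]
        for j in range(2):
            gv[j] += 2 * err[j] * z / len(data)
            gw[j] += 2 * (err[0] * v[0] + err[1] * v[1]) * x[j] / len(data)
    w = [w[j] - lr * gw[j] for j in range(2)]
    v = [v[j] - lr * gv[j] for j in range(2)]
loss_after = mse_loss()   # reconstruction error drops as phi, theta co-adapt
```

<p>Since this toy data lies on a 1-D line in 2-D space, a single linear code unit is enough for a near-perfect reconstruction, mirroring what PCA would find.</p>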
<h2 id="denoising-autoencoder">Denoising Autoencoder</h2>
<p>Since the autoencoder learns the identity function, we are facing the risk of “overfitting” when there are more network parameters than the number of data points.</p>
<p>To avoid overfitting and improve the robustness, <strong>Denoising Autoencoder</strong> (Vincent et al. 2008) proposed a modification to the basic autoencoder. The input is partially corrupted by adding noise to, or masking some values of, the input vector in a stochastic manner, <script type="math/tex">\tilde{\mathbf{x}} \sim \mathcal{M}_\mathcal{D}(\tilde{\mathbf{x}} \vert \mathbf{x})</script>. Then the model is trained to recover the original input (<strong>Note: not the corrupted one!</strong>).</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\tilde{\mathbf{x}}^{(i)} &\sim \mathcal{M}_\mathcal{D}(\tilde{\mathbf{x}}^{(i)} \vert \mathbf{x}^{(i)})\\
L_\text{DAE}(\theta, \phi) &= \frac{1}{n} \sum_{i=1}^n (\mathbf{x}^{(i)} - f_\theta(g_\phi(\tilde{\mathbf{x}}^{(i)})))^2
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathcal{M}_\mathcal{D}</script> defines the mapping from the true data samples to the noisy or corrupted ones.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/denoising-autoencoder-architecture.png" alt="Denoising autoencoder architecture" /></p>
<p><em>Fig. 2. Illustration of denoising autoencoder model architecture.</em></p>
<p>This design is motivated by the fact that humans can easily recognize an object or a scene even when the view is partially occluded or corrupted. To “repair” the partially destroyed input, the denoising autoencoder has to discover and capture relationships between dimensions of the input in order to infer the missing pieces.</p>
<p>For high dimensional input with high redundancy, like images, the model is likely to depend on evidence gathered from a combination of many input dimensions to recover the denoised version (sounds like the <a href="/lil-log/2018/06/24/attention-attention.html">attention</a> mechanism, right?) rather than to overfit one dimension. This builds up a good foundation for learning <em>robust</em> latent representation.</p>
<p>The noise is controlled by a stochastic mapping <script type="math/tex">\mathcal{M}_\mathcal{D}(\tilde{\mathbf{x}} \vert \mathbf{x})</script>, and it is not specific to a particular type of corruption process (e.g. masking noise, Gaussian noise, salt-and-pepper noise, etc.). Naturally the corruption process can be equipped with prior knowledge.</p>
<p>In the experiment of the original DAE paper, the noise is applied in this way: a fixed proportion of input dimensions are selected at random and their values are forced to 0. Sounds a lot like dropout, right? Well, the denoising autoencoder was proposed in 2008, 4 years before the dropout paper (<a href="https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf">Hinton, et al. 2012</a>) ;)</p>
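<p>This masking corruption is straightforward to sketch; the proportion and input values below are arbitrary illustrations:</p>

```python
import random

def corrupt(x, proportion=0.3, rng=random):
    """Masking noise M_D(x~|x): pick a fixed fraction of input dimensions
    at random and force their values to 0. The training loss still compares
    the reconstruction against the *clean* x, not against x~."""
    x_tilde = list(x)
    n_masked = int(round(proportion * len(x)))
    for idx in rng.sample(range(len(x)), n_masked):
        x_tilde[idx] = 0.0
    return x_tilde

rng = random.Random(7)
x = [0.9, 0.1, 0.4, 0.8, 0.7, 0.2, 0.5, 0.3, 0.6, 1.0]
x_tilde = corrupt(x, proportion=0.3, rng=rng)
```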
<!--
**Stacked Denoising Autoencoder**: In the old days when it was still hard to train deep neural networks, stacking denoising autoencoders was a way to build deep models ([Vincent et al., 2010](http://www.jmlr.org/papers/volume11/vincent10a/vincent10a.pdf)). The denoising autoencoders are trained layer by layer. Once one layer has been trained, it is fed with clean, uncorrupted inputs to learn the encoding in the next layer.
![Stacking denoising autoencoder](/lil-log/assets/images/stacking-dae.png)
{: style="width: 100%;" class="center"}
*Fig. 3. Stacking denoising autoencoders. (Image source: [Vincent et al., 2010](http://www.jmlr.org/papers/volume11/vincent10a/vincent10a.pdf))*
-->
<h2 id="sparse-autoencoder">Sparse Autoencoder</h2>
<p><strong>Sparse Autoencoder</strong> applies a “sparse” constraint on the hidden unit activations to avoid overfitting and improve robustness. It forces the model to have only a small number of hidden units activated at the same time, or in other words, each hidden neuron should be inactive most of the time.</p>
<p>Recall that common <a href="http://cs231n.github.io/neural-networks-1/#actfun">activation functions</a> include sigmoid, tanh, ReLU, leaky ReLU, etc. A neuron is activated when its value is close to 1 and inactive when its value is close to 0.</p>
<p>Let’s say there are <script type="math/tex">s_l</script> neurons in the <script type="math/tex">l</script>-th hidden layer and the activation function for the <script type="math/tex">j</script>-th neuron in this layer is labelled as <script type="math/tex">a^{(l)}_j(.)</script>, <script type="math/tex">j=1, \dots, s_l</script>. The fraction of activation of this neuron <script type="math/tex">\hat{\rho}_j</script> is expected to be a small number <script type="math/tex">\rho</script>, known as <em>sparsity parameter</em>; a common config is <script type="math/tex">\rho = 0.05</script>.</p>
<script type="math/tex; mode=display">\hat{\rho}_j^{(l)} = \frac{1}{n} \sum_{i=1}^n [a_j^{(l)}(\mathbf{x}^{(i)})] \approx \rho</script>
<p>This constraint is achieved by adding a penalty term into the loss function. The KL-divergence <script type="math/tex">D_\text{KL}</script> measures the difference between two Bernoulli distributions, one with mean <script type="math/tex">\rho</script> and the other with mean <script type="math/tex">\hat{\rho}_j^{(l)}</script>. The hyperparameter <script type="math/tex">\beta</script> controls how strong the penalty we want to apply on the sparsity loss.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
L_\text{SAE}(\theta)
&= L(\theta) + \beta \sum_{l=1}^L \sum_{j=1}^{s_l} D_\text{KL}(\rho \| \hat{\rho}_j^{(l)}) \\
&= L(\theta) + \beta \sum_{l=1}^L \sum_{j=1}^{s_l} \rho\log\frac{\rho}{\hat{\rho}_j^{(l)}} + (1-\rho)\log\frac{1-\rho}{1-\hat{\rho}_j^{(l)}}
\end{aligned} %]]></script>
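<p>Computing the penalty term is mechanical once the mean activations are in hand; here is a small sketch with a made-up batch of sigmoid activations:</p>

```python
import math

def kl_sparsity_penalty(activations, rho=0.05):
    """Sum over hidden units j of KL(rho || rho_hat_j), where rho_hat_j is
    the unit's mean activation over the batch. `activations` is a list of
    per-sample activation vectors with values in (0, 1)."""
    n = len(activations)
    penalty = 0.0
    for j in range(len(activations[0])):
        rho_hat = sum(a[j] for a in activations) / n
        penalty += (rho * math.log(rho / rho_hat)
                    + (1 - rho) * math.log((1 - rho) / (1 - rho_hat)))
    return penalty

# Unit 0 averages exactly rho = 0.05 (no penalty); unit 1 fires far too often.
batch = [[0.04, 0.60], [0.06, 0.70], [0.05, 0.50]]
penalty = kl_sparsity_penalty(batch, rho=0.05)
```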
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/kl-metric-sparse-autoencoder.png" alt="KL divergence" /></p>
<p><em>Fig. 4. The KL divergence between a Bernoulli distribution with mean <script type="math/tex">\rho=0.25</script> and a Bernoulli distribution with mean <script type="math/tex">0 \leq \hat{\rho} \leq 1</script>.</em></p>
<p><strong><script type="math/tex">k</script>-Sparse Autoencoder</strong></p>
<p>In <script type="math/tex">k</script>-Sparse Autoencoder (<a href="https://arxiv.org/abs/1312.5663">Makhzani and Frey, 2013</a>), the sparsity is enforced by only keeping the top k highest activations in the bottleneck layer, which uses a linear activation function:</p>
<ol>
<li>First we run feedforward through the encoder network to get the compressed code: <script type="math/tex">\mathbf{z} = g(\mathbf{x})</script>.</li>
<li>Sort the values in the code vector <script type="math/tex">\mathbf{z}</script>. Only the k largest values are kept while the other neurons are set to 0. This can be done in a ReLU layer with an adjustable threshold too. Now we have a sparsified code: <script type="math/tex">\mathbf{z}' = \text{Sparsify}(\mathbf{z})</script>.</li>
<li>Compute the output and the loss from the sparsified code, <script type="math/tex">L = \|\mathbf{x} - f(\mathbf{z}') \|_2^2</script>.</li>
<li>The back-propagation only goes through the top k activated hidden units!</li>
</ol>
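<p>The sparsification step itself is a simple top-k selection; a minimal sketch (with a made-up code vector):</p>

```python
def sparsify_top_k(z, k):
    """Keep the k largest activations of the code and zero out the rest.
    The surviving indices (the support) are also returned, since
    back-propagation is restricted to those units."""
    support = sorted(range(len(z)), key=lambda i: z[i], reverse=True)[:k]
    keep = set(support)
    z_sparse = [zi if i in keep else 0.0 for i, zi in enumerate(z)]
    return z_sparse, sorted(support)

z = [0.1, 0.9, -0.3, 0.7, 0.05]
z_sparse, support = sparsify_top_k(z, k=2)
```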
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/k-sparse-autoencoder.png" alt="k-sparse autoencoder" /></p>
<p><em>Fig. 5. Filters of the k-sparse autoencoder for different sparsity levels k, learnt from MNIST with 1000 hidden units. (Image source: <a href="https://arxiv.org/abs/1312.5663">Makhzani and Frey, 2013</a>)</em></p>
<h2 id="contractive-autoencoder">Contractive Autoencoder</h2>
<p>Similar to sparse autoencoder, <strong>Contractive Autoencoder</strong> (<a href="http://www.icml-2011.org/papers/455_icmlpaper.pdf">Rifai, et al, 2011</a>) encourages the learned representation to stay in a contractive space for better robustness.</p>
<p>It adds a term in the loss function to penalize the representation being too sensitive to the input, and thus improve the robustness to small perturbations around the training data points. The sensitivity is measured by the Frobenius norm of the Jacobian matrix of the encoder activations with respect to the input:</p>
<script type="math/tex; mode=display">\|J_f(\mathbf{x})\|_F^2 = \sum_{ij} \Big( \frac{\partial h_j(\mathbf{x})}{\partial x_i} \Big)^2</script>
<p>where <script type="math/tex">h_j</script> is one unit output in the compressed code <script type="math/tex">\mathbf{z} = f(x)</script>.</p>
<p>This penalty term is the sum of squares of all partial derivatives of the learned encoding with respect to input dimensions. The authors claimed that empirically this penalty was found to carve a representation that corresponds to a lower-dimensional non-linear manifold, while staying more invariant to majority directions orthogonal to the manifold.</p>
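<p>As a sanity check on the penalty, the Frobenius norm of the encoder Jacobian can be estimated with finite differences; the two-unit sigmoid encoder and its weights below are made up for illustration:</p>

```python
import math

W = [[1.0, -0.5], [0.3, 0.8]]   # hypothetical encoder weights

def encoder(x):
    """Toy encoder h = f(x): two sigmoid units over a 2-D input."""
    return [1.0 / (1.0 + math.exp(-(W[j][0] * x[0] + W[j][1] * x[1])))
            for j in range(2)]

def contractive_penalty(x, eps=1e-5):
    """||J_f(x)||_F^2 = sum_ij (d h_j / d x_i)^2, estimated here with
    central finite differences instead of analytic derivatives."""
    penalty = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        hp, hm = encoder(xp), encoder(xm)
        for j in range(len(hp)):
            penalty += ((hp[j] - hm[j]) / (2 * eps)) ** 2
    return penalty
```

<p>Far from the origin the sigmoid units saturate, their derivatives shrink, and the penalty drops, which is exactly the kind of local insensitivity the loss encourages.</p>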
<h2 id="vae-variational-autoencoder">VAE: Variational Autoencoder</h2>
<p>The idea of <strong>Variational Autoencoder</strong> (<a href="https://arxiv.org/abs/1312.6114">Kingma & Welling, 2014</a>), short for <strong>VAE</strong>, is actually less similar to the autoencoder models above, but deeply rooted in the methods of variational Bayesian inference and graphical models.</p>
<p>Instead of mapping the input into a <em>fixed</em> vector, we want to map it into a distribution. Let’s label this distribution as <script type="math/tex">p_\theta</script>, parameterized by <script type="math/tex">\theta</script>. The relationship between the data input <script type="math/tex">\mathbf{x}</script> and the latent encoding vector <script type="math/tex">\mathbf{z}</script> can be fully defined by:</p>
<ul>
<li>Prior <script type="math/tex">p_\theta(\mathbf{z})</script></li>
<li>Likelihood <script type="math/tex">p_\theta(\mathbf{x}\vert\mathbf{z})</script></li>
<li>Posterior <script type="math/tex">p_\theta(\mathbf{z}\vert\mathbf{x})</script></li>
</ul>
<p>Assume that we know the real parameter <script type="math/tex">\theta^{*}</script> for this distribution. In order to generate a sample that looks like a real data point <script type="math/tex">\mathbf{x}^{(i)}</script>, we follow these steps:</p>
<ol>
<li>First, sample a <script type="math/tex">\mathbf{z}^{(i)}</script> from a prior distribution <script type="math/tex">p_{\theta^*}(\mathbf{z})</script>.</li>
<li>Then a value <script type="math/tex">\mathbf{x}^{(i)}</script> is generated from a conditional distribution <script type="math/tex">p_{\theta^*}(\mathbf{x} \vert \mathbf{z} = \mathbf{z}^{(i)})</script>.</li>
</ol>
<p>The optimal parameter <script type="math/tex">\theta^{*}</script> is the one that maximizes the probability of generating real data samples:</p>
<script type="math/tex; mode=display">\theta^{*} = \arg\max_\theta \prod_{i=1}^n p_\theta(\mathbf{x}^{(i)})</script>
<p>Commonly we use the log probabilities to convert the product on the RHS to a sum:</p>
<script type="math/tex; mode=display">\theta^{*} = \arg\max_\theta \sum_{i=1}^n \log p_\theta(\mathbf{x}^{(i)})</script>
<p>Now let’s update the equation to better demonstrate the data generation process so as to involve the encoding vector:</p>
<script type="math/tex; mode=display">p_\theta(\mathbf{x}^{(i)}) = \int p_\theta(\mathbf{x}^{(i)}\vert\mathbf{z}) p_\theta(\mathbf{z}) d\mathbf{z}</script>
<p>Unfortunately it is not easy to compute <script type="math/tex">p_\theta(\mathbf{x}^{(i)})</script> in this way, as it is very expensive to check all the possible values of <script type="math/tex">\mathbf{z}</script> and sum them up. To narrow down the value space to facilitate faster search, we would like to introduce a new approximation function to output what is a likely code given an input <script type="math/tex">\mathbf{x}</script>, <script type="math/tex">q_\phi(\mathbf{z}\vert\mathbf{x})</script>, parameterized by <script type="math/tex">\phi</script>.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/VAE-graphical-model.png" alt="Distributions in VAE" /></p>
<p><em>Fig. 6. The graphical model involved in Variational Autoencoder. Solid lines denote the generative distribution <script type="math/tex">p_\theta(.)</script> and dashed lines denote the distribution <script type="math/tex">q_\phi (\mathbf{z}\vert\mathbf{x})</script> to approximate the intractable posterior <script type="math/tex">p_\theta (\mathbf{z}\vert\mathbf{x})</script>.</em></p>
<p>Now the structure looks a lot like an autoencoder:</p>
<ul>
<li>The conditional probability <script type="math/tex">p_\theta(\mathbf{x} \vert \mathbf{z})</script> defines a generative model, similar to the decoder <script type="math/tex">f_\theta(\mathbf{x} \vert \mathbf{z})</script> introduced above. <script type="math/tex">p_\theta(\mathbf{x} \vert \mathbf{z})</script> is also known as <em>probabilistic decoder</em>.</li>
<li>The approximation function <script type="math/tex">q_\phi(\mathbf{z} \vert \mathbf{x})</script> is the <em>probabilistic encoder</em>, playing a similar role as <script type="math/tex">g_\phi(\mathbf{z} \vert \mathbf{x})</script> above.</li>
</ul>
<h3 id="loss-function-elbo">Loss Function: ELBO</h3>
<p>The estimated posterior <script type="math/tex">q_\phi(\mathbf{z}\vert\mathbf{x})</script> should be very close to the real one <script type="math/tex">p_\theta(\mathbf{z}\vert\mathbf{x})</script>. We can use <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler divergence</a> to quantify the distance between these two distributions. KL divergence <script type="math/tex">D_\text{KL}(X\|Y)</script> measures how much information is lost if the distribution Y is used to represent X.</p>
<p>In our case we want to minimize <script type="math/tex">D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) )</script> with respect to <script type="math/tex">\phi</script>.</p>
<p>But why use <script type="math/tex">D_\text{KL}(q_\phi \| p_\theta)</script> (reversed KL) instead of <script type="math/tex">D_\text{KL}(p_\theta \| q_\phi)</script> (forward KL)? Eric Jang has a great explanation in his <a href="https://blog.evjang.com/2016/08/variational-bayes.html">post</a> on Bayesian Variational methods. As a quick recap:</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/forward_vs_reversed_KL.png" alt="Forward vs reversed KL" /></p>
<p><em>Fig. 7. Forward and reversed KL divergence have different demands on how to match two distributions. (Image source: <a href="https://blog.evjang.com/2016/08/variational-bayes.html">blog.evjang.com/2016/08/variational-bayes.html</a>)</em></p>
<ul>
<li>Forward KL divergence: <script type="math/tex">D_\text{KL}(P\|Q) = \mathbb{E}_{z\sim P(z)} \log\frac{P(z)}{Q(z)}</script>; we have to ensure that Q(z)>0 wherever P(z)>0. The optimized variational distribution <script type="math/tex">q(z)</script> has to cover over the entire <script type="math/tex">p(z)</script>.</li>
<li>Reversed KL divergence: <script type="math/tex">D_\text{KL}(Q\|P) = \mathbb{E}_{z\sim Q(z)} \log\frac{Q(z)}{P(z)}</script>; minimizing the reversed KL divergence squeezes the <script type="math/tex">Q(z)</script> under <script type="math/tex">P(z)</script>.</li>
</ul>
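<p>A small discrete example (with made-up probabilities) makes the asymmetry tangible: for a bimodal P, forward KL favors a Q that covers both modes, while reversed KL favors a Q squeezed onto one mode:</p>

```python
import math

def kl(p, q):
    """D_KL(p || q) for discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

P = [0.48, 0.02, 0.02, 0.48]          # bimodal target
Q_cover = [0.25, 0.25, 0.25, 0.25]    # spreads mass over both modes
Q_mode = [0.94, 0.02, 0.02, 0.02]     # commits to a single mode

forward_cover, forward_mode = kl(P, Q_cover), kl(P, Q_mode)
reverse_cover, reverse_mode = kl(Q_cover, P), kl(Q_mode, P)
```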
<p>Let’s now expand the equation:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) ) & \\
&=\int q_\phi(\mathbf{z} \vert \mathbf{x}) \log\frac{q_\phi(\mathbf{z} \vert \mathbf{x})}{p_\theta(\mathbf{z} \vert \mathbf{x})} d\mathbf{z} & \\
&=\int q_\phi(\mathbf{z} \vert \mathbf{x}) \log\frac{q_\phi(\mathbf{z} \vert \mathbf{x})p_\theta(\mathbf{x})}{p_\theta(\mathbf{z}, \mathbf{x})} d\mathbf{z} & \scriptstyle{\text{; Because }p(z \vert x) = p(z, x) / p(x)} \\
&=\int q_\phi(\mathbf{z} \vert \mathbf{x}) \big( \log p_\theta(\mathbf{x}) + \log\frac{q_\phi(\mathbf{z} \vert \mathbf{x})}{p_\theta(\mathbf{z}, \mathbf{x})} \big) d\mathbf{z} & \\
&=\log p_\theta(\mathbf{x}) + \int q_\phi(\mathbf{z} \vert \mathbf{x})\log\frac{q_\phi(\mathbf{z} \vert \mathbf{x})}{p_\theta(\mathbf{z}, \mathbf{x})} d\mathbf{z} & \scriptstyle{\text{; Because }\int q(z \vert x) dz = 1}\\
&=\log p_\theta(\mathbf{x}) + \int q_\phi(\mathbf{z} \vert \mathbf{x})\log\frac{q_\phi(\mathbf{z} \vert \mathbf{x})}{p_\theta(\mathbf{x}\vert\mathbf{z})p_\theta(\mathbf{z})} d\mathbf{z} & \scriptstyle{\text{; Because }p(z, x) = p(x \vert z) p(z)} \\
&=\log p_\theta(\mathbf{x}) + \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z} \vert \mathbf{x})}[\log \frac{q_\phi(\mathbf{z} \vert \mathbf{x})}{p_\theta(\mathbf{z})} - \log p_\theta(\mathbf{x} \vert \mathbf{z})] &\\
&=\log p_\theta(\mathbf{x}) + D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z})) - \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}\vert\mathbf{x})}\log p_\theta(\mathbf{x}\vert\mathbf{z}) &
\end{aligned} %]]></script>
<p>So we have:</p>
<script type="math/tex; mode=display">D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) ) =\log p_\theta(\mathbf{x}) + D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z})) - \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}\vert\mathbf{x})}\log p_\theta(\mathbf{x}\vert\mathbf{z})</script>
<p>Once we rearrange the left and right hand sides of the equation,</p>
<script type="math/tex; mode=display">\log p_\theta(\mathbf{x}) - D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) ) = \mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}\vert\mathbf{x})}\log p_\theta(\mathbf{x}\vert\mathbf{z}) - D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}))</script>
<p>The LHS of the equation is exactly what we want to maximize when learning the true distributions: we want to maximize the (log-)likelihood of generating real data (that is <script type="math/tex">\log p_\theta(\mathbf{x})</script>) and also minimize the difference between the real and estimated posterior distributions (the term <script type="math/tex">D_\text{KL}</script> works like a regularizer). Note that <script type="math/tex">p_\theta(\mathbf{x})</script> is fixed with respect to <script type="math/tex">q_\phi</script>.</p>
<p>The negation of the above defines our loss function:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
L_\text{VAE}(\theta, \phi)
&= -\log p_\theta(\mathbf{x}) + D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) )\\
&= - \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z}) + D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}) ) \\
\theta^{*}, \phi^{*} &= \arg\min_{\theta, \phi} L_\text{VAE}
\end{aligned} %]]></script>
<p>In Variational Bayesian methods, the negation of this loss function is known as the <em>variational lower bound</em>, or <em>evidence lower bound</em> (ELBO). The “lower bound” part in the name comes from the fact that KL divergence is always non-negative and thus <script type="math/tex">-L_\text{VAE}</script> is a lower bound of <script type="math/tex">\log p_\theta (\mathbf{x})</script>.</p>
<script type="math/tex; mode=display">-L_\text{VAE} = \log p_\theta(\mathbf{x}) - D_\text{KL}( q_\phi(\mathbf{z}\vert\mathbf{x}) \| p_\theta(\mathbf{z}\vert\mathbf{x}) ) \leq \log p_\theta(\mathbf{x})</script>
<p>Therefore by minimizing the loss, we are maximizing the lower bound of the probability of generating real data samples.</p>
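<p>To make the two loss terms concrete, here is a minimal numerical sketch, assuming (as is common but not required) a Gaussian encoder with diagonal covariance and a Bernoulli decoder; the function names are illustrative, not from any particular library:</p>

```python
import numpy as np

def gaussian_kl(mu, log_var):
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ),
    # summed over the latent dimensions.
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def bernoulli_nll(x, x_recon, eps=1e-8):
    # Negative log-likelihood -log p(x|z) for a Bernoulli decoder,
    # summed over input dimensions.
    return -np.sum(x * np.log(x_recon + eps)
                   + (1 - x) * np.log(1 - x_recon + eps), axis=-1)

def vae_loss(x, x_recon, mu, log_var):
    # L_VAE = reconstruction NLL + KL regularizer, averaged over the batch.
    return np.mean(bernoulli_nll(x, x_recon) + gaussian_kl(mu, log_var))
```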
<h3 id="reparameterization-trick">Reparameterization Trick</h3>
<p>The expectation term in the loss function requires generating samples from <script type="math/tex">\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})</script>. Sampling is a stochastic process and therefore we cannot backpropagate the gradient through it. To make it trainable, the reparameterization trick is introduced: it is often possible to express the random variable <script type="math/tex">\mathbf{z}</script> as a deterministic variable <script type="math/tex">\mathbf{z} = \mathcal{T}_\phi(\mathbf{x}, \boldsymbol{\epsilon})</script>, where <script type="math/tex">\boldsymbol{\epsilon}</script> is an auxiliary independent random variable, and the transformation function <script type="math/tex">\mathcal{T}_\phi</script>, parameterized by <script type="math/tex">\phi</script>, converts <script type="math/tex">\boldsymbol{\epsilon}</script> to <script type="math/tex">\mathbf{z}</script>.</p>
<p>For example, a common choice of the form of <script type="math/tex">q_\phi(\mathbf{z}\vert\mathbf{x})</script> is a multivariate Gaussian with a diagonal covariance structure:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{z} &\sim q_\phi(\mathbf{z}\vert\mathbf{x}^{(i)}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}^{(i)}, \boldsymbol{\sigma}^{2(i)}\boldsymbol{I}) & \\
\mathbf{z} &= \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon} \text{, where } \boldsymbol{\epsilon} \sim \mathcal{N}(0, \boldsymbol{I}) & \scriptstyle{\text{; Reparameterization trick.}}
\end{aligned} %]]></script>
<p>where <script type="math/tex">\odot</script> refers to element-wise product.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/reparameterization-trick.png" alt="Reparameterization trick" /></p>
<p><em>Fig. 8. Illustration of how the reparameterization trick makes the <script type="math/tex">\mathbf{z}</script> sampling process trainable.(Image source: Slide 12 in Kingma’s NIPS 2015 workshop <a href="http://dpkingma.com/wordpress/wp-content/uploads/2015/12/talk_nips_workshop_2015.pdf">talk</a>)</em></p>
<p>The reparameterization trick works for other types of distributions too, not only Gaussian.
In the multivariate Gaussian case, we make the model trainable by learning the mean and variance of the distribution, <script type="math/tex">\mu</script> and <script type="math/tex">\sigma</script>, explicitly using the reparameterization trick, while the stochasticity remains in the random variable <script type="math/tex">\boldsymbol{\epsilon} \sim \mathcal{N}(0, \boldsymbol{I})</script>.</p>
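<p>In code, the trick is essentially one line; a sketch in numpy (the encoder network producing <script type="math/tex">\boldsymbol{\mu}</script> and <script type="math/tex">\log\boldsymbol{\sigma}^2</script> is assumed to exist elsewhere):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I).
    # mu and sigma are deterministic outputs of the encoder, so gradients
    # can flow through them; all stochasticity lives in eps.
    eps = rng.standard_normal(np.shape(mu))
    return mu + np.exp(0.5 * log_var) * eps
```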
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/vae-gaussian.png" alt="Gaussian VAE" /></p>
<p><em>Fig. 9. Illustration of variational autoencoder model with the multivariate Gaussian assumption.</em></p>
<h2 id="beta-vae">Beta-VAE</h2>
<p>If each variable in the inferred latent representation <script type="math/tex">\mathbf{z}</script> is only sensitive to one single generative factor and relatively invariant to other factors, we will say this representation is disentangled or factorized. One benefit that often comes with disentangled representation is <em>good interpretability</em> and easy generalization to a variety of tasks.</p>
<p>For example, a model trained on photos of human faces might capture gender, skin color, hair color, hair length, emotion, whether the person is wearing a pair of glasses, and many other relatively independent factors in separate dimensions. Such a disentangled representation is very beneficial to facial image generation.</p>
<p>β-VAE (<a href="https://openreview.net/forum?id=Sy2fzU9gl">Higgins et al., 2017</a>) is a modification of Variational Autoencoder with a special emphasis to discover disentangled latent factors. Following the same incentive in VAE, we want to maximize the probability of generating real data, while keeping the distance between the real and estimated posterior distributions small (say, under a small constant <script type="math/tex">\delta</script>):</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
&\max_{\phi, \theta} \mathbb{E}_{\mathbf{x}\sim\mathcal{D}}[\mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z})]\\
&\text{subject to } D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x})\|p_\theta(\mathbf{z})) < \delta
\end{aligned} %]]></script>
<p>We can rewrite it as a Lagrangian with a Lagrangian multiplier <script type="math/tex">\beta</script> under the <a href="https://www.cs.cmu.edu/~ggordon/10725-F12/slides/16-kkt.pdf">KKT condition</a>. The above optimization problem with only one inequality constraint is equivalent to maximizing the following equation <script type="math/tex">\mathcal{F}(\theta, \phi, \beta)</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathcal{F}(\theta, \phi, \beta) &= \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z}) - \beta(D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x})\|p_\theta(\mathbf{z})) - \delta) & \\
& = \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z}) - \beta D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x})\|p_\theta(\mathbf{z})) + \beta \delta & \\
& \geq \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z}) - \beta D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x})\|p_\theta(\mathbf{z})) & \scriptstyle{\text{; Because }\beta,\delta\geq 0}
\end{aligned} %]]></script>
<p>The loss function of <script type="math/tex">\beta</script>-VAE is defined as:</p>
<script type="math/tex; mode=display">L_\text{BETA}(\phi, \beta) = - \mathbb{E}_{\mathbf{z} \sim q_\phi(\mathbf{z}\vert\mathbf{x})} \log p_\theta(\mathbf{x}\vert\mathbf{z}) + \beta D_\text{KL}(q_\phi(\mathbf{z}\vert\mathbf{x})\|p_\theta(\mathbf{z}))</script>
<p>where the Lagrangian multiplier <script type="math/tex">\beta</script> is considered as a hyperparameter.</p>
<p>Since the negation of <script type="math/tex">L_\text{BETA}(\phi, \beta)</script> is a lower bound of the Lagrangian <script type="math/tex">\mathcal{F}(\theta, \phi, \beta)</script>, minimizing the loss is equivalent to maximizing the Lagrangian, which solves our initial optimization problem.</p>
<p>When <script type="math/tex">\beta=1</script>, it is the same as VAE. When <script type="math/tex">\beta > 1</script>, it applies a stronger constraint on the latent bottleneck and limits the representation capacity of <script type="math/tex">\mathbf{z}</script>. For some conditionally independent generative factors, keeping them disentangled is the most efficient representation. Therefore a higher <script type="math/tex">\beta</script> encourages more efficient latent encoding and further encourages the disentanglement. Meanwhile, a higher <script type="math/tex">\beta</script> may create a trade-off between reconstruction quality and the extent of disentanglement.</p>
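<p>Relative to the VAE loss, the only change is the weight on the KL term; as a trivial sketch (<script type="math/tex">\beta=4</script> below is an arbitrary placeholder, since <script type="math/tex">\beta</script> is a tuned hyperparameter):</p>

```python
def beta_vae_loss(recon_nll, kl, beta=4.0):
    # beta = 1 recovers the plain VAE loss; beta > 1 penalizes the KL term
    # more heavily, trading reconstruction quality for disentanglement.
    return recon_nll + beta * kl
```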
<p><a href="https://arxiv.org/pdf/1804.03599.pdf">Burgess, et al. (2017)</a> discussed the disentangling in <script type="math/tex">\beta</script>-VAE in depth, drawing inspiration from the <a href="/lil-log/2017/09/28/anatomize-deep-learning-with-information-theory.html">information bottleneck theory</a>, and further proposed a modification to <script type="math/tex">\beta</script>-VAE to better control the encoding representation capacity.</p>
<h2 id="vq-vae-and-vq-vae-2">VQ-VAE and VQ-VAE-2</h2>
<p>The <strong>VQ-VAE</strong> (“Vector Quantised-Variational AutoEncoder”; <a href="http://papers.nips.cc/paper/7210-neural-discrete-representation-learning.pdf">van den Oord, et al. 2017</a>) model learns a discrete latent variable by the encoder, since discrete representations may be a more natural fit for problems like language, speech, reasoning, etc.</p>
<p>Vector quantisation (VQ) is a method to map <script type="math/tex">D</script>-dimensional vectors into a finite set of “code” vectors. The process is very similar to the <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">KNN</a> algorithm: the optimal centroid code vector that a sample should be mapped to is the one with minimum Euclidean distance.</p>
<p>Let <script type="math/tex">\mathbf{e}_i \in \mathbb{R}^{K \times D}, i=1, \dots, K</script> be the latent embedding space (also known as “codebook”) in VQ-VAE, where <script type="math/tex">K</script> is the number of latent variable categories and <script type="math/tex">D</script> is the embedding size. The encoder output <script type="math/tex">E(\mathbf{x}) = \mathbf{z}_e</script> goes through a nearest-neighbor lookup to match to one of <script type="math/tex">K</script> embedding vectors and then this matched code vector becomes the input for the decoder <script type="math/tex">D(.)</script>:</p>
<script type="math/tex; mode=display">\mathbf{z}_q(\mathbf{x}) = \text{Quantize}(E(\mathbf{x})) = \mathbf{e}_k \text{ where } k = \arg\min_i \|E(\mathbf{x}) - \mathbf{e}_i \|_2</script>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VQ-VAE.png" alt="VQ-VAE" /></p>
<p><em>Fig. 10. The architecture of VQ-VAE (Image source: <a href="http://papers.nips.cc/paper/7210-neural-discrete-representation-learning.pdf">van den Oord, et al. 2017</a>)</em></p>
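<p>The nearest-neighbor lookup above can be sketched in a few lines of numpy (function and variable names are illustrative only):</p>

```python
import numpy as np

def quantize(z_e, codebook):
    # z_e: (N, D) encoder outputs; codebook: (K, D) embedding vectors e_i.
    # Squared Euclidean distance from every z_e to every code vector,
    # expanded as |z|^2 - 2 z.e + |e|^2 to avoid explicit loops.
    d = (np.sum(z_e**2, axis=1, keepdims=True)
         - 2.0 * z_e @ codebook.T
         + np.sum(codebook**2, axis=1))
    k = np.argmin(d, axis=1)   # index of the nearest code vector
    return k, codebook[k]      # z_q(x) = e_k
```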
<p>Because argmin() is non-differentiable on a discrete space, the gradients <script type="math/tex">\nabla_z L</script> from the decoder input <script type="math/tex">\mathbf{z}_q</script> are copied to the encoder output <script type="math/tex">\mathbf{z}_e</script> (a straight-through gradient estimator). Other than the reconstruction loss, VQ-VAE also optimizes:</p>
<ul>
<li><em>VQ loss</em>: The L2 error between the embedding space and the encoder outputs.</li>
<li><em>Commitment loss</em>: A measure to encourage the encoder output to stay close to the embedding space and to prevent it from fluctuating too frequently from one code vector to another.</li>
</ul>
<script type="math/tex; mode=display">L = \underbrace{\|\mathbf{x} - D(\mathbf{e}_k)\|_2^2}_{\textrm{reconstruction loss}} +
\underbrace{\|\text{sg}[E(\mathbf{x})] - \mathbf{e}_k\|_2^2}_{\textrm{VQ loss}} +
\underbrace{\beta \|E(\mathbf{x}) - \text{sg}[\mathbf{e}_k]\|_2^2}_{\textrm{commitment loss}}</script>
<p>where <script type="math/tex">\text{sq}[.]</script> is the <code class="highlighter-rouge">stop_gradient</code> operator.</p>
<p>The embedding vectors in the codebook are updated through EMA (exponential moving average). Given a code vector <script type="math/tex">\mathbf{e}_i</script>, say we have <script type="math/tex">n_i</script> encoder output vectors, <script type="math/tex">\{\mathbf{z}_{i,j}\}_{j=1}^{n_i}</script>, that are quantized to <script type="math/tex">\mathbf{e}_i</script>:</p>
<p><script type="math/tex">N_i^{(t)} = \gamma N_i^{(t-1)} + (1-\gamma)n_i^{(t)}\;\;\;
\mathbf{m}_i^{(t)} = \gamma \mathbf{m}_i^{(t-1)} + (1-\gamma)\sum_{j=1}^{n_i^{(t)}}\mathbf{z}_{i,j}^{(t)}\;\;\;
\mathbf{e}_i^{(t)} = \mathbf{m}_i^{(t)} / N_i^{(t)}</script>
where <script type="math/tex">(t)</script> indexes mini-batches in time. <script type="math/tex">N_i</script> and <script type="math/tex">\mathbf{m}_i</script> are the accumulated vector count and volume, respectively.</p>
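<p>The EMA update above can be sketched for a single code vector as follows (names are illustrative; <script type="math/tex">\gamma=0.99</script> is a typical decay value, not prescribed here):</p>

```python
import numpy as np

def ema_update(N_prev, m_prev, assigned_z, gamma=0.99):
    # One EMA step for one code vector e_i.
    # assigned_z: (n_i, D) encoder outputs quantized to e_i in this batch.
    n_i = assigned_z.shape[0]
    N = gamma * N_prev + (1.0 - gamma) * n_i                     # count
    m = gamma * m_prev + (1.0 - gamma) * assigned_z.sum(axis=0)  # volume
    return N, m, m / N                                           # e_i = m / N
```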
<p>VQ-VAE-2 (<a href="https://arxiv.org/abs/1906.00446">Ali Razavi, et al. 2019</a>) is a two-level hierarchical VQ-VAE combined with self-attention autoregressive model.</p>
<ol>
<li>Stage 1 is to <strong>train a hierarchical VQ-VAE</strong>: The design of hierarchical latent variables intends to separate local patterns (i.e., texture) from global information (i.e., object shapes). The training of the larger bottom level codebook is conditioned on the smaller top level code too, so that it does not have to learn everything from scratch.</li>
<li>Stage 2 is to <strong>learn a prior over the latent discrete codebook</strong> so that we can sample from it and generate images. In this way, the decoder can receive input vectors sampled from a similar distribution as the one in training. A powerful autoregressive model enhanced with multi-headed self-attention layers is used to capture the prior distribution (like <a href="https://arxiv.org/abs/1712.09763">PixelSNAIL; Chen et al. 2017</a>).</li>
</ol>
<p>Considering that VQ-VAE-2 depends on discrete latent variables configured in a simple hierarchical setting, the quality of its generated images is pretty amazing.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VQ-VAE-2.png" alt="VQ-VAE-2" /></p>
<p><em>Fig. 11. Architecture of hierarchical VQ-VAE and multi-stage image generation. (Image source: <a href="https://arxiv.org/abs/1906.00446">Ali Razavi, et al. 2019</a>)</em></p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/VQ-VAE-2-algo.png" alt="VQ-VAE-2-algo" /></p>
<p><em>Fig. 12. The VQ-VAE-2 algorithm. (Image source: <a href="https://arxiv.org/abs/1906.00446">Ali Razavi, et al. 2019</a>)</em></p>
<h2 id="td-vae">TD-VAE</h2>
<p><strong>TD-VAE</strong> (“Temporal Difference VAE”; <a href="https://arxiv.org/abs/1806.03107">Gregor et al., 2019</a>) works with sequential data. It relies on three main ideas, described below.</p>
<p style="width: 80%;" class="center"><img src="/lil-log/assets/images/TD-VAE-state-space.png" alt="TD-VAE-state-space" /></p>
<p><em>Fig. 13. State-space model as a Markov Chain model.</em></p>
<p><strong>1. State-Space Models</strong>
<br />
In (latent) state-space models, a sequence of unobserved hidden states <script type="math/tex">\mathbf{z} = (z_1, \dots, z_T)</script> determine the observation states <script type="math/tex">\mathbf{x} = (x_1, \dots, x_T)</script>. Each time step in the Markov chain model in Fig. 13 can be trained in a similar manner as in Fig. 6, where the intractable posterior <script type="math/tex">p(z \vert x)</script> is approximated by a function <script type="math/tex">q(z \vert x)</script>.</p>
<p><strong>2. Belief State</strong>
<br />
An agent should learn to encode all the past states to reason about the future; this summary is called the <em>belief state</em>, <script type="math/tex">b_t = belief(x_1, \dots, x_t) = belief(b_{t-1}, x_t)</script>. Given this, the distribution of future states conditioned on the past can be written as <script type="math/tex">p(x_{t+1}, \dots, x_T \vert x_1, \dots, x_t) \approx p(x_{t+1}, \dots, x_T \vert b_t)</script>. In TD-VAE, the hidden states of a recurrent network are used as the agent’s belief state. Thus we have <script type="math/tex">b_t = \text{RNN}(b_{t-1}, x_t)</script>.</p>
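<p>As a sketch of the belief update, a vanilla tanh RNN cell stands in below for the recurrent network; any cell fitting the <script type="math/tex">b_t = \text{RNN}(b_{t-1}, x_t)</script> form works, and the weight names here are purely illustrative:</p>

```python
import numpy as np

def belief_update(b_prev, x_t, W_b, W_x, bias):
    # One step of b_t = RNN(b_{t-1}, x_t): the belief state aggregates all
    # observations x_1..x_t through the recurrence.
    return np.tanh(W_b @ b_prev + W_x @ x_t + bias)
```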
<p><strong>3. Jumpy Prediction</strong>
<br />
Further, an agent is expected to imagine distant futures based on all the information gathered so far, suggesting the capability of making jumpy predictions, that is, predicting states several steps further into the future.</p>
<p>Recall what we have learned from the variational lower bound <a href="#loss-function-elbo">above</a>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\log p(x)
&\geq \log p(x) - D_\text{KL}(q(z|x)\|p(z|x)) \\
&= \mathbb{E}_{z\sim q} \log p(x|z) - D_\text{KL}(q(z|x)\|p(z)) \\
&= \mathbb{E}_{z \sim q} \log p(x|z) - \mathbb{E}_{z \sim q} \log \frac{q(z|x)}{p(z)} \\
&= \mathbb{E}_{z \sim q}[\log p(x|z) -\log q(z|x) + \log p(z)] \\
&= \mathbb{E}_{z \sim q}[\log p(x, z) -\log q(z|x)] \\
\log p(x)
&\geq \mathbb{E}_{z \sim q}[\log p(x, z) -\log q(z|x)]
\end{aligned} %]]></script>
<p>Now let’s model the distribution of the state <script type="math/tex">x_t</script> as a probability function conditioned on all the past states <script type="math/tex">% <![CDATA[
x_{<t} %]]></script> and two latent variables, <script type="math/tex">z_t</script> and <script type="math/tex">z_{t-1}</script>, at current time step and one step back:</p>
<script type="math/tex; mode=display">% <![CDATA[
\log p(x_t|x_{<t}) \geq \mathbb{E}_{(z_{t-1}, z_t) \sim q}[\log p(x_t, z_{t-1}, z_{t}|x_{<t}) -\log q(z_{t-1}, z_t|x_{\leq t})] %]]></script>
<p>Continue expanding the equation:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& \log p(x_t|x_{<t}) \\
&\geq \mathbb{E}_{(z_{t-1}, z_t) \sim q}[\log p(x_t, z_{t-1}, z_{t}|x_{<t}) -\log q(z_{t-1}, z_t|x_{\leq t})] \\
&\geq \mathbb{E}_{(z_{t-1}, z_t) \sim q}[\log p(x_t|\color{red}{z_{t-1}}, z_{t}, \color{red}{x_{<t}}) + \color{blue}{\log p(z_{t-1}, z_{t}|x_{<t})} -\log q(z_{t-1}, z_t|x_{\leq t})] \\
&\geq \mathbb{E}_{(z_{t-1}, z_t) \sim q}[\log p(x_t|z_{t}) + \color{blue}{\log p(z_{t-1}|x_{<t})} + \color{blue}{\log p(z_{t}|z_{t-1})} - \color{green}{\log q(z_{t-1}, z_t|x_{\leq t})}] \\
&\geq \mathbb{E}_{(z_{t-1}, z_t) \sim q}[\log p(x_t|z_{t}) + \log p(z_{t-1}|x_{<t}) + \log p(z_{t}|z_{t-1}) - \color{green}{\log q(z_t|x_{\leq t})} - \color{green}{\log q(z_{t-1}|z_t, x_{\leq t})}]
\end{aligned} %]]></script>
<p>Notice three things:</p>
<ul>
<li>The <span style="color: red;">red</span> terms can be ignored according to Markov assumptions.</li>
<li>The <span style="color: blue;">blue</span> term is expanded according to Markov assumptions.</li>
<li>The <span style="color: green;">green</span> term is expanded to include a one-step prediction back to the past as a smoothing distribution.</li>
</ul>
<p>Precisely, there are four types of distributions to learn:</p>
<ol>
<li><script type="math/tex">p_D(.)</script> is the <strong>decoder</strong> distribution:
<ul>
<li><script type="math/tex">p(x_t \mid z_t)</script> is the decoder by the common definition;</li>
<li><script type="math/tex">p(x_t \mid z_t) \to p_D(x_t \mid z_t)</script>;</li>
</ul>
</li>
<li><script type="math/tex">p_T(.)</script> is the <strong>transition</strong> distribution:
<ul>
<li><script type="math/tex">p(z_t \mid z_{t-1})</script> captures the sequential dependency between latent variables;</li>
<li><script type="math/tex">p(z_t \mid z_{t-1}) \to p_T(z_t \mid z_{t-1})</script>;</li>
</ul>
</li>
<li><script type="math/tex">p_B(.)</script> is the <strong>belief</strong> distribution:
<ul>
<li>Both <script type="math/tex">% <![CDATA[
p(z_{t-1} \mid x_{<t}) %]]></script> and <script type="math/tex">q(z_t \mid x_{\leq t})</script> can use the belief states to predict the latent variables;</li>
<li><script type="math/tex">% <![CDATA[
p(z_{t-1} \mid x_{<t}) \to p_B(z_{t-1} \mid b_{t-1}) %]]></script>;</li>
<li><script type="math/tex">q(z_{t} \mid x_{\leq t}) \to p_B(z_t \mid b_t)</script>;</li>
</ul>
</li>
<li><script type="math/tex">p_S(.)</script> is the <strong>smoothing</strong> distribution:
<ul>
<li>The back-to-past smoothing term <script type="math/tex">q(z_{t-1} \mid z_t, x_{\leq t})</script> can be rewritten to depend on belief states too;</li>
<li><script type="math/tex">q(z_{t-1} \mid z_t, x_{\leq t}) \to p_S(z_{t-1} \mid z_t, b_{t-1}, b_t)</script>;</li>
</ul>
</li>
</ol>
<p>To incorporate the idea of jumpy prediction, the sequential ELBO has to work not only on <script type="math/tex">t, t+1</script>, but also on two distant timestamps <script type="math/tex">% <![CDATA[
t_1 < t_2 %]]></script>. Here is the final TD-VAE objective function to maximize:</p>
<script type="math/tex; mode=display">J_{t_1, t_2} = \mathbb{E}[
\log p_D(x_{t_2}|z_{t_2})
+ \log p_B(z_{t_1}|b_{t_1})
+ \log p_T(z_{t_2}|z_{t_1})
- \log p_B(z_{t_2}|b_{t_2})
- \log p_S(z_{t_1}|z_{t_2}, b_{t_1}, b_{t_2})]</script>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/TD-VAE.png" alt="TD-VAE" /></p>
<p><em>Fig. 14. A detailed overview of TD-VAE architecture, very nicely done. (Image source: <a href="https://arxiv.org/abs/1806.03107">TD-VAE paper</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018VAE,
title = "From Autoencoder to Beta-VAE",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html"
}
</code></pre></div></div>
<h2 id="references">References</h2>
<p>[1] Geoffrey E. Hinton, and Ruslan R. Salakhutdinov. <a href="https://pdfs.semanticscholar.org/c50d/ca78e97e335d362d6b991ae0e1448914e9a3.pdf">“Reducing the dimensionality of data with neural networks.”</a> Science 313.5786 (2006): 504-507.</p>
<p>[2] Pascal Vincent, et al. <a href="http://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf">“Extracting and composing robust features with denoising autoencoders.”</a> ICML, 2008.</p>
<p>[3] Pascal Vincent, et al. <a href="http://www.jmlr.org/papers/volume11/vincent10a/vincent10a.pdf">“Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion.”</a>. Journal of machine learning research 11.Dec (2010): 3371-3408.</p>
<p>[4] Geoffrey E. Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R. Salakhutdinov. “Improving neural networks by preventing co-adaptation of feature detectors.” arXiv preprint arXiv:1207.0580 (2012).</p>
<p>[5] <a href="https://web.stanford.edu/class/cs294a/sparseAutoencoder.pdf">Sparse Autoencoder</a> by Andrew Ng.</p>
<p>[6] Alireza Makhzani, Brendan Frey (2013). <a href="https://arxiv.org/abs/1312.5663">“k-sparse autoencoder”</a>. ICLR 2014.</p>
<p>[7] Salah Rifai, et al. <a href="http://www.icml-2011.org/papers/455_icmlpaper.pdf">“Contractive auto-encoders: Explicit invariance during feature extraction.”</a> ICML, 2011.</p>
<p>[8] Diederik P. Kingma, and Max Welling. <a href="https://arxiv.org/abs/1312.6114">“Auto-encoding variational bayes.”</a> ICLR 2014.</p>
<p>[9] <a href="https://jaan.io/what-is-variational-autoencoder-vae-tutorial/">Tutorial - What is a variational autoencoder?</a> on jaan.io</p>
<p>[10] Youtube tutorial: <a href="https://www.youtube.com/watch?v=9zKuYvjFFS8">Variational Autoencoders</a> by Arxiv Insights</p>
<p>[11] <a href="https://blog.evjang.com/2016/08/variational-bayes.html">“A Beginner’s Guide to Variational Methods: Mean-Field Approximation”</a> by Eric Jang.</p>
<p>[12] Carl Doersch. <a href="https://arxiv.org/abs/1606.05908">“Tutorial on variational autoencoders.”</a> arXiv:1606.05908, 2016.</p>
<p>[13] Irina Higgins, et al. <a href="https://openreview.net/forum?id=Sy2fzU9gl">”<script type="math/tex">\beta</script>-VAE: Learning basic visual concepts with a constrained variational framework.”</a> ICLR 2017.</p>
<p>[14] Christopher P. Burgess, et al. <a href="https://arxiv.org/abs/1804.03599">“Understanding disentangling in beta-VAE.”</a> NIPS 2017.</p>
<p>[15] Aaron van den Oord, et al. <a href="https://arxiv.org/abs/1711.00937">“Neural Discrete Representation Learning”</a> NIPS 2017.</p>
<p>[16] Ali Razavi, et al. <a href="https://arxiv.org/abs/1906.00446">“Generating Diverse High-Fidelity Images with VQ-VAE-2”</a>. arXiv preprint arXiv:1906.00446 (2019).</p>
<p>[17] Xi Chen, et al. <a href="https://arxiv.org/abs/1712.09763">“PixelSNAIL: An Improved Autoregressive Generative Model.”</a> arXiv preprint arXiv:1712.09763 (2017).</p>
<p>[18] Karol Gregor, et al. <a href="https://arxiv.org/abs/1806.03107">“Temporal Difference Variational Auto-Encoder.”</a> ICLR 2019.</p>
<hr />
<p><strong>Attention? Attention!</strong> — Lilian Weng, 2018-06-24. <a href="https://lilianweng.github.io/lil-log/2018/06/24/attention-attention">lilianweng.github.io/lil-log/2018/06/24/attention-attention</a></p><blockquote>
<p>Attention has been a fairly popular concept and a useful tool in the deep learning community in recent years. In this post, we are gonna look into how attention was invented, and various attention mechanisms and models, such as transformer and SNAIL.</p>
</blockquote>
<!--more-->
<p><span style="color: #286ee0;">[Updated on 2018-10-28: Add <a href="#pointer-network">Pointer Network</a> and the <a href="https://github.com/lilianweng/transformer-tensorflow">link</a> to my implementation of Transformer.]</span><br />
<span style="color: #286ee0;">[Updated on 2018-11-06: Add a <a href="https://github.com/lilianweng/transformer-tensorflow">link</a> to the implementation of Transformer model.]</span><br />
<span style="color: #286ee0;">[Updated on 2018-11-18: Add <a href="#neural-turing-machines">Neural Turing Machines</a>.]</span><br />
<span style="color: #286ee0;">[Updated on 2019-07-18: Correct the mistake on using the term “self-attention” when introducing the <a href="https://arxiv.org/abs/1502.03044">show-attention-tell</a> paper; moved it to <a href="#self-attention">Self-Attention</a> section.]</span></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#whats-wrong-with-seq2seq-model" id="markdown-toc-whats-wrong-with-seq2seq-model">What’s Wrong with Seq2Seq Model?</a></li>
<li><a href="#born-for-translation" id="markdown-toc-born-for-translation">Born for Translation</a> <ul>
<li><a href="#definition" id="markdown-toc-definition">Definition</a></li>
</ul>
</li>
<li><a href="#a-family-of-attention-mechanisms" id="markdown-toc-a-family-of-attention-mechanisms">A Family of Attention Mechanisms</a> <ul>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
<li><a href="#self-attention" id="markdown-toc-self-attention">Self-Attention</a></li>
<li><a href="#soft-vs-hard-attention" id="markdown-toc-soft-vs-hard-attention">Soft vs Hard Attention</a></li>
<li><a href="#global-vs-local-attention" id="markdown-toc-global-vs-local-attention">Global vs Local Attention</a></li>
</ul>
</li>
<li><a href="#neural-turing-machines" id="markdown-toc-neural-turing-machines">Neural Turing Machines</a> <ul>
<li><a href="#reading-and-writing" id="markdown-toc-reading-and-writing">Reading and Writing</a></li>
<li><a href="#attention-mechanisms" id="markdown-toc-attention-mechanisms">Attention Mechanisms</a></li>
</ul>
</li>
<li><a href="#pointer-network" id="markdown-toc-pointer-network">Pointer Network</a></li>
<li><a href="#transformer" id="markdown-toc-transformer">Transformer</a> <ul>
<li><a href="#key-value-and-query" id="markdown-toc-key-value-and-query">Key, Value and Query</a></li>
<li><a href="#multi-head-self-attention" id="markdown-toc-multi-head-self-attention">Multi-Head Self-Attention</a></li>
<li><a href="#encoder" id="markdown-toc-encoder">Encoder</a></li>
<li><a href="#decoder" id="markdown-toc-decoder">Decoder</a></li>
<li><a href="#full-architecture" id="markdown-toc-full-architecture">Full Architecture</a></li>
</ul>
</li>
<li><a href="#snail" id="markdown-toc-snail">SNAIL</a></li>
<li><a href="#self-attention-gan" id="markdown-toc-self-attention-gan">Self-Attention GAN</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<p>Attention is, to some extent, motivated by how we pay visual attention to different regions of an image or correlate words in one sentence. Take the picture of a Shiba Inu in Fig. 1 as an example.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/shiba-example-attention.png" alt="shiba" /></p>
<p><em>Fig. 1. A Shiba Inu in a men’s outfit. The credit of the original photo goes to Instagram <a href="https://www.instagram.com/mensweardog/?hl=en">@mensweardog</a>.</em></p>
<p>Human visual attention allows us to focus on a certain region with “high resolution” (i.e. look at the pointy ear in the yellow box) while perceiving the surrounding image in “low resolution” (i.e. now how about the snowy background and the outfit?), and then adjust the focal point or do the inference accordingly. Given a small patch of an image, pixels in the rest of the image provide clues about what should be displayed there. We expect to see a pointy ear in the yellow box because we have seen a dog’s nose, another pointy ear on the right, and Shiba’s mystery eyes (stuff in the red boxes). However, the sweater and blanket at the bottom would not be as helpful as those doggy features.</p>
<p>Similarly, we can explain the relationship between words in one sentence or close context. When we see “eating”, we expect to encounter a food word very soon. The color term describes the food, but probably does not relate to “eating” directly.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/sentence-example-attention.png" alt="sentence" /></p>
<p><em>Fig. 2. One word “attends” to other words in the same sentence differently.</em></p>
<p>In a nutshell, attention in deep learning can be broadly interpreted as a vector of importance weights: in order to predict or infer one element, such as a pixel in an image or a word in a sentence, we estimate using the attention vector how strongly it is correlated with (or “<em>attends to</em>” as you may have read in many papers) other elements and take the sum of their values weighted by the attention vector as the approximation of the target.</p>
<h2 id="whats-wrong-with-seq2seq-model">What’s Wrong with Seq2Seq Model?</h2>
<p>The <strong>seq2seq</strong> model was born in the field of language modeling (<a href="https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf">Sutskever, et al. 2014</a>). Broadly speaking, it aims to transform an input sequence (source) to a new one (target) and both sequences can be of arbitrary lengths. Examples of transformation tasks include machine translation between multiple languages in either text or audio, question-answer dialog generation, or even parsing sentences into grammar trees.</p>
<p>The seq2seq model normally has an encoder-decoder architecture, composed of:</p>
<ul>
<li>An <strong>encoder</strong> processes the input sequence and compresses the information into a context vector (also known as sentence embedding or “thought” vector) of a <em>fixed length</em>. This representation is expected to be a good summary of the meaning of the <em>whole</em> source sequence.</li>
<li>A <strong>decoder</strong> is initialized with the context vector to emit the transformed output. The early work only used the last state of the encoder network as the decoder initial state.</li>
</ul>
<p>Both the encoder and decoder are recurrent neural networks, typically using <a href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">LSTM or GRU</a> units.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/encoder-decoder-example.png" alt="encoder-decoder model with additive attention layer" /></p>
<p><em>Fig. 3. The encoder-decoder model, translating the sentence “she is eating a green apple” to Chinese. The visualization of both encoder and decoder is unrolled in time.</em></p>
<p>A critical and apparent disadvantage of this fixed-length context vector design is its incapability of remembering long sentences. Often the model has forgotten the first part once it completes processing the whole input. The attention mechanism was born (<a href="https://arxiv.org/pdf/1409.0473.pdf">Bahdanau et al., 2015</a>) to resolve this problem.</p>
<h2 id="born-for-translation">Born for Translation</h2>
<p>The attention mechanism was born to help memorize long source sentences in neural machine translation (<a href="https://arxiv.org/pdf/1409.0473.pdf">NMT</a>). Rather than building a single context vector out of the encoder’s last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.</p>
<p>Since the context vector has access to the entire input sequence, we don’t need to worry about forgetting. The alignment between the source and target is learned and controlled by the context vector. Essentially the context vector consumes three pieces of information:</p>
<ul>
<li>encoder hidden states;</li>
<li>decoder hidden states;</li>
<li>alignment between source and target.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/encoder-decoder-attention.png" alt="encoder-decoder model with additive attention layer" /></p>
<p><em>Fig. 4. The encoder-decoder model with additive attention mechanism in <a href="https://arxiv.org/pdf/1409.0473.pdf">Bahdanau et al., 2015</a>.</em></p>
<h3 id="definition">Definition</h3>
<p>Now let’s define the attention mechanism introduced in NMT in a scientific way. Say, we have a source sequence <script type="math/tex">\mathbf{x}</script> of length <script type="math/tex">n</script> and try to output a target sequence <script type="math/tex">\mathbf{y}</script> of length <script type="math/tex">m</script>:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{x} &= [x_1, x_2, \dots, x_n] \\
\mathbf{y} &= [y_1, y_2, \dots, y_m]
\end{aligned} %]]></script>
<p>(Variables in bold indicate that they are vectors; same for everything else in this post.)</p>
<p>The encoder is a <a href="https://www.coursera.org/lecture/nlp-sequence-models/bidirectional-rnn-fyXnn">bidirectional RNN</a> (or other recurrent network setting of your choice) with a forward hidden state <script type="math/tex">\overrightarrow{\boldsymbol{h}}_i</script> and a backward one <script type="math/tex">\overleftarrow{\boldsymbol{h}}_i</script>. A simple concatenation of the two represents the encoder state. The motivation is to include both the preceding and following words in the annotation of one word.</p>
<script type="math/tex; mode=display">\boldsymbol{h}_i = [\overrightarrow{\boldsymbol{h}}_i^\top; \overleftarrow{\boldsymbol{h}}_i^\top]^\top, i=1,\dots,n</script>
<p>The decoder network has hidden state <script type="math/tex">\boldsymbol{s}_t=f(\boldsymbol{s}_{t-1}, y_{t-1}, \mathbf{c}_t)</script> for the output word at position t, <script type="math/tex">t=1,\dots,m</script>, where the context vector <script type="math/tex">\mathbf{c}_t</script> is a sum of hidden states of the input sequence, weighted by alignment scores:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\mathbf{c}_t &= \sum_{i=1}^n \alpha_{t,i} \boldsymbol{h}_i & \small{\text{; Context vector for output }y_t}\\
\alpha_{t,i} &= \text{align}(y_t, x_i) & \small{\text{; How well two words }y_t\text{ and }x_i\text{ are aligned.}}\\
&= \frac{\exp(\text{score}(\boldsymbol{s}_{t-1}, \boldsymbol{h}_i))}{\sum_{i'=1}^n \exp(\text{score}(\boldsymbol{s}_{t-1}, \boldsymbol{h}_{i'}))} & \small{\text{; Softmax of some predefined alignment score.}}.
\end{aligned} %]]></script>
<p>The alignment model assigns a score <script type="math/tex">\alpha_{t,i}</script> to the pair of input at position i and output at position t, <script type="math/tex">(y_t, x_i)</script>, based on how well they match. The set of <script type="math/tex">\{\alpha_{t, i}\}</script> are weights defining how much of each source hidden state should be considered for each output. In Bahdanau’s paper, the alignment score <script type="math/tex">\alpha</script> is parametrized by a <strong>feed-forward network</strong> with a single hidden layer and this network is jointly trained with other parts of the model. The score function is therefore in the following form, given that tanh is used as the non-linear activation function:</p>
<script type="math/tex; mode=display">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \mathbf{v}_a^\top \tanh(\mathbf{W}_a[\boldsymbol{s}_t; \boldsymbol{h}_i])</script>
<p>where <script type="math/tex">\mathbf{v}_a</script> is a weight vector and <script type="math/tex">\mathbf{W}_a</script> is a weight matrix, both to be learned in the alignment model.</p>
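<p>If you prefer code to equations, here is a minimal NumPy sketch of the additive scoring and context computation above. The names <code>W_a</code>, <code>v_a</code>, <code>s_prev</code> and <code>H</code> mirror the notation in this section; it is an illustration of the math, not a training-ready implementation:</p>

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def additive_attention(s_prev, H, W_a, v_a):
    """Bahdanau-style additive attention for one decoder step.

    s_prev : (d_s,)   previous decoder state s_{t-1}
    H      : (n, d_h) encoder hidden states h_1..h_n
    W_a    : (d_a, d_s + d_h) and v_a : (d_a,) learned parameters
    Returns alignment weights alpha_t (n,) and context vector c_t (d_h,).
    """
    # score(s_{t-1}, h_i) = v_a^T tanh(W_a [s_{t-1}; h_i]) for each encoder state
    scores = np.array([v_a @ np.tanh(W_a @ np.concatenate([s_prev, h])) for h in H])
    alpha = softmax(scores)   # alignment weights over the n source positions
    context = alpha @ H       # c_t = sum_i alpha_{t,i} h_i
    return alpha, context
```

The softmax guarantees the weights are non-negative and sum to one, exactly as in the alignment equation above.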
<p>The matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words.</p>
<p style="width: 65%;" class="center"><img src="/lil-log/assets/images/bahdanau-fig3.png" alt="alignment matrix" /></p>
<p><em>Fig. 5. Alignment matrix of “L’accord sur l’Espace économique européen a été signé en août 1992” (French) and its English translation “The agreement on the European Economic Area was signed in August 1992”. (Image source: Fig 3 in <a href="https://arxiv.org/pdf/1409.0473.pdf">Bahdanau et al., 2015</a>)</em></p>
<p>Check out this nice <a href="https://www.tensorflow.org/versions/master/tutorials/seq2seq">tutorial</a> by the TensorFlow team for more implementation instructions.</p>
<h2 id="a-family-of-attention-mechanisms">A Family of Attention Mechanisms</h2>
<p>With the help of the attention, the dependencies between source and target sequences are not restricted by the in-between distance anymore! Given the big improvement by attention in machine translation, it soon got extended into the computer vision field (<a href="http://proceedings.mlr.press/v37/xuc15.pdf">Xu et al. 2015</a>) and people started exploring various other forms of attention mechanisms (<a href="https://arxiv.org/pdf/1508.04025.pdf">Luong, et al., 2015</a>; <a href="https://arxiv.org/abs/1703.03906">Britz et al., 2017</a>; <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani, et al., 2017</a>).</p>
<h3 id="summary">Summary</h3>
<p>Below is a summary table of several popular attention mechanisms and corresponding alignment score functions:</p>
<table class="info">
<thead>
<tr>
<th>Name</th>
<th>Alignment score function</th>
<th>Citation</th>
</tr>
</thead>
<tbody>
<tr>
<td>Content-base attention</td>
<td><script type="math/tex">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \text{cosine}[\boldsymbol{s}_t, \boldsymbol{h}_i]</script></td>
<td><a href="https://arxiv.org/abs/1410.5401">Graves2014</a></td>
</tr>
<tr>
<td>Additive(*)</td>
<td><script type="math/tex">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \mathbf{v}_a^\top \tanh(\mathbf{W}_a[\boldsymbol{s}_t; \boldsymbol{h}_i])</script></td>
<td><a href="https://arxiv.org/pdf/1409.0473.pdf">Bahdanau2015</a></td>
</tr>
<tr>
<td>Location-Base</td>
<td><script type="math/tex">\alpha_{t,i} = \text{softmax}(\mathbf{W}_a \boldsymbol{s}_t)</script><br />Note: This simplifies the softmax alignment to only depend on the target position.</td>
<td><a href="https://arxiv.org/pdf/1508.04025.pdf">Luong2015</a></td>
</tr>
<tr>
<td>General</td>
<td><script type="math/tex">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \boldsymbol{s}_t^\top\mathbf{W}_a\boldsymbol{h}_i</script><br />where <script type="math/tex">\mathbf{W}_a</script> is a trainable weight matrix in the attention layer.</td>
<td><a href="https://arxiv.org/pdf/1508.04025.pdf">Luong2015</a></td>
</tr>
<tr>
<td>Dot-Product</td>
<td><script type="math/tex">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \boldsymbol{s}_t^\top\boldsymbol{h}_i</script></td>
<td><a href="https://arxiv.org/pdf/1508.04025.pdf">Luong2015</a></td>
</tr>
<tr>
<td>Scaled Dot-Product(^)</td>
<td><script type="math/tex">\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \frac{\boldsymbol{s}_t^\top\boldsymbol{h}_i}{\sqrt{n}}</script><br />Note: very similar to the dot-product attention except for a scaling factor; where n is the dimension of the source hidden state.</td>
<td><a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani2017</a></td>
</tr>
</tbody>
</table>
<p>(*) Referred to as “concat” in Luong, et al., 2015 and as “additive attention” in Vaswani, et al., 2017.<br />
(^) It adds a scaling factor <script type="math/tex">1/\sqrt{n}</script>, motivated by the concern that when the input is large, the softmax function may have an extremely small gradient, making learning inefficient.<br /></p>
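<p>A tiny NumPy experiment (my own illustration, not from any of the cited papers) makes the motivation behind the scaling factor concrete: for random unit-variance vectors of dimension <code>n</code>, dot products have standard deviation around <script type="math/tex">\sqrt{n}</script>, so the unscaled softmax becomes nearly one-hot and most weights get vanishingly small gradients:</p>

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
n = 64                            # dimension of the hidden states
s = rng.standard_normal(n)        # one decoder state
H = rng.standard_normal((10, n))  # 10 encoder states
raw = H @ s                       # dot-product scores; magnitude grows with sqrt(n)
scaled = raw / np.sqrt(n)         # scaled dot-product scores
# The unscaled softmax concentrates almost all mass on one position.
print(softmax(raw).max(), softmax(scaled).max())
```

The scaled distribution is noticeably flatter, which keeps the softmax away from its saturated regime.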
<p>Here is a summary of broader categories of attention mechanisms:</p>
<table class="info">
<thead>
<tr>
<th>Name</th>
<th>Definition</th>
<th>Citation</th>
</tr>
</thead>
<tbody>
<tr>
<td>Self-Attention(&)</td>
<td>Relating different positions of the same input sequence. Theoretically the self-attention can adopt any score functions above, but just replace the target sequence with the same input sequence.</td>
<td><a href="https://arxiv.org/pdf/1601.06733.pdf">Cheng2016</a></td>
</tr>
<tr>
<td>Global/Soft</td>
<td>Attending to the entire input state space.</td>
<td><a href="http://proceedings.mlr.press/v37/xuc15.pdf">Xu2015</a></td>
</tr>
<tr>
<td>Local/Hard</td>
<td>Attending to a part of the input state space, e.g. a patch of the input image.</td>
<td><a href="http://proceedings.mlr.press/v37/xuc15.pdf">Xu2015</a>; <a href="https://arxiv.org/pdf/1508.04025.pdf">Luong2015</a></td>
</tr>
</tbody>
</table>
<p>(&) Also referred to as “intra-attention” in Cheng et al., 2016 and some other papers.</p>
<h3 id="self-attention">Self-Attention</h3>
<p><strong>Self-attention</strong>, also known as <strong>intra-attention</strong>, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the same sequence. It has been shown to be very useful in machine reading, abstractive summarization, or image description generation.</p>
<p>The <a href="https://arxiv.org/pdf/1601.06733.pdf">Long Short-Term Memory-Network</a> paper (Cheng et al., 2016) used self-attention to do machine reading. In the example below, the self-attention mechanism enables us to learn the correlation between the current word and the previous part of the sentence.</p>
<p style="width: 70%;" class="center"><img src="/lil-log/assets/images/cheng2016-fig1.png" alt="intra-attention" /></p>
<p><em>Fig. 6. The current word is in red and the size of the blue shade indicates the activation level. (Image source: <a href="https://arxiv.org/pdf/1601.06733.pdf">Cheng et al., 2016</a>)</em></p>
<h3 id="soft-vs-hard-attention">Soft vs Hard Attention</h3>
<p>In the <a href="http://proceedings.mlr.press/v37/xuc15.pdf">show, attend and tell</a> paper, attention mechanism is applied to images to generate captions. The image is first encoded by a CNN to extract features. Then an LSTM decoder consumes the convolutional features to produce descriptive words one by one, where the weights are learned through attention. The visualization of the attention weights clearly demonstrates which regions of the image the model is paying attention to so as to output a certain word.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/xu2015-fig6b.png" alt="show-attend-and-tell" /></p>
<p><em>Fig. 7. “A woman is throwing a frisbee in a park.” (Image source: Fig. 6(b) in <a href="http://proceedings.mlr.press/v37/xuc15.pdf">Xu et al. 2015</a>)</em></p>
<p>This paper first proposed the distinction between “soft” vs “hard” attention, based on whether the attention has access to the entire image or only a patch:</p>
<ul>
<li><strong>Soft</strong> Attention: the alignment weights are learned and placed “softly” over all patches in the source image; essentially the same type of attention as in <a href="https://arxiv.org/abs/1409.0473">Bahdanau et al., 2015</a>.
<ul>
<li><em>Pro</em>: the model is smooth and differentiable.</li>
<li><em>Con</em>: expensive when the source input is large.</li>
</ul>
</li>
<li><strong>Hard</strong> Attention: only selects one patch of the image to attend to at a time.
<ul>
<li><em>Pro</em>: less calculation at the inference time.</li>
<li><em>Con</em>: the model is non-differentiable and requires more complicated techniques such as variance reduction or reinforcement learning to train. (<a href="https://arxiv.org/abs/1508.04025">Luong, et al., 2015</a>)</li>
</ul>
</li>
</ul>
<h3 id="global-vs-local-attention">Global vs Local Attention</h3>
<p><a href="https://arxiv.org/pdf/1508.04025.pdf">Luong, et al., 2015</a> proposed the “global” and “local” attention. The global attention is similar to the soft attention, while the local one is an interesting blend between <a href="#soft-vs-hard-attention">hard and soft</a>, an improvement over the hard attention to make it differentiable: the model first predicts a single aligned position for the current target word and a window centered around the source position is then used to compute a context vector.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/luong2015-fig2-3.png" alt="global-local-attention" /></p>
<p><em>Fig. 8. Global vs local attention (Image source: Fig 2 & 3 in <a href="https://arxiv.org/pdf/1508.04025.pdf">Luong, et al., 2015</a>)</em></p>
<h2 id="neural-turing-machines">Neural Turing Machines</h2>
<p>Alan Turing in <a href="https://en.wikipedia.org/wiki/Turing_machine">1936</a> proposed a minimalistic model of computation. It is composed of an infinitely long tape and a head that interacts with the tape. The tape has countless cells on it, each filled with a symbol: 0, 1 or blank (“ ”). The operation head can read symbols, edit symbols and move left/right on the tape. Theoretically a Turing machine can simulate any computer algorithm, irrespective of how complex or expensive the procedure might be. The infinite memory gives a Turing machine the edge of being mathematically limitless. However, infinite memory is not feasible in real modern computers, so we only consider the Turing machine as a mathematical model of computation.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/turing-machine.jpg" alt="turing-machine" /></p>
<p><em>Fig. 9. What a Turing machine looks like: a tape + a head that handles the tape. (Image source: http://aturingmachine.com/)</em></p>
<p><strong>Neural Turing Machine</strong> (<strong>NTM</strong>, <a href="https://arxiv.org/abs/1410.5401">Graves, Wayne & Danihelka, 2014</a>) is a model architecture for coupling a neural network with external memory storage. The memory mimics the Turing machine tape and the neural network controls the operation heads to read from or write to the tape. However, the memory in NTM is finite, and thus it probably looks more like a “Neural <a href="https://en.wikipedia.org/wiki/Von_Neumann_architecture">von Neumann</a> Machine”.</p>
<p>NTM contains two major components, a <em>controller</em> neural network and a <em>memory</em> bank.</p>
<ul>
<li><strong>Controller</strong>: in charge of executing operations on the memory. It can be any type of neural network, feed-forward or recurrent.</li>
<li><strong>Memory</strong>: stores processed information. It is a matrix of size <script type="math/tex">N \times M</script>, containing <script type="math/tex">N</script> vector rows, each of <script type="math/tex">M</script> dimensions.</li>
</ul>
<p>In one update iteration, the controller processes the input and interacts with the memory bank accordingly to generate output. The interaction is handled by a set of parallel <em>read</em> and <em>write</em> heads. Both read and write operations are “blurry” by softly attending to all the memory addresses.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/NTM.png" alt="turing-machine" /></p>
<p><em>Fig 10. Neural Turing Machine Architecture.</em></p>
<h3 id="reading-and-writing">Reading and Writing</h3>
<p>When reading from the memory at time t, an attention vector of size <script type="math/tex">N</script>, <script type="math/tex">\mathbf{w}_t</script> controls how much attention to assign to different memory locations (matrix rows). The read vector <script type="math/tex">\mathbf{r}_t</script> is a sum weighted by attention intensity:</p>
<script type="math/tex; mode=display">\mathbf{r}_t = \sum_{i=1}^N w_t(i)\mathbf{M}_t(i)\text{, where }\sum_{i=1}^N w_t(i)=1, \forall i: 0 \leq w_t(i) \leq 1</script>
<p>where <script type="math/tex">w_t(i)</script> is the <script type="math/tex">i</script>-th element in <script type="math/tex">\mathbf{w}_t</script> and <script type="math/tex">\mathbf{M}_t(i)</script> is the <script type="math/tex">i</script>-th row vector in the memory.</p>
<p>When writing into the memory at time t, as inspired by the input and forget gates in LSTM, a write head first wipes off some old content according to an erase vector <script type="math/tex">\mathbf{e}_t</script> and then adds new information by an add vector <script type="math/tex">\mathbf{a}_t</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\tilde{\mathbf{M}}_t(i) &= \mathbf{M}_{t-1}(i) [\mathbf{1} - w_t(i)\mathbf{e}_t] &\scriptstyle{\text{; erase}}\\
\mathbf{M}_t(i) &= \tilde{\mathbf{M}}_t(i) + w_t(i) \mathbf{a}_t &\scriptstyle{\text{; add}}
\end{aligned} %]]></script>
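<p>The read and the erase-then-add write above are just a few matrix operations. Here is a small NumPy sketch (the function names are mine; shapes follow the <script type="math/tex">N \times M</script> memory convention in this section):</p>

```python
import numpy as np

def ntm_read(M, w):
    """Blurry read: r_t = sum_i w_t(i) M_t(i), a weighted sum of memory rows."""
    return w @ M                          # (N,) @ (N, M) -> (M,)

def ntm_write(M, w, e, a):
    """Erase then add, inspired by the forget/input gates of an LSTM.

    M : (N, M) memory, w : (N,) attention, e / a : (M,) erase / add vectors.
    """
    M_tilde = M * (1.0 - np.outer(w, e))  # erase: M~(i) = M(i) * [1 - w(i) e]
    return M_tilde + np.outer(w, a)       # add:   M(i) = M~(i) + w(i) a
```

With a one-hot attention vector and an all-ones erase vector, the write degenerates to replacing a single memory row, which is a useful sanity check.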
<h3 id="attention-mechanisms">Attention Mechanisms</h3>
<p>In Neural Turing Machine, how to generate the attention distribution <script type="math/tex">\mathbf{w}_t</script> depends on the addressing mechanisms: NTM uses a mixture of content-based and location-based addressings.</p>
<p><strong>Content-based addressing</strong></p>
<p>The content-addressing creates attention vectors based on the similarity between the key vector <script type="math/tex">\mathbf{k}_t</script>, extracted by the controller from the input, and the memory rows. The content-based attention scores are computed as cosine similarity and then normalized by softmax. In addition, NTM adds a strength multiplier <script type="math/tex">\beta_t</script> to amplify or attenuate the focus of the distribution.</p>
<script type="math/tex; mode=display">w_t^c(i)
= \text{softmax}(\beta_t \cdot \text{cosine}[\mathbf{k}_t, \mathbf{M}_t(i)])
= \frac{\exp(\beta_t \frac{\mathbf{k}_t \cdot \mathbf{M}_t(i)}{\|\mathbf{k}_t\| \cdot \|\mathbf{M}_t(i)\|})}{\sum_{j=1}^N \exp(\beta_t \frac{\mathbf{k}_t \cdot \mathbf{M}_t(j)}{\|\mathbf{k}_t\| \cdot \|\mathbf{M}_t(j)\|})}</script>
<p><strong>Interpolation</strong></p>
<p>Then an interpolation gate scalar <script type="math/tex">g_t</script> is used to blend the newly generated content-based attention vector with the attention weights in the last time step:</p>
<script type="math/tex; mode=display">\mathbf{w}_t^g = g_t \mathbf{w}_t^c + (1 - g_t) \mathbf{w}_{t-1}</script>
<p><strong>Location-based addressing</strong></p>
<p>The location-based addressing sums up the values at different positions in the attention vector, weighted by a weighting distribution over allowable integer shifts. It is equivalent to a 1-d convolution with a kernel <script type="math/tex">\mathbf{s}_t(.)</script>, a function of the position offset. There are multiple ways to define this distribution. See Fig. 11 for inspiration.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/shift-weighting.png" alt="shift-weighting" /></p>
<p><em>Fig. 11. Two ways to represent the shift weighting distribution <script type="math/tex">\mathbf{s}_t</script>.</em></p>
<p>Finally the attention distribution is enhanced by a sharpening scalar <script type="math/tex">\gamma_t \geq 1</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\tilde{w}_t(i) &= \sum_{j=1}^N w_t^g(j) s_t(i-j) & \scriptstyle{\text{; circular convolution}}\\
w_t(i) &= \frac{\tilde{w}_t(i)^{\gamma_t}}{\sum_{j=1}^N \tilde{w}_t(j)^{\gamma_t}} & \scriptstyle{\text{; sharpen}}
\end{aligned} %]]></script>
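<p>Putting the four stages together, the whole addressing pipeline can be sketched in NumPy as below. This is my own toy illustration (the shift kernel <code>s</code> is a length-<script type="math/tex">N</script> array where <code>s[m]</code> weights a shift of <code>m</code> positions, wrapping around), assuming the controller has already produced <script type="math/tex">\mathbf{k}_t, \beta_t, g_t, \mathbf{s}_t, \gamma_t</script>:</p>

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def ntm_address(M, w_prev, k, beta, g, s, gamma):
    # 1) content-based addressing: cosine similarity sharpened by beta
    sim = (M @ k) / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + 1e-8)
    w_c = softmax(beta * sim)
    # 2) interpolation with the attention vector from the last time step
    w_g = g * w_c + (1.0 - g) * w_prev
    # 3) circular convolution with the shift kernel s
    N = len(w_g)
    w_tilde = np.array([sum(w_g[j] * s[(i - j) % N] for j in range(N))
                        for i in range(N)])
    # 4) sharpening with gamma >= 1, then renormalize
    w = w_tilde ** gamma
    return w / w.sum()
```

Each stage preserves the property that the output is a valid attention distribution: non-negative and summing to one.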
<p>The complete process of generating the attention vector <script type="math/tex">\mathbf{w}_t</script> at time step t is illustrated in Fig. 12. All the parameters produced by the controller are unique for each head. If there are multiple read and write heads in parallel, the controller would output multiple sets.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/NTM-flow-addressing.png" alt="NTM-flow-addressing" /></p>
<p><em>Fig. 12. Flow diagram of the addressing mechanisms in Neural Turing Machine. (Image source: <a href="https://arxiv.org/abs/1410.5401">Graves, Wayne & Danihelka, 2014</a>)</em></p>
<h2 id="pointer-network">Pointer Network</h2>
<p>In problems like sorting or the travelling salesman problem, both input and output are sequential data. Unfortunately, they cannot be easily solved by classic seq2seq or NMT models, given that the discrete categories of output elements are not determined in advance, but depend on the variable input size. The <strong>Pointer Net</strong> (<strong>Ptr-Net</strong>; <a href="https://arxiv.org/abs/1506.03134">Vinyals, et al. 2015</a>) is proposed to resolve this type of problem: when the output elements correspond to <em>positions</em> in an input sequence. Rather than using attention to blend hidden units of an encoder into a context vector (See Fig. 8), the Pointer Net applies attention over the input elements to pick one as the output at each decoder step.</p>
<p style="width: 75%;" class="center"><img src="/lil-log/assets/images/ptr-net.png" alt="pointer network" /></p>
<p><em>Fig. 13. The architecture of a Pointer Network model. (Image source: <a href="https://arxiv.org/abs/1506.03134">Vinyals, et al. 2015</a>)</em></p>
<p>The Ptr-Net outputs a sequence of integer indices, <script type="math/tex">\boldsymbol{c} = (c_1, \dots, c_m)</script> given a sequence of input vectors <script type="math/tex">\boldsymbol{x} = (x_1, \dots, x_n)</script> and <script type="math/tex">1 \leq c_i \leq n</script>. The model still embraces an encoder-decoder framework. The encoder and decoder hidden states are denoted as <script type="math/tex">(\boldsymbol{h}_1, \dots, \boldsymbol{h}_n)</script> and <script type="math/tex">(\boldsymbol{s}_1, \dots, \boldsymbol{s}_m)</script>, respectively. Note that <script type="math/tex">\mathbf{s}_i</script> is the output gate after cell activation in the decoder. The Ptr-Net applies additive attention between states and then normalizes it by softmax to model the output conditional probability:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
y_i &= p(c_i \vert c_1, \dots, c_{i-1}, \boldsymbol{x}) \\
&= \text{softmax}(\text{score}(\boldsymbol{s}_t; \boldsymbol{h}_i)) = \text{softmax}(\mathbf{v}_a^\top \tanh(\mathbf{W}_a[\boldsymbol{s}_t; \boldsymbol{h}_i]))
\end{aligned} %]]></script>
<p>The attention mechanism is simplified, as Ptr-Net does not blend the encoder states into the output with attention weights. In this way, the output only responds to the positions but not the input content.</p>
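<p>One decoder step of this simplified mechanism can be sketched as follows; the function name and parameter shapes are my own, but the scoring follows the additive form above, and the “pointer” is just the softmax distribution over input positions:</p>

```python
import numpy as np

def pointer_step(s_t, H, W_a, v_a):
    """One Ptr-Net decoder step: a distribution over input positions, no blending.

    s_t : (d_s,) decoder state, H : (n, d_h) encoder states.
    """
    # additive attention scores over the n input positions
    scores = np.array([v_a @ np.tanh(W_a @ np.concatenate([s_t, h])) for h in H])
    e = np.exp(scores - scores.max())
    p = e / e.sum()                 # p(c_i | c_1..c_{i-1}, x)
    return p, int(np.argmax(p))     # the "pointed-to" input index
```

Note that, unlike the NMT attention earlier, <code>p</code> itself is the output; the encoder states are never summed into a context vector.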
<h2 id="transformer">Transformer</h2>
<p><a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">“Attention is All you Need”</a>
(Vaswani, et al., 2017), without a doubt, is one of the most impactful and interesting papers of 2017. It presented a lot of improvements to soft attention and makes it possible to do seq2seq modeling <em>without</em> recurrent network units. The proposed “<strong>transformer</strong>” model is entirely built on the self-attention mechanisms without using sequence-aligned recurrent architecture.</p>
<p>The secret recipe is carried in its model architecture.</p>
<h3 id="key-value-and-query">Key, Value and Query</h3>
<p>The major component in the transformer is the unit of <em>multi-head self-attention mechanism</em>. The transformer views the encoded representation of the input as a set of <strong>key</strong>-<strong>value</strong> pairs, <script type="math/tex">(\mathbf{K}, \mathbf{V})</script>, both of dimension <script type="math/tex">n</script> (input sequence length); in the context of NMT, both the keys and values are the encoder hidden states. In the decoder, the previous output is compressed into a <strong>query</strong> (<script type="math/tex">\mathbf{Q}</script> of dimension <script type="math/tex">m</script>) and the next output is produced by mapping this query and the set of keys and values.</p>
<p>The transformer adopts the <a href="#summary">scaled dot-product attention</a>: the output is a weighted sum of the values, where the weight assigned to each value is determined by the dot-product of the query with all the keys:</p>
<script type="math/tex; mode=display">\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{n}})\mathbf{V}</script>
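<p>The scaled dot-product attention formula above translates almost line-by-line into NumPy. As a minimal sketch (batching and masking omitted):</p>

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V.

    Q : (m, d) queries, K : (n, d) keys, V : (n, d_v) values.
    Returns the attended output (m, d_v) and the weights (m, n).
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                    # scaled dot products
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V, weights
```

Each row of <code>weights</code> is a distribution over the <code>n</code> key/value positions for one query.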
<h3 id="multi-head-self-attention">Multi-Head Self-Attention</h3>
<p style="width: 40%;" class="center"><img src="/lil-log/assets/images/multi-head-attention.png" alt="multi-head scaled dot-product attention" /></p>
<p><em>Fig. 14. Multi-head scaled dot-product attention mechanism. (Image source: Fig 2 in <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani, et al., 2017</a>)</em></p>
<p>Rather than only computing the attention once, the multi-head mechanism runs through the scaled dot-product attention multiple times in parallel. The independent attention outputs are simply concatenated and linearly transformed into the expected dimensions. I assume the motivation is because ensembling always helps? ;) According to the paper, <em>“multi-head attention allows the model to jointly attend to information from different representation <strong>subspaces</strong> at different positions. With a single attention head, averaging inhibits this.”</em></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) &= [\text{head}_1; \dots; \text{head}_h]\mathbf{W}^O \\
\text{where head}_i &= \text{Attention}(\mathbf{Q}\mathbf{W}^Q_i, \mathbf{K}\mathbf{W}^K_i, \mathbf{V}\mathbf{W}^V_i)
\end{aligned} %]]></script>
<p>where <script type="math/tex">\mathbf{W}^Q_i</script>, <script type="math/tex">\mathbf{W}^K_i</script>, <script type="math/tex">\mathbf{W}^V_i</script>, and <script type="math/tex">\mathbf{W}^O</script> are parameter matrices to be learned.</p>
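<p>In code, the multi-head computation is just the scaled dot-product attention run once per head on projected inputs, then concatenated and linearly mapped back. A self-contained NumPy sketch (per-head projection matrices passed as lists, which is one of several equivalent layouts):</p>

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention, softmax over the key axis."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    """head_i = attention(Q W_Q[i], K W_K[i], V W_V[i]); concat, then project by W_O."""
    heads = [attention(Q @ wq, K @ wk, V @ wv)
             for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O
```

Each head attends in its own lower-dimensional subspace, which is exactly the “different representation subspaces” the paper refers to.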
<h3 id="encoder">Encoder</h3>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/transformer-encoder.png" alt="Transformer encoder" /></p>
<p><em>Fig. 15. The transformer’s encoder. (Image source: <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani, et al., 2017</a>)</em></p>
<p>The encoder generates an attention-based representation with the capability to locate a specific piece of information from a potentially infinitely-large context.</p>
<ul>
<li>A stack of N=6 identical layers.</li>
<li>Each layer has a <strong>multi-head self-attention layer</strong> and a simple position-wise <strong>fully connected feed-forward network</strong>.</li>
<li>Each sub-layer adopts a <a href="https://arxiv.org/pdf/1512.03385.pdf"><strong>residual</strong></a> connection and a layer <strong>normalization</strong>.
All the sub-layers output data of the same dimension <script type="math/tex">d_\text{model} = 512</script>.</li>
</ul>
<h3 id="decoder">Decoder</h3>
<p style="width: 58%;" class="center"><img src="/lil-log/assets/images/transformer-decoder.png" alt="Transformer decoder" /></p>
<p><em>Fig. 16. The transformer’s decoder. (Image source: <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani, et al., 2017</a>)</em></p>
<p>The decoder is able to retrieve information from the encoded representation.</p>
<ul>
<li>A stack of N = 6 identical layers</li>
<li>Each layer has two sub-layers of multi-head attention mechanisms and one sub-layer of fully-connected feed-forward network.</li>
<li>Similar to the encoder, each sub-layer adopts a residual connection and a layer normalization.</li>
<li>The first multi-head attention sub-layer is <strong>modified</strong> to prevent positions from attending to subsequent positions, as we don’t want to look into the future of the target sequence when predicting the current position.</li>
</ul>
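<p>The “modified” first sub-layer is usually implemented with a causal mask: score entries for future positions are set to a large negative value before the softmax, so their weights vanish. A small NumPy sketch of that masking step (the function name is mine):</p>

```python
import numpy as np

def causal_self_attention_weights(scores):
    """Row-wise softmax of (m, m) scores, with positions j > i masked out.

    Position i can only attend to positions 0..i of the target sequence.
    """
    m = scores.shape[0]
    mask = np.triu(np.ones((m, m), dtype=bool), k=1)  # True above the diagonal
    masked = np.where(mask, -1e9, scores)             # future -> huge negative logit
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```

After the softmax, everything above the diagonal is (numerically) zero, so no information leaks from the future of the target sequence.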
<h3 id="full-architecture">Full Architecture</h3>
<p>Finally here is the complete view of the transformer’s architecture:</p>
<ul>
<li>Both the source and target sequences first go through embedding layers to produce data of the same dimension <script type="math/tex">d_\text{model} =512</script>.</li>
<li>To preserve the position information, a sinusoid-wave-based positional encoding is applied and summed with the embedding output.</li>
<li>The final decoder output goes through a linear layer followed by a softmax.</li>
</ul>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/transformer.png" alt="Transformer model" /></p>
<p><em>Fig. 17. The full model architecture of the transformer. (Image source: Fig 1 & 2 in <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">Vaswani, et al., 2017</a>.)</em></p>
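<p>The sinusoid-wave-based positional encoding mentioned above is simple enough to write out directly. A sketch following the paper’s formula, <script type="math/tex">PE(pos, 2i) = \sin(pos / 10000^{2i/d_\text{model}})</script> and <script type="math/tex">PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_\text{model}})</script> (assuming an even <code>d_model</code>):</p>

```python
import numpy as np

def positional_encoding(max_len, d_model):
    """(max_len, d_model) matrix of sinusoidal positional encodings."""
    pos = np.arange(max_len)[:, None]          # positions 0..max_len-1
    i = np.arange(d_model // 2)[None, :]       # dimension-pair indices
    angle = pos / np.power(10000.0, (2 * i) / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle)                # even dimensions
    pe[:, 1::2] = np.cos(angle)                # odd dimensions
    return pe
```

This matrix is simply summed with the token embeddings, giving each position a unique, deterministic signature.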
<p>Implementing the transformer model is an interesting experience; here is mine: <a href="https://github.com/lilianweng/transformer-tensorflow">lilianweng/transformer-tensorflow</a>. Read the comments in the code if you are interested.</p>
<h2 id="snail">SNAIL</h2>
<p>The transformer has no recurrent or convolutional structure; even with the positional encoding added to the embedding vector, the sequential order is only weakly incorporated. For problems sensitive to positional dependency like <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html">reinforcement learning</a>, this can be a big problem.</p>
<p>The <strong>Simple Neural Attention <a href="http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/">Meta-Learner</a></strong> (<strong>SNAIL</strong>) (<a href="http://metalearning.ml/papers/metalearn17_mishra.pdf">Mishra et al., 2017</a>) was developed partially to resolve the problem with <a href="#full-architecture">positioning</a> in the transformer model by combining the self-attention mechanism in transformer with <a href="https://deepmind.com/blog/wavenet-generative-model-raw-audio/">temporal convolutions</a>. It has been demonstrated to be good at both supervised learning and reinforcement learning tasks.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/snail.png" alt="SNAIL" /></p>
<p><em>Fig. 18. SNAIL model architecture (Image source: <a href="http://metalearning.ml/papers/metalearn17_mishra.pdf">Mishra et al., 2017</a>)</em></p>
<p>SNAIL was born in the field of meta-learning, which is another big topic worthy of a post by itself. But in simple words, a meta-learning model is expected to generalize to novel, unseen tasks from a similar distribution. Read <a href="http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/">this</a> nice introduction if interested.</p>
<h2 id="self-attention-gan">Self-Attention GAN</h2>
<p><em>Self-Attention GAN</em> (<strong>SAGAN</strong>; <a href="https://arxiv.org/pdf/1805.08318.pdf">Zhang et al., 2018</a>) adds self-attention layers into <a href="/lil-log/2017/08/20/from-GAN-to-WGAN.html">GAN</a> to enable both the generator and the discriminator to better model relationships between spatial regions.</p>
<p>The classic <a href="https://arxiv.org/abs/1511.06434">DCGAN</a> (Deep Convolutional GAN) represents both the discriminator and the generator as multi-layer convolutional networks. However, the representation capacity of the network is constrained by the filter size, as the feature at one pixel is limited to a small local region. To connect regions far apart, the features have to be diluted through layers of convolutional operations, and the dependencies are not guaranteed to be maintained.</p>
<p>As the (soft) self-attention in the vision context is designed to explicitly learn the relationship between one pixel and all other positions, even regions far apart, it can easily capture global dependencies. Hence GAN equipped with self-attention is expected to <em>handle details better</em>, hooray!</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/conv-vs-self-attention.png" alt="Conv vs self-attention on images" /></p>
<p><em>Fig. 19. Convolution operation and self-attention have access to regions of very different sizes.</em></p>
<p>The SAGAN adopts the <a href="https://arxiv.org/pdf/1711.07971.pdf">non-local neural network</a> to apply the attention computation. The convolutional image feature maps <script type="math/tex">\mathbf{x}</script> are branched out into three copies, corresponding to the concepts of <a href="#key-value-and-query">key, value, and query</a> in the transformer:</p>
<ul>
<li>Key: <script type="math/tex">f(\mathbf{x}) = \mathbf{W}_f \mathbf{x}</script></li>
<li>Query: <script type="math/tex">g(\mathbf{x}) = \mathbf{W}_g \mathbf{x}</script></li>
<li>Value: <script type="math/tex">h(\mathbf{x}) = \mathbf{W}_h \mathbf{x}</script></li>
</ul>
<p>Then we apply the dot-product attention to output the self-attention feature maps:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\alpha_{i,j} &= \text{softmax}(f(\mathbf{x}_i)^\top g(\mathbf{x}_j)) \\
\mathbf{o}_j &= \sum_{i=1}^N \alpha_{i,j} h(\mathbf{x}_i)
\end{aligned} %]]></script>
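<p>The two equations above can be traced through in plain Python for a handful of feature positions. This toy version is my own sketch: the 1x1-convolution weights are just small matrices applied per position, and the values are made up purely to illustrate the data flow.</p>

```python
import math

def matvec(W, x):
    """A 1x1 convolution at one position is a matrix-vector product."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sagan_attention(X, Wf, Wg, Wh):
    """X: list of N feature vectors. Returns the N output vectors o_j,
    each a weighted sum of values with weights normalized over positions i."""
    keys = [matvec(Wf, x) for x in X]     # f(x_i)
    queries = [matvec(Wg, x) for x in X]  # g(x_j)
    values = [matvec(Wh, x) for x in X]   # h(x_i)
    outputs = []
    for q in queries:  # one output per query position j
        scores = [sum(kd * qd for kd, qd in zip(k, q)) for k in keys]
        alpha = softmax(scores)  # alpha_{i,j} over positions i
        o = [sum(a * v[d] for a, v in zip(alpha, values))
             for d in range(len(values[0]))]
        outputs.append(o)
    return outputs

# Toy 2-d features at N=3 positions; identity weights for simplicity.
I = [[1.0, 0.0], [0.0, 1.0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = sagan_attention(X, I, I, I)  # 3 output vectors, each of dim 2
```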
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/self-attention-gan-network.png" alt="SAGAN" /></p>
<p><em>Fig. 20. The self-attention mechanism in SAGAN. (Image source: Fig. 2 in <a href="https://arxiv.org/pdf/1805.08318.pdf">Zhang et al., 2018</a>)</em></p>
<p>Note that <script type="math/tex">\alpha_{i,j}</script> is one entry in the attention map, indicating how much attention the model should pay to the i-th position when synthesizing the j-th location. <script type="math/tex">\mathbf{W}_f</script>, <script type="math/tex">\mathbf{W}_g</script>, and <script type="math/tex">\mathbf{W}_h</script> are all 1x1 convolution filters. If you feel that 1x1 conv sounds like a weird concept (i.e., isn’t it just to multiply the whole feature map with one number?), watch this short <a href="https://www.coursera.org/lecture/convolutional-neural-networks/networks-in-networks-and-1x1-convolutions-ZTb8x">tutorial</a> by Andrew Ng. The output <script type="math/tex">\mathbf{o}_j</script> is a column vector of the final output <script type="math/tex">\mathbf{o}= (\mathbf{o}_1, \mathbf{o}_2, \dots, \mathbf{o}_j, \dots, \mathbf{o}_N)</script>.</p>
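<p>To make the 1x1-convolution point concrete: with C input channels, a 1x1 conv applies the same small matrix independently at every spatial location, mixing channels but not neighboring pixels. A tiny illustration with made-up numbers of my own:</p>

```python
def conv1x1(W, feature_map):
    """W: C_out x C_in matrix; feature_map: H x W grid of C_in-channel pixels.
    A 1x1 convolution applies W to each pixel's channel vector independently."""
    return [[[sum(w * c for w, c in zip(row, pixel)) for row in W]
             for pixel in line]
            for line in feature_map]

# 2x2 feature map with 2 channels per pixel, mapped down to 1 channel.
W = [[0.5, 0.5]]                      # C_out=1, C_in=2: channel-wise average
fmap = [[[2.0, 4.0], [0.0, 0.0]],
        [[1.0, 3.0], [6.0, 2.0]]]
print(conv1x1(W, fmap))  # [[[3.0], [0.0]], [[2.0], [4.0]]]
```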
<p>Furthermore, the output of the attention layer is multiplied by a scale parameter and added back to the original input feature map:</p>
<script type="math/tex; mode=display">\mathbf{y}_i = \mathbf{x}_i + \gamma \mathbf{o}_i</script>
<p>The scaling parameter <script type="math/tex">\gamma</script> is increased gradually from 0 during training, so the network first relies on the cues in the local regions and then gradually learns to assign more weight to the regions that are further away.</p>
<p style="width: 100%;" class="center"><img src="/lil-log/assets/images/SAGAN-examples.png" alt="SAGAN examples" /></p>
<p><em>Fig. 21. 128×128 example images generated by SAGAN for different classes. (Image source: Partial Fig. 6 in <a href="https://arxiv.org/pdf/1805.08318.pdf">Zhang et al., 2018</a>)</em></p>
<hr />
<p>Cited as:</p>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{weng2018attention,
title = "Attention? Attention!",
author = "Weng, Lilian",
journal = "lilianweng.github.io/lil-log",
year = "2018",
url = "http://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html"
}
</code></pre></div></div>
<p><em>If you notice mistakes and errors in this post, don’t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them right away!</em></p>
<p>See you in the next post :D</p>
<h2 id="references">References</h2>
<p>[1] <a href="http://www.wildml.com/2016/01/attention-and-memory-in-deep-learning-and-nlp/">“Attention and Memory in Deep Learning and NLP.”</a> - Jan 3, 2016 by Denny Britz</p>
<p>[2] <a href="https://github.com/tensorflow/nmt">“Neural Machine Translation (seq2seq) Tutorial”</a></p>
<p>[3] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. <a href="https://arxiv.org/pdf/1409.0473.pdf">“Neural machine translation by jointly learning to align and translate.”</a> ICLR 2015.</p>
<p>[4] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua Bengio. <a href="http://proceedings.mlr.press/v37/xuc15.pdf">“Show, attend and tell: Neural image caption generation with visual attention.”</a> ICML, 2015.</p>
<p>[5] Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. <a href="https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf">“Sequence to sequence learning with neural networks.”</a> NIPS 2014.</p>
<p>[6] Thang Luong, Hieu Pham, Christopher D. Manning. <a href="https://arxiv.org/pdf/1508.04025.pdf">“Effective Approaches to Attention-based Neural Machine Translation.”</a> EMNLP 2015.</p>
<p>[7] Denny Britz, Anna Goldie, Thang Luong, and Quoc Le. <a href="https://arxiv.org/abs/1703.03906">“Massive exploration of neural machine translation architectures.”</a> ACL 2017.</p>
<p>[8] Ashish Vaswani, et al. <a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf">“Attention is all you need.”</a> NIPS 2017.</p>
<p>[9] Jianpeng Cheng, Li Dong, and Mirella Lapata. <a href="https://arxiv.org/pdf/1601.06733.pdf">“Long short-term memory-networks for machine reading.”</a> EMNLP 2016.</p>
<p>[10] Xiaolong Wang, et al. <a href="https://arxiv.org/pdf/1711.07971.pdf">“Non-local Neural Networks.”</a> CVPR 2018</p>
<p>[11] Han Zhang, Ian Goodfellow, Dimitris Metaxas, and Augustus Odena. <a href="https://arxiv.org/pdf/1805.08318.pdf">“Self-Attention Generative Adversarial Networks.”</a> arXiv preprint arXiv:1805.08318 (2018).</p>
<p>[12] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. <a href="https://arxiv.org/abs/1707.03141">“A simple neural attentive meta-learner.”</a> ICLR 2018.</p>
<p>[13] <a href="https://deepmind.com/blog/wavenet-generative-model-raw-audio/">“WaveNet: A Generative Model for Raw Audio”</a> - Sep 8, 2016 by DeepMind.</p>
<p>[14] Oriol Vinyals, Meire Fortunato, and Navdeep Jaitly. <a href="https://arxiv.org/abs/1506.03134">“Pointer networks.”</a> NIPS 2015.</p>
<p>[15] Alex Graves, Greg Wayne, and Ivo Danihelka. <a href="https://arxiv.org/abs/1410.5401">“Neural turing machines.”</a> arXiv preprint arXiv:1410.5401 (2014).</p>Lilian WengAttention has been a fairly popular concept and a useful tool in the deep learning community in recent years. In this post, we are gonna look into how attention was invented, and various attention mechanisms and models, such as transformer and SNAIL.Implementing Deep Reinforcement Learning Models with Tensorflow + OpenAI Gym2018-05-05T16:00:00+00:002018-05-05T16:00:00+00:00https://lilianweng.github.io/lil-log/2018/05/05/implementing-deep-reinforcement-learning-models<blockquote>
<p>Let’s see how to implement a number of classic deep reinforcement learning models in code.</p>
</blockquote>
<!--more-->
<p>The full implementation is available in <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym">lilianweng/deep-reinforcement-learning-gym</a></p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#environment-setup" id="markdown-toc-environment-setup">Environment Setup</a></li>
<li><a href="#gym-environment" id="markdown-toc-gym-environment">Gym Environment</a></li>
<li><a href="#naive-q-learning" id="markdown-toc-naive-q-learning">Naive Q-Learning</a></li>
<li><a href="#deep-q-network" id="markdown-toc-deep-q-network">Deep Q-Network</a> <ul>
<li><a href="#double-q-learning" id="markdown-toc-double-q-learning">Double Q-Learning</a></li>
<li><a href="#dueling-q-network" id="markdown-toc-dueling-q-network">Dueling Q-Network</a></li>
</ul>
</li>
<li><a href="#monte-carlo-policy-gradient" id="markdown-toc-monte-carlo-policy-gradient">Monte-Carlo Policy Gradient</a></li>
<li><a href="#actor-critic" id="markdown-toc-actor-critic">Actor-Critic</a></li>
<li><a href="#references" id="markdown-toc-references">References</a></li>
</ul>
<p>In the previous two posts, I have introduced the algorithms of many deep reinforcement learning models. Now it is time to get our hands dirty and practice how to implement the models in the wild. The implementation is gonna be built in Tensorflow and the OpenAI <a href="https://github.com/openai/gym">gym</a> environment. The full version of the code in this tutorial is available in <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym">[lilianweng/deep-reinforcement-learning-gym]</a>.</p>
<h2 id="environment-setup">Environment Setup</h2>
<p>0) Make sure you have <a href="https://docs.brew.sh/Installation">Homebrew</a> installed:</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>/usr/bin/ruby <span class="nt">-e</span> <span class="s2">"</span><span class="k">$(</span>curl <span class="nt">-fsSL</span> https://raw.githubusercontent.com/Homebrew/install/master/install<span class="k">)</span><span class="s2">"</span>
</code></pre></div></div>
<p>1) I would suggest starting a virtualenv for your development. It makes life so much easier when you have multiple projects with conflicting requirements; i.e. one works in Python 2.7 while the other is only compatible with Python 3.5+.</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Install python virtualenv</span>
brew install pyenv-virtualenv
<span class="c"># Create a virtual environment of any name you like with Python 3.6.4 support</span>
pyenv virtualenv 3.6.4 workspace
<span class="c"># Activate the virtualenv named "workspace"</span>
pyenv activate workspace
</code></pre></div></div>
<p><em>[*] For every new installation below, please make sure you are in the virtualenv.</em></p>
<p>2) Install OpenAI gym according to the <a href="https://github.com/openai/gym#installation">instruction</a>. For a minimal installation, run:</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>git clone https://github.com/openai/gym.git
<span class="nb">cd </span>gym
pip install <span class="nt">-e</span> <span class="nb">.</span>
</code></pre></div></div>
<p>If you are interested in playing with Atari games or other advanced packages, please continue to get a couple of system packages installed.</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>brew install cmake boost boost-python sdl2 swig wget
</code></pre></div></div>
<p>For Atari, go to the gym directory and pip install it. This <a href="http://alvinwan.com/installing-arcade-learning-environment-with-python3-on-macosx/">post</a> is pretty helpful if you have troubles with ALE (arcade learning environment) installation.</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip install <span class="nt">-e</span> <span class="s1">'.[atari]'</span>
</code></pre></div></div>
<p>3) Finally clone the “playground” code and install the requirements.</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>git clone git@github.com:lilianweng/deep-reinforcement-learning-gym.git
<span class="nb">cd </span>deep-reinforcement-learning-gym
pip install <span class="nt">-e</span> <span class="nb">.</span> <span class="c"># install the "playground" project.</span>
pip install <span class="nt">-r</span> requirements.txt <span class="c"># install required packages.</span>
</code></pre></div></div>
<h2 id="gym-environment">Gym Environment</h2>
<p>The <a href="https://gym.openai.com/">OpenAI Gym</a> toolkit provides a set of physical simulation environments, games, and robot simulators that we can play with and design reinforcement learning agents for. An environment object can be initialized by <code class="highlighter-rouge">gym.make("{environment name}")</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">gym</span>
<span class="n">env</span> <span class="o">=</span> <span class="n">gym</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="s">"MsPacman-v0"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/lil-log/assets/images/pacman-original.gif" alt="Pacman" /></p>
<p>The formats of action and observation of an environment are defined by <code class="highlighter-rouge">env.action_space</code> and <code class="highlighter-rouge">env.observation_space</code>, respectively.</p>
<p>Types of gym <a href="https://gym.openai.com/docs/#spaces">spaces</a>:</p>
<ul>
<li><code class="highlighter-rouge">gym.spaces.Discrete(n)</code>: discrete values from 0 to n-1.</li>
<li><code class="highlighter-rouge">gym.spaces.Box</code>: a multi-dimensional vector of numeric values, the upper and lower bounds of each dimension are defined by <code class="highlighter-rouge">Box.low</code> and <code class="highlighter-rouge">Box.high</code>.</li>
</ul>
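<p>Conceptually, sampling from and checking membership in these two space types looks like the following. This is a plain-Python imitation of the behavior for illustration, not the actual gym classes:</p>

```python
import random

def sample_discrete(n):
    """Mimics gym.spaces.Discrete(n).sample(): an int in [0, n-1]."""
    return random.randrange(n)

def contains_box(x, low, high):
    """Mimics Box.contains(x): every dimension within [low, high]."""
    return all(l <= xi <= h for xi, l, h in zip(x, low, high))

a = sample_discrete(4)  # one of 0, 1, 2, 3
ok = contains_box([0.1, -0.5], low=[-1.0, -1.0], high=[1.0, 1.0])
print(a in range(4), ok)  # True True
```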
<p>We interact with the env through two major api calls:</p>
<p><strong><code class="highlighter-rouge">ob = env.reset()</code></strong></p>
<ul>
<li>Resets the env to the original setting.</li>
<li>Returns the initial observation.</li>
</ul>
<p><strong><code class="highlighter-rouge">ob_next, reward, done, info = env.step(action)</code></strong></p>
<ul>
<li>Applies one action in the env which should be compatible with <code class="highlighter-rouge">env.action_space</code>.</li>
<li>Gets back the new observation <code class="highlighter-rouge">ob_next</code> (env.observation_space), a reward (float), a <code class="highlighter-rouge">done</code> flag (bool), and other meta information (dict). If <code class="highlighter-rouge">done=True</code>, the episode is complete and we should reset the env to restart. Read more <a href="https://gym.openai.com/docs/#observations">here</a>.</li>
</ul>
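<p>The reset/step contract above is easy to exercise even without gym installed. The toy environment below is entirely my own stand-in (not a gym class), but it mimics the interface so the interaction loop reads the same:</p>

```python
class ToyEnv:
    """A stand-in with the same reset()/step() contract as a gym env:
    the episode ends after 5 steps, and each step pays reward 1."""
    def reset(self):
        self.t = 0
        return self.t  # initial observation

    def step(self, action):
        self.t += 1
        done = self.t >= 5
        return self.t, 1.0, done, {}  # ob_next, reward, done, info

env = ToyEnv()
ob, total_reward, done = env.reset(), 0.0, False
while not done:
    ob, r, done, info = env.step(0)  # any action works in this toy env
    total_reward += r
print(total_reward)  # 5.0
```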
<h2 id="naive-q-learning">Naive Q-Learning</h2>
<p><a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#q-learning-off-policy-td-control">Q-learning</a> (Watkins & Dayan, 1992) learns the action value ("Q-value") and updates it according to the <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#bellman-equations">Bellman equation</a>. The key point is that while estimating the value of the next action, it does not follow the current policy but instead adopts the best Q value (the part in red) independently.</p>
<script type="math/tex; mode=display">Q(s, a) \leftarrow (1 - \alpha) Q(s, a) + \alpha (r + \gamma \color{red}{\max_{a' \in \mathcal{A}} Q(s', a')})</script>
<p>In a naive implementation, the Q value for all (s, a) pairs can be simply tracked in a dict. No complicated machine learning model is involved yet.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
<span class="n">gamma</span> <span class="o">=</span> <span class="mf">0.99</span> <span class="c"># Discounting factor</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="c"># soft update param</span>
<span class="n">env</span> <span class="o">=</span> <span class="n">gym</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="s">"CartPole-v0"</span><span class="p">)</span>
<span class="n">actions</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update_Q</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">r</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">s_next</span><span class="p">,</span> <span class="n">done</span><span class="p">):</span>
<span class="n">max_q_next</span> <span class="o">=</span> <span class="nb">max</span><span class="p">([</span><span class="n">Q</span><span class="p">[</span><span class="n">s_next</span><span class="p">,</span> <span class="n">a</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">actions</span><span class="p">])</span>
<span class="c"># Do not include the next state's value if currently at the terminal state.</span>
<span class="n">Q</span><span class="p">[</span><span class="n">s</span><span class="p">,</span> <span class="n">a</span><span class="p">]</span> <span class="o">+=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="p">(</span><span class="n">r</span> <span class="o">+</span> <span class="n">gamma</span> <span class="o">*</span> <span class="n">max_q_next</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">done</span><span class="p">)</span> <span class="o">-</span> <span class="n">Q</span><span class="p">[</span><span class="n">s</span><span class="p">,</span> <span class="n">a</span><span class="p">])</span>
</code></pre></div></div>
<p>Most gym environments have a multi-dimensional continuous observation space (<code class="highlighter-rouge">gym.spaces.Box</code>). To make sure our Q dictionary will not explode by trying to memorize an infinite number of keys, we apply a wrapper to discretize the observation. The concept of <a href="https://github.com/openai/gym/tree/master/gym/wrappers">wrappers</a> is very powerful: with it, we can customize the observation, action, step function, etc. of an env. No matter how many wrappers are applied, <code class="highlighter-rouge">env.unwrapped</code> always gives back the internal original environment object.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">gym</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">gym.spaces</span> <span class="kn">import</span> <span class="n">Box</span><span class="p">,</span> <span class="n">Discrete</span>
<span class="k">class</span> <span class="nc">DiscretizedObservationWrapper</span><span class="p">(</span><span class="n">gym</span><span class="o">.</span><span class="n">ObservationWrapper</span><span class="p">):</span>
<span class="s">"""This wrapper converts a Box observation into a single integer.
"""</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">n_bins</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">__init__</span><span class="p">(</span><span class="n">env</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="p">,</span> <span class="n">Box</span><span class="p">)</span>
<span class="n">low</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">observation_space</span><span class="o">.</span><span class="n">low</span> <span class="k">if</span> <span class="n">low</span> <span class="ow">is</span> <span class="bp">None</span> <span class="k">else</span> <span class="n">low</span>
<span class="n">high</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">observation_space</span><span class="o">.</span><span class="n">high</span> <span class="k">if</span> <span class="n">high</span> <span class="ow">is</span> <span class="bp">None</span> <span class="k">else</span> <span class="n">high</span>
<span class="bp">self</span><span class="o">.</span><span class="n">n_bins</span> <span class="o">=</span> <span class="n">n_bins</span>
<span class="bp">self</span><span class="o">.</span><span class="n">val_bins</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">l</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">n_bins</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">l</span><span class="p">,</span> <span class="n">h</span> <span class="ow">in</span>
<span class="nb">zip</span><span class="p">(</span><span class="n">low</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="n">high</span><span class="o">.</span><span class="n">flatten</span><span class="p">())]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">observation_space</span> <span class="o">=</span> <span class="n">Discrete</span><span class="p">(</span><span class="n">n_bins</span> <span class="o">**</span> <span class="n">low</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">_convert_to_one_number</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">digits</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">sum</span><span class="p">([</span><span class="n">d</span> <span class="o">*</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_bins</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">digits</span><span class="p">)])</span>
<span class="k">def</span> <span class="nf">observation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">observation</span><span class="p">):</span>
<span class="n">digits</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">digitize</span><span class="p">([</span><span class="n">x</span><span class="p">],</span> <span class="n">bins</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">bins</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">observation</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_bins</span><span class="p">)]</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_convert_to_one_number</span><span class="p">(</span><span class="n">digits</span><span class="p">)</span>
<span class="n">env</span> <span class="o">=</span> <span class="n">DiscretizedObservationWrapper</span><span class="p">(</span>
<span class="n">env</span><span class="p">,</span>
<span class="n">n_bins</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="n">low</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mf">2.4</span><span class="p">,</span> <span class="o">-</span><span class="mf">2.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.42</span><span class="p">,</span> <span class="o">-</span><span class="mf">3.5</span><span class="p">],</span>
<span class="n">high</span><span class="o">=</span><span class="p">[</span><span class="mf">2.4</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">0.42</span><span class="p">,</span> <span class="mf">3.5</span><span class="p">]</span>
<span class="p">)</span>
</code></pre></div></div>
<p>Let’s plug in the interaction with a gym env and update the Q function every time a new transition is generated. When picking the action, we use ε-greedy to force exploration.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">gym</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">n_steps</span> <span class="o">=</span> <span class="mi">100000</span>
<span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.1</span> <span class="c"># 10% chances to apply a random action</span>
<span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="n">ob</span><span class="p">):</span>
<span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o"><</span> <span class="n">epsilon</span><span class="p">:</span>
<span class="c"># action_space.sample() is a convenient function to get a random action</span>
<span class="c"># that is compatible with this given action space.</span>
<span class="k">return</span> <span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
<span class="c"># Pick the action with highest q value.</span>
<span class="n">qvals</span> <span class="o">=</span> <span class="p">{</span><span class="n">a</span><span class="p">:</span> <span class="n">Q</span><span class="p">[</span><span class="n">ob</span><span class="p">,</span> <span class="n">a</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">actions</span><span class="p">}</span>
<span class="n">max_q</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">qvals</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
<span class="c"># In case multiple actions have the same maximum q value.</span>
<span class="n">actions_with_max_q</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">q</span> <span class="ow">in</span> <span class="n">qvals</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">q</span> <span class="o">==</span> <span class="n">max_q</span><span class="p">]</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">actions_with_max_q</span><span class="p">)</span>
<span class="n">ob</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="n">rewards</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">reward</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_steps</span><span class="p">):</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">act</span><span class="p">(</span><span class="n">ob</span><span class="p">)</span>
<span class="n">ob_next</span><span class="p">,</span> <span class="n">r</span><span class="p">,</span> <span class="n">done</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="n">update_Q</span><span class="p">(</span><span class="n">ob</span><span class="p">,</span> <span class="n">r</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">ob_next</span><span class="p">,</span> <span class="n">done</span><span class="p">)</span>
<span class="n">reward</span> <span class="o">+=</span> <span class="n">r</span>
<span class="k">if</span> <span class="n">done</span><span class="p">:</span>
<span class="n">rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">reward</span><span class="p">)</span>
<span class="n">reward</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">ob</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ob</span> <span class="o">=</span> <span class="n">ob_next</span>
</code></pre></div></div>
<p>Often we start with a high <code class="highlighter-rouge">epsilon</code> and gradually decrease it during the training, known as “epsilon annealing”. The full code of <code class="highlighter-rouge">QLearningPolicy</code> is available <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym/blob/master/playground/policies/qlearning.py">here</a>.</p>
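<p>A minimal sketch of such an annealing schedule (a simple linear decay with hypothetical parameter names; the repository may implement the schedule differently):</p>

```python
def annealed_epsilon(step, eps_init=1.0, eps_final=0.02, anneal_steps=10000):
    """Linearly anneal epsilon from eps_init down to eps_final over
    anneal_steps, then keep it constant at eps_final."""
    fraction = min(float(step) / anneal_steps, 1.0)
    return eps_init + fraction * (eps_final - eps_init)

# Early in training we explore a lot (epsilon near 1.0);
# later we mostly exploit (epsilon near eps_final).
```

With the defaults above, epsilon starts at 1.0, passes roughly 0.51 halfway through annealing, and stays at 0.02 afterwards.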
<h2 id="deep-q-network">Deep Q-Network</h2>
<p><a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#deep-q-network">Deep Q-network</a> is a seminal piece of work that makes the training of Q-learning more stable and more data-efficient when the Q value is approximated with a nonlinear function. Two key ingredients are experience replay and a separately updated target network.</p>
<p>The main loss function looks like the following,</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
& Y(s, a, r, s') = r + \gamma \max_{a'} Q_{\theta^{-}}(s', a') \\
& \mathcal{L}(\theta) = \mathbb{E}_{(s, a, r, s') \sim U(D)} \Big[ \big( Y(s, a, r, s') - Q_\theta(s, a) \big)^2 \Big]
\end{aligned} %]]></script>
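<p>Experience replay stores past transitions and samples them uniformly at random, which is the <script type="math/tex">U(D)</script> in the loss above. A minimal sketch (illustrative only; the class name and capacity are my own choices, not from the repository):</p>

```python
import random
from collections import deque

class ReplayBuffer:
    """A minimal FIFO experience replay buffer."""
    def __init__(self, capacity=100000):
        # A deque with maxlen drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def add(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        # Uniform sampling over stored transitions, without replacement.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
```

Sampling uniformly from a large buffer breaks the temporal correlation between consecutive transitions, which is a large part of why replay stabilizes training.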
<p>The Q network can be a multi-layer dense neural network, a convolutional network, or a recurrent network, depending on the problem. In the <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym/blob/master/playground/policies/dqn.py">full implementation</a> of the DQN policy, it is determined by the <code class="highlighter-rouge">model_type</code> parameter, one of (“dense”, “conv”, “lstm”).</p>
<p>In the following example, I’m using a 2-layer densely connected neural network to learn Q values for the cart pole balancing problem.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">gym</span>
<span class="n">env</span> <span class="o">=</span> <span class="n">gym</span><span class="o">.</span><span class="n">make</span><span class="p">(</span><span class="s">'CartPole-v1'</span><span class="p">)</span>
<span class="c"># The observation space is `Box(4,)`, a 4-element vector.</span>
<span class="n">observation_size</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">observation_space</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</code></pre></div></div>
<p>We have a helper function for creating the networks below:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="k">def</span> <span class="nf">dense_nn</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">layers_sizes</span><span class="p">,</span> <span class="n">scope_name</span><span class="p">):</span>
<span class="s">"""Creates a densely connected multi-layer neural network.
inputs: the input tensor
layers_sizes (list<int>): defines the number of units in each layer. The output
layer has the size layers_sizes[-1].
"""</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="n">scope_name</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">size</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">layers_sizes</span><span class="p">):</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span>
<span class="n">inputs</span><span class="p">,</span>
<span class="n">size</span><span class="p">,</span>
<span class="c"># Add relu activation only for internal layers.</span>
<span class="n">activation</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span> <span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="n">layers_sizes</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">else</span> <span class="bp">None</span><span class="p">,</span>
<span class="n">kernel_initializer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">xavier_initializer</span><span class="p">(),</span>
<span class="n">name</span><span class="o">=</span><span class="n">scope_name</span> <span class="o">+</span> <span class="s">'_l'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">inputs</span>
</code></pre></div></div>
<p>The Q-network and the target network are updated with a batch of transitions (state, action, reward, state_next, done_flag). The input tensors are:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="c"># A tunable hyperparameter.</span>
<span class="n">states</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">observation_size</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s">'state'</span><span class="p">)</span>
<span class="n">states_next</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">observation_size</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s">'state_next'</span><span class="p">)</span>
<span class="n">actions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'action'</span><span class="p">)</span>
<span class="n">rewards</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'reward'</span><span class="p">)</span>
<span class="n">done_flags</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'done'</span><span class="p">)</span>
</code></pre></div></div>
<p>We create two networks with the same architecture: both take the state observation as input and output Q values over all the actions.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">q</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="s">'Q_primary'</span><span class="p">)</span>
<span class="n">q_target</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states_next</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="s">'Q_target'</span><span class="p">)</span>
</code></pre></div></div>
<p>The target network “Q_target” takes the <code class="highlighter-rouge">states_next</code> tensor as the input, because we use its predicted Q values at the next state to compute the optimal next action in the Bellman equation.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># The prediction by the primary Q network for the actual actions.</span>
<span class="n">action_one_hot</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">actions</span><span class="p">,</span> <span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'action_one_hot'</span><span class="p">)</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">q</span> <span class="o">*</span> <span class="n">action_one_hot</span><span class="p">,</span> <span class="n">reduction_indices</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'q_acted'</span><span class="p">)</span>
<span class="c"># The optimization target defined by the Bellman equation and the target network.</span>
<span class="n">max_q_next_by_target</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">q_target</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">rewards</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">done_flags</span><span class="p">)</span> <span class="o">*</span> <span class="n">gamma</span> <span class="o">*</span> <span class="n">max_q_next_by_target</span>
<span class="c"># The loss measures the mean squared error between prediction and target.</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">pred</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="n">y</span><span class="p">)),</span> <span class="n">name</span><span class="o">=</span><span class="s">"loss_mse_train"</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="mf">0.001</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">"adam_optim"</span><span class="p">)</span>
</code></pre></div></div>
<p>Note that we apply <a href="https://www.tensorflow.org/api_docs/python/tf/stop_gradient">tf.stop_gradient()</a> on the target y, because the target network should stay fixed during the loss-minimizing gradient update.</p>
<p><img src="/lil-log/assets/images/dqn-tensorboard-graph.png" alt="DQN-tensorflow" /></p>
<p>The target network is updated by copying the primary Q network parameters over every <code class="highlighter-rouge">C</code> steps (“hard update”), or by Polyak averaging towards the primary network (“soft update”).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Get all the variables in the Q primary network.</span>
<span class="n">q_vars</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">get_collection</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">GraphKeys</span><span class="o">.</span><span class="n">GLOBAL_VARIABLES</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s">"Q_primary"</span><span class="p">)</span>
<span class="c"># Get all the variables in the Q target network.</span>
<span class="n">q_target_vars</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">get_collection</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">GraphKeys</span><span class="o">.</span><span class="n">GLOBAL_VARIABLES</span><span class="p">,</span> <span class="n">scope</span><span class="o">=</span><span class="s">"Q_target"</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">q_vars</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">q_target_vars</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">update_target_q_net_hard</span><span class="p">():</span>
<span class="c"># Hard update</span>
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">([</span><span class="n">v_t</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="n">v_t</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">q_target_vars</span><span class="p">,</span> <span class="n">q_vars</span><span class="p">)])</span>
<span class="k">def</span> <span class="nf">update_target_q_net_soft</span><span class="p">(</span><span class="n">tau</span><span class="o">=</span><span class="mf">0.05</span><span class="p">):</span>
<span class="c"># Soft update: polyak averaging.</span>
<span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">([</span><span class="n">v_t</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">v_t</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">tau</span><span class="p">)</span> <span class="o">+</span> <span class="n">v</span> <span class="o">*</span> <span class="n">tau</span><span class="p">)</span> <span class="k">for</span> <span class="n">v_t</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">q_target_vars</span><span class="p">,</span> <span class="n">q_vars</span><span class="p">)])</span>
</code></pre></div></div>
<h3 id="double-q-learning">Double Q-Learning</h3>
<p>If we look into the standard form of the Q value target, <script type="math/tex">Y(s, a) = r + \gamma \max_{a' \in \mathcal{A}} Q_\theta (s', a')</script>, it is easy to notice that we use <script type="math/tex">Q_\theta</script> to select the best next action at state s’ and then apply the action value predicted by the same <script type="math/tex">Q_\theta</script>. This two-step reinforcing procedure could potentially lead to overestimation of an (already) overestimated value, further leading to training instability. The solution proposed by double Q-learning (<a href="http://papers.nips.cc/paper/3964-double-q-learning.pdf">Hasselt, 2010</a>) is to decouple the action selection and action value estimation by using two Q networks, <script type="math/tex">Q_1</script> and <script type="math/tex">Q_2</script>: when <script type="math/tex">Q_1</script> is being updated, <script type="math/tex">Q_2</script> decides the best next action, and vice versa.</p>
<script type="math/tex; mode=display">Y_1(s, a, r, s') = r + \gamma Q_1 (s', \arg\max_{a' \in \mathcal{A}}Q_2(s', a'))\\
Y_2(s, a, r, s') = r + \gamma Q_2 (s', \arg\max_{a' \in \mathcal{A}}Q_1(s', a'))</script>
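<p>As a concrete illustration, one tabular update following <script type="math/tex">Y_1</script> above can be sketched as follows (hypothetical function and variable names, not code from the paper):</p>

```python
import numpy as np

def double_q_update(Q1, Q2, s, a, r, s_next, done, alpha=0.5, gamma=0.9):
    """Update Q1[s, a] in place toward Y1 = r + gamma * Q1(s', argmax_a' Q2(s', a'))."""
    if done:
        target = r
    else:
        a_next = np.argmax(Q2[s_next])           # Q2 decides the best next action.
        target = r + gamma * Q1[s_next, a_next]  # Q1 estimates its value.
    Q1[s, a] += alpha * (target - Q1[s, a])
    # In the full algorithm, each step flips a coin to decide
    # whether Q1 or Q2 receives the update (swapping their roles).
```

Because the table that selects the action is not the table that supplies the value, a single table's estimation noise no longer feeds back into its own max operator.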
<p>To incorporate double Q-learning into DQN, the minimum modification (<a href="https://arxiv.org/pdf/1509.06461.pdf">Hasselt, Guez, & Silver, 2016</a>) is to use the primary Q network to select the action while the action value is estimated by the target network:</p>
<script type="math/tex; mode=display">Y(s, a, r, s') = r + \gamma Q_{\theta^{-}}(s', \arg\max_{a' \in \mathcal{A}} Q_\theta(s', a'))</script>
<p>In the code, we add a placeholder for the actions selected by the primary Q network at the next states, plus a tensor operation for computing this selection.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">actions_next</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'action_next'</span><span class="p">)</span>
<span class="n">actions_selected_by_q</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'action_selected'</span><span class="p">)</span>
</code></pre></div></div>
<p>The prediction target y in the loss function becomes:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">actions_next_flatten</span> <span class="o">=</span> <span class="n">actions_next</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">q_target</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">max_q_next_target</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">q_target</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">actions_next_flatten</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">rewards</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">done_flags</span><span class="p">)</span> <span class="o">*</span> <span class="n">gamma</span> <span class="o">*</span> <span class="n">max_q_next_target</span>
</code></pre></div></div>
<p>Here I used <a href="https://www.tensorflow.org/api_docs/python/tf/gather">tf.gather()</a> to select the action values of interest.</p>
<p style="width: 60%;" class="center"><img src="/lil-log/assets/images/tf_gather.png" alt="tf-gather" /></p>
<p><em>(Image source: <a href="https://www.tensorflow.org/api_docs/python/tf/gather">tf.gather() docs</a>)</em></p>
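<p>The flattening trick pairs each row of <code class="highlighter-rouge">q_target</code> with its selected action. A quick numpy analogue (illustrative, with made-up numbers):</p>

```python
import numpy as np

q_target = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])    # (batch_size, n_actions) Q values.
actions_next = np.array([1, 0, 1])   # Action chosen per row by the primary net.
batch_size, n_actions = q_target.shape

# Offset each action index by where its row starts in the flattened tensor.
flat_idx = actions_next + np.arange(batch_size) * n_actions
selected = q_target.reshape(-1)[flat_idx]
# selected -> array([2., 3., 6.])
```

This picks <code class="highlighter-rouge">q_target[i, actions_next[i]]</code> for every row i of the batch, exactly what the tf.gather() call computes.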
<p>During the episode rollout, we compute the <code class="highlighter-rouge">actions_next</code> by feeding the next states’ data into the <code class="highlighter-rouge">actions_selected_by_q</code> operation.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># batch_data is a dict with keys 's', 'a', 'r', 's_next' and 'done', containing a batch of transitions.</span>
<span class="n">actions_next</span> <span class="o">=</span> <span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">actions_selected_by_q</span><span class="p">,</span> <span class="p">{</span><span class="n">states</span><span class="p">:</span> <span class="n">batch_data</span><span class="p">[</span><span class="s">'s_next'</span><span class="p">]})</span>
</code></pre></div></div>
<h3 id="dueling-q-network">Dueling Q-Network</h3>
<p>The dueling Q-network (<a href="https://arxiv.org/pdf/1511.06581.pdf">Wang et al., 2016</a>) is equipped with an enhanced network architecture: the output layer branches out into two heads, one for predicting state value, V, and the other for <a href="/lil-log/2018/02/19/a-long-peek-into-reinforcement-learning.html#value-function">advantage</a>, A. The Q-value is then reconstructed, <script type="math/tex">Q(s, a) = V(s) + A(s, a)</script>.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
A(s, a) &= Q(s, a) - V(s)\\
V(s) &= \sum_a Q(s, a) \pi(a \vert s) = \sum_a (V(s) + A(s, a)) \pi(a \vert s) = V(s) + \sum_a A(s, a)\pi(a \vert s)\\
\text{Thus, }& \sum_a A(s, a)\pi(a \vert s) = 0
\end{aligned} %]]></script>
<p>To make sure the estimated advantage values sum up to zero, <script type="math/tex">\sum_a A(s, a)\pi(a \vert s) = 0</script>, we deduct the mean value from the prediction.</p>
<script type="math/tex; mode=display">Q(s, a) = V(s) + (A(s, a) - \frac{1}{|\mathcal{A}|} \sum_a A(s, a))</script>
<p>The code change is straightforward:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">q_hidden</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">],</span> <span class="s">'Q_primary_hidden'</span><span class="p">)</span>
<span class="n">adv</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">q_hidden</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">],</span> <span class="s">'Q_primary_adv'</span><span class="p">)</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">q_hidden</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="s">'Q_primary_v'</span><span class="p">)</span>
<span class="c"># Average dueling</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">v</span> <span class="o">+</span> <span class="p">(</span><span class="n">adv</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">adv</span><span class="p">,</span> <span class="n">reduction_indices</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
</code></pre></div></div>
<p><img src="/lil-log/assets/images/dueling-q-network.png" alt="dueling-q-network" />
<em>(Image source: <a href="https://arxiv.org/pdf/1511.06581.pdf">Wang et al., 2016</a>)</em></p>
<p>Check the <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym/blob/master/playground/policies/dqn.py">code</a> for the complete flow.</p>
<h2 id="monte-carlo-policy-gradient">Monte-Carlo Policy Gradient</h2>
<p>I reviewed a number of popular policy gradient methods in my <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html">last post</a>. Monte-Carlo policy gradient, also known as <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce">REINFORCE</a>, is a classic on-policy method that learns the policy model explicitly. It uses the return estimated from a full on-policy trajectory and updates the policy parameters with policy gradient.</p>
<p>The returns are computed during rollouts and then fed into the TensorFlow graph as inputs.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Inputs</span>
<span class="n">states</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="n">obs_size</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s">'state'</span><span class="p">)</span>
<span class="n">actions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'action'</span><span class="p">)</span>
<span class="n">returns</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'return'</span><span class="p">)</span>
</code></pre></div></div>
<p>The policy network is constructed similarly with <code class="highlighter-rouge">dense_nn</code>. We update the policy parameters by minimizing the loss function, <script type="math/tex">\mathcal{L} = - G_t \log \pi(a \vert s)</script>, where <script type="math/tex">G_t</script> is the return.
<a href="https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits">tf.nn.sparse_softmax_cross_entropy_with_logits()</a> asks for the raw logits as inputs, rather than the probabilities after softmax, and that’s why we do not have a softmax layer on top of the policy network.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Policy network</span>
<span class="n">pi</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">],</span> <span class="s">'pi_network'</span><span class="p">)</span>
<span class="n">sampled_actions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">multinomial</span><span class="p">(</span><span class="n">pi</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c"># For sampling actions according to probabilities.</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="s">'pi_optimize'</span><span class="p">):</span>
<span class="n">loss_pi</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span>
<span class="n">returns</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sparse_softmax_cross_entropy_with_logits</span><span class="p">(</span>
<span class="n">logits</span><span class="o">=</span><span class="n">pi</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">actions</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s">'loss_pi'</span><span class="p">)</span>
<span class="n">optim_pi</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="mf">0.001</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss_pi</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'adam_optim_pi'</span><span class="p">)</span>
</code></pre></div></div>
<p>During the episode rollout, the return is calculated as follows:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># env = gym.make(...)</span>
<span class="c"># gamma = 0.99</span>
<span class="c"># sess = tf.Session(...)</span>
<span class="k">def</span> <span class="nf">act</span><span class="p">(</span><span class="n">ob</span><span class="p">):</span>
<span class="k">return</span> <span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">sampled_actions</span><span class="p">,</span> <span class="p">{</span><span class="n">states</span><span class="p">:</span> <span class="p">[</span><span class="n">ob</span><span class="p">]})</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_episodes</span><span class="p">):</span>
    <span class="n">ob</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
    <span class="n">done</span> <span class="o">=</span> <span class="bp">False</span>
    <span class="n">obs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">actions</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">rewards</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">returns</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">while</span> <span class="ow">not</span> <span class="n">done</span><span class="p">:</span>
        <span class="n">a</span> <span class="o">=</span> <span class="n">act</span><span class="p">(</span><span class="n">ob</span><span class="p">)</span>
        <span class="n">new_ob</span><span class="p">,</span> <span class="n">r</span><span class="p">,</span> <span class="n">done</span><span class="p">,</span> <span class="n">info</span> <span class="o">=</span> <span class="n">env</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
        <span class="n">obs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ob</span><span class="p">)</span>
        <span class="n">actions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
        <span class="n">rewards</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">r</span><span class="p">)</span>
        <span class="n">ob</span> <span class="o">=</span> <span class="n">new_ob</span>
    <span class="c"># Estimate returns backwards.</span>
    <span class="n">return_so_far</span> <span class="o">=</span> <span class="mf">0.0</span>
    <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">rewards</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
        <span class="n">return_so_far</span> <span class="o">=</span> <span class="n">gamma</span> <span class="o">*</span> <span class="n">return_so_far</span> <span class="o">+</span> <span class="n">r</span>
        <span class="n">returns</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">return_so_far</span><span class="p">)</span>
    <span class="n">returns</span> <span class="o">=</span> <span class="n">returns</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    <span class="c"># Update the policy network with the data from one episode.</span>
    <span class="n">sess</span><span class="o">.</span><span class="n">run</span><span class="p">([</span><span class="n">optim_pi</span><span class="p">],</span> <span class="n">feed_dict</span><span class="o">=</span><span class="p">{</span>
        <span class="n">states</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">obs</span><span class="p">),</span>
        <span class="n">actions</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">actions</span><span class="p">),</span>
        <span class="n">returns</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">returns</span><span class="p">),</span>
    <span class="p">})</span>
</code></pre></div></div>
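<p>As a sanity check, the backward accumulation of discounted returns above can be reproduced framework-free. <code>discount_returns</code> is a hypothetical helper written for illustration, not part of the linked implementation:</p>

```python
def discount_returns(rewards, gamma):
    # Walk the rewards from the last step backwards, mirroring
    # the loop in the rollout code above: G_t = r_t + gamma * G_{t+1}.
    returns = []
    return_so_far = 0.0
    for r in reversed(rewards):
        return_so_far = gamma * return_so_far + r
        returns.append(return_so_far)
    return returns[::-1]

# With gamma = 0.5 and rewards [1, 1, 1]:
# G_2 = 1, G_1 = 1 + 0.5 * 1 = 1.5, G_0 = 1 + 0.5 * 1.5 = 1.75
print(discount_returns([1.0, 1.0, 1.0], 0.5))  # [1.75, 1.5, 1.0]
```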
<p>The full implementation of REINFORCE is <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym/blob/master/playground/policies/reinforce.py">here</a>.</p>
<h2 id="actor-critic">Actor-Critic</h2>
<p>The <a href="/lil-log/2018/04/08/policy-gradient-algorithms.html#actor-critic">actor-critic</a> algorithm learns two models at the same time: the actor learns the best policy, and the critic estimates the state value.</p>
<ol>
<li>Initialize the actor network, <script type="math/tex">\pi(a \vert s)</script> and the critic, <script type="math/tex">V(s)</script></li>
<li>Collect a new transition (s, a, r, s’): Sample the action <script type="math/tex">a \sim \pi(a \vert s)</script> for the current state s, and get the reward r and the next state s’.</li>
<li>Compute the TD target during episode rollout, <script type="math/tex">G_t = r + \gamma V(s')</script> and TD error, <script type="math/tex">\delta_t = r + \gamma V(s') - V(s)</script>.</li>
<li>Update the critic network by minimizing the critic loss: <script type="math/tex">L_c = (V(s) - G_t)^2</script>.</li>
<li>Update the actor network by minimizing the actor loss: <script type="math/tex">L_a = - \delta_t \log \pi(a \vert s)</script>.</li>
<li>Set s = s’ and repeat steps 2–5.</li>
</ol>
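<p>Plugging made-up numbers into steps 3–5 makes the quantities concrete (the reward and value estimates below are purely illustrative):</p>

```python
gamma = 0.99
r, v_s, v_s_next = 1.0, 0.5, 0.7  # hypothetical reward and critic outputs

td_target = r + gamma * v_s_next   # G_t = r + gamma * V(s')
td_error = td_target - v_s         # delta_t = G_t - V(s)
critic_loss = td_error ** 2        # L_c = (V(s) - G_t)^2
# The actor loss weights -log pi(a|s) by delta_t, treated as a constant.
print(round(td_target, 3), round(td_error, 3))  # 1.693 1.193
```

A positive TD error means the transition was better than the critic expected, so the actor update raises the probability of the sampled action.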
<p>Overall the implementation looks pretty similar to REINFORCE, with an extra critic network. The full implementation lives in the same <a href="https://github.com/lilianweng/deep-reinforcement-learning-gym">repository</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Inputs</span>
<span class="n">states</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="n">observation_size</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s">'state'</span><span class="p">)</span>
<span class="n">actions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'action'</span><span class="p">)</span>
<span class="n">td_targets</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">placeholder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s">'td_target'</span><span class="p">)</span>
<span class="c"># Actor: action probabilities</span>
<span class="n">actor</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">env</span><span class="o">.</span><span class="n">action_space</span><span class="o">.</span><span class="n">n</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s">'actor'</span><span class="p">)</span>
<span class="c"># Critic: state value V(s)</span>
<span class="n">critic</span> <span class="o">=</span> <span class="n">dense_nn</span><span class="p">(</span><span class="n">states</span><span class="p">,</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s">'critic'</span><span class="p">)</span>
<span class="n">action_ohe</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">actions</span><span class="p">,</span> <span class="n">act_size</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'action_one_hot'</span><span class="p">)</span>
<span class="n">pred_value</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">critic</span> <span class="o">*</span> <span class="n">action_ohe</span><span class="p">,</span> <span class="n">reduction_indices</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s">'q_acted'</span><span class="p">)</span>
<span class="n">td_errors</span> <span class="o">=</span> <span class="n">td_targets</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">pred_value</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="s">'critic_train'</span><span class="p">):</span>
    <span class="n">loss_c</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">td_errors</span><span class="p">))</span>
    <span class="n">optim_c</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="mf">0.01</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss_c</span><span class="p">)</span>
<span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">variable_scope</span><span class="p">(</span><span class="s">'actor_train'</span><span class="p">):</span>
    <span class="n">loss_a</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="n">td_errors</span><span class="p">)</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sparse_softmax_cross_entropy_with_logits</span><span class="p">(</span>
            <span class="n">logits</span><span class="o">=</span><span class="n">actor</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">actions</span><span class="p">),</span>
        <span class="n">name</span><span class="o">=</span><span class="s">'loss_actor'</span><span class="p">)</span>
    <span class="n">optim_a</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="mf">0.01</span><span class="p">)</span><span class="o">.</span><span class="n">minimize</span><span class="p">(</span><span class="n">loss_a</span><span class="p">)</span>
<span class="n">train_ops</span> <span class="o">=</span> <span class="p">[</span><span class="n">optim_c</span><span class="p">,</span> <span class="n">optim_a</span><span class="p">]</span>
</code></pre></div></div>
<p>The TensorBoard graph is always helpful:
<img src="/lil-log/assets/images/actor-critic-tensorboard-graph.png" alt="ac-tensorflow" /></p>
<h2 id="references">References</h2>
<p>[1] <a href="https://www.tensorflow.org/api_docs/">Tensorflow API Docs</a></p>
<p>[2] Christopher JCH Watkins, and Peter Dayan. <a href="https://link.springer.com/content/pdf/10.1007/BF00992698.pdf">“Q-learning.”</a> Machine learning 8.3-4 (1992): 279-292.</p>
<p>[3] Hado Van Hasselt, Arthur Guez, and David Silver. <a href="https://arxiv.org/pdf/1509.06461.pdf">“Deep Reinforcement Learning with Double Q-Learning.”</a> AAAI. Vol. 16. 2016.</p>
<p>[4] Hado van Hasselt. <a href="http://papers.nips.cc/paper/3964-double-q-learning.pdf">“Double Q-learning.”</a> NIPS, 23:2613–2621, 2010.</p>
<p>[5] Ziyu Wang, et al. <a href="https://arxiv.org/pdf/1511.06581.pdf">“Dueling network architectures for deep reinforcement learning.”</a> ICML. 2016.</p>