Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass (Noah Weber, 2017-11-15)

<h1 id="lstms-the-gentle-giants">LSTMs: The Gentle Giants</h1>
<p>On their surface, LSTMs (and related architectures such as GRUs) seem like wonky, overly complex contraptions. Indeed, at first it seems almost sacrilegious to
add these bulky accessories to our beautifully elegant <a href="https://tatar.ucsd.edu/jeffelman/">Elman-style</a> recurrent neural networks (RNNs)! However, unlike bloated software (such as Skype), this extra complexity is
warranted in the case of LSTMs (also unlike Skype is the fact that LSTMs/GRUs usually <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">work pretty well</a>). If you have read any paper from around 2015-2016 that uses LSTMs, you probably know that LSTMs solve the vanishing gradient problem that had plagued vanilla RNNs beforehand.</p>
<p>If you don’t already know, the vanishing gradient problem arises when, during backprop, the error signal used to train the network exponentially decreases the further you travel backwards in your network. The effect of this is that the
layers closer to your input don’t get trained. In the case of RNNs (which can be unrolled and thought of as feed forward networks with shared weights) this means that you don’t keep track of any long term dependencies. This is kind of
a bummer, since the whole point of an RNN is to keep track of long term dependencies. The situation is analogous to having a video chat application that can’t handle video chats!</p>
<p>Looking at these big pieces of machinery, it's hard to get a concrete understanding of exactly <em>why</em> they solve the vanishing gradient problem. The purpose of this blog post is to <del>put it on my resume</del> give a brief explanation
as to why LSTMs (and related models) solve the vanishing gradient problem. The reason <em>why</em> is actually pretty simple, which is all the more reason to know it.
If you are unfamiliar with LSTM models I would check out this <a href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">post</a>.</p>
<p><strong>Notation</strong>
Notation is always a pain when describing LSTMs since there are so many variables. I will list all notation here for convenience:</p>
<ul>
<li>\(E_t\)=Error at time \(t\), assume \(E_t\) is a function of output \(y_t\)</li>
<li>\(W_R\)=The recurrent set of weights (other sets of weights denoted with a different subscript)</li>
<li>\(tanh, \sigma\)=The activation function tanh, or sigmoid</li>
<li>\(h_t\)=The hidden vector at time \(t\)</li>
<li>\(C_t\)=The LSTM cell state at time \(t\)</li>
<li>\(o_t, f_t, i_t\)=The LSTM output, forget, and input gates at time \(t\)</li>
<li>\(x_t, y_t\)=The input and output at time \(t\)</li>
</ul>
<p><strong>LSTM Equation Reference</strong>
Quickly, here is a little review of the LSTM equations, with the biases left off (and mostly the same notation as <a href="http://colah.github.io/posts/2015-08-Understanding-LSTMs/">Chris Olah’s post</a>):</p>
<ul>
<li>\(f_t=\sigma(W_f[h_{t-1},x_t])\)</li>
<li>\(i_t=\sigma(W_i[h_{t-1},x_t])\)</li>
<li>\(o_t=\sigma(W_o[h_{t-1},x_t])\)</li>
<li>\(\widetilde{C}_t=tanh(W_C[h_{t-1},x_t])\)</li>
<li>\(C_t=f_tC_{t-1} + i_t\widetilde{C}_t\)</li>
<li>\(h_t=o_ttanh(C_t)\)</li>
</ul>
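<p>To make the equations concrete, here is a minimal NumPy sketch of a single LSTM step (a toy implementation of my own, not from any library; the dimensions and weight shapes are illustrative, and biases are left off, as in the equations above):</p>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W_f, W_i, W_o, W_C):
    """One step of the LSTM equations above (biases omitted).
    Each weight matrix has shape (hidden_dim, hidden_dim + input_dim)."""
    z = np.concatenate([h_prev, x_t])    # the concatenation [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z)               # forget gate
    i_t = sigmoid(W_i @ z)               # input gate
    o_t = sigmoid(W_o @ z)               # output gate
    C_tilde = np.tanh(W_C @ z)           # candidate cell state
    C_t = f_t * C_prev + i_t * C_tilde   # cell state update
    h_t = o_t * np.tanh(C_t)             # new hidden state
    return h_t, C_t

# toy dimensions and random weights (illustrative only)
rng = np.random.default_rng(0)
hidden_dim, input_dim = 3, 2
W_f, W_i, W_o, W_C = (0.1 * rng.standard_normal((hidden_dim, hidden_dim + input_dim))
                      for _ in range(4))
h_t, C_t = lstm_step(rng.standard_normal(input_dim),
                     np.zeros(hidden_dim), np.zeros(hidden_dim),
                     W_f, W_i, W_o, W_C)
print(h_t.shape, C_t.shape)  # (3,) (3,)
```

Note that because \(h_t=o_t tanh(C_t)\) and both factors are bounded, the hidden state always lies in \((-1,1)\), while the cell state \(C_t\) itself is unbounded.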
<h1 id="the-case-of-the-vanishing-gradients">The Case of the Vanishing Gradients</h1>
<p>To understand why LSTMs help, we need to understand the problem with vanilla RNNs. In a vanilla RNN, the hidden vector and the output are computed as follows:</p>
<script type="math/tex; mode=display">h_t = tanh(W_Ix_t + W_Rh_{t-1})\\
y_t = W_Oh_t</script>
<p>To do backpropagation through time to train the RNN, we need to compute the gradient of \(E\) with respect to \(W_R\). The overall error gradient is equal to the sum of the error gradients at each time step.
For step \(t\), we can use the multivariate chain rule to derive the error gradient as:</p>
<script type="math/tex; mode=display">\frac{\partial E_t}{\partial W_R} = \sum^{t}_{i=0} \frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_i}\frac{\partial h_i}{\partial W_R}</script>
<p>Now everything here can be computed pretty easily <em>except</em> the term \(\frac{\partial h_t}{\partial h_i}\), which needs another chain rule application to compute:</p>
<script type="math/tex; mode=display">\frac{\partial h_t}{\partial h_i} = \frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial h_{t-2}}...\frac{\partial h_{i+1}}{\partial h_i} = \prod^{t-1}_{k=i} \frac{\partial h_{k+1}}{\partial h_k}</script>
<p>Now let us look at a single one of these terms by taking the derivative of \(h_{k+1}\) with respect to \(h_{k}\) (where <em>diag</em> turns a vector into a diagonal matrix)<sup id="fnref:1"><a href="#fn:1" class="footnote">1</a></sup>:</p>
<script type="math/tex; mode=display">\frac{\partial h_{k+1}}{\partial h_k} = diag(f'(W_Ix_{k+1} + W_Rh_{k}))W_R</script>
<p>Thus, if we want to backpropagate through \(k\) timesteps, this gradient will be:</p>
<script type="math/tex; mode=display">\frac{\partial h_{k}}{\partial h_1} = \prod^{k}_{i=2} diag(f'(W_Ix_i + W_Rh_{i-1}))W_R</script>
<p>As shown in <a href="https://arxiv.org/pdf/1211.5063.pdf">this paper</a>, if the dominant eigenvalue of the matrix \(W_R\) is greater than 1, the gradient explodes. If it is less than 1, the gradient vanishes.<sup id="fnref:4"><a href="#fn:4" class="footnote">2</a></sup>
The fact that this equation leads to either vanishing or exploding gradients should make intuitive sense. Note that the values of \(f'(x)\) will always be less than 1. So if the magnitude of the values of
\(W_R\) is too small, then inevitably the derivative will go to 0. The repeated multiplications of values less than one would overpower the repeated multiplications of \(W_R\). Conversely, make \(W_R\) <em>too</em> big and
the derivative will go to infinity, since the exponentiation of \(W_R\) will overpower the repeated multiplication of the values less than 1. In practice, the
vanishing gradient is more common, so we will mostly focus on that.</p>
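<p>This vanishing effect is easy to observe numerically. Below is a toy setup of my own (the 0.9 rescaling is an arbitrary choice that puts the largest singular value of \(W_R\) below 1) that accumulates the product of Jacobians from the equation above and watches its norm shrink:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10
# hypothetical recurrent weights, rescaled so the largest singular value is 0.9 (< 1)
W_R = rng.standard_normal((n, n))
W_R *= 0.9 / np.linalg.svd(W_R, compute_uv=False)[0]
W_I = 0.1 * rng.standard_normal((n, n))

h = np.zeros(n)
J = np.eye(n)   # running product of Jacobians: dh_k / dh_1
norms = []
for _ in range(50):
    pre = W_I @ rng.standard_normal(n) + W_R @ h   # preactivation with random inputs
    h = np.tanh(pre)
    # dh_{k+1}/dh_k = diag(tanh'(pre)) W_R, where tanh'(z) = 1 - tanh(z)^2
    J = np.diag(1.0 - h**2) @ W_R @ J
    norms.append(np.linalg.norm(J, 2))

print(norms[0], norms[-1])  # the gradient norm decays toward zero
```

Rescaling \(W_R\) so its largest singular value is above 1 instead tends to produce the opposite, exploding behavior.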
<p>The derivative \(\frac{\partial h_{k}}{\partial h_1}\) is essentially telling us how much our hidden
state at time \(k\) will change when we change the hidden state at time 1 by a little bit. According to the above math, if the gradient vanishes it means
the earlier hidden states have no real effect on the later hidden states, meaning no long term dependencies are learned! This can be formally proved, and has been in <a href="https://arxiv.org/pdf/1211.5063.pdf">many papers</a>, including the <a href="http://www.bioinf.jku.at/publications/older/2604.pdf">original LSTM paper</a>.</p>
<h1 id="preventing-vanishing-gradients-with-lstms">Preventing Vanishing Gradients with LSTMs</h1>
<p>As we can see above, the biggest culprit in causing our gradients to vanish is that dastardly recursive derivative we need to compute: \(\frac{\partial h_t}{\partial h_i}\). If only this derivative were ‘well behaved’ (that is, if it didn’t go to 0 or infinity as we backpropagate through layers), then we could learn long term dependencies!</p>
<p><strong>The original LSTM solution</strong>
The original motivation behind the LSTM was to make this recursive derivative have a constant value. If this were the case, then our gradients would neither explode nor
vanish. How is this accomplished? As you may know, the LSTM introduces a separate cell state \(C_t\). In the original 1997 LSTM, the value for \(C_t\) depends on the previous value
of the cell state and an update term weighted by the input gate value (for motivation on why the input/output gates are needed, I would check out <a href="https://r2rt.com/written-memories-understanding-deriving-and-extending-the-lstm.html">this great post</a>):</p>
<script type="math/tex; mode=display">C_t = C_{t-1} + i\widetilde{C}_t</script>
<p>This formulation doesn’t work well because the cell state tends to grow uncontrollably. In order to prevent this unbounded growth, a <a href="https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf">forget gate was added</a> to scale the previous cell state, leading to the more modern formulation:</p>
<script type="math/tex; mode=display">C_t = fC_{t-1} + i\widetilde{C}_t</script>
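<p>A quick numeric sketch (with made-up constant gate values and a candidate update kept positive, choices of my own) shows why the forget gate matters: the original additive update grows without bound, while the forget-gated update stays bounded:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
f, i = 0.9, 0.9                 # made-up constant gate activations
C_1997, C_modern = 0.0, 0.0
for _ in range(200):
    C_tilde = rng.uniform(0.5, 1.0)         # candidate update, kept positive here
    C_1997 = C_1997 + i * C_tilde           # original update: C_t = C_{t-1} + i C~_t
    C_modern = f * C_modern + i * C_tilde   # forget gate update: C_t = f C_{t-1} + i C~_t

print(C_1997)    # grows roughly linearly with the number of steps
print(C_modern)  # stays bounded, near i * E[C~_t] / (1 - f)
```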
<p><strong>A common misconception</strong>
Most explanations for why LSTMs solve the vanishing gradient state that under this update rule, the recursive derivative is equal to 1 (in the case of the original LSTM)
or \(f\) (in the case of the modern LSTM)<sup id="fnref:2"><a href="#fn:2" class="footnote">3</a></sup> and is thus well behaved! One thing that is often forgotten is that \(f\), \(i\), and \(\widetilde{C}_t\) are all functions of \(C_t\), and thus we must take them into consideration when calculating the gradient.</p>
<p>The reason for this misconception is pretty reasonable. In the original LSTM formulation in 1997, the recursive gradient actually was equal to 1. The reason is that,
in order to enforce this constant error flow, the gradient calculation was truncated so as not to flow back to the input or candidate gates. So with respect
to \(C_{t-1}\) they could be treated as constants. Here is what they say in the <a href="http://www.bioinf.jku.at/publications/older/2604.pdf">original paper</a>:</p>
<blockquote>
<p>However, to ensure non-decaying error backprop through internal states of memory cells, as with truncated
BPTT (e.g., Williams and Peng 1990), errors arriving at “memory cell net inputs” [the cell output, input, forget, and candidate gates] …do not get propagated back
further in time (although they do serve to change the incoming weights). Only within memory cells [the cell state], errors are propagated back through previous internal states.</p>
</blockquote>
<p>In fact, truncating the gradients in this way was done up until about 2005, when <a href="ftp://ftp.idsia.ch/pub/juergen/nn_2005.pdf">this paper</a> by Alex
Graves was published. Since most popular neural network frameworks now do automatic differentiation, it's likely that you are using the full LSTM gradient formulation too! So, does the
above argument about why LSTMs solve the vanishing gradient change when using the full gradient? The answer is no; actually, it remains mostly the same. It just
gets a bit messier.</p>
<p><strong>Looking at the full LSTM gradient</strong><sup id="fnref:3"><a href="#fn:3" class="footnote">4</a></sup>
To understand why nothing really changes when using the full gradient, we need to look at what happens to the recursive gradient when we take the full gradient.
As we stated before, the recursive derivative is the main thing that is causing the vanishing gradient, so let's expand out the full derivative for
\(\frac{\partial C_t}{\partial C_{t-1}}\). First recall that in the LSTM, \(C_t\) is a function of \(f_t\) (the forget gate), \(i_t\) (the input gate),
and
\(\widetilde{C}_t\) (the candidate cell state), each of these being a function of \(C_{t-1}\) (since they are all functions of \(h_{t-1}\)). Via the multivariate chain rule we get:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align*}
\frac{\partial C_t}{\partial C_{t-1}} &= \frac{\partial C_t}{\partial f_{t}}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} + \frac{\partial C_t}{\partial i_{t}}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} \\
&+ \frac{\partial C_t}{\partial \widetilde{C}_{t}}\frac{\partial \widetilde{C}_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} + \frac{\partial C_t}{\partial C_{t-1}}
\end{align*} %]]></script>
<p>Now let's explicitly write out these derivatives (note that the last term above is the direct dependence of \(C_t\) on \(C_{t-1}\) through the update equation, which is just \(f_t\)):</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align*}
\frac{\partial C_t}{\partial C_{t-1}} &= C_{t-1}\sigma'(\cdot)W_f*o_{t-1}tanh'(C_{t-1}) \\
&+ \widetilde{C}_t\sigma'(\cdot)W_i*o_{t-1}tanh'(C_{t-1}) \\
&+ i_t tanh'(\cdot)W_C*o_{t-1}tanh'(C_{t-1}) \\
&+ f_t
\end{align*} %]]></script>
<p>Now if we want to backpropagate back \(k\) time steps, we simply multiply terms in the form of the one above \(k\) times. Note the big difference between this
recursive gradient and the one for vanilla RNNs. In vanilla RNNs, the terms \(\frac{\partial h_t}{\partial h_{t-1}}\) will eventually take on values
that are either always above 1 or always in the range \([0,1]\); this is essentially what leads to the vanishing/exploding gradient problem. The terms here, \(\frac{\partial C_t}{\partial C_{t-1}}\), <em>at any time step</em> can take on either values that are greater than 1 or values in the range \([0,1]\). Thus if we extend to an infinite number
of time steps, it is not guaranteed that we will end up converging to 0 or infinity (unlike in vanilla RNNs). If we start to converge to zero, we can always set
the values of \(f_t\) (and other gate values) to be higher in order to bring the value of \(\frac{\partial C_t}{\partial C_{t-1}}\) closer to 1, thus preventing the gradients from
vanishing (or at the very least, preventing them from vanishing <em>too</em> quickly). One important thing to note is that the values \(f_t\), \(o_t\), \(i_t\), and
\(\widetilde{C}_t\) are things that the network <em>learns</em> to set (conditioned on the current input and hidden state). Thus, in this way the network learns to
decide <em>when</em> to let the gradient vanish, and <em>when</em> to preserve it, by setting the gate values accordingly!</p>
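<p>We can check this numerically with a scalar LSTM cell (a toy of my own; the weights and fixed quantities are arbitrary choices). A finite-difference estimate of the full derivative \(\frac{\partial C_t}{\partial C_{t-1}}\) differs from the forget gate \(f_t\) alone, confirming that the extra gate terms contribute:</p>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# hypothetical scalar LSTM: made-up fixed weights, biases omitted
w_f, w_i, w_C = 0.5, 0.4, 0.3   # weights multiplying h_{t-1} in each gate
x_term = 0.2                     # the W x_t contribution, held fixed
o_prev = 0.9                     # previous output gate value, held fixed

def next_cell(C_prev):
    h_prev = o_prev * np.tanh(C_prev)         # h_{t-1} depends on C_{t-1}
    f_t = sigmoid(w_f * h_prev + x_term)      # forget gate
    i_t = sigmoid(w_i * h_prev + x_term)      # input gate
    C_tilde = np.tanh(w_C * h_prev + x_term)  # candidate cell state
    return f_t * C_prev + i_t * C_tilde       # C_t

C_prev, eps = 1.0, 1e-6
full_grad = (next_cell(C_prev + eps) - next_cell(C_prev - eps)) / (2 * eps)
f_t = sigmoid(w_f * o_prev * np.tanh(C_prev) + x_term)
print(full_grad, f_t)  # the full derivative is close to, but not equal to, f_t alone
```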
<p>This might all seem magical, but it really is just the result of two main things:</p>
<ul>
<li>The additive update function for the cell state gives a derivative that's much more ‘well behaved’</li>
<li>The gating functions allow the network to decide how much the gradient vanishes, and can take on different values at each time step. The values that they take on are learned functions
of the current input and hidden state.</li>
</ul>
<p>And that is essentially it. It is good to know that truncating the gradient (as done in the original LSTM) is not integral to explaining why the LSTM can prevent
the vanishing gradient. As we saw, the arguments for why it prevents the vanishing gradient remain mostly the same even when taking the full gradient into account.
Thanks for <del>reading</del> <del>skimming</del> scrolling to the bottom to look at the comments.</p>
<div class="footnotes">
<ol>
<li id="fn:1">
<p>Keep in mind that this recursive partial derivative is a (Jacobian) matrix! <a href="#fnref:1" class="reversefootnote">↩</a></p>
</li>
<li id="fn:4">
<p>For intuition on the importance of the eigenvalues of the recurrent weight matrix, I would look <a href="https://smerity.com/articles/2016/orthogonal_init.html">here</a> <a href="#fnref:4" class="reversefootnote">↩</a></p>
</li>
<li id="fn:2">
<p>In the case of the forget gate LSTM, the recursive derivative will still be a product of many terms between 0 and 1 (the forget gates at each time step); however, in practice this is not as much of a problem compared to the case of RNNs. One thing to remember is that our network has direct control over what the values of \(f\) will be. If it needs to remember something, it can easily set the value of \(f\) to be high (let's say around 0.95). These values thus tend to shrink at a much slower rate than the derivative values of \(tanh\), which later on during the training process are likely to be saturated and thus have a value close to 0. <a href="#fnref:2" class="reversefootnote">↩</a></p>
</li>
<li id="fn:3">
<p>There are <em>lots</em> of little derivatives that need to be derived in order to do the full LSTM derivation. I won’t do that here, as we only need to look at one of them. The <a href="https://www.cs.toronto.edu/~graves/phd.pdf">PhD thesis of Alex Graves</a> lists the derivative formulas needed, for those interested. <a href="#fnref:3" class="reversefootnote">↩</a></p>
</li>
</ol>
</div>

Steps Towards Understanding Deep Learning: The Information Bottleneck Connection (Part 1) (Noah Weber, 2017-11-08)

<p align="center">
<img src="/images/blog1/knobs_label2.png" alt="my alt text" />
</p>
<h1 id="why-open-the-black-box">Why Open the Black Box?</h1>
<p>Since the dawn of time, deep learning models have been puzzling theoreticians as to
why they work so well. With the large number of parameters they have, the fact that they seem to generalize well is a confounding one. Of course, getting
them to work well is often easier said than done. Today's models typically have more buttons and switches to fiddle with than an airplane cockpit. And much like an airplane, flipping the wrong set of switches is likely to end in a crash and burn scenario.
It's definitely true that some good general machine learning knowledge can help in setting these parameters,
but there are many times where it seems the only thing we have to guide us is our shaky intuition.</p>
<p>A good theory can often serve as
a highly effective guide for practitioners. While understanding the theory of why deep learning models work may not completely eliminate all the knobs
we need to set, it would definitely make many choices much easier. It's no surprise that there has been some theory work on trying to understand deep learning models.</p>
<h1 id="looking-at-the-black-box-using-information-theory">Looking at the Black Box Using Information Theory</h1>
<p>One line of work that has recently gotten some attention (in the form of articles from mainstream press such as <a href="https://www.quantamagazine.org/new-theory-cracks-open-the-black-box-of-deep-learning-20170921/">Quanta</a>, as well as
praise from Geoffrey Hinton himself) has been that of <a href="http://naftali-tishby.strikingly.com/#my-current-lab">Naftali Tishby’s lab</a>, on providing an information theory based explanation
as to why these models can generalize. Most of the excitement revolves around the work titled <a href="https://arxiv.org/abs/1703.00810">“Opening the Black Box of Deep Neural Networks via Information”</a>;
however, the main theoretic arguments were actually presented two years earlier, in <a href="https://arxiv.org/abs/1503.02406">2015</a>.</p>
<p>Since <em>some</em> of the conclusions above are empirical, some caution should be taken before fully accepting them
(indeed, a <a href="https://openreview.net/forum?id=ry_WPG-A-&noteId=ry_WPG-A-">paper currently under review</a> at ICLR 2018 gives results that contradict the above paper<sup id="fnref:1"><a href="#fn:1" class="footnote">1</a></sup>); however, the theory behind
it is incredibly appealing, and holds a lot of value in itself. In this post I hope to give an overview of the theory,
the main points behind it, and what it could possibly mean for the future of deep learning.
I will proceed by describing the preliminaries to understanding Tishby’s information theory-deep learning connection: Rate Distortion Theory, and the Information Bottleneck principle.
The post assumes you have a grasp on the basic ideas of deep learning, and an understanding of <em>basic</em> concepts from information theory, such as entropy and mutual information (this concept is important!).
This <a href="http://colah.github.io/posts/2015-09-Visual-Information/">post</a> is a great introduction to these concepts if you need a hand.</p>
<h1 id="what-does-information-theory-have-to-do-with-deep-learning">What does information theory have to do with deep learning?</h1>
<p>What exactly is the connection between information theory (IT) and deep learning (DL)? Rather than jump right into the connection between the two,
it actually helps to first understand the Information Bottleneck (IB) principle. As we will see, once we understand
IB, the connection between IT and DL arises quite naturally. The main purpose of IB is to answer the following question: “How
do we define what information is relevant, in a rigorous way?” This seems rather counter-intuitive, as the ‘relevancy’ of
information seems rather subjective. Information theory as laid out by Claude Shannon omits any notion of ‘meaning’ in the
information, and without meaning, how can we possibly measure relevancy?</p>
<h1 id="rate-distortion-theory">Rate Distortion Theory</h1>
<p>Information theory does, however, hold one possible answer to the question of relevancy, in the form of lossy compression.
The goal of lossy compression is to find the ‘most compressed’ representation possible for our set of input data \(X\) such that we don’t lose <em>too</em>
much information about our original data \(X\). Our relevant information would thus be the information preserved in the
compressed version of \(X\). Rate Distortion Theory provides a clear formalization for these concepts.</p>
<p><strong>Defining Compression</strong>
The first thing we need to formalize is how the ‘compression’ is actually done. We can formally think of compression
as defining a (possibly stochastic) function that maps an element \(x \in X\) to its corresponding compressed representation
(sometimes called the ‘codeword’ for \(x\)).
We will refer to the set of codewords of \(X\) as \(T\), with an arbitrary element of \(T\) being denoted as \(t\).
This map is implicitly defined through the probability distribution \(p(t|x)\), so by defining the distribution \(p(t|x)\)
we will have defined our method of compression/assignment of codewords.</p>
<p>The diagram below gives a visual representation of this compression (for a deterministic \(p(t|x)\), which is what we will have with most neural networks):</p>
<p align="center">
<img src="/images/blog1/compress.png" style="width: 400px" alt="my alt text" />
</p>
<p>As you can see, there may be many elements of \(X\) that map to the same \(t \in T\).</p>
<p><strong>Measuring Compression</strong>
We now need a good mathematical way to measure how ‘compressed’ our representation (decided by \(p(t|x)\)) is. One natural
way to measure how much we have compressed the set of codewords is to measure how much information a codeword contains (on average).
This can be done, of course, by using the entropy rate of \(T\)<sup id="fnref:2"><a href="#fn:2" class="footnote">2</a></sup>. The rate of \(T\) can be made arbitrarily large, however, simply
by packing redundant information into the codewords. What we are interested in measuring is the best rate we can achieve for a
certain type of encoding (recall, this is defined by \(p(t|x)\)). It turns out the mutual information between \(X\) and \(T\)
(denoted \(I(X;T)\)) gives a lower bound on the rate of \(T\); thus we can use \(I(X;T)\) as a stand-in for ‘the best rate achieved
for encoding \(p(t|x)\)’ and thus as a measure of how ‘compressed’ the representation \(T\) is<sup id="fnref:3"><a href="#fn:3" class="footnote">3</a></sup>.
The smaller \(I(X;T)\), the more
compressed a representation \(T\) is of \(X\) (this should make intuitive sense: smaller \(I(X;T)\) means \(T\) holds less
information about \(X\), which is what we would expect as our representation gets more compressed). Of course, higher values
of mutual information indicate a less compressed representation.</p>
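<p>As a quick sanity check (a toy example of my own), we can compute \(I(X;T)\) for two deterministic encoders of four equally likely inputs, and see that merging inputs into shared codewords lowers the mutual information:</p>

```python
import numpy as np

def mutual_information(p_xt):
    """I(X;T) in bits, computed from a joint distribution table p(x, t)."""
    p_x = p_xt.sum(axis=1, keepdims=True)
    p_t = p_xt.sum(axis=0, keepdims=True)
    m = p_xt > 0   # skip zero-probability cells to avoid log(0)
    return float((p_xt[m] * np.log2(p_xt[m] / (p_x @ p_t)[m])).sum())

# four equally likely inputs
# identity encoding: every x keeps its own codeword (no compression)
p_identity = np.eye(4) / 4.0
# merging encoding: x0,x1 share codeword t0 and x2,x3 share t1 (more compressed)
p_merge = np.array([[0.25, 0.00],
                    [0.25, 0.00],
                    [0.00, 0.25],
                    [0.00, 0.25]])

print(mutual_information(p_identity))  # 2.0 bits
print(mutual_information(p_merge))     # 1.0 bit
```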
<p><strong>Losing Information</strong> One more thing we need to formalize: we need to define how much information loss is too much information loss, and more particularly, what constitutes information loss. We do this by defining a distortion function \(d(x,t)\), which takes
in an element of \(X\) and \(T\) and outputs a single value indicating how different (i.e., distorted) \(x\) and \(t\) are. The exact
function is arbitrary and must be picked on a per-task basis (if our data were images, we would probably want a different distortion measure than if we were using something like text data).
Example distortion functions include the squared difference or
Hamming distance between the input \(x\) and \(t\), but there are many possible distortion functions that can be defined.</p>
<p>Now that we have a distortion function, we would like to be able to measure how much information is lost when using
the encoding defined by \(p(t|x)\). For this we can use the expected distortion \(D(p)\):</p>
<script type="math/tex; mode=display">D(p)= \sum_{x,t} p(x,t)d(x,t)</script>
<p>A smaller value of \(D(p)\) can be reached by making sure \(x,t\) pairs with
high joint probability (remember, \(p(x,t)=p(t|x)p(x)\)) have a low value
of \(d(x,t)\).</p>
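<p>For instance, here is a toy computation of \(D(p)\) (my own example; the encoder and the representative values for each codeword are arbitrary choices) under a squared-difference distortion:</p>

```python
import numpy as np

# toy alphabets: X = {0, 1, 2, 3} compressed to T = {0, 1}
p_x = np.full(4, 0.25)
# hypothetical deterministic encoder p(t|x): low inputs -> t0, high inputs -> t1
p_t_given_x = np.array([[1.0, 0.0],
                        [1.0, 0.0],
                        [0.0, 1.0],
                        [0.0, 1.0]])
rep = np.array([0.5, 2.5])   # a representative value for each codeword

# squared-difference distortion d(x, t) = (x - rep[t])^2
d = (np.arange(4)[:, None] - rep[None, :]) ** 2

# D(p) = sum over x, t of p(x) p(t|x) d(x, t)
D = float((p_x[:, None] * p_t_given_x * d).sum())
print(D)  # 0.25: every input sits 0.5 away from its codeword's representative
```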
<p>We can define how much information loss is too much by defining a threshold value \(D^*\), and refusing to use any compression method (i.e.,
a specific set of representations \(T\))
whose expected distortion \(D(p) > D^*\).</p>
<p><strong>Rate Distortion Curves</strong> With all this in place we can now answer the golden question: What’s
the most we can possibly compress an input (equivalently, how low can we get the rate of \(T\)) if we are allowed to distort
the input up to threshold \(D^*\)? Note that we ask this question without
regard to the compression method. This is a theoretical limit, the best
that we can (or ever will) do. The answer is deceptively simple: the best
(lowest) rate we can get is equal to the mutual information under
the best possible compression method \(p(t|x)\). This value is
described through the rate distortion function, which takes in a threshold
\(D^*\) and returns the rate of the best possible \(p(t|x)\) we can use:</p>
<script type="math/tex; mode=display">R(D^*) = \min_{p(t|x):D(p) \leq D^*} I(X;T)</script>
<p>We can look at this problem as a constrained optimization problem, where we wish to find a \(p(t|x)\) that minimizes \(I(X;T)\), with
a constraint on \(D(p)\) (this is typically turned into an unconstrained problem by using Lagrange multipliers). Given that we know \(P(X)\), the above optimization problem can, in practice, be solved!<sup id="fnref:4"><a href="#fn:4" class="footnote">4</a></sup>
Note that for each value of \(D^*\) we get a different value of
best possible rate, \(R\). We can plot \(R\) as a function of \(D^*\)
to get what is called the rate distortion curve. An example Rate Distortion Curve (taken from the <a href="https://en.wikipedia.org/wiki/Rate%E2%80%93distortion_theory">Wikipedia article</a>) is shown below.
This curve represents the optimal values of \(R\); any values above the curve are suboptimal
(meaning the representations \(T\) are suboptimal), and
any values below the curve are theoretically impossible to achieve. This
idea of a rate distortion curve will be central later on, so make sure
you are comfortable with it. One of the important things to note on the curve is that there
is a constant theoretical tug of war between rate and distortion. Want to lower the rate? You’ll inevitably need
to add more distortion. Want to lower the amount of distortion? Well, your rate is going to have to go up.</p>
<p align="center">
<img src="/images/blog1/rdf_curve.png" alt="no image" />
</p>
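<p>To make the curve concrete, here is a small sketch of my own (not from the original posts) of the classic Blahut-Arimoto iteration, the standard algorithm for solving this optimization problem given \(p(x)\) and \(d(x,t)\). Sweeping the trade-off parameter \(\beta\) traces out points along the curve; the source alphabet and distortion below are toy choices:</p>

```python
import numpy as np

def blahut_arimoto(p_x, d, beta, iters=500):
    """Find an (approximately) optimal encoder p(t|x) for one trade-off point;
    beta balances rate against distortion."""
    n_x, n_t = d.shape
    p_t_given_x = np.full((n_x, n_t), 1.0 / n_t)   # start from a uniform encoder
    for _ in range(iters):
        q_t = p_x @ p_t_given_x                    # codeword marginal q(t)
        w = q_t * np.exp(-beta * d)                # reweight codewords by distortion
        p_t_given_x = w / w.sum(axis=1, keepdims=True)
    q_t = p_x @ p_t_given_x
    joint = p_x[:, None] * p_t_given_x             # joint p(x, t)
    rate = float((joint * np.log2(p_t_given_x / q_t)).sum())   # I(X;T) in bits
    dist = float((joint * d).sum())                # expected distortion D(p)
    return rate, dist

p_x = np.full(4, 0.25)                                      # uniform source, 4 symbols
d = (np.arange(4)[:, None] - np.arange(4)[None, :]) ** 2    # squared-difference distortion
for beta in (0.1, 1.0, 10.0):
    r, D = blahut_arimoto(p_x, d, beta)
    print(f"beta={beta}: rate={r:.3f} bits, distortion={D:.3f}")
```

Small \(\beta\) favors compression (low rate, high distortion) and large \(\beta\) favors fidelity: exactly the tug of war described above, with each \(\beta\) giving one point on the curve.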
<p>This is the main gist of RDT; the question is, can it be used to
define what information in our signal is relevant? Alas, the answer is
no. The biggest problem is that the distortion function itself must
be defined. By defining the distortion function, we are essentially
defining what information is relevant in the signal \(X\), which is the
very thing we are trying to find out! It’s clear that some of the ideas
of RDT can be useful for defining relevant information; however, we need to be a bit more clever in terms of how we define relevancy.</p>
<p><strong>TL;DR Rate Distortion Theory</strong>
Here are the main takeaways from Rate Distortion Theory:</p>
<ul>
<li>Mutual Information \(I(X;T)\) can be seen as measuring the amount of <em>compression</em></li>
<li>There’s always a tug-of-war between our rate and distortion values</li>
<li>With the pairs of optimal rate and distortion values, we can define a curve called the <em>rate distortion curve</em></li>
<li>We have to provide a distortion measure, and because of this, Rate Distortion Theory doesn’t serve our mission of defining information relevancy in a rigorous way.</li>
</ul>
<h1 id="the-information-bottleneck-method">The Information Bottleneck Method</h1>
<p>The information bottleneck principle (IB) is quite similar to RDT, with some key differences. Instead of defining relevance directly through the set of codewords \(T\), IB instead
defines relevance through another variable \(Y\). That is, given an input signal \(X\), we would like to find a
compressed representation of \(X\) (again this set of codewords \(T\)) such that \(T\) preserves as much information about
the output signal \(Y\) as possible. Like RDT, we still want to obtain as compressed a representation of \(X\) as possible
(ie we still want to minimize \(I(X;T)\), however now our constraint is different. Our goal in IB is to choose
a representation \(T\) (which, like RDT, is still defined by \(p(t|x)\)) which preserves as much information about \(Y\) as
possible. How do we measure how much information about \(Y\) the representations \(T\) have? Mutual information of course!
Rather than constraining how much distortion occurs like in RDT, in IB we constrain how much information about \(Y\)
we are willing to lose by compressing \(X\) into \(T\). We do this by indicating the minimum value of \(I(T;Y)\) (denoted as \(I^*\)) we are willing
to have in a representation \(T\). This leads to the optimization problem:</p>
<script type="math/tex; mode=display">\min_{p(t|x):I(T;Y) \geq I^*} I(X;T)</script>
<p>Like RDT, this optimization problem can also be solved (given we know \(P(X,Y)\)).</p>
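<p>As a toy illustration (the joint distribution and the candidate encoder below are made-up choices of my own, not an IB solver), we can compute the two quantities \(I(X;T)\) and \(I(T;Y)\) for a given encoder directly:</p>

```python
import numpy as np

def mi(joint):
    """Mutual information in bits from a 2-D joint distribution table."""
    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)
    m = joint > 0
    return float((joint[m] * np.log2(joint[m] / (px @ py)[m])).sum())

# made-up joint p(x, y): inputs x0,x1 lean toward label 0; x2,x3 toward label 1
p_xy = np.array([[0.20, 0.05],
                 [0.20, 0.05],
                 [0.05, 0.20],
                 [0.05, 0.20]])
p_x = p_xy.sum(axis=1)

# hypothetical deterministic encoder: merge x0,x1 into t0 and x2,x3 into t1
p_t_given_x = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)

p_xt = p_x[:, None] * p_t_given_x   # joint p(x, t)
p_ty = p_t_given_x.T @ p_xy         # joint p(t, y)

print(mi(p_xt), mi(p_ty), mi(p_xy))
# I(X;T) = 1 bit of compression, yet I(T;Y) = I(X;Y): no relevant information lost
```

Here the encoder happens to discard only detail that is irrelevant to \(Y\); an encoder that merged a ‘label 0’ input with a ‘label 1’ input would lower \(I(T;Y)\) as well.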
<p><strong>The Bottleneck</strong> In order to get a good representation \(T\), we must ‘squeeze out’ any information in \(X\) that is irrelevant to \(Y\), leaving us with
only the parts of \(X\) relevant to \(Y\). This is the <em>bottleneck</em> in the information bottleneck principle.
For each value of \(I(T;Y)\) there is a corresponding minimal value of \(I(X;T)\). So just like in RDT we can define a sort
of rate distortion curve of corresponding optimal \(I(T;Y)\), \(I(X;T)\) pairs (to avoid confusion, I will refer to this as
the rate-information curve).
Just like in a rate distortion curve, this curve defines the optimal values, this time with impossible values lying
above the curve, and suboptimal values below the curve.
Note that just like in RDT, there is always a tug of war between the rate value \(I(X;T)\) and the information value \(I(Y;T)\). It is important to realize that this curve is defined
solely by the distribution \(P(X,Y)\), since given this distribution we can solve the above optimization problem.
The set of all points (all possible pairs of \(I(X;T), I(T;Y)\) values, both optimal and sub-optimal) form what is called the <em>information plane</em>. As
we shall see, analyzing deep learning models with respect to the information plane will be a key insight the IB principle
brings to deep learning.</p>
<h1 id="deep-learning-and-the-information-bottleneck">Deep Learning and the Information Bottleneck</h1>
<p>We now come to the whole purpose of this write-up: the IB connection with deep learning. As you might have been noticing, IB
sort of ‘smells’ like deep learning. In particular, you may have spotted
the following analogs: \(X\)=Inputs, \(T\)=Hidden Layers, \(Y\)=Outputs. Indeed, you would be correct here!
In deep learning we can think of the set of codewords \(T\) as being the representations output by one of the hidden layers
of the network. Of course in deep learning, we typically have many hidden layers (let’s say we have \(k\) here), and thus
we have many sets of codewords for \(X\). For notational purposes, let \(T_i\) indicate the outputs for the \(ith\) hidden layer
. Other than the fact that we have \(k\) sets of codewords, we can treat everything pretty much the same as in IB theory.
With IB we can measure how good the representation of each layer is by looking at \(I(Y;T_i)\). We can additionally
see how much we have compressed the input \(X\) at the \(ith\) layer by looking at \(I(X;T_i)\).</p>
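<p>As a rough sketch of how \(I(X;T_i)\) and \(I(T_i;Y)\) might be estimated in practice, one can discretize a layer's activations and treat each distinct binned pattern as a codeword, in the spirit of the binning approach used by Shwartz-Ziv and Tishby. All names below are my own, and this is only a plug-in estimate over a finite sample:</p>

```python
import numpy as np

def layer_information(x_ids, y_ids, activations, n_bins=10):
    """Plug-in estimates of (I(X;T), I(T;Y)) in bits, where T is the layer's
    activation vector discretized into equal-width bins (one codeword per
    distinct binned pattern)."""
    lo, hi = activations.min(), activations.max()
    binned = np.floor((activations - lo) / (hi - lo + 1e-12) * n_bins).astype(int)
    _, t_ids = np.unique(binned, axis=0, return_inverse=True)
    t_ids = t_ids.ravel()                      # one codeword id per example

    def mi(a, b):
        joint = np.zeros((a.max() + 1, b.max() + 1))
        np.add.at(joint, (a, b), 1.0)          # count co-occurrences
        joint /= joint.sum()
        pa = joint.sum(axis=1, keepdims=True)
        pb = joint.sum(axis=0, keepdims=True)
        m = joint > 0
        return float((joint[m] * np.log2(joint[m] / (pa @ pb)[m])).sum())

    return mi(x_ids, t_ids), mi(t_ids, y_ids)

# Toy check: 8 distinct inputs, label = parity, one hidden unit that copies the label.
x_ids = np.arange(8)
y_ids = x_ids % 2
acts = y_ids[:, None].astype(float)
print(layer_information(x_ids, y_ids, acts))  # (1.0, 1.0): T keeps 1 bit of X, all of it about Y
```

<p>Note that estimates like this degrade quickly as the number of distinct activation patterns grows, which is one of the practical challenges of information-plane analysis.</p>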
<p><strong>Losing Information</strong> An important fact to note is that each time we go through a layer of our network, we have no way to obtain extra
information about \(Y\). All the information we know about \(Y\) comes from \(X\); it’s not possible for new information about
\(Y\) to magically appear simply because we compress \(X\) into a new
representation \(T_i\). So each time we obtain a new representation \(T_i\) by going through the \(i\)th hidden layer, we lose a little information about \(Y\) (at the very best
we lose no information about \(Y\)). This can be described through the following inequality (called the <em>Data Processing Inequality</em>) for all \(j > i\):</p>
<script type="math/tex; mode=display">I(Y;X) \geq I(T_i;Y) \geq I(T_j;Y) \geq I(Y';Y)</script>
<p>\(Y'\) here indicates the outputs of our network.
Note that with each layer we have an associated \(I(X;T)\) and \(I(T;Y)\) pair, which means we can plot each layer as
a point on the information plane.</p>
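<p>The Data Processing Inequality is easy to check numerically: build a toy Markov chain in which each step past \(X\) is a noisy binary channel, so each step can only lose information about \(Y\). This is a sketch of my own invention, not the author's setup, using a simple plug-in mutual information estimate:</p>

```python
import numpy as np

def mi_from_samples(a, b):
    """Plug-in mutual information estimate (bits) from paired discrete samples."""
    joint = np.zeros((a.max() + 1, b.max() + 1))
    np.add.at(joint, (a, b), 1.0)              # count co-occurrences
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)
    pb = joint.sum(axis=0, keepdims=True)
    m = joint > 0
    return float((joint[m] * np.log2(joint[m] / (pa @ pb)[m])).sum())

rng = np.random.default_rng(0)
n = 50_000

# Markov chain Y -> X -> T1 -> T2: each arrow flips some bits at random,
# so information about Y can only shrink as we move down the chain.
y  = rng.integers(0, 2, size=n)
x  = y  ^ (rng.random(n) < 0.1)    # flip 10% of bits
t1 = x  ^ (rng.random(n) < 0.1)    # flip 10% more
t2 = t1 ^ (rng.random(n) < 0.3)    # flip 30% more

print(mi_from_samples(x, y), mi_from_samples(t1, y), mi_from_samples(t2, y))
assert mi_from_samples(x, y) > mi_from_samples(t1, y) > mi_from_samples(t2, y)
```

<p>With 50,000 samples the sampling noise in the estimates is far smaller than the information actually destroyed at each step, so the chain of inequalities comes out cleanly.</p>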
<p><strong>Learning by forgetting</strong> You may be scratching your head at this apparent contradiction. Doesn’t having multiple layers help us in deep learning? Why do we want to add more layers if
they simply make us lose information about \(Y\)?</p>
<p>First recall that the curve (and thus all values of \(I(X;T)\) and \(I(T;Y)\)) is defined
solely by the distribution \(P(X,Y)\), a distribution that we do not have access to. However we
do have access to a set of \(n\) samples, \(S\), from the true distribution. This, of course, is our training
set. With \(S\) we can make an empirical estimate of the joint distribution, \(\hat{P}(X,Y)\), by using
some maximum likelihood method. With this empirical distribution, we can calculate the empirical
mutual information values, \(\hat{I}(X;T)\) and \(\hat{I}(T;Y)\). How far off our empirical estimate
\(\hat{I}(T;Y)\) is from \(I(T;Y)\) is proven to have the following bounds (similar bounds are proven
for \(\hat{I}(X;T)\) as well):</p>
<script type="math/tex; mode=display">I(T;Y) \leq \hat{I}(T;Y) + O\Big(\frac{|T|*|Y|}{\sqrt{n}}\Big)</script>
<p>Where \(|T|\) is the cardinality of \(T\), which is approximately equal to \(2^{I(X;T)}\). So the tightness
of the bound depends on how compressed \(T\) is! As \(T\) becomes more compressed, \(I(X;T)\) becomes
smaller, meaning the bound above becomes tighter, meaning our empirical estimate \(\hat{I}(T;Y)\)
is more likely to be closer to the true value of \(I(T;Y)\).</p>
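<p>A back-of-the-envelope sketch of this trade-off, plugging \(|T| \approx 2^{I(X;T)}\) into the \(O(|T|\cdot|Y|/\sqrt{n})\) term (constants are dropped, so only relative comparisons between settings mean anything; the function name is my own):</p>

```python
import math

def bound_gap(i_xt_bits, n_classes, n_samples):
    """Order-of-magnitude width of the confidence term |T|*|Y|/sqrt(n),
    with |T| ~ 2^I(X;T).  Constants are dropped: only relative comparisons
    between settings are meaningful."""
    card_t = 2 ** i_xt_bits
    return card_t * n_classes / math.sqrt(n_samples)

# More compression (smaller I(X;T)) gives a tighter bound on the true I(T;Y).
for i_xt_bits in (16, 12, 8):
    print(i_xt_bits, bound_gap(i_xt_bits, n_classes=10, n_samples=60_000))
```

<p>Each halving of \(I(X;T)\)'s bit budget shrinks \(|T|\) exponentially, which is why compression buys so much tightness here.</p>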
<p>But why is this important? We don’t really care about estimating \(I(T;Y)\) in the end; what we
really care about is minimizing the generalization error (that is, the probability of
misclassifying an arbitrary unseen instance drawn from the true distribution). Surprisingly, it turns out that \(I(T;Y)\)
can actually act as a stand-in for the generalization error!</p>
<p>For a classification problem, if we always choose a class \(y\) by picking the class that maximizes
the likelihood of \(y\) with respect to \(T\) (that is, maximizes the probability \(P(t|y)\)), then
\(I(T;Y)\) actually (roughly<sup id="fnref:5"><a href="#fn:5" class="footnote">5</a></sup>) forms a lower bound on the negative log probability of error.
Thus \(I(T;Y)\) can be used as a proxy for generalization performance! The higher
\(I(T;Y)\) is, the better our system generalizes. The lower \(I(X;T)\) is, the tighter the bound is
between our empirical estimate \(\hat{I}(T;Y)\) and the true generalization performance \(I(T;Y)\). As said
in Tishby and Zaslavsky, 2015:</p>
<blockquote>
<p>“The hidden layers must compress the input in order to reach a point where the worst case
generalization performance error is tolerable”.</p>
</blockquote>
<h1 id="looking-at-networks-in-the-information-plane">Looking at Networks in the Information Plane</h1>
<p>With this all set up, we can now think of our network as indicating a point somewhere on the
information plane (that is, with respect to \(I(T;Y)\) and \(I(X;T)\)). The poorly drawn chart below gives
a picture of this. The locations of our network on the information plane are indicated by the green
arrows. Of course, we start at a point (just the input \(X\)) with less compression, and more information about \(Y\) (since \(X\) contains all the information about \(Y\) that we can ever get). As we add hidden layers our representation becomes more and more compressed, and we lose information about \(Y\).
With our training data we can estimate an empirical information curve (represented by the black curve)
as well as a worst case bound for our estimated value of \(I(T;Y)\) (indicated by the red curve).</p>
<p>One important thing to remember about the chart is that the
only thing really true with respect to the true distribution \(P(X,Y)\) is the red curve! The rest of the curves, including the location of our network on the information plane, are simply empirical estimates. However, we can take solace in knowing that we will never go below the red curve. Ideally
we would like to end up at the maximum value of the red curve, as this indicates the best worst-case scenario for us: the location on the information plane where the generalization bound is
tightest. We can refer to this point as \(R^*\). On the graph below, the point is indicated via the standard smiley face notation. Of course, there is no reason to expect SGD to take us to this optimal point. How far we are from \(R^*\) on the \(I(X;T)\) axis is called the complexity
gap, \(C\); it indicates how much more we could have compressed our inputs. How far we are from \(R^*\) in
terms of \(I(T;Y)\) is called the generalization gap, \(G\); it indicates how much more information about
\(Y\) we could have stuffed into our representations.</p>
<p align="center">
<img src="/images/blog1/graph1.png" alt="no image" />
</p>
<h1 id="concluding-thoughts">Concluding Thoughts</h1>
<p>By looking through the lens of information theory, we have developed a theoretically well motivated framework that provides us with another way of thinking about neural networks (in terms of compression and information relevancy).
What does this mean for the future? The theoretical tools presented here could possibly
motivate a whole array of optimization techniques whose goal is to reach the elusive \(R^*\).
The whole concept of compression leading to better generalization might motivate theoretical analysis
into why techniques such as dropout (and other mysterious techniques) work. The possibilities are, as they say, countably infinite.
The next interesting step would be to try to analyze real networks in the information plane, which is
what <a href="https://arxiv.org/abs/1703.00810">“Opening the Black Box of Deep Neural Networks via Information”</a> does.
The next post will discuss these results, the contradictory results reported in a new ICLR submission, the
implications of these results, as
well as some of the practical challenges of analyzing networks in the information plane.</p>
<h1 id="one-more-concluding-thought">One more concluding thought…</h1>
<p>One last thing to note before you go. Above I said that \(I(T;Y)\) can form
a bound on generalization error <em>if</em> we use a decision rule which selects the class
\(y\) that maximizes the likelihood of \(y\). For networks trained on cross entropy error (as most networks
that perform multi-class classification are), the decision rule typically employed is to select the
class that maximizes the posterior probability (i.e. maximizes \(p(y|t)\)). This is the same as
maximizing \(p(t|y)\) if our distribution of class labels \(P(Y)\) is uniform, though this is often not
the case. So is this a cause for concern? I don’t really know, which is why I am putting it here. Intuitively, switching the decision rule
from maximum likelihood to MAP seems like it shouldn’t have too much of an impact. I haven’t heard any
discussions about this, so it would be nice if future work could address this issue (if it’s an issue). Feel
free to leave a comment if you have any insights into this.</p>
<h1 id="references">References</h1>
<ul>
<li>O. Shamir, S. Sabato, and N. Tishby. “Learning and generalization with the information bottleneck.” <em>ALT</em>. Vol 8. 2008.</li>
<li>R. Shwartz-Ziv and N. Tishby. “Opening the Black Box of Deep Neural Networks via Information”. <em>ArXiv</em>. <a href="https://arxiv.org/abs/1703.00810">https://arxiv.org/abs/1703.00810</a>. April 2017.</li>
<li>N. Tishby and N. Zaslavsky. “Deep Learning and the Information Bottleneck Principle”. <em>arXiv</em>. <a href="https://arxiv.org/abs/1503.02406">https://arxiv.org/abs/1503.02406</a>. March 2015.</li>
<li>N. Tishby, F.C. Pereira, and W. Bialek. “The information bottleneck method”. <em>arXiv</em>. <a href="https://arxiv.org/abs/physics/0004057">https://arxiv.org/abs/physics/0004057</a>. 1999.</li>
</ul>
<div class="footnotes">
<ol>
<li id="fn:1">
<p>For those interested, a rebuttal from Tishby and Shwartz-Ziv has been posted on the given link <a href="#fnref:1" class="reversefootnote">↩</a></p>
</li>
<li id="fn:2">
<p>There is a slight difference between the entropy rate and just plain old entropy, the difference however is not important here <a href="#fnref:2" class="reversefootnote">↩</a></p>
</li>
<li id="fn:3">
<p>It actually turns out that \(I(X;T)\) is even better than a lower bound on the ‘best achievable rate’… the two are equivalent! <a href="#fnref:3" class="reversefootnote">↩</a></p>
</li>
<li id="fn:4">
<p>The problem can be solved by an alternating algorithm called the Blahut-Arimoto algorithm, a similar method is used to solve the IB equations <a href="#fnref:4" class="reversefootnote">↩</a></p>
</li>
<li id="fn:5">
<p>See (Shamir, Sabato, and Tishby, 2008) for the proof of this bound, as well as the other generalization bounds for IB <a href="#fnref:5" class="reversefootnote">↩</a></p>
</li>
</ol>
</div>Noah Weber