Soft attention vs. hard attention

dimid · Feb 22, 2016

In this blog post, The Unreasonable Effectiveness of Recurrent Neural Networks, Andrej Karpathy mentions future directions for neural network-based machine learning:

The concept of attention is the most interesting recent architectural innovation in neural networks. [...] soft attention scheme for memory addressing is convenient because it keeps the model fully-differentiable, but unfortunately one sacrifices efficiency because everything that can be attended to is attended to (but softly). Think of this as declaring a pointer in C that doesn't point to a specific address but instead defines an entire distribution over all addresses in the entire memory, and dereferencing the pointer returns a weighted sum of the pointed content (that would be an expensive operation!). This has motivated multiple authors to swap soft attention models for hard attention where one samples a particular chunk of memory to attend to (e.g. a read/write action for some memory cell instead of reading/writing from all cells to some degree). This model is significantly more philosophically appealing, scalable and efficient, but unfortunately it is also non-differentiable.

I think I understood the pointer metaphor, but what exactly is attention, and why is hard attention not differentiable?
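This is roughly how I picture the difference from the pointer metaphor, in a toy NumPy sketch (the memory contents, shapes and weights are made up, just to check my understanding):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "memory": 5 cells, each holding a 3-dimensional vector.
memory = rng.standard_normal((5, 3))

# Soft attention: a distribution over ALL addresses, read = weighted sum.
weights = np.array([0.1, 0.6, 0.1, 0.1, 0.1])   # sums to 1
soft_read = weights @ memory                     # touches every cell (expensive)

# Hard attention: sample ONE address and read only that cell.
address = rng.choice(len(memory), p=weights)
hard_read = memory[address]                      # a single cell, no mixing
```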

I found an explanation of attention here, but I'm still confused about the soft/hard distinction.

Answer

Sahil · Mar 7, 2016

What exactly is attention?

To understand this question, we need to dive a little into the problems that attention seeks to solve. I think one of the seminal papers on hard attention is Recurrent Models of Visual Attention, and I would encourage the reader to go through that paper, even if it doesn't seem fully comprehensible at first.

To answer what exactly attention is, I'll pose a different question that I believe is easier to answer: why attention? The paper I linked seeks to answer that question succinctly, and I'll reproduce part of the reasoning here.

Imagine you were blindfolded and taken to a surprise birthday party, and you just opened your eyes. What would you see? A birthday party!

Now, when we say you see the picture, that is shorthand for a more technically correct sequence of actions: you move your eyes around over time and gather information about the scene. You don't see every pixel of the image at once. You attend to certain aspects of the picture one time-step at a time and aggregate the information. Even in such a cluttered picture, for example, you would recognize your uncle Bill and cousin Sam :). Why is that? Because you attend to certain salient aspects of the current image.

That is exactly the kind of power we want to give to our neural network models. Why? Think of it as a form of regularization. (This portion of the answer references the paper.) Your usual convolutional network model can, in principle, learn to recognize cluttered images, but how do we find the exact set of weights that are "good"? That is a difficult task. By providing the network with an architecture-level feature that allows it to attend to different parts of the image sequentially and aggregate information over time, we make that job easier, because now the network can simply learn to ignore the clutter (or so is the hope).

I hope this answers the question What is hard attention?. Now onto the nature of its differentiability. Remember how we conveniently picked the correct spots to look at while viewing the birthday picture? How did we do that? The process involves making choices that are difficult to represent as a differentiable function of the input (the image). For example: based on the image and what you've looked at already, decide where to look next. You could have a neural network output the answer here, but we do not know the correct answer! In fact, there is no correct answer. How then are we to train the network parameters? Neural network training depends critically on a differentiable loss function of the inputs; examples include the log-likelihood loss and the squared loss. But in this case we have no correct answer for where to look next, so how can we define a loss? This is where a field of machine learning called reinforcement learning (RL) comes in. RL lets you do gradient ascent in the space of policies, using methods such as REINFORCE and actor-critic algorithms.
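To make that concrete, here is a minimal sketch (my own illustration, with made-up shapes and a made-up scalar reward, not the model from the paper) of how REINFORCE gets a gradient through a sampled glimpse location, which plain backpropagation cannot do because the sampling step is not differentiable:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: a policy scores 4 candidate glimpse locations from a
# feature vector, samples one of them, and later receives a reward telling
# it how useful that glimpse turned out to be.
features = rng.standard_normal(8)          # current state of the model (made up)
W = rng.standard_normal((4, 8)) * 0.1      # policy parameters

def softmax(z):
    z = z - z.max()                        # for numerical stability
    e = np.exp(z)
    return e / e.sum()

probs = softmax(W @ features)              # distribution over the 4 locations
location = rng.choice(4, p=probs)          # hard choice: sample ONE location

# We cannot backprop through the sampling step, so REINFORCE uses the
# score-function estimator:
#   grad_W = reward * d/dW [ log pi(location | features) ]
reward = 1.0                               # made-up reward for this glimpse

one_hot = np.zeros(4)
one_hot[location] = 1.0
grad_logits = one_hot - probs              # gradient of log-softmax w.r.t. logits
grad_W = reward * np.outer(grad_logits, features)

W += 0.01 * grad_W                         # gradient ascent on expected reward
```

The key point is that the gradient is taken with respect to the log-probability of the action that was sampled, weighted by the reward, rather than through the sampling operation itself.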

What is soft attention?

This part of the answer borrows from the paper Teaching Machines to Read and Comprehend. A major problem with RL methods such as REINFORCE is that they have high variance (in the gradient estimate of the expected reward), which scales linearly with the number of hidden units in your network. That's not a good thing, especially if you're going to build a large network. Hence, people look for differentiable models of attention. All this means is that the attention term, and as a result the loss function, is a differentiable function of the inputs, so all gradients exist. We can therefore use our standard backprop algorithm, along with one of the usual loss functions, to train the network. So what is soft attention?

In the context of text, it refers to the model's ability to assign more importance to certain tokens in the document than to others. If you're reading a document and have to answer a question based on it, concentrating on certain tokens will help you answer the question better than reading every token as if it were equally important. That is the basic idea behind soft attention in text. The reason it is a differentiable model is that you decide how much attention to pay to each token based purely on that token and the query at hand. For example, you could represent the tokens of the document and the query in the same vector space and use the dot product (or cosine similarity) as a measure of how much attention to pay to a particular token given the query. The cosine similarity operation is completely differentiable with respect to its inputs, so the overall model ends up being differentiable. Note that the exact model used by the paper differs, and this argument is just for demonstration's sake, although other models do use a dot-product-based attention score.
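As a rough sketch of that dot-product idea (my own illustration, not the model from the paper): score each token against the query, turn the scores into a softmax distribution, and return the weighted sum of the token vectors. Every step is differentiable, so gradients flow back to both the token and the query representations:

```python
import numpy as np

def soft_attention(tokens, query):
    """tokens: (n_tokens, d) matrix of token vectors; query: (d,) query vector."""
    scores = tokens @ query                 # dot-product attention scores
    scores = scores - scores.max()          # for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()       # softmax: how much to attend to each token
    return weights @ tokens                 # weighted sum of token vectors

# Toy example: 6 tokens embedded in a 4-dimensional space.
rng = np.random.default_rng(0)
tokens = rng.standard_normal((6, 4))
query = rng.standard_normal(4)
context = soft_attention(tokens, query)     # differentiable w.r.t. tokens and query
```

The softmax weights play the role of "how much attention to pay to each token"; swapping the dot product for cosine similarity would keep the whole computation differentiable as well.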