GANs for Sequence Modeling

Today we’ll be trying to demystify the SeqGAN paper. We’ll go over

  • Why we might want to use GANs for sequence modeling,
  • Why we can’t apply GANs to sequence modeling directly,
  • And how we can reformulate sequence modeling with reinforcement learning to get past this barrier.

Reviewing RNNs

Let’s remind ourselves of how language modeling is typically done these days. We have some prefix of words $w_1,\ldots,w_{n}$ and we want to generate $w_{n+1}$, so in particular we want to learn the conditional distribution

\[\begin{equation} P\left(w_{n+1} | w_1, \ldots, w_{n}\right) \end{equation}\]

for every $w_{n+1}$ in our vocabulary $V$. If we know this conditional distribution, one reasonable $w_{n+1}$ to choose might be

\[\begin{equation} \underset{w_{n+1} \in V}{\arg \max}\hspace{0.1cm} P\left(w_{n+1} | w_1, \ldots, w_{n}\right). \end{equation}\]

In practice, we approximate this conditional distribution, for every word $w_{n+1}$ and every prefix $w_1,\ldots,w_{n}$, with a language model (LM). The LM consumes the prefix words $w_1, \ldots, w_{n-1}$ and passes them through an encoder to produce a hidden state $h_{n-1}$, a vector representation of the prefix. RNN language models consume the prefix words one at a time, while transformer language models consume them all at once. The LM then passes $h_{n-1}$ along with $w_n$ through a decoder to produce a vector of scores (logits) $s_1, \ldots, s_{| V |}$, one for each word in the vocabulary $V$. Passing these scores through a $\operatorname{softmax}$ squashes them into probabilities, giving a predicted conditional distribution

\[\begin{equation} \hat P\left(w_{n+1} | w_1, \ldots, w_{n}\right). \end{equation}\]
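As a rough illustration of the last step, here is a minimal NumPy sketch, assuming the decoder has already produced one logit per vocabulary word. The toy vocabulary and logit values are made up for illustration and do not come from any trained model.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Squash a vector of scores s_1, ..., s_|V| into probabilities."""
    # Subtracting the max is for numerical stability; it does not change the result.
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / exps.sum()

# Hypothetical vocabulary and decoder scores for a single prefix.
vocab = ["jam", "nuggets", "pie", "the"]
logits = np.array([2.0, 1.0, 0.5, -1.0])

probs = softmax(logits)               # predicted \hat P(w_{n+1} | w_1, ..., w_n)
best = vocab[int(np.argmax(probs))]   # the arg max choice of next word
```

Because softmax is strictly increasing in each score, taking the $\arg \max$ over `probs` picks the same word as taking it over `logits`.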

During training, $w_1,\ldots,w_n, w_{n+1}$ are consecutive words from our dataset, and we penalize the model with the negative log-likelihood (NLL) loss

\[\begin{equation} L(w_1, \ldots, w_n, w_{n+1}) = -\log \hat P\left( w_{n+1} | w_1, \ldots, w_n \right). \end{equation}\]
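For a single training step, this loss can be computed directly from the logits. The sketch below is one common way to do it (the log-sum-exp trick), assuming a NumPy vector of logits and the vocabulary index of the data word $w_{n+1}$; the function name `nll_loss` and the numbers are illustrative.

```python
import numpy as np

def nll_loss(logits: np.ndarray, target_index: int) -> float:
    """-log \hat P(w_{n+1} | w_1, ..., w_n) for one step.

    Equivalent to -log(softmax(logits)[target_index]), but computed with the
    log-sum-exp trick so large logits do not overflow.
    """
    m = np.max(logits)
    log_normalizer = m + np.log(np.sum(np.exp(logits - m)))
    return float(log_normalizer - logits[target_index])

# Hypothetical scores over a 4-word vocabulary; the data says the next word is index 0.
logits = np.array([2.0, 1.0, 0.5, -1.0])
loss = nll_loss(logits, target_index=0)
```

Over a training sentence, this per-word loss is typically summed or averaged across positions, with every prefix taken from the data.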

Now suppose we have a trained LM and we want to generate text from it (we call this inference). We can do this one word at a time: each generated word is appended to the prefix, and the extended prefix is fed back into the LM autoregressively to get the next word. To keep things simple, suppose we use greedy decoding; that is, at each step we take the $\arg \max$ over the output probability distribution (or, equivalently, over the output logits) to choose the next word. We then have the relation

\[\begin{equation} w_{n+1} = \arg \max \operatorname{LM}(h_{n-1}, w_n), \end{equation}\]

recalling that $h_{n-1}$ is a hidden state representing the prefix $w_1, \ldots, w_{n-1}$.
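To make the recursion concrete, here is a minimal sketch of greedy autoregressive decoding with an RNN-style LM. The randomly initialised Elman-RNN weights are a hypothetical stand-in for a trained model, and `lm_step` and `greedy_decode` are illustrative names; the point is only the data flow, where each chosen word is fed back in as the next input.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB_SIZE, HIDDEN = 5, 8

# Hypothetical, randomly initialised weights standing in for a trained LM.
E = rng.normal(size=(VOCAB_SIZE, HIDDEN))      # word embeddings
W_h = rng.normal(size=(HIDDEN, HIDDEN)) * 0.5  # hidden-to-hidden
W_o = rng.normal(size=(HIDDEN, VOCAB_SIZE))    # hidden-to-logits

def lm_step(h_prev: np.ndarray, w: int):
    """Fold word w into the prefix state; return (new state, next-word logits)."""
    h = np.tanh(E[w] + h_prev @ W_h)
    return h, h @ W_o

def greedy_decode(prefix: list[int], num_new_words: int) -> list[int]:
    h = np.zeros(HIDDEN)
    # Consume w_1, ..., w_{n-1} to build h_{n-1}.
    for w in prefix[:-1]:
        h, _ = lm_step(h, w)
    words = list(prefix)
    for _ in range(num_new_words):
        # w_{n+1} = argmax LM(h_{n-1}, w_n); the returned state now covers w_1..w_n.
        h, logits = lm_step(h, words[-1])
        words.append(int(np.argmax(logits)))  # the model's own output becomes the next input
    return words
```

For example, `greedy_decode([1, 3], num_new_words=4)` returns the two prefix word indices followed by four generated ones.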

Thus, during training, the input prefixes come from the data, while during inference, they come from the outputs of the model.


RNNs are unreasonably effective. At the very least, they might be better at blog writing than this poor author. However, they are not devoid of weaknesses.

Problem: Exposure bias

Recall that we have a distribution shift between training and inference: during training, inputs come from the data, while during inference, inputs come from the outputs of the model. We call this problem exposure bias. The model is never exposed to its own outputs during training, so during inference, with every model-generated word we add to the prefix, the generated text distribution drifts further from the data text distribution.

Of course, one solution is to expose the model to its own outputs during training; this is known as scheduled sampling. Briefly, the problem with this approach is that the label for each input still comes from the data. Suppose the model generates “I like eating chicken”; its next word might then be “nuggets,” which is perfectly reasonable. However, the excerpt from the text may in fact have been “I like eating elderberry jam,” so the label is “jam,” and “nuggets” gets penalized as if it were an answer to the data prefix “I like eating elderberry,” to which it is a ridiculous answer. We’ll come back to the exposure bias problem shortly.
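The mismatch is easy to see in a small sketch of how scheduled sampling builds training pairs. This is a simplified illustration, not a reference implementation: `scheduled_sampling_pairs` and `toy_model` are hypothetical names, the model is replaced by a lookup table, and the schedule that gradually increases the sampling probability over training is omitted.

```python
import random

def scheduled_sampling_pairs(data_words, predict_next, sample_prob, rng):
    """Build (input prefix, label) pairs for one sentence under scheduled sampling.

    With probability `sample_prob`, the next input word is the model's own
    prediction instead of the data word. Labels always come from the data.
    `predict_next(prefix) -> word` is a hypothetical stand-in for the LM.
    """
    prefix = [data_words[0]]
    pairs = []
    for t in range(1, len(data_words)):
        label = data_words[t]  # the label is always the data word
        pairs.append((list(prefix), label))
        if rng.random() < sample_prob:
            prefix.append(predict_next(prefix))  # expose the model to its own output
        else:
            prefix.append(label)                 # ordinary teacher forcing
    return pairs

# Hypothetical toy "model" that follows the data until it says "chicken".
def toy_model(prefix):
    return {"I": "like", "like": "eating", "eating": "chicken",
            "chicken": "nuggets"}.get(prefix[-1], "the")

sentence = ["I", "like", "eating", "elderberry", "jam"]
pairs = scheduled_sampling_pairs(sentence, toy_model, sample_prob=1.0, rng=random.Random(0))
```

With `sample_prob=1.0`, the last pair is the model's prefix `["I", "like", "eating", "chicken"]` paired with the data label `"jam"`, which is exactly the situation described above.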

Problem: Penalizing full generations

In traditional LM training, penalties are assigned word by word: given a prefix from the data, we maximize the predicted likelihood of the next word from the data. This chained strategy never assigns a penalty holistically to an entire generation. And in fact, for greedily decoded generations, backpropagating a holistic penalty is futile.

Recall that in greedy decoding, we generate a piece of text word by word and pick each next word by taking an $\arg \max$ over the $| V |$ logits in the output layer of our LM, where $\arg \max (\mathbf x)$ gives us the index of the largest entry of $\mathbf x$. Taking zero-indexing as the convention, $\arg \max([7, 10]) = 1$. The figure below shows a plot of $z = \arg \max(x, y)$. Try hovering your cursor over the figure. Do you notice any problems this might cause for gradient updates?
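One way to probe this question without the figure is numerically. The sketch below estimates the partial derivatives of $z = \arg \max(x, y)$ with central finite differences; `argmax2`, `finite_diff_grad`, and the step size `eps` are illustrative choices, not anything from the paper.

```python
import numpy as np

def argmax2(x: float, y: float) -> int:
    """z = argmax(x, y) with zero-indexing: 0 if x is larger, 1 otherwise (ties give 0)."""
    return int(np.argmax([x, y]))

def finite_diff_grad(f, x: float, y: float, eps: float = 1e-4):
    """Central-difference estimate of (dz/dx, dz/dy)."""
    dzdx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
    dzdy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps)
    return dzdx, dzdy

print(argmax2(7.0, 10.0))                     # → 1
print(finite_diff_grad(argmax2, 7.0, 10.0))   # → (0.0, 0.0)
print(finite_diff_grad(argmax2, 3.0, 3.0))    # huge values on the x = y boundary
```

Away from the line $x = y$ the estimate is exactly zero, and on that line it blows up as `eps` shrinks, because $z$ is a step function.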