Language Modeling on One GPU
Single-headed attention competes with transformers.
The latest large, pretrained language models rely on trendy layers based on transformer networks. New research shows that these newfangled layers may not be necessary.
What’s new: Networks such as BERT and ERNIE take advantage of multi-headed attention layers to outcompete LSTM language models. But training these layers requires lots of compute on enormous GPU clusters. Stephen Merity of d⁄dx Times Labs struck a blow for garage AI with Single Headed Attention RNN (SHA-RNN), which nearly matched state-of-the-art performance after training on a single GPU for less than 24 hours. As he puts it in a tartly worded paper, “Take that, Sesame Street.”
Key insight: The author set out to find a high-performance language model suitable for his personal computer. He used a single attention head out of skepticism that multiple heads are worth their computational cost. Simplifying the transformer’s feed-forward network enabled him to run the model on a single GPU.
How it works: SHA-RNN is built on an LSTM, which represents the sequential nature of text more explicitly than a transformer does.
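- The model reads an input text sequence token by token and predicts the next token. Tokens are often words or word pieces; in the enwik8 experiments described below, they are individual characters. The LSTM's memory cell carries information about earlier tokens forward through the sequence.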
- The LSTM's output feeds the single-headed attention layer, which models relationships between tokens across the sequence.
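- The attention layer's output feeds a so-called boom layer. This layer replaces the transformer's usual pair of feed-forward layers with a single feed-forward layer that expands each vector to several times its original size, followed by a step that splits the result into equal chunks and sums them to restore the original vector dimension.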
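The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Merity's implementation: random vectors stand in for the LSTM's outputs, the attention is generic causal scaled dot-product attention that uses its inputs directly as queries, keys, and values (SHA-RNN's own attention differs in its details), and the boom layer assumes an expansion factor of 4 with a GELU activation. Residual connections and layer normalization are omitted. The names `single_head_attention`, `boom`, and `w_up` are hypothetical.

```python
import math
import random


def matvec(w, x):
    """Multiply a matrix w (a list of rows) by a vector x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def single_head_attention(xs):
    """Causal scaled dot-product attention with a single head.

    Assumption for brevity: the inputs serve directly as queries, keys,
    and values (no learned projections, unlike SHA-RNN's attention)."""
    d = len(xs[0])
    out = []
    for t, q in enumerate(xs):
        past = xs[: t + 1]  # each position sees only itself and earlier ones
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in past]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, past)) for i in range(d)])
    return out


def gelu(x):
    """Exact GELU activation (an assumed choice of nonlinearity)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def boom(x, w_up):
    """Boom layer sketch: one feed-forward layer expands x from d to n*d,
    then the n chunks of size d are summed to restore length d."""
    d = len(x)
    h = [gelu(v) for v in matvec(w_up, x)]  # length n*d
    n = len(h) // d
    return [sum(h[c * d + i] for c in range(n)) for i in range(d)]


# Usage with hypothetical sizes: d=4 features, expansion n=4, 5 tokens.
random.seed(0)
d, n, seq_len = 4, 4, 5
lstm_out = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]  # stand-in for LSTM outputs
w_up = [[random.gauss(0, 0.5) for _ in range(d)] for _ in range(n * d)]
attended = single_head_attention(lstm_out)
outputs = [boom(v, w_up) for v in attended]
print(len(outputs), len(outputs[0]))  # 5 4
```

The summing step is what lets the boom layer skip the second, down-projecting feed-forward layer: the vector returns to its original dimension without another weight matrix.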
Results: Merity tested SHA-RNN by compressing the enwik8 dataset. More accurate language models use fewer bits to represent a sequence because they predict, to some extent, which characters come next. SHA-RNN achieved 1.068 bits per character compared to 0.99 for Sparse Transformer: slightly less accurate, but with about half as many parameters.
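Bits per character is the average negative base-2 log-probability that a model assigns to each character that actually occurs, given the characters before it. The sketch below uses hypothetical probabilities to show why better prediction means fewer bits: a model that assigns probability 1/256 to every byte needs 8 bits per character, which is no compression at all, while a model that usually puts high probability on the right character needs far fewer. By this yardstick, SHA-RNN's 1.068 means it needs only about one bit per character.

```python
import math


def bits_per_character(probs):
    """Average bits per character, given the probability the model assigned
    to each character that actually occurred (conditioned on the prefix)."""
    return -sum(math.log2(p) for p in probs) / len(probs)


# Hypothetical per-character probabilities for a 4-character sequence.
uninformed = [1 / 256] * 4         # uniform over 256 byte values
confident = [0.9, 0.8, 0.95, 0.7]  # a model that usually guesses right

print(bits_per_character(uninformed))           # 8.0
print(round(bits_per_character(confident), 3))  # 0.266
```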
Yes, but: An LSTM is a good choice for sequential language-prediction tasks like enwik8, where the model predicts each character from the ones before it. In tasks that draw on context from both sides of a gap, such as filling in blanks, multi-headed attention is a better choice. Moreover, a version of Transformer-XL with even fewer parameters than SHA-RNN performed better on the compression task.
Why it matters: SHA-RNN isn’t an out-and-out replacement for transformer-based networks. But it shows that LSTMs remain relevant and useful in language modeling. And if you’re looking for a way to get people to read your research, the author’s style offers pointers: This paper is a very entertaining read!
We’re thinking: Researchers like to focus on optimizing state-of-the-art methods, and media hype frequently chases the latest leaderboard topper. Yet foundational algorithms remain valuable in a variety of contexts.