2023/04/01

"Attention" and "Transformers" in Large Language Models

Everyone is talking about OpenAI's ChatGPT these days. Here's a very quick attempt to summarize the core idea behind large language models (LLMs) like GPT.

"Attention is all you need" (aka the transformer paper) published in 2017 by Vaswani et al from Google is still the mother of current LLMs, including GPT.  "Effective Approaches to Attention-based Neural Machine Translation", an earlier paper by Luong et al from Stanford, was also quite important.

These are sequence-to-sequence models, i.e. their job is to map an input sequence of text into an output sequence of text. Applications include translating from one language to another, answering questions, having a conversation, etc.
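
To make that concrete, here is a toy sketch of the sequence-to-sequence interface in Python. The whitespace tokenizer, the three-word vocabulary, and the placeholder "model" that just reverses its input are all made up for illustration; a real LLM uses a learned subword tokenizer and a trained transformer in the middle.

    def tokenize(text, vocab):
        # toy tokenizer: split on whitespace and look up integer ids
        return [vocab[word] for word in text.split()]

    def detokenize(ids, vocab):
        inverse = {i: word for word, i in vocab.items()}
        return " ".join(inverse[i] for i in ids)

    def toy_seq2seq_model(input_ids):
        # placeholder for a trained transformer, which would generate output
        # tokens one at a time while attending to the whole input sequence
        return list(reversed(input_ids))

    vocab = {"the": 0, "cat": 1, "sat": 2}
    output_ids = toy_seq2seq_model(tokenize("the cat sat", vocab))
    print(detokenize(output_ids, vocab))  # -> "sat cat the"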

They use language embeddings (made famous by word2vec in 2013 and later by BERT, both also from Google) as the basic encoding/decoding building blocks, i.e. they map text to vectors of real numbers in an "embedding space".
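
As a rough illustration, here is what an embedding lookup looks like in numpy. The tiny vocabulary, the 4-dimensional vectors, and the random initialization are placeholders; in a real model the table is far larger and its entries are learned during training.

    import numpy as np

    # toy embedding table: one row (vector) per vocabulary entry
    rng = np.random.default_rng(0)
    vocab = {"the": 0, "cat": 1, "sat": 2}
    embedding_table = rng.normal(size=(len(vocab), 4))  # real models learn these values

    def embed(text):
        ids = [vocab[word] for word in text.split()]
        return embedding_table[ids]  # shape: (number_of_tokens, 4)

    print(embed("the cat sat").shape)  # -> (3, 4)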

The main new idea is in the architecture of the neural network between the input encoding and output decoding stages. The model uses the preceding terms in the current output sequence to decide which parts of the input sequence to pay more "attention" to for the next output term. A bit more precisely: the previous output acts as a "query", which is scored against "keys" derived from the input, and those scores weight a linear combination of "values", also derived from the input. That combination in turn gets transformed into the next output term by a few more layers of a plain feed-forward network (i.e. a bunch of layers of neurons, where each neuron puts a linear combination of its inputs through a non-linear activation function). Each step has trainable weights.
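
Here is a minimal numpy sketch of single-head scaled dot-product attention along the lines described above. The dimensions and the random projection matrices W_q, W_k, W_v stand in for trained weights; a real transformer runs many such heads in parallel, plus the feed-forward layers, all learned together.

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    def attention(query, keys, values):
        d_k = keys.shape[-1]
        scores = query @ keys.T / np.sqrt(d_k)  # how relevant each input position is
        weights = softmax(scores)               # attention weights, summing to 1
        return weights @ values                 # weighted combination of the values

    rng = np.random.default_rng(0)
    d_model = 8
    inputs = rng.normal(size=(5, d_model))       # 5 encoded input tokens
    prev_output = rng.normal(size=(1, d_model))  # encoding of the previous output term

    # W_q, W_k, W_v are placeholders for trained projection matrices
    W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
    context = attention(prev_output @ W_q, inputs @ W_k, inputs @ W_v)
    print(context.shape)  # -> (1, 8); this then goes through the feed-forward layers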

There are also clever tricks besides "attention". One is positional encoding, which represents word order so that the same input term has a different effect in a different position; unlike a recurrent neural network, a transformer otherwise just sees its input as a bag of words that could be in any order. Another is layer normalization, which sort of keeps the nonlinear outputs within a reasonable region of the embedding vector space.
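
For illustration, here is a bare-bones version of both tricks in numpy: the sinusoidal positional encoding from the original transformer paper, and a stripped-down layer normalization (real implementations also add learned scale and shift parameters).

    import numpy as np

    def positional_encoding(seq_len, d_model):
        pos = np.arange(seq_len)[:, None]   # position of each token
        i = np.arange(d_model)[None, :]     # index of each embedding dimension
        angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
        enc = np.zeros((seq_len, d_model))
        enc[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions use sine
        enc[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions use cosine
        return enc

    def layer_norm(x, eps=1e-5):
        mean = x.mean(axis=-1, keepdims=True)
        std = x.std(axis=-1, keepdims=True)
        return (x - mean) / (std + eps)  # keeps each vector in a stable range

    embeddings = np.random.default_rng(0).normal(size=(5, 8))  # 5 tokens, 8 dims
    x = layer_norm(embeddings + positional_encoding(5, 8))
    print(x.shape)  # -> (5, 8)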

This architecture, as far as I know, was not derived explicitly from the way human brains work. The "attention" analogy is really useful, but there are no principles saying this architecture is more fundamental to intelligence, or more natural, than many others. It just happens to produce remarkably good results when the weights are trained properly.

So that's the basic idea of contemporary LLMs. Of course, in some sense all computer neural networks are just a bunch of matrix multiplications and ad-hoc activation functions. But you can't just connect a large number of mathematical "neurons" randomly in a network and hope it learns something. The choice of architecture, i.e. how the "neurons" are connected, is key. On top of that, there is still an enormous amount of innovation and engineering needed to build real-world language models, not to mention turn them into a product like ChatGPT or Google Bard.