Greater sequence lengths will set us free
# March 20, 2023
GPT-4 was announced this past week. Some key takeaways from OpenAI's paper:
1. Greater performance on exams designed for humans (Uniform Bar Exam, LSAT, AP Calculus BC) in addition to ML benchmarks, showing stronger logical reasoning and the ability to synthesize information across multiple academic domains
2. Multimodal input, where images can be included alongside text and text prompts can reference the content of the images themselves
3. Longer sequence lengths, with 8K-token and 32K-token models compared to GPT-3.5's current 4K
Improvements in (1) and (2) speak for themselves. But personally I'm far more excited about the trends we're seeing in (3).
Historical Attention Limitations
The explosion of transformer architectures (from BERT all the way to the GPT family) was driven by attention. Attention allowed for far richer modeling of long-range text dependencies, since the end of a sequence could directly reference its start with no loss of representational power. Generative models use this same insight: every generated token can attend to all of the text provided in the prompt.
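To make the "direct reference" point concrete, here is a minimal pure-Python sketch of causal scaled dot-product attention (the softmax(Q K^T / sqrt(d)) V form), written with plain lists instead of a tensor library. The function name and the toy vectors are made up for illustration.

```python
import math

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(queries, keys, values):
    """Scaled dot-product attention with a causal mask, on plain lists of vectors.

    Output i is a weighted average of values[0..i]; the weights compare query i
    against each visible key, regardless of how far back that key sits.
    """
    d = len(queries[0])
    outputs, all_weights = [], []
    for i, q in enumerate(queries):
        visible = range(i + 1)  # causal mask: position i sees positions 0..i
        scores = [sum(a * b for a, b in zip(q, keys[t])) / math.sqrt(d) for t in visible]
        weights = softmax(scores)
        outputs.append([sum(w * values[t][c] for w, t in zip(weights, visible))
                        for c in range(len(values[0]))])
        all_weights.append(weights)
    return outputs, all_weights

# Toy sequence: only the last query "matches" the first key.
queries = [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [3.0, 0.0]]
keys    = [[3.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
values  = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
outputs, weights = causal_attention(queries, keys, values)
# weights[-1][0] is close to 1: the last position reads position 0 directly.
```

The distance between the two positions never enters the computation; only the query/key similarity does.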
Attention unlocked the key performance leap over LSTMs and RNNs, both of which suffer from a local-window bias because they must back-propagate through time. The further back you go in a sequence, the smaller the gradient becomes due to repeated multiplication, classically known as the vanishing gradient problem. This made long-range dependencies very hard to learn during a regular training loop.
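A small sketch of the vanishing gradient, assuming the simplest possible case: a scalar RNN h_t = tanh(w * h_{t-1} + x_t) with made-up values w = 0.9 and a constant input of 0.5. The gradient of the last state with respect to the first is a product of one local derivative per step, so it shrinks as the sequence gets longer.

```python
import math

def gradient_through_time(w, inputs, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t).

    By the chain rule this is the product over t of w * (1 - h_t**2),
    i.e. one multiplicative factor per time step.
    """
    h, grad = h0, 1.0
    for x in inputs:
        h = math.tanh(w * h + x)
        grad *= w * (1.0 - h * h)  # local derivative d h_t / d h_{t-1}
    return grad

for steps in (10, 50, 100):
    g = gradient_through_time(w=0.9, inputs=[0.5] * steps)
    print(f"{steps:>3} steps back: gradient = {g:.2e}")
```

Every factor here is below 1, so each extra step of distance multiplies the learning signal by a number smaller than 1.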
Attention did have one core limitation. Whereas LSTMs can (theoretically) support an unlimited sequence length, attention models fix a maximum sequence length up front so they can add positional encodings and fit all time steps in memory at once. This sequence length also isn't particularly long. It was 512 tokens for the original generation of transformers and has slowly crept into the thousands with larger server-based models like GPT: GPT-3 was capped at 2048 tokens and GPT-3.5 at 4096. That is roughly 3000 words, assuming about 3/4 of a word per token. Since classic attention's cost grows quadratically with input length, it becomes progressively harder to increase the model's context window.
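The arithmetic behind both numbers, as a rough sketch. It assumes a naive implementation that materializes the full n x n score matrix for one head in one layer, stored in fp16 (2 bytes per score); real models multiply this by the number of heads and layers. The helper names are illustrative.

```python
def attention_matrix_bytes(seq_len, bytes_per_score=2, heads=1):
    """Memory for one layer's n x n attention score matrix (fp16 by default)."""
    return seq_len * seq_len * bytes_per_score * heads

def approx_words(tokens, words_per_token=0.75):
    """Rule of thumb: one token is roughly 3/4 of an English word."""
    return round(tokens * words_per_token)

for n in (512, 2048, 4096, 8192, 32768):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"{n:>6} tokens ~ {approx_words(n):>6} words, one fp16 score matrix = {mib:,.1f} MiB")
```

Going from 4K to 32K tokens is an 8x longer input but a 64x larger score matrix, which is why context windows grew so slowly.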
Sequence Length In the Wild
Complex input domains (academic papers, financial reports, legal documents, code) typically have much wider windows of relevance. A preamble might define terms used much later. A "Related Works" section might provide context for advances revealed later on. An author might sketch out two arguments in their own sections and link them together in a third.
Chat might be an even more salient example, given the explosion of assistants like ChatGPT and Microsoft Copilot. A conversation naturally meanders, and to provide a human-like experience a model needs to be able to recall information provided much earlier in the dialogue. The accumulation of earlier information typically dictates where the conversation goes next. Ideally, the model would be able to consider all previous sessions you've had with it, building up internal knowledge of your habits and preferences and adapting its communication style over time.
Previous approaches to these domains relied on workarounds: a pre-processing model to extract the most relevant text regions, retrieval of the top-ranked documents, or heuristics about the user's current state. Limited LLM sequence length prevented the model from learning the best internal representation for these domains.
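A minimal sketch of the retrieval workaround, to show what gets thrown away. `retrieve_context` is a hypothetical helper; it scores fixed-size chunks by raw word overlap with the query (real systems typically use embeddings or BM25) and keeps only as many chunks as fit in a word budget.

```python
import re

def tokenize(text):
    # Crude word-level tokenizer, standing in for a real subword tokenizer.
    return re.findall(r"[a-z0-9]+", text.lower())

def retrieve_context(document, query, chunk_words=50, budget_words=200):
    """Keep only the chunks most lexically similar to the query."""
    words = document.split()
    chunks = [" ".join(words[i:i + chunk_words]) for i in range(0, len(words), chunk_words)]
    query_terms = set(tokenize(query))
    # Rank chunk indices by how many query terms each chunk contains.
    ranked = sorted(range(len(chunks)),
                    key=lambda i: len(query_terms & set(tokenize(chunks[i]))),
                    reverse=True)
    kept, used = [], 0
    for i in ranked:
        n = len(chunks[i].split())
        if used + n > budget_words:
            break
        kept.append(i)
        used += n
    # Restore document order so the model sees the chunks as they were written.
    return [chunks[i] for i in sorted(kept)]

doc = "filler " * 500 + "the termination fee is ten million dollars " + "filler " * 500
context = retrieve_context(doc, "What is the termination fee?")
```

Whatever the scorer misses never reaches the model, so the model cannot learn to use it.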
Linear Attention Models
As far as I've seen, OpenAI hasn't released technical details on how they're increasing sequence length for GPT-4. If I had to guess, it's some combination of clever partitioning of GPU resources, a lower-level optimized attention implementation, and possibly reduced numeric precision, like switching from 32- or 16-bit floating point to 8-bit representations.
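OpenAI hasn't confirmed any of this, so the following only illustrates the generic idea behind the last guess: symmetric "absmax" int8 quantization, which stores one float scale plus one signed byte per value. The random weights and the helper names are assumptions for the sketch.

```python
import random
import struct

def quantize_int8(values):
    """Symmetric absmax quantization: map floats into the int8 range [-127, 127]."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # fall back to 1.0 for all-zero input
    return [round(v / scale) for v in values], scale

def dequantize(q, scale):
    return [x * scale for x in q]

random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(1024)]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

fp32_bytes = len(struct.pack(f"{len(weights)}f", *weights))  # 4 bytes per value
fp16_bytes = len(struct.pack(f"{len(weights)}e", *weights))  # 2 bytes per value
int8_bytes = len(struct.pack(f"{len(q)}b", *q))              # 1 byte per value
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(fp32_bytes, fp16_bytes, int8_bytes, max_err)
```

The trade-off is a rounding error of at most half a quantization step per value in exchange for 2x to 4x less memory.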
There's been a niche field of research into sub-quadratic or linear attention implementations that allow for much longer context windows to fit into memory. These largely have not caught on, as mentioned in a good survey of the field:
> Many applications even in language and vision are still dominated by vanilla Transformers with quadratic attention and none of these xformer variants have caught on as the defacto standard. There might be multiple explanations from multiple angles for this phenomena... As noted by (Rabe and Staats, 2021), for applications that require to flex on sequence length and memory needs time to time, it might be suffice to ‘just sequentially process it’ even if that might not be inherently as satisfying as finding a theoretical approximate.
In other words, sliding a 512-, 2048-, or 4096-token context window across the input is typically enough to model most problems. But some of this perspective no doubt stems from the fields where ML is already applied, not the fields that might benefit from it in the future. Most classic language model benchmarks are short text classification or generation tasks that cap out at a few thousand tokens.
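The "just sequentially process it" strategy can be sketched as overlapping fixed-size windows. `sliding_windows` is a hypothetical helper, and the 512-token window with a 256-token stride is an arbitrary choice for the example.

```python
def sliding_windows(tokens, window=512, stride=256):
    """Split a long token sequence into overlapping fixed-size windows.

    Each window fits a model with a `window`-token limit; the overlap
    (window - stride tokens) gives every window some preceding context.
    """
    if len(tokens) <= window:
        return [tokens]
    starts = list(range(0, len(tokens) - window + 1, stride))
    # Make sure the tail of the sequence is covered by a final window.
    if starts[-1] + window < len(tokens):
        starts.append(len(tokens) - window)
    return [tokens[s:s + window] for s in starts]

windows = sliding_windows(list(range(1300)))
```

This works when relevant context is local, but any dependency spanning more than one window is invisible to the model, which is exactly the long-document case above.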
Context windows matter more for the current breed of chatbots. Since LLMs have to handle a large amount of user chat history within one conversation, or follow instructions over large and complex task definitions, they need a wider view of the user's context.
I'm optimistic that the growth of data in fields that require long context windows will encourage more of this research, and its deployment in highly optimized LLMs. Existing linear attention models have mostly been technical proofs of concept, without shipping an end-to-end generative model like LLaMA or GPT. Some of this comes down to different goals. Many linear attention models are non-autoregressive because of their training process, so they can't do the text generation that has exploded over the past year. But others can, and simply haven't benefited from the same computational resources as the LLMs being shipped by OpenAI and Meta. I bet this will change.
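As an example of a linear attention variant that can generate autoregressively, here is a sketch of the causal linear attention from Katharopoulos et al. (2020, "Transformers are RNNs"), which uses the feature map elu(x) + 1. Replacing softmax with this kernel lets causal attention be computed as a recurrence with a fixed-size state. The pure-list implementation and the quadratic reference version are for illustration only.

```python
import math

def phi(x):
    # elu(x) + 1: a positive feature map standing in for exp(q . k).
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def causal_linear_attention(queries, keys, values):
    """Generate outputs one position at a time with a constant-size state.

    S accumulates phi(k) v^T and z accumulates phi(k); neither grows with the
    sequence length, so each new token costs O(d * d_v) rather than O(n).
    """
    d, dv = len(keys[0]), len(values[0])
    S = [[0.0] * dv for _ in range(d)]
    z = [0.0] * d
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk = phi(k)
        for a in range(d):
            z[a] += fk[a]
            for b in range(dv):
                S[a][b] += fk[a] * v[b]
        fq = phi(q)
        denom = sum(fq[a] * z[a] for a in range(d))
        outputs.append([sum(fq[a] * S[a][b] for a in range(d)) / denom for b in range(dv)])
    return outputs

def causal_linear_attention_quadratic(queries, keys, values):
    # The same computation written as an explicit n x n weighting, for comparison.
    outputs = []
    for i, q in enumerate(queries):
        fq = phi(q)
        sims = [sum(a * b for a, b in zip(fq, phi(keys[j]))) for j in range(i + 1)]
        total = sum(sims)
        outputs.append([sum(s * values[j][b] for j, s in enumerate(sims)) / total
                        for b in range(len(values[0]))])
    return outputs
```

Both functions produce the same outputs; the recurrent one simply never looks back at earlier tokens, which is what makes token-by-token generation cheap.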
The approach that's gaining the widest traction is FlashAttention. Unlike most sub-quadratic architectures, it isn't an approximation: it computes attention exactly, and while compute is still quadratic, memory usage is linear in sequence length thanks to tiling and lower-level CUDA memory-access optimizations. Most of the new OSS models are being trained with FlashAttention by default; it's easily swappable for classic attention without the model-quality drawbacks of approximate methods. Having the same API might be what the linear attention community needed to really take off.
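FlashAttention stays exact by computing the softmax incrementally over blocks of keys (an "online softmax"), so the full n x n score matrix is never materialized. Below is a simplified single-query sketch of that idea in pure Python, checked against the naive version. The real kernel also tiles the queries and keeps blocks in on-chip SRAM, which this does not model; the block size is arbitrary.

```python
import math

def naive_attention_row(q, keys, values):
    # Materializes every score for this query at once.
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]

def tiled_attention_row(q, keys, values, block=4):
    """Exact attention for one query, visiting keys block by block.

    Only a running max m, a running denominator l, a running weighted sum acc,
    and one block of scores are held at a time.
    """
    d, dv = len(q), len(values[0])
    m, l, acc = float("-inf"), 0.0, [0.0] * dv
    for start in range(0, len(keys), block):
        kb, vb = keys[start:start + block], values[start:start + block]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in kb]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescale what was accumulated under the old max
        l = l * correction + sum(math.exp(s - m_new) for s in scores)
        acc = [a * correction for a in acc]
        for s, v in zip(scores, vb):
            w = math.exp(s - m_new)
            acc = [a + w * vc for a, vc in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]
```

Because the rescaling is exact, the tiled result matches the naive one up to floating-point rounding, which is why it can be dropped in for classic attention.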
Lossless attention is non-human, but maybe that's okay
The framing of quadratic attention is quite different from how humans think. It provides the entire sequence at one time, allowing the model to recover the exact content that was provided earlier in a sequence. Humans obviously don't have this level of precision. When you're having a conversation, you can remember generally what was said a few minutes ago, but you can't remember word-for-word. Brains do varying degrees of information compression; short-term memory has moderate compression to allow for higher fidelity and long-term memory has more aggressive compression to allow more content to be stored.
In ML terms, humans are creating a continuous representation of the conversation. Let's call this $R_1$. We contextualize new information $I$ that we receive based on our current representation. This serves two goals: to understand the additional information even when not self-explanatory, $P(I \mid R_1)$, as well as to update our internal representation for future turns, $R_1 \rightarrow R_2$.
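A toy sketch of the $R_t \rightarrow R_{t+1}$ update as a fixed-size state. Everything here is an assumption for illustration, not a model of cognition: `embed` is a hypothetical hashing encoder, the update is a simple exponential moving average, and `similarity` is only loosely analogous to $P(I \mid R)$ (it is not a probability).

```python
import hashlib
import math

DIM = 64  # hypothetical size of the running representation R

def embed(text, dim=DIM):
    """Toy encoder: hash each word into a bucket, then L2-normalize."""
    vec = [0.0] * dim
    for word in text.lower().split():
        vec[int(hashlib.sha256(word.encode()).hexdigest(), 16) % dim] += 1.0
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def update(R, I, alpha=0.3):
    """R_{t+1} = (1 - alpha) * R_t + alpha * embed(I_t).

    The state never grows; each older turn's contribution shrinks by a factor
    of (1 - alpha) per new turn, a crude form of lossy compression.
    """
    e = embed(I)
    return [(1 - alpha) * r + alpha * x for r, x in zip(R, e)]

def similarity(R, I):
    """Cosine similarity between the state and a new input (embed() is unit-norm)."""
    e = embed(I)
    dot = sum(r * x for r, x in zip(R, e))
    norm_r = math.sqrt(sum(r * r for r in R))
    return dot / norm_r if norm_r else 0.0

R = [0.0] * DIM
for turn in ["my dog is called rex", "we talked about the weather", "i prefer short answers"]:
    R = update(R, turn)
print(len(R), similarity(R, "what is my dog called"))
```

However many turns go in, the state stays `DIM` numbers long, which is the property full quadratic attention lacks.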
LLMs that use quadratic attention don't do this, which probably sets an upper limit on how large their context windows can grow. One benefit of linear attention research is that it typically introduces an intermediate stage (like a kernel approximation) that vaguely resembles the human compression process. But this compression process could be modeled more directly.
- We could model the representation state more explicitly by providing an external vector that captures session state. This vector would likely need to be very high-dimensional to capture all the subtleties of potentially pertinent information from earlier in the session. But in theory, and with enough training data, I believe a model can learn a useful compression into this intermediate state. It's a way to outsource some of the memory storage requirements without having to optimize attention directly.
- We could use existing LLMs to continuously summarize a longer session over time. Instead of appending new inputs verbatim to the session ($I_1 + I_2$), we repeatedly ask the model to generate a new $R$. As long as each summary is kept short, this keeps the session plus the new input below the maximum allowable sequence length, allowing a much longer effective context in a text-only domain.
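The summarization loop from the second idea, sketched under loud assumptions: `summarize` is a hypothetical stand-in for an LLM call (it just keeps the last words so the example runs offline, which is truncation, not real summarization), tokens are approximated by whitespace-separated words, and the 100-token limit and 40-token summary budget are arbitrary.

```python
def count_tokens(text):
    # Rough proxy: whitespace-separated words. A real system would use the model's tokenizer.
    return len(text.split())

def summarize(text, max_tokens):
    """Hypothetical stand-in for asking an LLM to summarize `text` in at most
    `max_tokens` tokens. Here it only keeps the last words."""
    return " ".join(text.split()[-max_tokens:])

def run_session(turns, context_limit=100, summary_budget=40):
    """Carry a running summary R forward instead of appending every turn (I1 + I2 + ...)."""
    R = ""  # running representation of the session, kept as text
    prompts = []
    for I in turns:
        if count_tokens(I) > context_limit - summary_budget:
            raise ValueError("a single input must fit alongside the summary")
        prompt = f"{R}\n{I}".strip()
        prompts.append(prompt)                 # what would be sent to the model
        R = summarize(prompt, summary_budget)  # R_t -> R_{t+1}
    return prompts, R

turns = [f"turn {i} " + "word " * 30 for i in range(50)]
prompts, R = run_session(turns)
```

The bound only holds because each input is checked against the space left after the summary; with a real model the summary length would have to be enforced too.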
The second is much easier, since it layers on top of existing models. Its drawback might be the same one people face: some concepts are more easily captured as abstract thoughts before being pushed into words. An abstract vector representation might preserve that property better than forced language summarization.
I'm particularly excited to see production models with growing sequence lengths, alongside increasing research into sub-quadratic attention. The business mania around NLP right now might be the push we need to double down on longer context windows. The longer the input sequence, the more useful the current generation of LLMs will be for information extraction, summarization, and basic inference. I'm excited to get my hands on gpt4-32k, and even more excited by the prospect of gpt5-1m. A man can dream, right?