Representations in autoregressive models
May 11, 2023
I took a computer vision course in college with a rather provocative professor. One of the more memorable lectures opened with a declaration: representations in machine learning are everything[1]. With a good enough representation, everything is basically a linear classifier. Focus on the representation, not on the network.
Given that we were training neural networks with millions of parameters at the same time, I thought it a rather insane thing to say.
Technically he's right, of course. If you can push most of the difficult work into projecting an input into a numerical representation (and that representation is conditioned on the task you want to accomplish), you have by definition turned your problem into a separable one. Once you've solved your problem, you've basically... solved your problem. Technically speaking, the representations can collapse to one common representation for True and one common representation for False.
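A toy way to see the professor's point, using XOR as an example of my own choosing (not from the lecture): no linear classifier can separate XOR in its two raw coordinates, but adding a single interaction feature, x1 * x2, to the representation lets a plain perceptron fit it perfectly. The helper names below are illustrative, not from any library.

```python
# Toy illustration: the same linear learner fails or succeeds depending
# only on the representation it is handed.

def train_perceptron(xs, ys, max_epochs=100):
    """Classic perceptron on labels +1 / -1. Returns (weights, bias)."""
    w = [0.0] * len(xs[0])
    b = 0.0
    for _ in range(max_epochs):
        mistakes = 0
        for x, y in zip(xs, ys):
            score = sum(wi * xi for wi, xi in zip(w, x)) + b
            if y * score <= 0:  # wrong side of (or on) the boundary: nudge toward y
                w = [wi + y * xi for wi, xi in zip(w, x)]
                b += y
                mistakes += 1
        if mistakes == 0:  # every point strictly on the correct side
            break
    return w, b


def accuracy(w, b, xs, ys):
    preds = [1 if sum(wi * xi for wi, xi in zip(w, x)) + b > 0 else -1 for x in xs]
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)


raw = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [-1, 1, 1, -1]  # XOR

# Representation 1: the raw inputs.
w_raw, b_raw = train_perceptron(raw, labels)

# Representation 2: raw inputs plus a hand-picked interaction feature.
featurized = [(x1, x2, x1 * x2) for x1, x2 in raw]
w_feat, b_feat = train_perceptron(featurized, labels)

print("raw accuracy:", accuracy(w_raw, b_raw, raw, labels))
print("featurized accuracy:", accuracy(w_feat, b_feat, featurized, labels))
```

The raw run can never reach full accuracy because no line separates the XOR points. The featurized run converges because the data is linearly separable in the new representation, which is exactly the case the perceptron convergence guarantee covers.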
It's interesting to think about this again with the recent wave of autoregressive models. The latency in the current generation of generative models comes from the fact that they feed every newly generated token back into the decoder, one by one by one. The longer your desired output, the longer the model will take to generate the answer.
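A minimal sketch of that loop, assuming a hypothetical `ToyDecoder` whose `next_token` method stands in for one full forward pass (the arithmetic inside is placeholder logic, not a model). It only counts forward passes: generating N tokens costs N sequential calls, and a one-token answer costs one. Real decoders usually cache attention keys and values so each step only processes the newest token, but the steps still have to run one after another.

```python
class ToyDecoder:
    """Hypothetical stand-in for a decoder that counts its forward passes."""

    def __init__(self):
        self.forward_calls = 0

    def next_token(self, tokens):
        # One "forward pass" over the sequence so far. Placeholder logic only.
        self.forward_calls += 1
        return (tokens[-1] + 1) % 100


def generate(model, prompt, max_new_tokens):
    """Autoregressive loop: each new token is appended and fed back in."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model.next_token(tokens))
    return tokens[len(prompt):]


long_model = ToyDecoder()
answer = generate(long_model, prompt=[1, 2, 3], max_new_tokens=50)
print(len(answer), long_model.forward_calls)  # 50 50

short_model = ToyDecoder()
verdict = generate(short_model, prompt=[1, 2, 3], max_new_tokens=1)
print(verdict, short_model.forward_calls)  # [4] 1
```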
But if you can fit a lot of data into the initial request to the model and manage to frame the problem as a binary classification, results can be delivered almost instantly. You still need the model to encode the initial string, but this is lightning fast: it's a few big tensor multiplications, and they parallelize across the whole prompt. That's it. The result arrives in one decoder step.
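To make "parallelized" concrete, here is one single-head causal self-attention layer written two ways in NumPy: once over the whole prompt with a handful of matrix multiplications, and once position by position, the way generation has to proceed. Both give the same result. The weights are random placeholders and the function names are hypothetical; this is a sketch of a single attention layer, not a full transformer.

```python
import numpy as np


def softmax_rows(scores):
    # Numerically stable softmax along the last axis; -inf entries become 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def causal_attention_parallel(x, wq, wk, wv):
    """All T positions at once: a few matrix multiplications plus a mask."""
    q, k, v = x @ wq, x @ wk, x @ wv                      # (T, d) each
    scores = q @ k.T / np.sqrt(k.shape[-1])               # (T, T)
    t = x.shape[0]
    future = np.triu(np.ones((t, t)), k=1).astype(bool)   # True where j > i
    scores = np.where(future, -np.inf, scores)            # no attending ahead
    return softmax_rows(scores) @ v                       # (T, d)


def causal_attention_sequential(x, wq, wk, wv):
    """Same layer, one position at a time, the way decoding must proceed."""
    outputs = []
    for i in range(x.shape[0]):
        q = x[i] @ wq                                     # (d,)
        k = x[: i + 1] @ wk                               # (i+1, d)
        v = x[: i + 1] @ wv
        scores = k @ q / np.sqrt(k.shape[-1])             # (i+1,)
        outputs.append(softmax_rows(scores) @ v)
    return np.stack(outputs)


rng = np.random.default_rng(0)
T, d = 6, 8
x = rng.normal(size=(T, d))                               # stand-in prompt embeddings
wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))

parallel = causal_attention_parallel(x, wq, wk, wv)
sequential = causal_attention_sequential(x, wq, wk, wv)
print(np.allclose(parallel, sequential))  # True
```

The parallel version is what lets a long prompt be encoded quickly on hardware built for big matrix multiplications; the sequential version is the shape that generation is forced into.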
Representations are hidden from the average user of a language model, but they sit just beneath the surface. It's the representation that lives a layer or two short of the projection head, right before the network has to decide which word to output next. It has a sense of what you're intending to do, at least insofar as that intent minimizes the perplexity of what comes next.
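A sketch of that last step, with random NumPy arrays standing in for a trained model: `hidden_states` plays the role of the final block's output and `W_unembed` the projection head (both names and the five-word vocabulary are hypothetical). Only the last position's hidden state is used to score the next word. Many decoders apply a final normalization layer before the projection; it is omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["True", "False", "maybe", "the", "cat"]  # hypothetical tiny vocabulary
d_model = 16

# Stand-in for the last transformer block's output: one row per prompt token.
hidden_states = rng.normal(size=(7, d_model))
# Stand-in projection head ("unembedding"): one logit per vocabulary entry.
W_unembed = rng.normal(size=(len(vocab), d_model))

h_last = hidden_states[-1]          # the representation that predicts the next word
logits = W_unembed @ h_last         # shape (len(vocab),)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                # softmax over the whole vocabulary

next_word = vocab[int(np.argmax(logits))]
# Framed as a yes/no question, only two of those logits matter.
verdict = "True" if logits[vocab.index("True")] > logits[vocab.index("False")] else "False"
print(next_word, verdict)
```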
Interestingly, encouraging the model to "think" or "criticize" itself in writing helps to refine these representations further. Some linear, text-based reasoning can be enough to arrive at the right answer, even in cases where jumping straight to predicting the output might produce a wrong one.
Another perspective on LLMs is that they're universal representation agents, able to condition a good representation on the goals you want them to achieve. Once a model has that internal state, the eventual projection to [True, False] is the easy part. Representations might not be everything, but they come pretty close.
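If the projection head is a linear map and only the two candidate tokens are compared, the "easy part" is literally logistic regression on the hidden state. Writing h for the final hidden state and w_T, w_F, b_T, b_F for the head's weight rows and biases for the tokens "True" and "False" (my notation; many models have no bias here, in which case set it to zero), the softmax restricted to those two tokens reduces to a sigmoid:

```latex
\begin{aligned}
P(\text{True} \mid h)
  &= \frac{e^{w_T^\top h + b_T}}{e^{w_T^\top h + b_T} + e^{w_F^\top h + b_F}} \\
  &= \frac{1}{1 + e^{-\left[(w_T - w_F)^\top h + (b_T - b_F)\right]}} \\
  &= \sigma\!\left((w_T - w_F)^\top h + (b_T - b_F)\right)
\end{aligned}
```

The two-way decision is a single direction, w_T - w_F, in representation space; all of the hard work goes into producing h.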
[1] Representations are the numerical equivalents of real-world objects. Neural networks operate on numbers, not words or images, so before a network can even start learning, you need to make some choices about how to turn those words into floating-point numbers (bigram, character-level, WordPiece, byte-pair encoding, etc.).
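As an illustration of one of those choices, here is a simplified byte-pair encoding learner in the style popularized for subword tokenization by Sennrich et al.: repeatedly count adjacent symbol pairs across a word-frequency table and merge the most frequent pair. It is a sketch under simplifying assumptions (no end-of-word marker, ties broken by first occurrence, a toy corpus), not the implementation used by any particular tokenizer library.

```python
from collections import Counter


def most_frequent_pair(words):
    """words maps a tuple of symbols to its corpus frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    # max() returns the first maximal pair encountered, so ties follow insertion order.
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(words, pair):
    """Rewrite every word, fusing each occurrence of `pair` into one symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


def learn_bpe(word_counts, num_merges):
    """Start from single characters and greedily learn `num_merges` merge rules."""
    words = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        merges.append(pair)
        words = merge_pair(words, pair)
    return merges, words


corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}  # toy word counts
merges, segmented = learn_bpe(corpus, num_merges=4)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o'), ('lo', 'w')]
print(list(segmented))
```

Frequent fragments such as "est" and "low" become single symbols, which is the kind of representation choice that has to be made before any network sees the text.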