This paper's been doing the rounds, so I thought I'd give a mechanistic interpretability take on what's going on here!
— Neel Nanda (@NeelNanda5) September 24, 2023
The core intuition is that "When you see 'A is', output B" is implemented as an asymmetric look-up table, with an entry for A->B.
B->A would be a separate entry https://t.co/XROT8D5MrS
Here's the rest of Nanda's tweet stream:
The key question to ask with a mystery like this about models is what algorithms are needed to get the correct answer, and how these can be implemented in transformer weights. These are what get reinforced when fine-tuning.
The two hard parts of "A is B" are recognising the input tokens A (out of all possible input tokens) and connecting this to the action to output tokens B (out of all possible output tokens). These are both hard!
Further, the A -> B look-up must happen at a single token position.
Intuitively, the algorithm here has early attention heads attend to the prev token to create a previous token subspace on the Cruise token. Then an MLP neuron activates on "Current==Cruise & Prev==Tom" and outputs "Output=Mary", "Next Output=Lee" and "Next Next Output=Pfeiffer"
"Output=Mary" directly connects to the unembed, and "Next Output=Lee", etc., get moved by late attention heads to subsequent tokens once Mary is output.
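To make the mechanism above concrete, here is a purely symbolic sketch of it in Python. It is a cartoon, not real transformer weights: the residual stream at each position is a dict, the "early attention heads" copy the previous token into a `prev` slot, a single hypothetical MLP neuron keyed on `(prev, current)` writes the answer tokens, and a "late attention head" moves the pending outputs forward. `FACT_NEURONS`, `prev_token_heads`, `mlp_layer`, `predict_next` and `generate` are all illustrative names, not from the thread or any library.

```python
# Hypothetical fact store: one MLP "neuron" keyed on (prev token, current token).
# Its output weights write Output, Next Output, Next Next Output, in order.
FACT_NEURONS = {
    ("Tom", "Cruise"): ["Mary", "Lee", "Pfeiffer"],
}


def prev_token_heads(tokens):
    """Early attention heads: copy the previous token into a 'prev' subspace."""
    return [{"current": tok, "prev": tokens[i - 1] if i > 0 else None}
            for i, tok in enumerate(tokens)]


def mlp_layer(resid):
    """The neuron fires on Current==Cruise & Prev==Tom and writes its outputs."""
    for pos in resid:
        outputs = FACT_NEURONS.get((pos["prev"], pos["current"]))
        if outputs is not None:
            pos["outputs"] = outputs
    return resid


def predict_next(tokens):
    resid = mlp_layer(prev_token_heads(tokens))
    last = len(resid) - 1
    # "Output=..." at the current position connects directly to the unembedding.
    if "outputs" in resid[last]:
        return resid[last]["outputs"][0]
    # Late attention heads: move "Next Output" from the firing position onward.
    for i in range(last - 1, -1, -1):
        if "outputs" in resid[i]:
            offset = last - i
            outs = resid[i]["outputs"]
            return outs[offset] if offset < len(outs) else None
    return None


def generate(prompt, max_new=3):
    tokens = list(prompt)
    for _ in range(max_new):
        nxt = predict_next(tokens)
        if nxt is None:
            break
        tokens.append(nxt)
    return tokens[len(prompt):]


print(generate(["Tom", "Cruise"]))            # ['Mary', 'Lee', 'Pfeiffer']
print(generate(["Mary", "Lee", "Pfeiffer"]))  # [] (no entry keyed on these tokens)
```

The reverse prompt produces nothing because the only entry is keyed on the *input* pair `("Tom", "Cruise")`; "Mary Lee Pfeiffer" only ever appears as an *output* of that entry, which is the asymmetry the thread goes on to describe.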
Crucially, there's an asymmetry between "input A" and "output A". Inputs are present in early layers, come from input embeddings, and touch the input weights of MLP neurons. Outputs are more present in late layers, compose with the unembedding, and come from the output weights of MLPs.
This is especially true with multi-token A and B. Detecting "Tom Cruise" is saying "the current token embedding is Cruise, and the prev token space says Tom", while outputting "Tom Cruise" means outputting the token Tom, and then a late attention head moves "output Cruise" to the next token.
Thus, when given a gradient signal to output B given "A is", training reinforces/creates a lookup "A -> B" but doesn't create "B -> A"; these are different lookups, in different parameters, and there's no gradient signal from one to the other.
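The "no gradient signal from one to the other" point can be checked numerically in a toy setting. The sketch below is a linear associative memory, not a transformer, and it assumes each entity has separate (untied) random input-embedding and unembedding directions; `DIM`, `train_lookup`, `embed` and `unembed` are hypothetical names. Gradient descent on "A is B" only ever adds rank-1 updates of the form `u_B e_A^T`, so the learned map sends input B toward output B's direction (scaled by the chance overlap of the two input embeddings), never specifically toward output A.

```python
import math
import random

DIM = 128  # toy residual-stream width (assumption)


def rand_unit(rng):
    """Random unit vector standing in for an embedding or unembedding direction."""
    v = [rng.gauss(0.0, 1.0) for _ in range(DIM)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, v):
    return [dot(row, v) for row in W]


def cosine(a, b):
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


def train_lookup(e_in, u_out, steps=20, lr=0.5):
    """Gradient descent on 0.5 * ||W @ e_in - u_out||^2, starting from W = 0.

    The gradient step is the rank-1 update lr * err * e_in^T, so W only ever
    learns "input direction e_in -> output direction u_out".
    """
    W = [[0.0] * DIM for _ in range(DIM)]
    for _ in range(steps):
        err = [t - p for t, p in zip(u_out, matvec(W, e_in))]
        for i in range(DIM):
            for j in range(DIM):
                W[i][j] += lr * err[i] * e_in[j]
    return W


rng = random.Random(0)
# Untied toy vocabulary: each entity has separate input and output directions.
embed = {name: rand_unit(rng) for name in ("A", "B")}
unembed = {name: rand_unit(rng) for name in ("A", "B")}

W = train_lookup(embed["A"], unembed["B"])             # train only on "A is B"
forward = cosine(matvec(W, embed["A"]), unembed["B"])  # input A -> output B?
reverse = cosine(matvec(W, embed["B"]), unembed["A"])  # input B -> output A?
print(f"forward: {forward:.3f}  reverse: {reverse:.3f}")
```

`forward` comes out essentially 1, while `reverse` is only chance-level alignment between two random output directions, since `W @ embed["B"]` is just a scaled copy of `unembed["B"]`.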
How can you fix this? Honestly, I can't think of anything. I broadly think of this as LLMs working as intended. They have a one-way flow from inputs to outputs, and a fundamental asymmetry between inputs and outputs. It's wild to me to expect symmetry/flow reversing to be possible.
Why is this surprising at all then? My guess is that symmetry is intuitive to us, and we're used to LLMs being capable of surprising and impressive things, so it's weird to see something seemingly basic missing.
LLMs are not human! Certain things are easy for us and not for them, and vice versa. My guess is that the key difference here is that when detecting/outputting specific tokens, the LLM has no notion of a variable that can take on arbitrary values; a direction has a fixed meaning.
A better analogy might be in-context learning, where LLMs CAN use "variables". The text "Tom Cruise is the son of Mary Lee Pfeiffer. Mary Lee Pfeiffer is the mother of..." has the algorithmic solution "Attend to the subject of sentence 1 (Tom Cruise), and copy to the output"
Unsurprisingly, the model has no issue with reversing facts in context! Intuitively, when I remember a fact A is B, it's closer to a mix of retrieving it into my "context window" and then doing in-context learning, rather than pure memorised recall.
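By contrast, the in-context solution ("attend to the subject of sentence 1 and copy it") behaves like a variable: the attention pattern is set by *role* features, and whatever token sits in that slot gets copied. The toy below assumes hypothetical earlier layers have already tagged each token with role features (`role_features`), uses a fixed, name-agnostic query (`QUERY`), and copies only the first subject token for brevity. The second prompt uses made-up names. None of this is from the thread; it is just an illustration that the same weights reverse any fact placed in context.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def role_features(tokens):
    """Hypothetical earlier layers: [starts the text, inside sentence 1] per token."""
    feats, sentence = [], 1
    for i, tok in enumerate(tokens):
        feats.append([1.0 if i == 0 else 0.0, 1.0 if sentence == 1 else 0.0])
        if tok == ".":
            sentence += 1
    return feats


QUERY = [8.0, 1.0]  # fixed query: "the subject of sentence 1", whatever it is


def copy_head(tokens):
    """One attention head: score role keys against QUERY, copy the top token."""
    keys = role_features(tokens)
    weights = softmax([sum(q * k for q, k in zip(QUERY, key)) for key in keys])
    best = max(range(len(tokens)), key=lambda i: weights[i])
    return tokens[best]


prompt_1 = ("Tom Cruise is the son of Mary Lee Pfeiffer . "
            "Mary Lee Pfeiffer is the mother of").split()
prompt_2 = ("Alice Smith is the daughter of Carol Jones . "
            "Carol Jones is the mother of").split()  # hypothetical names
print(copy_head(prompt_1), copy_head(prompt_2))  # Tom Alice
```

Unlike the look-up-table sketch, nothing here is keyed on a specific name, which is why reversal in context comes for free.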
* * * * *
The paper: The Reversal Curse: LLMs trained on "A is B" fail to learn "B is A."