This repo contains the code for the experiments in our paper:
Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics
Hanlin Zhu, Baihe Huang, Shaolun Zhang, Michael Jordan, Jiantao Jiao, Yuandong Tian, Stuart Russell
To set up the environment:

```
conda create -n reversal_curse python=3.10
conda activate reversal_curse
pip install -r requirements.txt
```
To run the standard reversal logic experiment of the form "A → B ⇒ B → A":

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_reverse
```
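For intuition, the "A → B ⇒ B → A" setup trains only on the forward direction and evaluates the reversed one. Below is a minimal sketch of such a dataset; the `make_reversal_data` helper and the token layout are hypothetical illustrations, not the repo's actual data generator:

```python
import random

def make_reversal_data(vocab_size=800, n_pairs=100, arrow_id=0, seed=1234):
    """Hypothetical illustration of the reversal-logic data (not the repo's generator)."""
    rng = random.Random(seed)
    entities = rng.sample(range(1, vocab_size), 2 * n_pairs)
    pairs = list(zip(entities[:n_pairs], entities[n_pairs:]))

    # Training sequences contain only the forward direction: [A, ->, B].
    train = [[a, arrow_id, b] for a, b in pairs]
    # Test queries ask for the reversed direction: given [B, ->], the target is A.
    test = [([b, arrow_id], a) for a, b in pairs]
    return train, test

train_seqs, test_queries = make_reversal_data()
print(train_seqs[0], test_queries[0])
```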
To run the logits variant of the reversal logic experiment:

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_reverse_logits
```
To run the embedding variant of the reversal logic experiment:

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_reverse_embed
```
To run the reversal logic experiment with In-Context Learning (ICL):

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_reverse_ICL
```
To run the standard Chain-of-Thought experiment of the form "A → B, B → C ⇒ A → C":

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_chain
```
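Analogously, the Chain-of-Thought data composes two one-hop facts. The sketch below is a hypothetical illustration of the "A → B, B → C ⇒ A → C" format, not the repo's actual generator:

```python
import random

def make_chain_data(vocab_size=800, n_triples=100, arrow_id=0, seed=1234):
    """Hypothetical illustration of the chain-of-thought data (not the repo's generator)."""
    rng = random.Random(seed)
    entities = rng.sample(range(1, vocab_size), 3 * n_triples)
    triples = list(zip(entities[0::3], entities[1::3], entities[2::3]))

    # Training sequences contain the two individual hops: [A, ->, B] and [B, ->, C].
    train = [[a, arrow_id, b] for a, b, c in triples] + \
            [[b, arrow_id, c] for a, b, c in triples]
    # Test: given [A, ->], the two-hop target is C.
    test = [([a, arrow_id], c) for a, b, c in triples]
    return train, test
```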
To run the alternative version of the Chain-of-Thought experiment with correlated tokens:

```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_chain_related_tokens
```
When running the scripts for each experiment, several command-line arguments can be passed to customize the model configuration, training hyperparameters, and dataset generation. The following example shows how these arguments can be passed on the command line, together with their default values; descriptions of the flags are listed below the command:
```
CUDA_VISIBLE_DEVICES=0 python3 -m src.scripts.train_reverse \
    --pos_encode_type 'absolute' \
    --n_layers 24 \
    --embed_dim 768 \
    --vocab_size 800 \
    --word_size 1 \
    --seed 1234 \
    --num_epochs 3000 \
    --batch_size 64 \
    --lr 0.01 \
    --decay 0.9 \
    --betas '(0.9, 0.999)' \
    --loss_whole_sequence \
    --freeze_wte_wpe \
    --output_dir 'exp_reverse'
```

- `--pos_encode_type`: positional embedding type; one of `'null'`, `'absolute'`, `'rotary'`
- `--n_layers`: number of transformer layers
- `--embed_dim`: dimension of the token embedding vectors
- `--vocab_size`: size of the vocabulary from which the datasets are constructed
- `--word_size`: number of tokens that form an entity
- `--loss_whole_sequence`: see note below
- `--freeze_wte_wpe`: see note below
- `--output_dir`: directory where plots will be saved
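For reference, here is a minimal sketch of an argument parser exposing the flags and defaults listed above. It is an assumption for illustration (including the tuple-style `--betas` parsing), not the repo's actual `src/scripts` code:

```python
import argparse
import ast

def _betas(s):
    # Accept a tuple-style string such as "(0.9, 0.999)" (assumed format).
    return tuple(ast.literal_eval(s))

def build_parser():
    """Sketch of a parser matching the flags and defaults documented above (assumed)."""
    p = argparse.ArgumentParser()
    p.add_argument('--pos_encode_type', default='absolute',
                   choices=['null', 'absolute', 'rotary'])
    p.add_argument('--n_layers', type=int, default=24)
    p.add_argument('--embed_dim', type=int, default=768)
    p.add_argument('--vocab_size', type=int, default=800)
    p.add_argument('--word_size', type=int, default=1)
    p.add_argument('--seed', type=int, default=1234)
    p.add_argument('--num_epochs', type=int, default=3000)
    p.add_argument('--batch_size', type=int, default=64)
    p.add_argument('--lr', type=float, default=0.01)
    p.add_argument('--decay', type=float, default=0.9)
    p.add_argument('--betas', type=_betas, default=(0.9, 0.999))
    p.add_argument('--loss_whole_sequence', action='store_true')
    p.add_argument('--freeze_wte_wpe', action='store_true')
    p.add_argument('--output_dir', default='exp_reverse')
    return p
```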
NOTE: By default, neither of the following flags is passed:
- With the `--loss_whole_sequence` flag, the loss is applied to the entire input sequence; otherwise it is applied only to the tokens corresponding to the last entity.
- With the `--freeze_wte_wpe` flag, the token embedding matrix and the positional embedding matrices of the model are frozen; otherwise they are trainable.
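For intuition, the effect of `--loss_whole_sequence` can be sketched as follows. This is assumed PyTorch-style pseudologic rather than the repo's actual training loop, and `entity_len` is a hypothetical name for the number of tokens in the last entity:

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits, tokens, loss_whole_sequence, entity_len=1):
    """Sketch of the two loss modes (assumed, not the repo's actual code).

    logits: (batch, seq_len, vocab_size) model outputs
    tokens: (batch, seq_len) input token ids
    """
    # Standard next-token prediction: position t predicts token t+1.
    pred = logits[:, :-1, :]
    target = tokens[:, 1:]

    if not loss_whole_sequence:
        # Default: only the tokens of the last entity contribute to the loss.
        pred = pred[:, -entity_len:, :]
        target = target[:, -entity_len:]

    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
```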
We used this implementation of Rotary Embeddings in our experiments.
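For background, rotary embeddings rotate each pair of query/key channels by a position-dependent angle. The following is a minimal self-contained sketch of the idea, not the specific implementation referenced above:

```python
import torch

def apply_rotary(x, base=10000.0):
    """Minimal rotary position embedding sketch (illustrative only).

    x: (..., seq_len, dim) queries or keys, with an even dim.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    # One rotation frequency per pair of channels.
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()

    # Rotate each (even, odd) channel pair by its position-dependent angle.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```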