Direct Preference Optimization (DPO) is a method for improving the alignment of language models.
Current approaches to training large language models start with an initial step of unsupervised training on a large corpus of text. This yields models that generate the tokens most likely to follow a prompt under the conditional distribution of the data they were trained on. When we want our language model to serve as a chat bot or code assistant, this often produces undesirable text: high-quality conversations or coding examples may be rare in the training corpus, and therefore rare in the model's output.
To address this, we can fine-tune our model on a preference dataset so that it generates text more closely aligned with a downstream task. The dataset has three columns: a prompt, a chosen output, and a rejected output. Given the prompt, we want our model to generate an output closer to the chosen output than to the rejected output. In this example we use a dataset where the preferred output is considered safer and more appropriate for a chat bot.
How can we structure this problem in a way that our model can learn from it?
One approach is Reinforcement Learning from Human Feedback (RLHF). With RLHF, we first train a binary classifier that learns to discriminate between outputs. A Bradley-Terry model is often used as the classifier here, where the probability that output y1 is preferred to y2 given prompt x is defined as p(y1 >> y2 | x) = sigmoid(r(x, y1) - r(x, y2)), where >> denotes "is preferred to" and r(x, y) is the reward. The reward function is learned using a binary cross-entropy loss. As a starting point for the reward function, we can take a supervised fine-tuned (SFT) LLM with its last layer replaced to output a scalar value representing the reward. The next step in RLHF is to use an algorithm like PPO to optimize our policy to maximize this reward while not straying too far from a reference policy. In this step, we sample prompts from our dataset, generate completions, and send them to our reward model to get a score. During this fine-tuning process, we aim to maximize the reward while minimizing the KL divergence between the token-level probability distributions of our fine-tuned model and of our reference model.
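To make the reward-modeling step concrete, here is a minimal sketch of the Bradley-Terry loss in PyTorch. The function name and the toy reward values are made up for illustration; in practice the rewards would come from the SFT model with its scalar head.

import torch
import torch.nn.functional as F

def reward_model_loss(reward_chosen, reward_rejected):
    # Bradley-Terry: p(chosen >> rejected) = sigmoid(r_chosen - r_rejected).
    # Maximizing the likelihood of the preferences is equivalent to minimizing
    # -log sigmoid(r_chosen - r_rejected), a binary cross-entropy where the
    # chosen output is always the positive class.
    return -F.logsigmoid(reward_chosen - reward_rejected).mean()

# toy scalar rewards for a batch of three preference pairs
reward_chosen = torch.tensor([1.2, 0.3, 2.0])
reward_rejected = torch.tensor([0.4, 0.5, -1.0])
print(reward_model_loss(reward_chosen, reward_rejected))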
A downside of this approach is that it is expensive to first train a reward model and then query it inside the training loop while sampling from the policy during fine-tuning. RLHF is also complex and often unstable.
An alternative approach is DPO, which doesn't use reinforcement learning. From the DPO paper:
Instead of using the preference model to define a preference loss to train a reward model and then train a policy that optimizes the learned reward model, DPO uses a change of variables to define the preference loss as a function of the policy directly. Given a dataset of human preferences over model responses, DPO can therefore optimize a policy using a simple binary cross entropy objective, producing the optimal policy to an implicit reward function fit to the preference data.
With DPO, the preference model is expressed in terms of the policy being optimized and a reference policy rather than a separate reward model. The loss function, as laid out in the paper, changes accordingly: it becomes a binary cross-entropy loss on the difference between the scaled log-probability ratios (policy over reference) of the chosen and rejected completions.
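To make the change of variables concrete, here is a minimal sketch of the DPO loss. The function name and the toy log-probabilities are made up for illustration; the DPOTrainer used below implements this loss for us. It takes per-sequence log probabilities of the chosen and rejected completions under the policy being trained and under the frozen reference model.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # log ratios of the policy to the reference model for each completion
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # binary cross-entropy on the scaled difference of log ratios
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()

# toy per-sequence log probabilities for a batch of two preference pairs
policy_chosen = torch.tensor([-12.0, -15.0])
policy_rejected = torch.tensor([-14.0, -13.5])
ref_chosen = torch.tensor([-13.0, -15.5])
ref_rejected = torch.tensor([-13.5, -13.0])
print(dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected))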
The trl library has a DPOTrainer, subclassed from the Hugging Face Trainer class. DPOTrainer implements the DPO loss function, along with methods to help with tokenizing inputs, generating predictions, and evaluating the model. It also supports PEFT (Parameter-Efficient Fine-Tuning), which we will use to apply LoRA (Low-Rank Adaptation of Large Language Models). This, along with quantization (8-bit model weights), reduces the number of trainable parameters and the memory footprint of the model so that we can fine-tune it on a local machine with a GPU.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import DPOTrainer
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
# note: to get bitsandbytes to work on windows, uninstall bitsandbytes and reinstall with
# pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
def process_hh_rlhf_sample(sample):
    """
    sample is a dictionary with keys 'chosen' and 'rejected'.
    Both values contain the full prompt followed by a completion.
    Find the index of the last occurrence of the substring '\n\nAssistant: '
    to split the prompt from the completion.
    Return a dictionary with keys 'prompt', 'chosen', and 'rejected'.
    """
    term = '\n\nAssistant: '
    end_of_prompt_index = sample['chosen'].rfind(term)
    # extract the prompt, including the final '\n\nAssistant: ' marker
    prompt = sample['chosen'][:end_of_prompt_index + len(term)]
    # extract the chosen completion
    chosen = sample['chosen'][len(prompt):]
    # extract the rejected completion
    rejected = sample['rejected'][len(prompt):]
    return {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
def get_anthropic_hh_rlhf_dataset(split='train'):
    """
    The Anthropic HH-RLHF dataset contains 160k training examples and 8k test examples.
    Each example is a dictionary with two keys: 'chosen' and 'rejected'.
    Each of these includes the prompt and the completion.
    Extract the prompt, chosen completion, and rejected completion from each example.
    https://arxiv.org/abs/2204.05862
    https://huggingface.co/datasets/Anthropic/hh-rlhf
    """
    dataset = load_dataset('Anthropic/hh-rlhf', split=split)
    return dataset.map(process_hh_rlhf_sample)
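As a quick sanity check, we can run the processing function on a single made-up sample in the dataset's '\n\nHuman: ... \n\nAssistant: ...' format (the dialogue below is not from the dataset):

example = {
    'chosen': '\n\nHuman: What is DPO?\n\nAssistant: It is a method for aligning language models with preference data.',
    'rejected': '\n\nHuman: What is DPO?\n\nAssistant: No idea.',
}
print(process_hh_rlhf_sample(example))
# prompt:   '\n\nHuman: What is DPO?\n\nAssistant: '
# chosen:   'It is a method for aligning language models with preference data.'
# rejected: 'No idea.'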
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)
torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
    'stabilityai/stablelm-2-1_6b',
    quantization_config=bnb_config,
    torch_dtype=torch_dtype,
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    'stabilityai/stablelm-2-1_6b',
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
# https://github.com/huggingface/trl/issues/1073
tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
tokenizer.bos_token_id = tokenizer.eos_token_id
train_dataset = get_anthropic_hh_rlhf_dataset(split='train')
test_dataset = get_anthropic_hh_rlhf_dataset(split='test[:1000]') # use a small test set for now
# define the training arguments
training_args = TrainingArguments(
    max_steps=64,  # only 64 gradient updates, not even one epoch
    remove_unused_columns=False,
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    output_dir='output',
    logging_strategy='steps',
    logging_dir='logs',
    logging_steps=16,
    lr_scheduler_type='constant'  # default is linear
)
peft_config = LoraConfig(
    r=64,  # rank of the low-rank matrices
    lora_alpha=16,  # scaling factor for the weight matrices
    bias='none',  # don't train bias params
    task_type='CAUSAL_LM',
    target_modules=[
        'q_proj',
        'k_proj',
        'v_proj',
        'o_proj',
        'gate_proj',
        'up_proj',
        'down_proj',
        'lm_head',
    ]
)
model = get_peft_model(model, peft_config)
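PEFT models expose print_trainable_parameters, which is a handy check that LoRA has reduced the trainable parameter count to a small fraction of the full model (the exact numbers depend on the base model and the LoRA rank chosen above):

model.print_trainable_parameters()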
# DPOTrainer expects the raw 'prompt', 'chosen', and 'rejected' text columns and
# performs tokenization internally, so no manual tokenization step is needed here.
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    beta=0.1,  # beta param for the DPO loss
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    max_length=512,  # max length of the full prompt + completion sequence
    max_target_length=128,  # max length of the completion
    max_prompt_length=128,  # max length of the prompt
    generate_during_eval=False,
    peft_config=peft_config
)
# Evaluate the model before training.
print(trainer.evaluate())
trainer.train()
# Save and evaluate the model after training.
trainer.save_model('out')
eval_results = trainer.evaluate()
print(eval_results)
# View the training loss after logged steps.
print(trainer.state.log_history)
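trainer.state.log_history is a list of dictionaries, one per logging or evaluation event, so the training losses can be pulled out with a small comprehension (using the standard Trainer key 'loss'):

# collect the training loss from each logged step
train_losses = [entry['loss'] for entry in trainer.state.log_history if 'loss' in entry]
print(train_losses)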
# turn off warnings for this cell
import warnings
warnings.filterwarnings('ignore')
tokenizer.pad_token = tokenizer.eos_token
device = torch.device('cuda')
model = model.to(device)
prompt = 'Some popular cities for tourists are'
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
input_ids = input_ids.to(device)
output_ids = model.generate(input_ids, max_length=128, do_sample=True, num_return_sequences=1)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))