DPO training on d2l data. Version 0
llauraa23 committed Jan 24, 2024
1 parent beeedaa commit 04b9fa5
Showing 1 changed file with 263 additions and 0 deletions.
263 changes: 263 additions & 0 deletions pykoi/rlhf/rl_finetuning_dpo.py
@@ -0,0 +1,263 @@
# The code is adapted from Huggingface.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.

# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, is_xpu_available


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(default="models/rlhf_step1_sft/", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    # reduced from the upstream default of 4
    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    output_dir: Optional[str] = field(default="outputdpo", metadata={"help": "the output directory"})
    fp16: Optional[bool] = field(
        default=False, metadata={"help": "Whether to activate fp16 mixed precision during training"}
    )
    bf16: Optional[bool] = field(
        default=False, metadata={"help": "Whether to activate bf16 mixed precision during training"}
    )
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "Only used for encoder-decoder models. Max length of each sample's target"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
    # reduced from the upstream default of 1000
    max_steps: Optional[int] = field(default=500, metadata={"help": "max number of training steps"})

    # lora parameters
    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT or not to train adapters"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    # instrumentation
    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            " https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    generate_during_eval: Optional[bool] = field(default=False, metadata={"help": "Generate during evaluation"})


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]
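
# Illustrative example (not part of the upstream script): given
#   "\n\nHuman: How do I sort a list in Python?\n\nAssistant: Use sorted(my_list)."
# extract_anthropic_prompt returns everything up to and including the final "\n\nAssistant:"
# marker, i.e. "\n\nHuman: How do I sort a list in Python?\n\nAssistant:".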


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)
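
# Illustrative record shape after the map (hypothetical text, not taken from the actual dataset):
#   {"prompt": "\n\nHuman: <question>\n\nAssistant:",
#    "chosen": " <preferred reply>", "rejected": " <other reply>"}
# The shared prompt is stripped from both completions so that only the responses differ.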


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit or script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
        )
        # Copy the model to each device
        device_map = (
            {"": f"xpu:{PartialState().local_process_index}"}
            if is_xpu_available()
            else {"": PartialState().local_process_index}
        )
        torch_dtype = torch.bfloat16
    else:
        # changed from the upstream default of device_map=None
        device_map = "auto"
        quantization_config = None
        torch_dtype = None

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map=device_map,
        quantization_config=quantization_config,
        torch_dtype=torch_dtype,
    )

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    if not script_args.use_peft:
        model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
    else:
        # When PEFT is used, no separate reference model is loaded; DPOTrainer can fall back to the
        # base model with adapters disabled as the implicit reference. TODO: verify this behavior.
        model_ref = None

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the training dataset. The original example used the Anthropic Helpful-Harmless dataset:
    # train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

    # 3. Load the evaluation dataset. Original example:
    # eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

    # Instead, load the d2l preference data from a local CSV file and split it into train/eval sets.
    dataset = load_dataset("csv", data_files="data/rlhf_training_data_d2ai.csv", split="train")

    def feature_format(sample) -> Dict[str, str]:
        return {
            "prompt": sample["input"],
            "chosen": sample["chosen"],
            "rejected": sample["rejected"],
        }

    dataset = dataset.map(feature_format)
    train_eval = dataset.train_test_split(test_size=0.1)
    train_dataset = train_eval["train"]
    eval_dataset = train_eval["test"]
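
    # Illustrative CSV schema (assumed from feature_format above; the actual file contents are not
    # shown in this commit). The CSV needs at least these columns:
    #   input,chosen,rejected
    #   "<prompt text>","<preferred answer>","<dispreferred answer>"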


    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir=script_args.output_dir,
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=script_args.bf16,
        fp16=script_args.fp16,
        gradient_checkpointing=script_args.gradient_checkpointing,
        # TODO: uncomment that on the next transformers release
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    if script_args.use_peft:
        peft_config = LoraConfig(
            r=script_args.peft_lora_r,
            lora_alpha=script_args.peft_lora_alpha,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=script_args.max_length,
        max_target_length=script_args.max_target_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=script_args.generate_during_eval,
        peft_config=peft_config,
    )

    # 6. train
    dpo_trainer.train()
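
# Example invocation (illustrative only; the paths and flag values below mirror the defaults in
# ScriptArguments and are not prescribed by this commit):
#   python pykoi/rlhf/rl_finetuning_dpo.py \
#       --model_name_or_path models/rlhf_step1_sft/ \
#       --per_device_train_batch_size 1 \
#       --max_steps 500 \
#       --output_dir outputdpo
# HfArgumentParser derives one command-line flag per ScriptArguments field, so any field can be
# overridden the same way.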
