Showing 1 changed file with 263 additions and 0 deletions.
# The code is adapted from Huggingface.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.
# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, is_xpu_available
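# DPOTrainer (from TRL) implements Direct Preference Optimization on top of the
# transformers Trainer API; is_xpu_available() reports whether an Intel XPU device
# is present so the device map can be chosen accordingly below.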

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(default="models/rlhf_step1_sft/", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    # per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    # dh
    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    output_dir: Optional[str] = field(default="outputdpo", metadata={"help": "the output directory"})
    fp16: Optional[bool] = field(
        default=False, metadata={"help": "Whether to activate fp16 mixed precision during training"}
    )
    bf16: Optional[bool] = field(
        default=False, metadata={"help": "Whether to activate bf16 mixed precision during training"}
    )
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "Only used for encoder-decoder models. Max target length of each sample"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
    # max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    max_steps: Optional[int] = field(default=500, metadata={"help": "max number of training steps"})

    # lora parameters
    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT to train adapters"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    # instrumentation
    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
    generate_during_eval: Optional[bool] = field(default=False, metadata={"help": "Generate during evaluation"})
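# All of the fields above can be overridden from the command line via HfArgumentParser,
# e.g. (assuming this file is saved as dpo_training.py -- the actual filename may differ):
#   python dpo_training.py --model_name_or_path models/rlhf_step1_sft/ --bf16 True --max_steps 500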

def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]
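# Example: for "\n\nHuman: How do I bake bread?\n\nAssistant: Preheat the oven...",
# this returns everything up to and including the final "\n\nAssistant:" marker,
# i.e. the shared prompt prefix of a chosen/rejected pair.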

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)
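# Note: get_hh() is kept for reference, but the training section below loads a local
# CSV file instead; the calls to get_hh() further down are commented out.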

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit or script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
        )
        # Copy the model to each device
        device_map = (
            {"": f"xpu:{PartialState().local_process_index}"}
            if is_xpu_available()
            else {"": PartialState().local_process_index}
        )
        torch_dtype = torch.bfloat16
    else:
        # device_map = None
        # dh
        device_map = "auto"
        quantization_config = None
        torch_dtype = None
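    # With quantization, each process pins the whole model to its local device; without it,
    # device_map="auto" lets Accelerate spread the model's layers across whatever devices
    # are available (this differs from the upstream script, which used device_map=None here).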

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map=device_map,
        quantization_config=quantization_config,
        torch_dtype=torch_dtype,
    )
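    # When quantization is off, torch_dtype is None, so from_pretrained falls back to its
    # default full-precision (fp32) weights; bf16/fp16 mixed precision is controlled
    # separately via the TrainingArguments flags below.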

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    if not script_args.use_peft:
        model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
    else:
        # If one uses PEFT, there is no need to load a reference model ## dh: TODO: CHECK THIS
        model_ref = None

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
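    # Many causal LM tokenizers ship without a pad token; reusing the EOS token lets the
    # trainer pad chosen/rejected sequences to a common length when batching.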

    # 2. Load the Anthropic Helpful-Harmless dataset
    # train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

    # 3. Load evaluation dataset
    # eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

    # dh
    dataset = load_dataset("csv", data_files="data/rlhf_training_data_d2ai.csv", split="train")

    def feature_format(sample) -> Dict[str, str]:
        return {
            "prompt": sample["input"],
            "chosen": sample["chosen"],
            "rejected": sample["rejected"],
        }

    dataset = dataset.map(feature_format)
    train_eval = dataset.train_test_split(test_size=0.1)
    train_dataset = train_eval["train"]
    eval_dataset = train_eval["test"]
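    # The CSV is expected to provide "input", "chosen" and "rejected" columns; feature_format
    # renames "input" to "prompt" so the records match the format DPOTrainer expects, and 10%
    # of the rows are held out for evaluation.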

    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir=script_args.output_dir,
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=script_args.bf16,
        fp16=script_args.fp16,
        gradient_checkpointing=script_args.gradient_checkpointing,
        # TODO: uncomment that on the next transformers release
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )
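    # Note: with the defaults above (max_steps=500, eval_steps=500), evaluation effectively
    # runs once, at the final training step. remove_unused_columns=False keeps the
    # prompt/chosen/rejected columns from being dropped before DPOTrainer can tokenize them.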

    if script_args.use_peft:
        peft_config = LoraConfig(
            r=script_args.peft_lora_r,
            lora_alpha=script_args.peft_lora_alpha,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None
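    # target_modules is not set here, so PEFT falls back to its built-in default mapping of
    # LoRA target layers for the detected model architecture; architectures missing from that
    # mapping would need target_modules to be passed explicitly.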

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=script_args.max_length,
        max_target_length=script_args.max_target_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=script_args.generate_during_eval,
        peft_config=peft_config,
    )
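    # When peft_config is given and model_ref is None (the PEFT branch above), DPOTrainer
    # computes the reference log-probabilities from the same base model with the LoRA
    # adapters disabled, so no separate reference copy needs to be kept in memory.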

    # 6. train
    dpo_trainer.train()
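    # Optionally persist the result after training, e.g.:
    # dpo_trainer.save_model(script_args.output_dir)  # saves the (adapter) weights and config
    # tokenizer.save_pretrained(script_args.output_dir)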