diff --git a/pykoi/rlhf/rl_finetuning_dpo.py b/pykoi/rlhf/rl_finetuning_dpo.py
new file mode 100644
index 0000000..4a8fc38
--- /dev/null
+++ b/pykoi/rlhf/rl_finetuning_dpo.py
@@ -0,0 +1,263 @@
+# The code is adapted from Huggingface.
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
+# TODO: bump transformers version in requirements at next release.
+
+# 0. imports
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+import torch
+from accelerate import PartialState
+from datasets import Dataset, load_dataset
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
+
+from trl import DPOTrainer, is_xpu_available
+
+
+# Define and parse arguments.
+@dataclass
+class ScriptArguments:
+    """
+    The arguments for the DPO training script.
+    """
+
+    # data parameters
+    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
+
+    # training parameters
+    model_name_or_path: Optional[str] = field(default="models/rlhf_step1_sft/", metadata={"help": "the model name"})
+    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
+    # per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
+    # dh
+    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "batch size per device"})
+    gradient_accumulation_steps: Optional[int] = field(
+        default=1, metadata={"help": "the number of gradient accumulation steps"}
+    )
+    output_dir: Optional[str] = field(default="outputdpo", metadata={"help": "the output directory"})
+    fp16: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to activate fp16 mixed precision during training"}
+    )
+    bf16: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to activate bf16 mixed precision during training"}
+    )
+    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
+    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
+    max_target_length: Optional[int] = field(
+        default=128, metadata={"help": "Only used for encoder-decoder models. Max target length of each sample."}
+    )
+    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
+    # max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
+    max_steps: Optional[int] = field(default=500, metadata={"help": "max number of training steps"})
+
+    # lora parameters
+    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT to train adapters"})
+    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
+    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
+    # instrumentation
+    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
+    report_to: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
+            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
+            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
+        },
+    )
+    # debug argument for distributed training
+    ignore_bias_buffers: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
+            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
+        },
+    )
+    gradient_checkpointing: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
+    )
+    gradient_checkpointing_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "keyword arguments to be passed to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
+        },
+    )
+    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
+    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
+    generate_during_eval: Optional[bool] = field(default=False, metadata={"help": "Generate during evaluation"})
+
+
+def extract_anthropic_prompt(prompt_and_response):
+    """Extract the Anthropic prompt from a prompt and response pair."""
+    search_term = "\n\nAssistant:"
+    search_term_idx = prompt_and_response.rfind(search_term)
+    assert search_term_idx != -1, f"Prompt and response do not contain '{search_term}'"
+    return prompt_and_response[: search_term_idx + len(search_term)]
+
+
+def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
+    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
+
+    The dataset is converted to a dictionary with the following structure:
+    {
+        'prompt': List[str],
+        'chosen': List[str],
+        'rejected': List[str],
+    }
+
+    Prompts should be structured as follows:
+      \n\nHuman: <prompt>\n\nAssistant:
+    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
+ """ + import pdb; pdb.set_trace() + dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir) + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def split_prompt_and_responses(sample) -> Dict[str, str]: + prompt = extract_anthropic_prompt(sample["chosen"]) + return { + "prompt": prompt, + "chosen": sample["chosen"][len(prompt) :], + "rejected": sample["rejected"][len(prompt) :], + } + + return dataset.map(split_prompt_and_responses) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + if script_args.load_in_8bit and script_args.load_in_4bit: + raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") + elif script_args.load_in_8bit or script_args.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit + ) + # Copy the model to each device + device_map = ( + {"": f"xpu:{PartialState().local_process_index}"} + if is_xpu_available() + else {"": PartialState().local_process_index} + ) + torch_dtype = torch.bfloat16 + else: + # device_map = None + # dh + device_map = "auto" + quantization_config = None + torch_dtype = None + + # 1. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + device_map=device_map, + quantization_config=quantization_config, + torch_dtype=torch_dtype, + ) + + if script_args.ignore_bias_buffers: + # torch distributed hack + import pdb; pdb.set_trace() + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + import pdb; pdb.set_trace() + + if not script_args.use_peft: + model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path) + else: + # If one uses PEFT, there is no need to load a reference model ## dh: TODO: CHECK THIS + model_ref = None + + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Anthropic Helpful-Harmless dataset + # train_dataset = get_hh("train", sanity_check=script_args.sanity_check) + + # 3. Load evaluation dataset + # eval_dataset = get_hh("test", sanity_check=script_args.sanity_check) + + # dh + dataset= load_dataset("csv", data_files="data/rlhf_training_data_d2ai.csv", split="train") + def feature_format(sample) -> Dict[str, str]: + return { + "prompt": sample["input"], + "chosen": sample["chosen"], + "rejected": sample["rejected"], + } + dataset = dataset.map(feature_format) + train_eval = dataset.train_test_split(test_size=0.1) + import pdb; pdb.set_trace() + train_dataset = train_eval["train"] + eval_dataset = train_eval["test"] + + + # 4. 
+
+    # 4. initialize training arguments:
+    training_args = TrainingArguments(
+        per_device_train_batch_size=script_args.per_device_train_batch_size,
+        max_steps=script_args.max_steps,
+        remove_unused_columns=False,
+        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+        learning_rate=script_args.learning_rate,
+        evaluation_strategy="steps",
+        logging_first_step=True,
+        logging_steps=10,  # match results in blog post
+        eval_steps=500,
+        output_dir=script_args.output_dir,
+        optim="rmsprop",
+        warmup_steps=150,
+        report_to=script_args.report_to,
+        bf16=script_args.bf16,
+        fp16=script_args.fp16,
+        gradient_checkpointing=script_args.gradient_checkpointing,
+        # TODO: uncomment that on the next transformers release
+        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
+    )
+
+    if script_args.use_peft:
+        peft_config = LoraConfig(
+            r=script_args.peft_lora_r,
+            lora_alpha=script_args.peft_lora_alpha,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+    else:
+        peft_config = None
+
+    # 5. initialize the DPO trainer
+    dpo_trainer = DPOTrainer(
+        model,
+        model_ref,
+        args=training_args,
+        beta=script_args.beta,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        max_length=script_args.max_length,
+        max_target_length=script_args.max_target_length,
+        max_prompt_length=script_args.max_prompt_length,
+        generate_during_eval=script_args.generate_during_eval,
+        peft_config=peft_config,
+    )
+
+    # 6. train
+    dpo_trainer.train()
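+
+    # Example invocation (CLI flags are generated from the ScriptArguments fields
+    # by HfArgumentParser; the values shown are the script defaults):
+    #   python pykoi/rlhf/rl_finetuning_dpo.py \
+    #       --model_name_or_path models/rlhf_step1_sft/ \
+    #       --learning_rate 1e-3 \
+    #       --max_steps 500 \
+    #       --output_dir outputdpo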