
Commit

code cleanup for d2l demo. In SFT, make data collator, formatting function, and whether to disable evaluation configurable.
llauraa23 committed Jan 24, 2024
1 parent 58f946c commit beeedaa
Showing 4 changed files with 123 additions and 74 deletions.
29 changes: 16 additions & 13 deletions example/rlhf/supervised_finetuning_demo_d2l.py
@@ -1,8 +1,9 @@
"""Demo for the supervised fine tuning.
python -m example.rlhf.supervised_finetuning_demo
python -m example.rlhf.supervised_finetuning_demo_d2l
"""

from peft import LoraConfig
from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
@@ -26,27 +27,29 @@
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
from peft import LoraConfig
config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
train_test_split_ratio=0, # ratio for test set DH:TODO: COMBINE TRAIN AND EVAL
max_seq_length=896,
per_device_eval_batch_size=1,
log_freq=20,
# dh: NOTE: 1 EPOCH iterates the dataset once. So log freq 20 means iterating 20 entries when training batch size = 1.
# (i.e., log_freq = 0.12 epoch when the dataset has 166 entries).
save_freq=40000,
num_train_epochs=20,
max_steps=-1, # if a positive number is given, it will override num_train_epochs
device_map="auto",
lora_config_rl=LoraConfig(
r=512,
lora_alpha=1024,
lora_dropout=0.05,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # optionally also: "gate_proj", "up_proj", "down_proj", "lm_head"
bias="none",
task_type="CAUSAL_LM"
),
data_collator="DataCollatorForCompletionOnlyLM",
no_evaluation=True,
prepare_text="d2l",
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
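
Once training completes, the saved artifact can be smoke-tested. A minimal sketch, assuming the PEFT adapter and tokenizer both land in the output directory (the actual layout depends on SupervisedFinetuning.save and the peft/trl versions in use):

# Sketch: load the fine-tuned adapter for inference. Paths and layout are assumptions.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained(
    "./models/rlhf_step1_sft", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("./models/rlhf_step1_sft")

prompt = "### Instruction:\nWhat is a convolutional layer?\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))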
12 changes: 12 additions & 0 deletions pykoi/rlhf/config.py
@@ -184,6 +184,18 @@ class RLHFConfig:
),
metadata={"help": "LoRA configuration."},
)
data_collator: Optional[str] = field(
default=None,
metadata={"help": "The name of data collator to use for training."},
)
no_evaluation: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable evaluations during training."},
)
prepare_text: Optional[str] = field(
default="sample",
metadata={"help": "How to prepare the text for the model."},
)

# Step 2 reward modeling parameters
reward_model_path: Optional[str] = field(
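
Together, the three new fields make the collator, evaluation, and text formatting switchable from the config alone. A minimal sketch of toggling them (values mirror the d2l demo above; the import path follows this file's location):

# Sketch: the new knobs, set to the values the d2l demo uses.
from pykoi.rlhf.config import RLHFConfig

config = RLHFConfig(
    base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
    data_collator="DataCollatorForCompletionOnlyLM",  # None keeps the default collator
    no_evaluation=True,   # skip the train/test split and the eval dataset
    prepare_text="d2l",   # anything else falls back to prepare_sample_text
)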
15 changes: 10 additions & 5 deletions pykoi/rlhf/customize_data_collator.py
@@ -1,8 +1,11 @@
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
from transformers import DataCollatorForLanguageModeling
import numpy as np


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)

# The prompt ends with the response key plus a newline. We encode this and then try to find it in the
@@ -16,7 +19,8 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
for i in range(len(examples)):

response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
response_token_ids_start_idx = idx
break

@@ -27,9 +31,10 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D

response_token_ids_end_idx = response_token_ids_start_idx + 1

# Make pytorch loss function ignore all tokens up through the end of the response key
labels[i, :response_token_ids_end_idx] = -100

batch["labels"] = labels

return batch
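
The effect of torch_call is easiest to see on a toy batch. A minimal sketch of the masking step, with made-up token ids (42 stands in for the first token of the encoded response key):

# Sketch: ignore everything up through the response key so the loss
# covers only the answer tokens. Ids are illustrative, not real tokenizer output.
import torch

labels = torch.tensor([[11, 12, 13, 42, 21, 22, 2]])
response_token_id = 42
idx = (labels[0] == response_token_id).nonzero()[0].item()
labels[0, : idx + 1] = -100  # -100 is ignored by PyTorch's cross-entropy loss
print(labels)  # tensor([[-100, -100, -100, -100,   21,   22,    2]])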
141 changes: 85 additions & 56 deletions pykoi/rlhf/supervised_finetuning.py
@@ -22,6 +22,7 @@
from pykoi.telemetry.telemetry import Telemetry
from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM


class SupervisedFinetuning:
"""
A class representing the supervised finetuning trainer.
@@ -37,7 +38,10 @@ class SupervisedFinetuning:
trainer (SFTTrainer): The trainer object used for training the model.
"""

def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None:
"""
Initializes the SFTTrainer object.
@@ -47,17 +51,18 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
"""
self._telemetry = Telemetry(enable_telemetry)
self._rlhf_config = rlhf_config
self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path)
# dh: add special tokens to tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
END_KEY = "### End"
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
self.tokenizer.add_special_tokens(
{"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
self.num_proc = (
self._rlhf_config.num_workers if not self._rlhf_config.streaming else None)
self.dataset = self.create_datasets(self.tokenizer, self._rlhf_config)
self.torch_dtype = torch.bfloat16 if self._rlhf_config.bf16 else torch.float16
# self.torch_dtype = torch.bfloat16 if bf16 else (torch.float16 if fp16 else torch.float32)
@@ -77,8 +82,7 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
gradient_accumulation_steps=self._rlhf_config.gradient_accumulation_steps,
gradient_checkpointing=self._rlhf_config.gradient_checkpointing,
gradient_checkpointing_kwargs={
"use_reentrant": self._rlhf_config.gradient_checkpointing_use_reentrant},
fp16=self._rlhf_config.fp16,
bf16=self._rlhf_config.bf16,
weight_decay=self._rlhf_config.weight_decay,
@@ -93,18 +97,20 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> No
)
# resize the token embeddings to include the added special tokens
self.model.resize_token_embeddings(len(self.tokenizer))

# dh: try the customized data collator that only predicts the answer part
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)
data_collator = None
if self._rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
# dh: try the customized data collator that only predicts the
# answer part
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8)

self.trainer = SFTTrainer(
model=self.model,
args=self.training_args,
train_dataset=self.dataset["train"],
eval_dataset=self.dataset["eval"],
peft_config=self._rlhf_config.lora_config_rl,
# TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
packing=True,
data_collator=data_collator,
dataset_text_field="text",
@@ -163,8 +169,9 @@ def save(self, output_path=None):

def train_and_save(self, output_path=None):
start_event = SFTStartEvent(
start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time())
)
self._telemetry.capture(start_event)
self.trainer.train()
self.save(output_path)
@@ -180,10 +187,8 @@ def prepare_sample_text(self, example):
"""Prepare the text from a sample of the dataset."""
text = (
f"Question: {example[self._rlhf_config.question_title]}\n\n "
f" Answer: {example[self._rlhf_config.answer_title]}"
)
return text


def prepare_d2l_text(self, example):
"""Prepare the text from a sample of the d2l dataset ."""
@@ -198,7 +203,8 @@ def prepare_d2l_text(self, example):
DEFAULT_SEED = 42

# This is a training prompt that does not contain an input string. The instruction by itself has enough information
# to respond. For example, the instruction might ask for the year a historic figure was born.
PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
@@ -214,7 +220,8 @@ def prepare_d2l_text(self, example):
)

# This is a training prompt that contains an input string that serves as context for the instruction. For example,
# the input might be a passage from Wikipedia and the instruction is to extract some information from it.
PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
Expand All @@ -232,14 +239,17 @@ def prepare_d2l_text(self, example):
response="{response}",
end_key=END_KEY,
)

context = example.get("context")
if context:
text = PROMPT_WITH_INPUT_FORMAT.format(
instruction=example["instruction"],
response=example["response"],
input=context)
else:
text = PROMPT_NO_INPUT_FORMAT.format(
instruction=example["instruction"],
response=example["response"])

return text
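
For orientation, PROMPT_NO_INPUT_FORMAT renders to text of roughly the following shape; the intro blurb is left as a placeholder because its wording sits outside the visible hunk, and the Q/A pair is made up:

{intro}
### Instruction:
Who wrote "Pride and Prejudice"?
### Response:
Jane Austen
### End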

@@ -258,13 +268,16 @@ def create_datasets(self, tokenizer, args):
)
dataset = Dataset.from_dict(my_data_pd)
elif args.dataset_type == "local_csv":
# this way will load 1660 entries
# dataset = load_dataset("csv", data_files=args.dataset_name)
# dataset = dataset[args.split] # Convert DatasetDict to Dataset

# this way will load 166 entries

dataset = load_dataset("csv", data_files=args.dataset_name, split='train[:10%]')

elif args.dataset_type == "huggingface":
dataset = load_dataset(
@@ -281,34 +294,50 @@ def create_datasets(self, tokenizer, args):
"No (supported) data files or dataset script found"
f" {args.dataset_type}"
)

# dh: temp change. No test set
# dataset = dataset.train_test_split(
# test_size=args.train_test_split_ratio, seed=args.seed
# )
print(
f"Size of the train set: {len(dataset)}. "
#f"Size of the train set: {len(dataset['train'])}. "
#f" Size of the validation set: {len(dataset['test'])}"
)

train_dataset = ConstantLengthDataset(
tokenizer,
dataset,
#dataset["train"], #dh: temp change. No test set
formatting_func=self.prepare_d2l_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
# temp change: no test set
# eval_dataset = ConstantLengthDataset(
# tokenizer,
# dataset["test"],
# formatting_func=self.prepare_d2l_text,
# infinite=False,
# seq_length=args.max_seq_length,
# # chars_per_token=chars_per_token,
# )
eval_dataset = None
if args.prepare_text == "d2l":
self.prepare_text = self.prepare_d2l_text
else:
self.prepare_text = self.prepare_sample_text
# No test set during training
if args.no_evaluation:
print(
f"Size of the train set: {len(dataset)}. "
)

train_dataset = ConstantLengthDataset(
tokenizer,
dataset,
formatting_func=self.prepare_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
eval_dataset = None
else:
dataset = dataset.train_test_split(
test_size=args.train_test_split_ratio, seed=args.seed
)
print(
f"Size of the train set: {len(dataset['train'])}. "
f" Size of the validation set: {len(dataset['test'])}")

train_dataset = ConstantLengthDataset(
tokenizer,
dataset["train"],
formatting_func=self.prepare_text,
infinite=True,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)

eval_dataset = ConstantLengthDataset(
tokenizer,
dataset["test"],
formatting_func=self.prepare_text,
infinite=False,
seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)

return {"train": train_dataset, "eval": eval_dataset}
