support sft training on d2l #100
Reviewer: let's rename this file to […]
Author: okay
example/rlhf/supervised_finetuning_demo.py (new file)
@@ -0,0 +1,45 @@
"""Demo for the supervised fine tuning. | ||
|
||
python -m example.rlhf.supervised_finetuning_demo | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: it should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. corrected |
||
""" | ||
from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
                                     QA_CSV_HEADER_QUESTION,
                                     QA_CSV_HEADER_VOTE_STATUS)
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning

# get data from local database
qa_database = QuestionAnswerDatabase()
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas()
my_data_pd = my_data_pd[
    [
        QA_CSV_HEADER_ID,
        QA_CSV_HEADER_QUESTION,
        QA_CSV_HEADER_ANSWER,
        QA_CSV_HEADER_VOTE_STATUS,
    ]
]

# analyze the data
print(my_data_pd)
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
from peft import LoraConfig

config = RLHFConfig(
    base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
    dataset_type="local_csv",
    dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
Reviewer: qq: I do not see this file in the data folder; is it missing?
Author: I am not sure if I should add these data files.
    train_test_split_ratio=0.1,
    max_seq_length=896,
    per_device_eval_batch_size=1,
    lora_config_rl=LoraConfig(
        r=512,
        lora_alpha=1024,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # optionally also: "gate_proj", "up_proj", "down_proj", "lm_head"
Reviewer: qq: what is the […]?
Author: It specifies the modules (e.g., which components in the attention layers) to be updated during PEFT training.
||
bias="none", | ||
task_type="CAUSAL_LM" | ||
), | ||
) | ||
rlhf_step1_sft = SupervisedFinetuning(config) | ||
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft") |
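
To make the target_modules answer above concrete, here is a minimal sketch (not part of the PR) for listing the linear-module names a model exposes as candidate LoRA targets. It loads the full base model from the demo, so substitute a smaller model to just try it out.

import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# Collect the distinct names of nn.Linear submodules; these are the strings
# that LoraConfig(target_modules=...) matches against.
linear_names = {name.split(".")[-1] for name, mod in model.named_modules()
                if isinstance(mod, nn.Linear)}
print(sorted(linear_names))
# e.g. ['down_proj', 'gate_proj', 'k_proj', 'lm_head', 'o_proj', 'q_proj', 'up_proj', 'v_proj']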
pykoi/rlhf/customize_data_collator.py (new file)
@@ -0,0 +1,35 @@
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from transformers import DataCollatorForLanguageModeling


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
Reviewer: qq: in this example, https://huggingface.co/docs/trl/sft_trainer#advanced-usage, it looks like it directly imports […]; it looks like […]
Author: It is Rachel and Yunfan's customized implementation.
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. We encode this and then
        # try to find it in the sequence of tokens. This should just be a single token.
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):
            # Find the first occurrence of the response key's leading token id.
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make the pytorch loss function ignore all tokens up through the end of the response key
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch
Reviewer: nit: add a new line at the end of the file. If you have not done so, please set up your dev environment following https://www.notion.so/goldpiggy/Python-Linter-and-formatter-Setup-30fb3b81f0904af889832e4c697c5ec9?pvs=4
Author: Thanks! I resolved it and ran pylint on the other files as well.
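
For comparison with the TRL built-in mentioned in the thread above, a minimal sketch (not part of the PR) of trl's own DataCollatorForCompletionOnlyLM; the model name is reused from the demo. Per the TRL docs, this built-in collator works only with packing=False, unlike the packing=True setting used in the trainer below.

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# TRL masks everything before (and including) the response template in each example.
collator = DataCollatorForCompletionOnlyLM(response_template="### Response:\n", tokenizer=tokenizer)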
pykoi/rlhf/supervised_finetuning.py
@@ -20,7 +20,7 @@
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry

from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM


class SupervisedFinetuning:
    """
@@ -48,6 +48,13 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None
        self._telemetry = Telemetry(enable_telemetry)
        self._rlhf_config = rlhf_config
        self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path)
        # dh: add special tokens to tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        END_KEY = "### End"
        INSTRUCTION_KEY = "### Instruction:"
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]}
        )
        self.num_proc = (
            self._rlhf_config.num_workers if not self._rlhf_config.streaming else None
        )
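
Adding special tokens grows the vocabulary, so the embedding matrix must be resized to match; the next hunk does exactly that via resize_token_embeddings. A standalone sketch of the mechanism, using gpt2 as a small stand-in model:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
n_added = tok.add_special_tokens(
    {"additional_special_tokens": ["### Instruction:", "### Response:\n", "### End"]}
)
model.resize_token_embeddings(len(tok))  # embedding rows now cover the new token ids
print(n_added, len(tok))  # 3 new tokens; vocab size grew by 3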
@@ -83,13 +90,23 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None
            load_in_8bit=self._rlhf_config.load_in_8bit,
            device_map=self._rlhf_config.device_map,
        )
        # resize the token embeddings to include the added special tokens
        self.model.resize_token_embeddings(len(self.tokenizer))

        # dh: try the customized data collator that only predicts the answer part
        data_collator = DataCollatorForCompletionOnlyLM(
Reviewer: qq: shall we make this configurable, to avoid breaking running the code in the old way?
Author: done
            tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
        )

        self.trainer = SFTTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["eval"],
            peft_config=self._rlhf_config.lora_config_rl,  # TODO: DH: LoraConfig MAY BE IGNORED IF USING FROM_PRETRAINED
            packing=True,
            data_collator=data_collator,
Reviewer: qq: could you please help explain in the PR description why we added this?
Author: This trains on the instruction-following objective by masking out the query, instead of the causal language model objective over every next token.
Reviewer: okay
            dataset_text_field="text",
        )

    def train(self):
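
To illustrate the masking discussed in the thread above (token ids invented for illustration; not part of the PR): the collator sets labels to -100 up through the response key, and PyTorch's cross-entropy loss ignores label -100, so only the answer tokens contribute to the loss.

import torch

labels = torch.tensor([[101, 2023, 2003, 1996, 3160, 25, 1996, 3437, 102]])
response_key_idx = 5                      # hypothetical position of "### Response:\n"
masked = labels.clone()
masked[0, :response_key_idx + 1] = -100   # ignore everything up through the response key
print(masked)  # tensor([[-100, -100, -100, -100, -100, -100, 1996, 3437, 102]])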
@@ -103,6 +120,8 @@ def load_lora(
        base_model_path: Optional[str] = None,
        lora_model_path: Optional[str] = None,
    ):
        # import pdb; pdb.set_trace()
        # dh: not used
        if base_model_path is None:
            base_model_path = self._rlhf_config.base_model_path
@@ -163,6 +182,65 @@ def prepare_sample_text(self, example):
            f" Answer: {example[self._rlhf_config.answer_title]}"
        )
        return text

    def prepare_d2l_text(self, example):
        """Prepare the text from a sample of the d2l dataset."""
        INTRO_BLURB = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
        )
Reviewer (on lines +189 to +191): qq: […] Unless the user always puts this blurb in their system prompt, I am wondering about the case where the user forgets to include it at inference time. It might hurt the inference performance.
Author: This is Yunfei's prompt and I kept it in order to reproduce his result in pykoi. I agree with the case you mentioned.
        INSTRUCTION_KEY = "### Instruction:"
        INPUT_KEY = "Input:"
        RESPONSE_KEY = "### Response:"
        END_KEY = "### End"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        DEFAULT_SEED = 42

        # This is a training prompt that does not contain an input string. The instruction by
        # itself has enough information to respond. For example, the instruction might ask for
        # the year a historic figure was born.
        PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{response}
{end_key}""".format(
            intro=INTRO_BLURB,
            instruction_key=INSTRUCTION_KEY,
            instruction="{instruction}",
            response_key=RESPONSE_KEY,
            response="{response}",
            end_key=END_KEY,
        )
        # This is a training prompt that contains an input string that serves as context for
        # the instruction. For example, the input might be a passage from Wikipedia and the
        # instruction is to extract some information from it.
        PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{input_key}
{input}
{response_key}
{response}
{end_key}""".format(
            intro=INTRO_BLURB,
            instruction_key=INSTRUCTION_KEY,
            instruction="{instruction}",
            input_key=INPUT_KEY,
            input="{input}",
            response_key=RESPONSE_KEY,
            response="{response}",
            end_key=END_KEY,
        )
        context = example.get("context")
        if context:
            text = PROMPT_WITH_INPUT_FORMAT.format(
                instruction=example["instruction"], response=example["response"], input=context
            )
        else:
            text = PROMPT_NO_INPUT_FORMAT.format(
                instruction=example["instruction"], response=example["response"]
            )

        return text
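
With a hypothetical sample (instruction and response strings invented for illustration), prepare_d2l_text produces the following in the no-input case:

example = {"instruction": "What does SFT stand for?",
           "response": "Supervised fine-tuning."}
# prepare_d2l_text(example) then returns roughly:
#
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
# ### Instruction:
# What does SFT stand for?
# ### Response:
# Supervised fine-tuning.
# ### End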
    def create_datasets(self, tokenizer, args):
        if args.dataset_type == "local_db":

@@ -181,6 +259,7 @@ def create_datasets(self, tokenizer, args):
        elif args.dataset_type == "local_csv":
            dataset = load_dataset("csv", data_files=args.dataset_name)
            dataset = dataset[args.split]  # Convert DatasetDict to Dataset
            dataset2 = load_dataset("csv", data_files=args.dataset_name, split="train[:10%]")
        elif args.dataset_type == "huggingface":
            dataset = load_dataset(
                args.dataset_name,
@@ -208,15 +287,15 @@ def create_datasets(self, tokenizer, args):
        train_dataset = ConstantLengthDataset(
Reviewer: One caveat here is that […] While […] However, I am a bit confused: your dataset is not prepared to train on the response only (masking out the query) but still uses the causal language model objective over all next tokens.
Author: Does the collator mask out the query?
            tokenizer,
            dataset["train"],
            formatting_func=self.prepare_d2l_text,
Reviewer: qq: same as my comment above, we should make this configurable to maintain the old functionality.
Author: done
            infinite=True,
            seq_length=args.max_seq_length,
            # chars_per_token=chars_per_token,
        )
        eval_dataset = ConstantLengthDataset(
            tokenizer,
            dataset["test"],
            formatting_func=self.prepare_d2l_text,
            infinite=False,
            seq_length=args.max_seq_length,
            # chars_per_token=chars_per_token,
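
A sketch of the packing caveat raised in the thread above (block contents invented for illustration): ConstantLengthDataset concatenates formatted examples into fixed-length blocks, while the custom collator masks only up to the first response key it finds, so the queries of later examples inside the same packed block would still be trained on with the plain causal LM loss.

# One packed block may hold several examples, hence several response keys:
packed_block = (
    "<query 1> ### Response:\n <answer 1> ### End "
    "<query 2> ### Response:\n <answer 2> ### End"
)
# The collator above assigns label -100 only to the prefix ending at the FIRST
# "### Response:\n", so "<query 2>" still contributes to the loss.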
Reviewer: could you please help describe why this is called auto-rater.ipynb? It looks like this file is used for generating a QA dataset.
Author: This script for auto evaluation is incomplete, and it contains redundant code from the previous QA generation code. Since the latest version of pykoi/uniflow already has an auto rater, shall we remove my related commits?