From b655e1ae6f0e7cf7f93933c0cc2dd72023a8d0f3 Mon Sep 17 00:00:00 2001
From: Sukriti Sharma
Date: Mon, 1 Jul 2024 11:33:59 -0600
Subject: [PATCH] minor refactor to allow modular functions (#224)

* minor refactor to allow modular functions

Signed-off-by: Sukriti-Sharma4

* minor fix in import

Signed-off-by: Sukriti-Sharma4

* minor fix to imports

Signed-off-by: Sukriti-Sharma4

* fix linting

Signed-off-by: Sukriti-Sharma4

* fix formatting

Signed-off-by: Sukriti-Sharma4

---------

Signed-off-by: Sukriti-Sharma4
---
 tests/utils/test_preprocessing_utils.py |  28 +++-
 tuning/sft_trainer.py                   |  37 +----
 tuning/utils/preprocessing_utils.py     | 196 ++++++++++++++++--------
 3 files changed, 162 insertions(+), 99 deletions(-)

diff --git a/tests/utils/test_preprocessing_utils.py b/tests/utils/test_preprocessing_utils.py
index e13486bb1..7a807da99 100644
--- a/tests/utils/test_preprocessing_utils.py
+++ b/tests/utils/test_preprocessing_utils.py
@@ -13,6 +13,7 @@
 )

 # Local
+from tuning.config import configs
 from tuning.utils.preprocessing_utils import (
     combine_sequence,
     get_data_trainer_kwargs,
@@ -180,14 +181,29 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
     assert trainer_kwargs["formatting_func"] is not None


-# Tests for fetching train args
+# Tests for validating data args
+# Invalid args return ValueError
 @pytest.mark.parametrize(
-    "dataset_text_field, response_template",
+    "data_args, packing",
     [
-        ("input", None),
-        (None, "output"),
+        # dataset_text_field with no response_template
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA,
+                dataset_text_field="output",
+            ),
+            False,
+        ),
+        # response template with no dataset_text_field or formatter
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA,
+                response_template="\n### Label:",
+            ),
+            False,
+        ),
     ],
 )
-def test_validate_args(dataset_text_field, response_template):
+def test_validate_args(data_args, packing):
     with pytest.raises(ValueError):
-        validate_data_args(dataset_text_field, response_template)
+        validate_data_args(data_args, packing)
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 045c9aa66..b221db898 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -34,7 +34,7 @@
     TrainerCallback,
 )
 from transformers.utils import is_accelerate_available, logging
-from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
+from trl import SFTConfig, SFTTrainer
 import datasets
 import fire
 import transformers
@@ -62,6 +62,7 @@
     USER_ERROR_EXIT_CODE,
     write_termination_log,
 )
+from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args


 def train(
@@ -195,14 +196,6 @@ def train(
         }
     )

-    # TODO: near term - how response template ids are parsed out needs to be cleaned.
-    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-    # otherwise template is not found. We will create issue to clean this out after we discuss
-    # data formats and collators we will support.
-    response_template_ids = tokenizer.encode(
-        data_args.response_template, add_special_tokens=False
-    )[2:]
-
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
     logger.info("Max sequence length is %s", max_seq_length)
     if train_args.max_seq_length > tokenizer.model_max_length:
@@ -244,31 +237,14 @@ def train(
         packing = True
     else:
         logger.info("Packing is set to False")
-        if data_args.response_template is None:
-            # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
-            # We should do this validation up front, then do the encoding, then handle the collator
-            raise ValueError("Response template is None, needs to be set for training")
-        data_collator = DataCollatorForCompletionOnlyLM(
-            response_template_ids,
-            tokenizer=tokenizer,
-            ignore_index=configs.IGNORE_INDEX,
-        )
         packing = False

-    # Currently we support formatted datasets with single sequence instances.
-    if not (data_args.dataset_text_field or data_args.data_formatter_template):
-        raise ValueError(
-            "dataset_text_field and data_formatter_template are None. \
-            One of them needs to be set for training"
-        )
-    # Only one of dataset_text_field or data_formatter_template should be set.
-    if data_args.dataset_text_field and data_args.data_formatter_template:
-        raise ValueError(
-            "dataset_text_field and data_formatter_template are both set,\
-            but are mutually exclusive options"
-        )
+    # Validate if data args are set properly
+    validate_data_args(data_args, packing)
+    data_collator = get_data_collator(packing, data_args.response_template, tokenizer)

     # load the data by parsing JSON
+    ### TODO: all the JSON file formatting will be moved to a separate function
     data_files = {"train": data_args.training_data_path}
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path
@@ -310,6 +286,7 @@ def train(
         logger.info(
             "Validation dataset length is %s", len(formatted_validation_dataset)
         )
+    ### JSON file formatting ends here

     if framework is not None and framework.requires_agumentation:
         model, (peft_config,) = framework.augmentation(
diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py
index 7de077973..545e16352 100644
--- a/tuning/utils/preprocessing_utils.py
+++ b/tuning/utils/preprocessing_utils.py
@@ -25,25 +25,141 @@
 from tuning.config import configs


-def validate_data_args(
+def validate_data_args(data_args: configs.DataArguments, packing: bool):
+
+    assert isinstance(
+        data_args.training_data_path, str
+    ), "Training data path has to be set and be a string"
+
+    # Dataset containing single sequence needs a response template for masking
+    if data_args.response_template is None and data_args.dataset_text_field is not None:
+        if packing is False:
+            raise ValueError(
+                "Since dataset_text_field is provided and packing is disabled, \
+                needs a corresponding response template for masking"
+            )
+
+    # Currently if packing is false, we require a response_template. This may change in future.
+    if packing is False:
+        if data_args.response_template is None:
+            raise ValueError(
+                "Response template is None, needs to be set for training \
+                with packing disabled."
+            )
+
+    if data_args.response_template:
+        # To use Response template, pass datasets with single sequence instances \
+        # or a formatter template to create single sequence on the fly.
+        if not (data_args.dataset_text_field or data_args.data_formatter_template):
+            raise ValueError(
+                "dataset_text_field and data_formatter_template are None. \
+                One of them needs to be set to use response_template"
+            )
+        # Only one of dataset_text_field or data_formatter_template should be set.
+        if data_args.dataset_text_field and data_args.data_formatter_template:
+            raise ValueError(
+                "dataset_text_field and data_formatter_template are both set,\
+                but are mutually exclusive options"
+            )
+    # TODO(s) In future support two more formats:
+    # 1. Allow no response template, and JSON with input/output fields and mask input
+
+    # 2. Allow pretokenized Dataset besides JSON.
+
+
+def get_data_collator(
+    packing: bool,
+    response_template: Optional[str],
+    tokenizer: AutoTokenizer,
+) -> Callable:
+    """Create and return the appropriate collator type based on the configuration for packing
+    and response_template.
+
+    Args:
+        packing: bool
+            Whether or not we should apply packing.
+        response_template: Optional[str]
+            Response template to be used for formatting by TRL.
+        tokenizer: AutoTokenizer
+            Loaded tokenizer object to be used by the collator.
+
+    Returns:
+        Callable
+            Callable collator to be leveraged by the trainer.
+    """
+    if not packing:
+        # TODO: near term - how response template ids are parsed out needs to be cleaned.
+        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+        # otherwise template is not found. We will create issue to clean this out after we discuss
+        # data formats and collators we will support.
+        if response_template:
+            response_template_ids = tokenizer.encode(
+                response_template, add_special_tokens=False
+            )[2:]
+            return DataCollatorForCompletionOnlyLM(
+                response_template=response_template_ids,
+                tokenizer=tokenizer,
+                ignore_index=configs.IGNORE_INDEX,
+            )
+    # TO DO with future changes,
+    # 1. Support no packing and seq2seq collator without response template
+    # # if dataset_text_field is None and response_template is None:
+    # # Use the seq2seq data collator;
+    # # Note that this automatically pads labels with -100
+    # return DataCollatorForSeq2Seq(
+    #     tokenizer=tokenizer, padding=True, max_length=max_sequence_length
+    # )
+    # 2. add anything needed for preprocessed input
+
+
+###################################################################################
+### The functions below are not yet used. Iterative development towards new features
+
+
+def get_data_collator_temp(
+    packing: bool,
     dataset_text_field: Optional[str],
     response_template: Optional[str],
-):
-    # Dataset containing single sequence needs a single sequence and a response template
-    if dataset_text_field is None and response_template is not None:
-        raise ValueError(
-            "Needs a corresponding dataset_text_feld \
-            in which to look for response_template"
-        )
-    if response_template is None and dataset_text_field is not None:
-        raise ValueError(
-            "Since dataset_text_field is provided, \
-            needs a corresponding response template for masking"
-        )
-    # Dataset containing JSON with fields and a formatter template
-    # TO DO load JSON and check input/output field is present
+    max_sequence_length: int,
+    tokenizer: AutoTokenizer,
+) -> Callable:
+    """Create and return the appropriate collator type based on the configuration for packing,
+    response_template, and dataset_text_field.

-    # in future : pretokenized Dataset may be added.
+    Args:
+        packing: bool
+            Whether or not we should apply packing.
+        dataset_text_field: Optional[str]
+            Dataset text field to be used for formatting by TRL.
+        response_template: Optional[str]
+            Response template to be used for formatting by TRL.
+        max_sequence_length: int
+            Max sequence length to be used for sequence tokenization.
+        tokenizer: AutoTokenizer
+            Loaded tokenizer object to be used by the collator.
+
+    Returns:
+        Callable
+            Callable collator to be leveraged by the trainer.
+    """
+    if not packing:
+        if dataset_text_field is None and response_template is None:
+            # Use the seq2seq data collator; note that this automatically pads labels with -100
+            return DataCollatorForSeq2Seq(
+                tokenizer=tokenizer, padding=True, max_length=max_sequence_length
+            )
+        # TODO: near term - how response template ids are parsed out needs to be cleaned.
+        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+        # otherwise template is not found. We will create issue to clean this out after we discuss
+        # data formats and collators we will support.
+        response_template_ids = tokenizer.encode(
+            response_template, add_special_tokens=False
+        )[2:]
+        return DataCollatorForCompletionOnlyLM(
+            response_template=response_template_ids,
+            tokenizer=tokenizer,
+            ignore_index=configs.IGNORE_INDEX,
+        )


 def get_data_trainer_kwargs(
@@ -82,7 +198,7 @@
         Dict[str, Any]
             Data related kwargs to be used by the SFT Trainer.
     """
-    data_collator = get_data_collator(
+    data_collator = get_data_collator_temp(
         packing, dataset_text_field, response_template, max_sequence_length, tokenizer
     )
     eval_dataset = None
@@ -122,52 +238,6 @@ def get_data_trainer_kwargs(
     return data_kwargs


-def get_data_collator(
-    packing: bool,
-    dataset_text_field: Optional[str],
-    response_template: Optional[str],
-    max_sequence_length: int,
-    tokenizer: AutoTokenizer,
-) -> Callable:
-    """Create and return the the appropriate collator type based on the configuration for packing,
-    response_template, and dataset_text_field.
-
-    Args:
-        packing: bool
-            Whether or not we should apply packing or not.
-        dataset_text_field: Optional[str]
-            Dataset text field fto be used for formatting by TRL.
-        response_template: Optional[str]
-            Response template to be used for formatting by TRL.
-        max_sequence_length: int
-            Max sequence length to be used for sequence tokenization.
-        tokenizer: AutoTokenizer
-            Loaded tokenizer object to be used by the collator.
-
-    Returns:
-        Callable
-            Callable collator to be leveraged by the trainer.
-    """
-    if not packing:
-        if dataset_text_field is None and response_template is None:
-            # Use the seq2seq data collator; note that this automatically pads labels with -100
-            return DataCollatorForSeq2Seq(
-                tokenizer=tokenizer, padding=True, max_length=max_sequence_length
-            )
-        # TODO: near term - how response template ids are parsed out needs to be cleaned.
-        # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
-        # otherwise template is not found. We will create issue to clean this out after we discuss
-        # data formats and collators we will support.
-        response_template_ids = tokenizer.encode(
-            response_template, add_special_tokens=False
-        )[2:]
-        return DataCollatorForCompletionOnlyLM(
-            response_template=response_template_ids,
-            tokenizer=tokenizer,
-            ignore_index=configs.IGNORE_INDEX,
-        )
-
-
 def get_formatted_dataset(
     data_path: str, dataset_text_field: str, tokenizer: AutoTokenizer
 ) -> Dataset:
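
For anyone trying the refactor locally, a minimal sketch of how the two new helpers compose, mirroring the call sequence in sft_trainer.train(). The tokenizer name and data path below are placeholder assumptions, not values from this patch; any causal-LM tokenizer and a JSON dataset with an "output" field would work the same way.

    # Illustrative usage of the refactored helpers; example values are assumed, not from this patch.
    from transformers import AutoTokenizer

    from tuning.config import configs
    from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model name
    data_args = configs.DataArguments(
        training_data_path="path/to/twitter_complaints.json",  # placeholder path
        dataset_text_field="output",
        response_template="\n### Label:",
    )
    packing = False

    # Raises ValueError when the field combination is unsupported (see the new tests above).
    validate_data_args(data_args, packing)

    # With packing disabled and a response template set, this returns a
    # DataCollatorForCompletionOnlyLM configured with the encoded template ids.
    data_collator = get_data_collator(packing, data_args.response_template, tokenizer)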