diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md
index 55ae4c47d..74f9a9223 100644
--- a/docs/advanced-data-preprocessing.md
+++ b/docs/advanced-data-preprocessing.md
@@ -204,14 +204,17 @@ Users can also pass any number of `kwargs` arguments required for each data handler
 
 #### Preexisting data handlers
 This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156):
- - `tokenize_and_apply_input_masking`:
-     Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
  - `add_tokenizer_eos_token`:
      Appends the tokenizer's EOS token to a specified dataset field.
  - `apply_custom_data_formatting_template`:
      Applies a custom template (e.g., Alpaca style) to format dataset elements.
+     By default this handler appends the tokenizer's `EOS_TOKEN`, which can be disabled via the `add_eos_token` handler argument (see [data_handlers.py](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)).
+ - `tokenize_and_apply_input_masking`:
+     Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
+     By default this handler appends the tokenizer's `EOS_TOKEN`, which can be disabled via the `add_eos_token` handler argument (see [data_handlers.py](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)).
  - `apply_custom_jinja_template`:
      Applies a custom jinja template (e.g., Alpaca style) to format dataset elements.
+     By default this handler appends the tokenizer's `EOS_TOKEN`, which can be disabled via the `add_eos_token` handler argument (see [data_handlers.py](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)).
  - `apply_tokenizer_chat_template`:
      Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
  - `duplicate_columns`:
diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py
index d91326040..c56813b1a 100644
--- a/tuning/data/data_handlers.py
+++ b/tuning/data/data_handlers.py
@@ -58,6 +58,7 @@ def tokenize_and_apply_input_masking(
     column_names: List[str],
     input_field_name: str,
     output_field_name: str,
+    add_eos_token: bool = True,
     **kwargs,
 ):
     """Function (data handler) to tokenize and apply instruction masking on dataset
@@ -68,6 +69,7 @@ def tokenize_and_apply_input_masking(
         column_names: Name of all the columns in the dataset.
         input_field_name: Name of the input (instruction) field in dataset
         output_field_name: Name of the output field in dataset
+        add_eos_token: whether to append tokenizer.eos_token to the text, defaults to True
         **kwargs: Any additional args passed to the handler
     Returns:
         Formatted Dataset element with input_ids, labels and attention_mask columns
@@ -83,7 +85,11 @@ def tokenize_and_apply_input_masking(
     input_text = element[input_field_name]
     output_text = element[output_field_name]
 
-    combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)
+    eos_token = ""
+    if add_eos_token:
+        eos_token = tokenizer.eos_token
+
+    combined = combine_sequence(input_text, output_text, eos_token=eos_token)
 
     tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
 
@@ -131,6 +137,7 @@ def apply_custom_data_formatting_template(
     tokenizer: AutoTokenizer,
     dataset_text_field: str,
     template: str,
+    add_eos_token: bool = True,
     **kwargs,
 ):
     """Function (data handler) to format datasets with Alpaca style / other templates.
@@ -142,12 +149,14 @@ def apply_custom_data_formatting_template(
         dataset_text_field: Text column name of the dataset where formatted text is saved.
         template: Template to format data with.
                   Features of Dataset should be referred to by {{key}}
+        add_eos_token: whether to append tokenizer.eos_token to the text, defaults to True
     Returns:
         Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN
         Saves the result to dataset_text_field argument.
     """
 
-    template += tokenizer.eos_token
+    if add_eos_token:
+        template += tokenizer.eos_token
 
     def replace_text(match_obj):
         captured_groups = match_obj.groups()
@@ -174,6 +183,7 @@ def apply_custom_jinja_template(
     tokenizer: AutoTokenizer,
     dataset_text_field: str,
     template: str,
+    add_eos_token: bool = True,
     **kwargs,
 ):
     """Function (data handler) to format datasets with jinja templates.
@@ -185,12 +195,14 @@ def apply_custom_jinja_template(
         dataset_text_field: formatted_dataset_field.
         template: Template to format data with.
                   Features of Dataset should be referred to by {{key}}.
+        add_eos_token: whether to append tokenizer.eos_token to the text, defaults to True
     Returns:
         Formatted HF Dataset element by formatting dataset with provided jinja template
         Saves the result to dataset_text_field argument.
     """
+    if add_eos_token:
+        template += tokenizer.eos_token
 
-    template += tokenizer.eos_token
     template = process_jinja_placeholders(template)
     env = SandboxedEnvironment(undefined=StrictUndefined)
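Not part of the diff: a minimal, hypothetical smoke test of the new flag, for reviewers who want to try it locally. It assumes the handler can be called directly on a single example dict with keyword arguments (outside of `datasets.map`); the tokenizer checkpoint and field names below are placeholders, and in normal use the flag would instead be passed through the handler's arguments in a data config.

```python
# Hypothetical smoke test for add_eos_token (illustrative names; not part of this PR).
from transformers import AutoTokenizer

from tuning.data.data_handlers import tokenize_and_apply_input_masking

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder causal-LM tokenizer

element = {"input": "Translate to French: cheese", "output": "fromage"}

common = dict(
    element=element,
    tokenizer=tokenizer,
    column_names=list(element.keys()),
    input_field_name="input",
    output_field_name="output",
)

# Default behaviour: tokenizer.eos_token is appended to the output before tokenization.
with_eos = tokenize_and_apply_input_masking(**common)

# New flag: skip the EOS token entirely.
without_eos = tokenize_and_apply_input_masking(**common, add_eos_token=False)

# Per the handler's docstring both results carry input_ids, labels and attention_mask;
# with EOS appending disabled, the combined sequence is expected to be shorter.
print(len(with_eos["input_ids"]), len(without_eos["input_ids"]))
```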