From 5c03aa8753c7412463ee97b70b23cdb3db48dc28 Mon Sep 17 00:00:00 2001 From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Date: Tue, 4 Feb 2025 22:49:28 -0500 Subject: [PATCH 01/11] feat: Add support for jinja based template rendering of the dataset (#438) Signed-off-by: Abhishek --- docs/advanced-data-preprocessing.md | 2 + .../predefined_data_configs/__init__.py | 3 ++ .../apply_custom_jinja_template.yaml | 14 ++++++ tests/data/test_data_handlers.py | 46 +++++++++++++++++++ tests/data/test_data_preprocessing_utils.py | 15 +++++- tuning/data/data_handlers.py | 40 ++++++++++++++++ tuning/utils/config_utils.py | 32 +++++++++++++ 7 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index dd22a99d9..2e53dc61e 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -210,6 +210,8 @@ This library currently supports the following [preexisting data handlers](https: Formats a dataset by appending an EOS token to a specified field. - `apply_custom_data_formatting_template`: Applies a custom template (e.g., Alpaca style) to format dataset elements. + - `apply_custom_data_formatting_jinja_template`: + Applies a custom jinja template (e.g., Alpaca style) to format dataset elements. - `apply_tokenizer_chat_template`: Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates. diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index c199406c6..8ecc07b8f 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -22,6 +22,9 @@ DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" ) +DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml" +) DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml" ) diff --git a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml new file mode 100644 index 000000000..474068fe8 --- /dev/null +++ b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: apply_custom_data_jinja_template + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_jinja_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + template: "dataset_template" \ No newline at end of file diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index d2a390fe9..bfe366ef8 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -25,6 +25,7 @@ # Local from tuning.data.data_handlers import ( + apply_custom_data_formatting_jinja_template, apply_custom_data_formatting_template, combine_sequence, ) @@ -57,6 +58,32 @@ def test_apply_custom_formatting_template(): assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response +def test_apply_custom_formatting_jinja_template(): + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### 
Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + formatted_dataset_field = "formatted_data_field" + formatted_dataset = json_dataset.map( + apply_custom_data_formatting_jinja_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + # First response from the data file that is read. + expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" + + tokenizer.eos_token + ) + + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response + + def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): """Tests that the formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset( @@ -76,6 +103,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): ) +def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(): + """Tests that the jinja formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + formatted_dataset_field = "formatted_data_field" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises(KeyError): + json_dataset.map( + apply_custom_data_formatting_jinja_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + + @pytest.mark.parametrize( "input_element,output_element,expected_res", [ diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 8de5dfc36..95153d1a4 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -29,6 +29,7 @@ # First Party from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, @@ -693,6 +694,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), @@ -731,7 +736,10 @@ def test_process_dataconfig_file(data_config_path, data_path): # Modify dataset_text_field and template according to dataset formatted_dataset_field = "formatted_data_field" - if datasets_name == "apply_custom_data_template": + if datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): template = "### Input: {{Tweet text}} \n\n ### 
Response: {{text_label}}" yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { "dataset_text_field": formatted_dataset_field, @@ -753,7 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path): assert set(train_set.column_names) == column_names elif datasets_name == "pretokenized_dataset": assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) - elif datasets_name == "apply_custom_data_template": + elif datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): assert formatted_dataset_field in set(train_set.column_names) diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 5b80dc4bb..d993dee31 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -19,8 +19,12 @@ import re # Third Party +from jinja2 import Environment, StrictUndefined from transformers import AutoTokenizer +# Local +from tuning.utils.config_utils import process_jinja_placeholders + ### Utils for custom masking / manipulating input / output strs, etc def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): @@ -108,6 +112,8 @@ def apply_custom_data_formatting_template( Expects to be run as a HF Map API function. Args: element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. template: Template to format data with. Features of Dataset should be referred to by {{key}} formatted_dataset_field: Dataset_text_field @@ -137,6 +143,39 @@ def replace_text(match_obj): } +def apply_custom_data_formatting_jinja_template( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + template: str, + **kwargs, +): + """Function to format datasets with jinja templates. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. + dataset_text_field: formatted_dataset_field. + template: Template to format data with. Features of Dataset + should be referred to by {{key}}. + Returns: + Formatted HF Dataset + """ + + template += tokenizer.eos_token + template = process_jinja_placeholders(template) + env = Environment(undefined=StrictUndefined) + jinja_template = env.from_string(template) + + try: + rendered_text = jinja_template.render(element=element, **element) + except Exception as e: + raise KeyError(f"Dataset does not contain field in template. 
{e}") from e + + return {dataset_text_field: rendered_text} + + def apply_tokenizer_chat_template( element: Dict[str, str], tokenizer: AutoTokenizer, @@ -157,5 +196,6 @@ def apply_tokenizer_chat_template( "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, "apply_dataset_formatting": apply_dataset_formatting, "apply_custom_data_formatting_template": apply_custom_data_formatting_template, + "apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template, "apply_tokenizer_chat_template": apply_tokenizer_chat_template, } diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index b5dede937..061d6017b 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -18,6 +18,7 @@ import json import os import pickle +import re # Third Party from peft import LoraConfig, PromptTuningConfig @@ -135,3 +136,34 @@ def txt_to_obj(txt): except UnicodeDecodeError: # Otherwise the bytes are a pickled python dictionary return pickle.loads(message_bytes) + + +def process_jinja_placeholders(template: str) -> str: + """ + Function to detect all placeholders of the form {{...}}. + - If the inside has a space (e.g. {{Tweet text}}), + rewrite to {{ element['Tweet text'] }}. + - If it doesn't have a space (e.g. {{text_label}}), leave it as is. + - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing. + + Args: + template: str + Return: template: str + """ + + pattern = r"\{\{([^}]+)\}\}" + matches = re.findall(pattern, template) + + for match in matches: + original_placeholder = f"{{{{{match}}}}}" + trimmed = match.strip() + + if trimmed.startswith("element["): + continue + + # If there's a space in the placeholder name, rewrite it to dictionary-style + if " " in trimmed: + new_placeholder = f"{{{{ element['{trimmed}'] }}}}" + template = template.replace(original_placeholder, new_placeholder) + + return template From 496daf98eb5cf88d695f29683e5e0f2485b3545c Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Wed, 5 Feb 2025 12:09:26 +0530 Subject: [PATCH 02/11] Fix bug in aim tracker where the server based tracking was not picked (#454) up. 
Signed-off-by: Dushyant Behl --- tuning/trackers/aimstack_tracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py index 8985c1966..314443f9c 100644 --- a/tuning/trackers/aimstack_tracker.py +++ b/tuning/trackers/aimstack_tracker.py @@ -121,7 +121,7 @@ def get_hf_callback(self): if url is not None: aim_callback = RunIDExporterAimCallback(repo=url, experiment=exp) - if repo: + elif repo: aim_callback = RunIDExporterAimCallback(repo=repo, experiment=exp) else: self.logger.error( From f32ac38e5b07978179d765c588be5b768b59587a Mon Sep 17 00:00:00 2001 From: PRINCE KUMAR Date: Wed, 5 Feb 2025 15:44:35 +0530 Subject: [PATCH 03/11] Removed duplicate main_process_port entry (#444) Signed-off-by: Prince Kumar Co-authored-by: Prince Kumar --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index af3da30f4..5c00587a3 100644 --- a/README.md +++ b/README.md @@ -322,7 +322,6 @@ Below example runs multi-GPU fine tuning on 8 GPUs with FSDP: # OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved accelerate launch \ ---main_process_port $MASTER_PORT \ --config_file fixtures/accelerate_fsdp_defaults.yaml \ --num_processes=8 \ --main_process_port=$MASTER_PORT \ From 59a72cd597f0ab140c59739e41d1f7be7b86538a Mon Sep 17 00:00:00 2001 From: Hari Date: Thu, 6 Feb 2025 00:29:16 +0530 Subject: [PATCH 04/11] fix: space missing from data_formatter_template causing mismatch with response_template (#455) Signed-off-by: Harikrishnan Balagopal --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5c00587a3..d35b38904 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ Example: Train.json }, ... ]` -data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` +data_formatter_template: `### Input: {{input}} \n\n## Label: {{output}}` Formatting will happen on the fly while tuning. The keys in template should match fields in the dataset file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`. 
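For illustration, the fix above matters because the rendered template and the supplied `response template` are matched character-for-character. A minimal sketch of that substitution, assuming a hypothetical record with `input` and `output` fields (it mirrors, but is not, the library's formatter handler, and omits the EOS token it would normally append):

```python
import re

# Hypothetical record; field names must match the {{...}} keys in the template.
record = {"input": "@HMRCcustomers No this is my first job", "output": "no complaint"}
template = "### Input: {{input}} \n\n## Label: {{output}}"

# Replace each {{key}} placeholder with the corresponding field value.
rendered = re.sub(
    r"{{([\s0-9a-zA-Z_\-\.]+)}}",
    lambda m: str(record[m.group(1).strip()]),
    template,
)

# The response template is searched for verbatim in the rendered text, so the
# old missing space ("##Label:") would have made "\n## Label:" unfindable.
assert "\n## Label:" in rendered
```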
@@ -299,7 +299,7 @@ python tuning/sft_trainer.py \ --gradient_accumulation_steps 4 \ --learning_rate 1e-5 \ --response_template "\n## Label:" \ ---data_formatter_template: "### Input: {{input}} \n\n##Label: {{output}}" +--data_formatter_template: "### Input: {{input}} \n\n## Label: {{output}}" ``` From f88f031a09604cc86c529b468792f12699f72a2a Mon Sep 17 00:00:00 2001 From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Date: Fri, 7 Feb 2025 08:12:08 -0500 Subject: [PATCH 05/11] SandboxedEnvironment in Jinja template (#456) Signed-off-by: Abhishek --- tests/data/test_data_handlers.py | 14 +++++++++++--- tuning/data/data_handlers.py | 25 +++++++++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index bfe366ef8..4a3736e6d 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -103,15 +103,23 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): ) -def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(): +@pytest.mark.parametrize( + "template", + [ + "### Input: {{ not found }} \n\n ### Response: {{ text_label }}", + "### Input: }} Tweet text {{ \n\n ### Response: {{ text_label }}", + "### Input: {{ Tweet text }} \n\n ### Response: {{ ''.__class__ }}", + "### Input: {{ Tweet text }} \n\n ### Response: {{ undefined_variable.split() }}", + ], +) +def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(template): """Tests that the jinja formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset( "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL ) - template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" formatted_dataset_field = "formatted_data_field" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises(KeyError): + with pytest.raises((KeyError, ValueError)): json_dataset.map( apply_custom_data_formatting_jinja_template, fn_kwargs={ diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index d993dee31..cdfb263b2 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -19,7 +19,8 @@ import re # Third Party -from jinja2 import Environment, StrictUndefined +from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError +from jinja2.sandbox import SandboxedEnvironment, SecurityError from transformers import AutoTokenizer # Local @@ -165,13 +166,29 @@ def apply_custom_data_formatting_jinja_template( template += tokenizer.eos_token template = process_jinja_placeholders(template) - env = Environment(undefined=StrictUndefined) - jinja_template = env.from_string(template) + env = SandboxedEnvironment(undefined=StrictUndefined) + + try: + jinja_template = env.from_string(template) + except TemplateSyntaxError as e: + raise ValueError( + f"Invalid template syntax in provided Jinja template. {e.message}" + ) from e try: rendered_text = jinja_template.render(element=element, **element) + except UndefinedError as e: + raise KeyError( + f"The dataset does not contain the key used in the provided Jinja template. {e.message}" + ) from e + except SecurityError as e: + raise ValueError( + f"Unsafe operation detected in the provided Jinja template. {e.message}" + ) from e except Exception as e: - raise KeyError(f"Dataset does not contain field in template. {e}") from e + raise ValueError( + f"Error occurred while rendering the provided Jinja template. 
{e.message}" + ) from e return {dataset_text_field: rendered_text} From d48d48360c0ce167f6394a6804f42b8ba1b4420d Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Mon, 10 Feb 2025 22:12:37 +0530 Subject: [PATCH 06/11] docs: Add documentation on how to do EPT runs with our library. (#461) * add documentation on ept Signed-off-by: Dushyant Behl * Update docs/ept.md Co-authored-by: Will Johnson Signed-off-by: Dushyant Behl * Apply suggestions from code review Co-authored-by: Will Johnson Signed-off-by: Dushyant Behl Signed-off-by: Dushyant Behl * Add additional information Signed-off-by: Dushyant Behl * fix statement Signed-off-by: Dushyant Behl --------- Signed-off-by: Dushyant Behl Signed-off-by: Dushyant Behl Co-authored-by: Will Johnson --- README.md | 4 ++ docs/ept.md | 112 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 docs/ept.md diff --git a/README.md b/README.md index d35b38904..f3bdaed4d 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ - [Prompt Tuning](#prompt-tuning) - [Fine Tuning](#fine-tuning) - [FMS Acceleration](#fms-acceleration) +- [Extended Pre-Training](#extended-pre-training) - [Inference](#inference) - [Running a single example](#running-a-single-example) - [Running multiple examples](#running-multiple-examples) @@ -828,6 +829,9 @@ Number of trainable parameters = 13,631,488 The `fms_acceleration.cli` can do more to search for all available configs, plugins and arguments, [see the advanced flow](https://github.com/foundation-model-stack/fms-acceleration#advanced-flow). +## Extended Pre-Training + +We also have support for extended pre training where users might wanna pretrain a model with large number of samples. Please refer our separate doc on [EPT Use Cases](./docs/ept.md) ## Inference Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time. diff --git a/docs/ept.md b/docs/ept.md new file mode 100644 index 000000000..c6fbdb379 --- /dev/null +++ b/docs/ept.md @@ -0,0 +1,112 @@ +# Extended Pre Training Support +Our library also supports Extended Pre-Training (EPT), which is generally useful when users want to train a pretrained model on a large number of samples. The training behaviour of EPT is similar to that of pretraining where users might wanna make sure the models runs through entire corpus of data available and be trained on whole set of tokens without any specific masking. + +See [below](#additional-information) for information on when this document was last updated and the release which supports this feature. + +## Packing support + +We support training via `packing` dataset samples by specifing `--packing=True` in the command line parameters. Users can choose to specify `--max_seq_len=` to provide the maxium sequence length of each chunk post packing. + +We provide below details on how to use different style of datasets with the library. + +## Non-Tokenized Dataset + +### Single Non-Tokenized Dataset +Users can pass a single dataset to the library by using a [data_config](./advanced-data-preprocessing.md#data-config). 
+Lets say you have a `JSONL` data file which contains text to be trained on in each line that you want to perform EPT on, you can create a `data_config` for the dataset in this manner, + +Example dataset, + +``` +{"Tweet":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"} +{"Tweet":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"} +... +``` + +Sample data config for the above use case. +``` +dataprocessor: + type: default +datasets: + - name: non_tokenized_text_dataset + data_paths: + - "" + data_handlers: + - name: apply_custom_data_formatting + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" +``` + +And the commandline passed to the library should include following. + +``` +--data_config --packing=True --max_seq_len 8192 +``` + +Please note that for non tokenized dataset our code adds `EOS_TOKEN` to the lines, for e.g. `Tweet` column before passing that as a dataset. + +### Multiple Non Tokenized Datasets + +If a user wants to utilize multiple datasets and want to [`sample`](./advanced-data-preprocessing.md#how-the-user-can-write-data-configs) the datasets. This can be achieved by specifying multiple datasets in the data config with different sampling ratios. + +Sample data config for sampling among multiple datasets +``` +dataprocessor: + type: default + sampling_stopping_strategy: first_exhausted + seed: 66 +datasets: + - name: non_tokenized_text_dataset_1 + sampling: 0.3 + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + template: "dataset_template" + - name: non_tokenized_text_dataset_2 + sampling: 0.4 + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + template: "dataset_template" + - name: non_tokenized_text_dataset_3 + sampling: 0.3 + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + template: "dataset_template" +``` + +NOTE: More in-depth documentation of `sampling_stopping_strategy` and how to specify data mixing parameters in the `data_config` is covered in the [data mixing](./advanced-data-preprocessing.md#data-mixing) section of the advanced data preprocessing documentation + +Here also the command line arguments would be + +``` +--data_config --packing=True --max_seq_len 8192 +``` + +The code again would add `EOS_TOKEN` to the non tokenized data before using it and also note that the `dataset_text_field` is assumed to be same across all datasets for now. + +### Additional Information +This feature is supported post [v2.3.1](https://github.com/foundation-model-stack/fms-hf-tuning/releases/tag/v2.3.1) of this library. 
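As a rough mental model for the dataset mixing described earlier in this document, the sampling ratios behave much like per-dataset probabilities passed to Hugging Face `interleave_datasets`. The sketch below is an approximation under that assumption, with hypothetical file paths standing in for the `FILE_PATH` entries:

```python
from datasets import interleave_datasets, load_dataset

# Hypothetical stand-ins for the three FILE_PATH entries in the sample config.
parts = [
    load_dataset("json", data_files=path)["train"]
    for path in ("dataset_1.json", "dataset_2.json", "dataset_3.json")
]

# Mix with the configured ratios and seed; "first_exhausted" stops the mixture
# as soon as the first (weighted) dataset runs out of samples.
mixed = interleave_datasets(
    parts,
    probabilities=[0.3, 0.4, 0.3],
    seed=66,
    stopping_strategy="first_exhausted",
)
```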
+Post Last Updated On: 10-02-2025 \ No newline at end of file From 381fdd55d42d2fffbf8ef0ed53d55bfe83373924 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Wed, 12 Feb 2025 01:02:15 +0530 Subject: [PATCH 07/11] feat: Rename data handlers and add a new one for EPT scenarios (#460) * Rename data handlers and add a data handler for EPT user case. Signed-off-by: Dushyant Behl * Fix minor bug in formatting where input_ids was missing post duplication. Signed-off-by: Dushyant Behl * Add docstring Signed-off-by: Dushyant Behl * change name of dataset in data config yaml Signed-off-by: Dushyant Behl --------- Signed-off-by: Dushyant Behl --- docs/advanced-data-preprocessing.md | 8 +- .../predefined_data_configs/__init__.py | 3 + .../apply_custom_jinja_template.yaml | 2 +- .../duplicate_columns.yaml | 14 +++ tests/artifacts/testdata/__init__.py | 4 + ...h_maykeye_tinyllama_v0_only_input_ids.json | 32 ++++++ tests/data/test_data_handlers.py | 61 ++++++++++- ...ng_utils.py => test_data_preprocessing.py} | 0 tests/test_sft_trainer.py | 52 ++++++++- tuning/data/data_handlers.py | 102 +++++++++++++++--- tuning/data/setup_dataprocessor.py | 2 +- 11 files changed, 254 insertions(+), 26 deletions(-) create mode 100644 tests/artifacts/predefined_data_configs/duplicate_columns.yaml create mode 100644 tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json rename tests/data/{test_data_preprocessing_utils.py => test_data_preprocessing.py} (100%) diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index 2e53dc61e..55ae4c47d 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -206,14 +206,16 @@ Users can also pass any number of `kwargs` arguments required for each data hand This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156): - `tokenize_and_apply_input_masking`: Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets. - - `apply_dataset_formatting`: - Formats a dataset by appending an EOS token to a specified field. + - `add_tokenizer_eos_token`: + Appends the tokenizer's EOS token to a specified dataset field. - `apply_custom_data_formatting_template`: Applies a custom template (e.g., Alpaca style) to format dataset elements. - - `apply_custom_data_formatting_jinja_template`: + - `apply_custom_jinja_template`: Applies a custom jinja template (e.g., Alpaca style) to format dataset elements. - `apply_tokenizer_chat_template`: Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates. + - `duplicate_columns`: + Duplicates one column of the dataset to another column. 
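For illustration, any of these handlers can also be driven directly through the `datasets` map API, which is how the new unit tests exercise them. A minimal sketch using `duplicate_columns`, assuming a hypothetical pre-tokenized file that only carries an `input_ids` column:

```python
from datasets import load_dataset

from tuning.data.data_handlers import duplicate_columns

# Hypothetical pre-tokenized JSON file containing only an "input_ids" column.
ds = load_dataset("json", data_files="pretokenized_input_ids_only.json")["train"]

# Copy input_ids into a new labels column so downstream collators see both.
ds = ds.map(
    duplicate_columns,
    fn_kwargs={"old_column": "input_ids", "new_column": "labels"},
)

assert ds[0]["labels"] == ds[0]["input_ids"]
```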
These handlers could be requested by their same name and users can lookup the function args from [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py) diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index 8ecc07b8f..7d0317466 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -34,3 +34,6 @@ DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml" ) +DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join( + PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml" +) diff --git a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml index 474068fe8..6dcf031d3 100644 --- a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml +++ b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml @@ -5,7 +5,7 @@ datasets: data_paths: - "FILE_PATH" data_handlers: - - name: apply_custom_data_formatting_jinja_template + - name: apply_custom_jinja_template arguments: remove_columns: all batched: false diff --git a/tests/artifacts/predefined_data_configs/duplicate_columns.yaml b/tests/artifacts/predefined_data_configs/duplicate_columns.yaml new file mode 100644 index 000000000..e94482b67 --- /dev/null +++ b/tests/artifacts/predefined_data_configs/duplicate_columns.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: pre_tokenized_with_only_input_ids + data_paths: + - "FILE_PATH" + data_handlers: + - name: duplicate_columns + arguments: + remove_columns: all + batched: false + fn_kwargs: + old_column: "input_ids" + new_column: "labels" \ No newline at end of file diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index c76f065b7..b7d0d7b6f 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -53,6 +53,10 @@ TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join( JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json" ) +TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON = os.path.join( + JSON_DATA_DIR, + "twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json", +) TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join( JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" ) diff --git a/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json b/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json new file mode 100644 index 000000000..eb022f541 --- /dev/null +++ b/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0_only_input_ids.json @@ -0,0 +1,32 @@ +[ + { + "input_ids": [1, 16121, 9211, 31871, 1662, 31866, 31856, 7416, 17632, 369, 1398, 433, 322, 629, 712, 1784, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 31892, 1260, 31825, 11273, 503, 31857, 632, 5284, 365, 329, 553, 1280, 31905, 960, 365, 6194, 289, 11025, 31844, 365, 473, 987, 12207, 4218, 389, 31822, 31853, 31854, 31886, 31852, 31852, 31854, 11300, 31847, 3873, 1507, 31843, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 960, 312, 473, 31876, 31824, 685, 629, 31822, 31878, 4449, 5861, 287, 1662, 
1299, 1574, 1590, 31833, 263, 1360, 1299, 1574, 289, 623, 31822, 31824, 16346, 312, 31876, 31836, 994, 277, 3560, 567, 31843, 672, 322, 260, 29458, 288, 629, 14881, 31843, 2628, 1423, 1662, 31858, 601, 1662, 31858, 601, 8378, 13, 13, 8458, 31922, 21597, 31871, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 7766, 1078, 8123, 17561, 308, 3456, 1833, 975, 10849, 291, 4372, 15379, 504, 10011, 2368, 1512, 31822, 31855, 31852, 31852, 1243, 31843, 3007, 322, 433, 31843, 13, 13, 8458, 31922, 21597, 31871, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 12371, 2208, 26657, 31844, 560, 14138, 31843, 21994, 1257, 24870, 496, 31829, 8198, 19057, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 31836, 651, 307, 395, 13094, 672, 1467, 701, 333, 515, 31844, 504, 1097, 2266, 282, 305, 781, 31902, 21626, 31822, 31824, 5540, 397, 560, 5253, 662, 365, 31876, 263, 4985, 31854, 8903, 16801, 291, 612, 31925, 2011, 1129, 31824, 31843, 1358, 31873, 19919, 31824, 31865, 31829, 469, 2131, 31874, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 31900, 307, 31837, 473, 382, 685, 266, 3195, 17532, 329, 260, 1173, 9363, 352, 1671, 1881, 646, 619, 31822, 31882, 5556, 504, 2091, 31822, 31882, 31843, 31855, 31861, 405, 499, 382, 863, 260, 31822, 31878, 4449, 2540, 2042, 31902, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 14390, 16373, 337, 312, 435, 697, 1579, 291, 266, 3925, 322, 1434, 291, 3877, 31843, 1456, 365, 499, 1419, 562, 433, 31902, 13, 13, 8458, 31922, 21597, 31871, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 7265, 7550, 389, 1662, 31856, 2226, 11596, 27771, 898, 31843, 3259, 647, 312, 498, 288, 635, 31844, 518, 3822, 397, 2168, 28910, 31873, 13627, 4107, 1708, 31843, 312, 31876, 608, 1090, 629, 10279, 289, 1662, 29966, 31831, 5605, 13, 13, 8458, 31922, 21597, 31871, 9566] + }, + { + "input_ids": [1, 16121, 9211, 31871, 1662, 31884, 1450, 7064, 31847, 6538, 30894, 4472, 289, 362, 828, 31843, 864, 685, 541, 9932, 843, 584, 18694, 31986, 13, 13, 8458, 31922, 21597, 31871, 697, 9566] + } +] diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 4a3736e6d..c61dada4e 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -21,13 +21,19 @@ import pytest # First Party -from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL +from tests.artifacts.testdata import ( + MODEL_NAME, + TWITTER_COMPLAINTS_DATA_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_JSON, + TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON, +) # Local from tuning.data.data_handlers import ( - apply_custom_data_formatting_jinja_template, apply_custom_data_formatting_template, + apply_custom_jinja_template, combine_sequence, + duplicate_columns, ) @@ -66,7 +72,7 @@ def test_apply_custom_formatting_jinja_template(): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) formatted_dataset_field = "formatted_data_field" formatted_dataset = json_dataset.map( - apply_custom_data_formatting_jinja_template, + apply_custom_jinja_template, fn_kwargs={ "tokenizer": tokenizer, "dataset_text_field": formatted_dataset_field, @@ -121,7 +127,7 @@ def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(temp tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) with pytest.raises((KeyError, ValueError)): json_dataset.map( - apply_custom_data_formatting_jinja_template, + 
apply_custom_jinja_template, fn_kwargs={ "tokenizer": tokenizer, "dataset_text_field": formatted_dataset_field, @@ -162,3 +168,50 @@ def test_combine_sequence_adds_eos(input_element, output_element, expected_res): expected_res += tokenizer.eos_token assert isinstance(comb_seq, str) assert comb_seq == expected_res + + +@pytest.mark.parametrize( + "dataset, old, new", + [ + (TWITTER_COMPLAINTS_DATA_JSONL, "input_ids", "labels"), + (TWITTER_COMPLAINTS_TOKENIZED_JSON, "input_ids", "labels"), + (TWITTER_COMPLAINTS_DATA_JSONL, None, None), + (TWITTER_COMPLAINTS_DATA_JSONL, "input_ids", None), + ], +) +def test_duplicate_columns_throws_error_on_wrong_args(dataset, old, new): + """Ensure that duplicate_columns data handler throws error if column names are wrong.""" + d = datasets.load_dataset("json", data_files=dataset) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises(ValueError): + d.map( + duplicate_columns, + fn_kwargs={ + "tokenizer": tokenizer, + "old_column": old, + "new_column": new, + }, + ) + + +def test_duplicate_columns_copies_columns(): + """Ensure that duplicate_columns data handler copies and maintains both columns.""" + old = "input_ids" + new = "labels" + d = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + updated_dataaset = d.map( + duplicate_columns, + fn_kwargs={ + "tokenizer": tokenizer, + "old_column": old, + "new_column": new, + }, + ) + + first_element = updated_dataaset["train"][0] + assert new in first_element + assert old in first_element + assert first_element[new] == first_element[old] diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing.py similarity index 100% rename from tests/data/test_data_preprocessing_utils.py rename to tests/data/test_data_preprocessing.py diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 8faa3746c..c7c73891d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -36,6 +36,7 @@ from build.utils import serialize_args from scripts.run_inference import TunedCausalLM from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_DUPLICATE_COLUMNS, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, ) @@ -58,6 +59,7 @@ TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON, TWITTER_COMPLAINTS_TOKENIZED_PARQUET, ) @@ -71,7 +73,7 @@ DataPreProcessorConfig, DataSetConfig, ) -from tuning.data.data_handlers import apply_dataset_formatting +from tuning.data.data_handlers import add_tokenizer_eos_token MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" @@ -880,6 +882,52 @@ def test_run_causallm_ft_and_inference_with_multiple_dataset( assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +def test_run_training_with_pretokenised_dataset_containing_input_ids(): + """Ensure that we can train on pretokenised dataset containing just input_ids by + choosing duplicate_columns data handler via data config.""" + with tempfile.TemporaryDirectory() as tempdir: + + data_args = copy.deepcopy(DATA_ARGS) + + # set training_data_path and response_template to none + data_args.response_template = None + data_args.training_data_path = None + + dataconfigfile = DATA_CONFIG_DUPLICATE_COLUMNS + datapath = 
TWITTER_COMPLAINTS_TOKENIZED_ONLY_INPUT_IDS_JSON + + # add data_paths in data_config file + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + with open(dataconfigfile, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + datasets = data["datasets"] + for _, d in enumerate(datasets): + d["data_paths"] = [datapath] + yaml.dump(data, temp_yaml_file) + data_args.data_config_path = temp_yaml_file.name + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_args, train_args) + + # validate full ft configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + @pytest.mark.parametrize( "dataset_path", [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN], @@ -1469,7 +1517,7 @@ def test_run_by_passing_additional_data_handlers(): TEST_HANDLER = "my_test_handler" def test_handler(element, tokenizer, **kwargs): - return apply_dataset_formatting(element, tokenizer, "custom_formatted_field") + return add_tokenizer_eos_token(element, tokenizer, "custom_formatted_field") # This data config calls for data handler to be applied to dataset preprocessor_config = DataPreProcessorConfig() diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index cdfb263b2..2951a5a26 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -16,6 +16,7 @@ # Standard from typing import Dict, List +import copy import re # Third Party @@ -59,6 +60,19 @@ def tokenize_and_apply_input_masking( output_field_name: str, **tokenizer_kwargs, ): + """Function (data handler) to tokenize and apply instruction masking on dataset + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element. + tokenizer: Tokenizer to be used for tokenization. + column_names: Name of all the columns in the dataset. + input_field_name: Name of the input (instruction) field in dataset + output_field_name: Name of the output field in dataset + **tokenizer_kwargs: Any additional kwargs to be passed to tokenizer + Returns: + Formatted Dataset element with input_ids, labels and attention_mask columns + """ + if (input_field_name or output_field_name) not in column_names: raise ValueError( f"Dataset should contain {input_field_name} \ @@ -89,12 +103,23 @@ def tokenize_and_apply_input_masking( } -def apply_dataset_formatting( +def add_tokenizer_eos_token( element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs, ): + """Function (data handler) to add tokenizer's EOS token to text field of an element + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. + dataset_text_field: Text column name of the dataset where EOS is to be added. + Returns: + Formatted Dataset element with EOS added to dataset_text_field of the element. 
+ """ + if dataset_text_field not in element: raise KeyError(f"Dataset should contain {dataset_text_field} field.") return { @@ -109,19 +134,18 @@ def apply_custom_data_formatting_template( template: str, **kwargs, ): - """Function to format datasets with Alpaca style / other templates. + """Function (data handler) to format datasets with Alpaca style / other templates. Expects to be run as a HF Map API function. Args: - element: the HF Dataset element loaded from a JSON or DatasetDict object. + element: the HF Dataset element. tokenizer: Tokenizer to be used for the EOS token, which will be appended when formatting the data into a single sequence. Defaults to empty. + dataset_text_field: Text column name of the dataset where formatted text is saved. template: Template to format data with. Features of Dataset should be referred to by {{key}} - formatted_dataset_field: Dataset_text_field - eos_token: string EOS token to be appended while formatting data to a single sequence. - Defaults to empty Returns: - Formatted HF Dataset + Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN + Saves the result to dataset_text_field argument. """ template += tokenizer.eos_token @@ -140,28 +164,31 @@ def replace_text(match_obj): return str(element[index_object]) return { - dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template) + f"{dataset_text_field}": re.sub( + r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template + ) } -def apply_custom_data_formatting_jinja_template( +def apply_custom_jinja_template( element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, template: str, **kwargs, ): - """Function to format datasets with jinja templates. + """Function (data handler) to format datasets with jinja templates. Expects to be run as a HF Map API function. Args: - element: the HF Dataset element loaded from a JSON or DatasetDict object. + element: the HF Dataset element tokenizer: Tokenizer to be used for the EOS token, which will be appended when formatting the data into a single sequence. Defaults to empty. dataset_text_field: formatted_dataset_field. template: Template to format data with. Features of Dataset should be referred to by {{key}}. Returns: - Formatted HF Dataset + Formatted HF Dataset element by formatting dataset with provided jinja template + Saves the result to dataset_text_field argument. """ template += tokenizer.eos_token @@ -190,7 +217,7 @@ def apply_custom_data_formatting_jinja_template( f"Error occurred while rendering the provided Jinja template. {e.message}" ) from e - return {dataset_text_field: rendered_text} + return {f"{dataset_text_field}": rendered_text} def apply_tokenizer_chat_template( @@ -199,6 +226,16 @@ def apply_tokenizer_chat_template( dataset_text_field: str, **kwargs, ): + """Function (data handler) to apply tokenizers chat template to dataset elements. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element. + tokenizer: Tokenizer to be used. + dataset_text_field: formatted_dataset_field. + Returns: + Formatted HF Dataset element by formatting dataset with tokenizer's chat template + Saves the result to dataset_text_field argument. 
+ """ if tokenizer.chat_template is None: raise ValueError( "Tokenizer does not contain tokenizer.chat_template\ @@ -209,10 +246,45 @@ def apply_tokenizer_chat_template( } +def duplicate_columns( + element: Dict[str, str], + old_column: str, + new_column: str, + **kwargs, +): + """Function (data handler) to duplicate one columne of a dataset to another. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element + old_column: Name of the column which is to be duplicated + new_column: Name of the new column where duplicated column is to be saved + Returns: + Formatted HF Dataset element with new_column where existing_columns content is deep copied. + """ + if not old_column or not new_column: + raise ValueError( + "for duplicating columns both old and new column name must be specified" + ) + if old_column not in element: + raise ValueError( + f"Cannot duplicate {old_column} to {new_column} as column {old_column} doesn't exist" + ) + if new_column in element: + raise ValueError( + f"Cannot duplicate {old_column} to f{new_column} as column {new_column} already exists" + ) + + return { + f"{old_column}": element[old_column], + f"{new_column}": copy.deepcopy(element[old_column]), + } + + AVAILABLE_DATA_HANDLERS = { "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, - "apply_dataset_formatting": apply_dataset_formatting, + "add_tokenizer_eos_token": add_tokenizer_eos_token, "apply_custom_data_formatting_template": apply_custom_data_formatting_template, - "apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template, + "apply_custom_jinja_template": apply_custom_jinja_template, "apply_tokenizer_chat_template": apply_tokenizer_chat_template, + "duplicate_columns": duplicate_columns, } diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 7921652b8..433ddbece 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -148,7 +148,7 @@ def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False): fn_kwargs["dataset_text_field"] = dataset_text_field if data_args.data_formatter_template is None: handler = DataHandlerConfig( - "apply_dataset_formatting", + "add_tokenizer_eos_token", arguments={"fn_kwargs": fn_kwargs, "batched": False}, ) else: From a89a4a336487b3bda5bde31f9a7b58a0563e3e02 Mon Sep 17 00:00:00 2001 From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Date: Wed, 12 Feb 2025 10:57:34 -0500 Subject: [PATCH 08/11] fix tokenize_and_apply_input_masking kwargs (#465) Signed-off-by: Abhishek --- tests/data/test_data_preprocessing.py | 4 +++- tuning/data/data_handlers.py | 11 +++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 95153d1a4..b435f2d2b 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -1173,9 +1173,10 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( def test_process_dataargs(data_args, is_padding_free): """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + max_seq_length = 5 TRAIN_ARGS = configs.TrainingArguments( packing=False, - max_seq_length=1024, + max_seq_length=max_seq_length, output_dir="tmp", # Not needed but positional ) (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs( @@ -1187,6 +1188,7 @@ def 
test_process_dataargs(data_args, is_padding_free): column_names = set(["input_ids", "attention_mask", "labels"]) assert set(eval_set.column_names) == column_names assert set(train_set.column_names) == column_names + assert len(train_set[0]["input_ids"]) == max_seq_length else: assert dataset_text_field in train_set.column_names assert dataset_text_field in eval_set.column_names diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 2951a5a26..d91326040 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -58,7 +58,7 @@ def tokenize_and_apply_input_masking( column_names: List[str], input_field_name: str, output_field_name: str, - **tokenizer_kwargs, + **kwargs, ): """Function (data handler) to tokenize and apply instruction masking on dataset Expects to be run as a HF Map API function. @@ -68,7 +68,7 @@ def tokenize_and_apply_input_masking( column_names: Name of all the columns in the dataset. input_field_name: Name of the input (instruction) field in dataset output_field_name: Name of the output field in dataset - **tokenizer_kwargs: Any additional kwargs to be passed to tokenizer + **kwargs: Any additional args passed to the handler Returns: Formatted Dataset element with input_ids, labels and attention_mask columns """ @@ -85,11 +85,10 @@ def tokenize_and_apply_input_masking( combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token) - fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {}) - tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {}) + tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) - tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs) - tokenized_input = tokenizer(input_text, **tokenizer_inner_kwargs) + tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs) + tokenized_input = tokenizer(input_text, **tokenizer_kwargs) masked_labels = [-100] * len( tokenized_input.input_ids From f1fd130d5f6622e0105b232b24b2dae18051283f Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 13 Feb 2025 21:16:58 +0530 Subject: [PATCH 09/11] feat: Add support for renaming and retaining columns in data preprocessor (#466) * Add functionality to rename and retain dataset columns in data preprocessor. Signed-off-by: Dushyant Behl * fix fmt Signed-off-by: Dushyant Behl * add unit tests Signed-off-by: Dushyant Behl * Update advanced-data-preprocessing.md Signed-off-by: Dushyant Behl --------- Signed-off-by: Dushyant Behl --- docs/advanced-data-preprocessing.md | 10 ++++ docs/ept.md | 4 +- .../predefined_data_configs/__init__.py | 3 + .../rename_retain_columns.yaml | 20 +++++++ tests/data/test_data_preprocessing.py | 55 +++++++++++++++++++ tests/test_sft_trainer.py | 5 ++ tuning/data/data_config.py | 14 +++++ tuning/data/data_processors.py | 20 +++++++ 8 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 tests/artifacts/predefined_data_configs/rename_retain_columns.yaml diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index 55ae4c47d..ad2434143 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -60,6 +60,10 @@ definitions: type: float builder: type: string + rename_columns: + type: object + retain_columns: + type: object data_paths: type: array items: @@ -118,6 +122,8 @@ Users can create a data config file in any of YAML or JSON format they choose (w - `name` (optional, str): A unique identifier for the dataset. 
- `data_paths` (optional, list): A `list` of file paths or directories containing the dataset. - `builder` (optional, str): Specifies a [Hugging Face dataset builder](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/loading_methods#datasets.load_dataset.path), if applicable. + - `rename_columns` (optional, dict[str:str]): Specifies a dictionary of columns to rename like `{"old_name": "new_name"}` at dataset load time. *Applied before `retain_columns` if both are specified*. + - `retain_columns` (optional, list[str]): Specifies a list of columns to retain `["input_ids", "labels"]` every other column will be dropped at dataset load time. *Applied strictly after `rename_columns` if both are specified*. - `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving. - `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset. @@ -149,6 +155,10 @@ Not Supported: Currently there's no support for sampling under multiple data paths which are defined inside a dataset definition. All dataset paths that will be specified inside one dataset will be [concatenated](https://huggingface.co/docs/datasets/v3.2.0/en/process#concatenate) after loading them, while across datasets users can specify [mixing via sampling datasets](#data-mixing) +Probably something like this: + +Additionally while loading the dataset, users can specify which columns to rename via `rename_columns` and which to retain via `retain_columns` arguments above. +The order of application of these operations is *strictly rename followed by retain* so users should note that an old column name which is renamed will not be available in retain and hence should be careful while applying these operations. The code will throw a `ValueError` in case user specified a column requested to be renamed via rename argument in retain argument as well. ### How can users specify data handlers. diff --git a/docs/ept.md b/docs/ept.md index c6fbdb379..19433bf86 100644 --- a/docs/ept.md +++ b/docs/ept.md @@ -32,7 +32,7 @@ datasets: data_paths: - "" data_handlers: - - name: apply_custom_data_formatting + - name: add_tokenizer_eos_token arguments: remove_columns: all batched: false @@ -109,4 +109,4 @@ The code again would add `EOS_TOKEN` to the non tokenized data before using it a ### Additional Information This feature is supported post [v2.3.1](https://github.com/foundation-model-stack/fms-hf-tuning/releases/tag/v2.3.1) of this library. 
-Post Last Updated On: 10-02-2025 \ No newline at end of file +Post Last Updated On: 12-02-2025 \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index 7d0317466..63c14867c 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -37,3 +37,6 @@ DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join( PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml" ) +DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join( + PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml" +) diff --git a/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml b/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml new file mode 100644 index 000000000..ecfe993ca --- /dev/null +++ b/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml @@ -0,0 +1,20 @@ +dataprocessor: + type: default +datasets: + - name: text_dataset_input_output_masking + rename_columns: + "input" : "instruction" + "output" : "response" + retain_columns: + - "instruction" + - "response" + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_input_masking + arguments: + remove_columns: all + batched: false + fn_kwargs: + input_field_name: instruction + output_field_name: response \ No newline at end of file diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index b435f2d2b..7bf07e58a 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -33,6 +33,7 @@ DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + DATA_CONFIG_RENAME_RETAIN_COLUMNS, DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, ) from tests.artifacts.testdata import ( @@ -1365,3 +1366,57 @@ def test_process_dataset_configs_with_sampling_error( (_, _, _, _, _, _) = process_dataargs( data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS ) + + +@pytest.mark.parametrize( + "datafile, rename, retain, final, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + {"input": "instruction", "output": "response"}, + None, + ["ID", "Label", "instruction", "response"], + DATA_CONFIG_RENAME_RETAIN_COLUMNS, + ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + None, + ["ID", "input", "output"], + ["ID", "input", "output"], + DATA_CONFIG_RENAME_RETAIN_COLUMNS, + ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + {"input": "instruction", "output": "response"}, + ["Label", "instruction", "response"], + ["Label", "instruction", "response"], + DATA_CONFIG_RENAME_RETAIN_COLUMNS, + ), + ], +) +def test_rename_and_retain_dataset_columns( + datafile, rename, retain, final, datasetconfigname +): + """Test process_dataset_configs for expected output.""" + dataprocessor_config = DataPreProcessorConfig() + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + processor = DataPreProcessor( + processor_config=dataprocessor_config, + tokenizer=tokenizer, + ) + datasetconfig = [ + DataSetConfig( + name=datasetconfigname, + data_paths=[datafile], + rename_columns=rename, + retain_columns=retain, + ) + ] + train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig) + + assert isinstance(train_dataset, Dataset) + assert set(train_dataset.column_names) == set(final) + + with open(datafile, "r") as file: + data = json.load(file) + assert len(train_dataset) == len(data) diff --git a/tests/test_sft_trainer.py 
b/tests/test_sft_trainer.py
index c7c73891d..b718a97cf 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -38,6 +38,7 @@
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
@@ -837,6 +838,10 @@ def test_run_causallm_ft_pretokenized(dataset_path):
             ],
             DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
         ),
+        (
+            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
     ],
 )
 def test_run_causallm_ft_and_inference_with_multiple_dataset(
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 0c5521baf..6707e8767 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -36,6 +36,8 @@ class DataSetConfig:
     data_paths: List[str]
     builder: Optional[str] = None  # Referring to Hugging Face dataset builder
     sampling: Optional[float] = None
+    rename_columns: Optional[Dict] = None
+    retain_columns: Optional[List] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
 
 
@@ -100,6 +102,18 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
             0 <= ratio <= 1.0
         ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
         c.sampling = ratio
+    if "rename_columns" in kwargs and kwargs["rename_columns"] is not None:
+        rename = kwargs["rename_columns"]
+        assert isinstance(
+            rename, dict
+        ), "rename_columns should be a dict with current_name:new_name"
+        c.rename_columns = rename
+    if "retain_columns" in kwargs and kwargs["retain_columns"] is not None:
+        retain = kwargs["retain_columns"]
+        assert isinstance(
+            retain, list
+        ), "retain_columns should be a list[str] with names of columns to retain"
+        c.retain_columns = retain
     if "data_handlers" in kwargs:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index bdac6947b..1f7e92da5 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -243,6 +243,26 @@ def _process_dataset_configs(
 
         logger.info("Loaded raw dataset : %s", str(raw_dataset))
 
+        # Check if both are conflicting options before proceeding.
+        if d.rename_columns and d.retain_columns:
+            common = set(d.rename_columns.keys()) & set(d.retain_columns)
+            if common:
+                raise ValueError(
+                    f"You are trying to retain {str(common)} columns"
+                    " which will be renamed via the rename operation."
+ ) + + if d.rename_columns: + logger.info("Renaming %s columns", str(d.rename_columns)) + raw_dataset = raw_dataset.rename_columns( + column_mapping=d.rename_columns + ) + logger.info("Done") + if d.retain_columns: + logger.info("Retaining %s columns", str(d.retain_columns)) + raw_dataset = raw_dataset.select_columns(column_names=d.retain_columns) + logger.info("Done") + raw_datasets = DatasetDict() # Assume all is train split From 2f033c72475cfd68a8bd3c4e45fed3b77d124d6c Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 13 Feb 2025 14:46:58 -0500 Subject: [PATCH 10/11] chore(deps): upgrade trl and transformers (#448) * chore(deps): revert trl restriction Signed-off-by: Will Johnson * fix: remove check Signed-off-by: Will Johnson * fix: remove packing from func def Signed-off-by: Will Johnson * chore(deps): upgrade transformers + trl Signed-off-by: Will Johnson * enable packing for pretokenized datasets Signed-off-by: Dushyant Behl * tests: gradient accum steps = 1, get checkpoint path Co-authored-by: Abhishek Signed-off-by: Will Johnson * add upper limit to trl of below 0.15 Signed-off-by: Anh Uong --------- Signed-off-by: Will Johnson Signed-off-by: Dushyant Behl Signed-off-by: Anh Uong Co-authored-by: Dushyant Behl Co-authored-by: Abhishek Co-authored-by: Anh Uong --- pyproject.toml | 4 ++-- tests/build/test_launch_script.py | 2 +- tests/data/test_data_preprocessing.py | 7 ------- tests/test_sft_trainer.py | 11 +++++++++-- tuning/data/setup_dataprocessor.py | 10 ++-------- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eb1da2993..43742f899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,12 +29,12 @@ classifiers=[ dependencies = [ "numpy>=1.26.4,<2.0", "accelerate>=0.20.3,!=0.34,<1.1", -"transformers>=4.45,<4.46", +"transformers>=4.46,<4.48.2", "torch>=2.2.0,<2.5", "sentencepiece>=0.1.99,<0.3", "tokenizers>=0.13.3,<1.0", "tqdm>=4.66.2,<5.0", -"trl>=0.9.3,<0.12", +"trl>=0.13,<0.15", "peft>=0.8.0,<0.14", "protobuf>=5.28.0,<6.0.0", "datasets>=2.15.0,<3.0", diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index c699e16da..2f81fd78f 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -46,7 +46,7 @@ "num_train_epochs": 5, "per_device_train_batch_size": 4, "per_device_eval_batch_size": 4, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 1, "learning_rate": 0.00001, "weight_decay": 0, "warmup_ratio": 0.03, diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 7bf07e58a..9e65b1302 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -667,13 +667,6 @@ def test_get_data_collator( ), False, ), - # Pretokenized data with packing to True - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - True, - ), ], ) def test_process_data_args_throws_error_where_needed(data_args, packing): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index b718a97cf..fb707327c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -22,6 +22,7 @@ import copy import json import os +import re import tempfile # Third Party @@ -88,7 +89,7 @@ num_train_epochs=5, per_device_train_batch_size=4, per_device_eval_batch_size=4, - gradient_accumulation_steps=4, + gradient_accumulation_steps=1, learning_rate=0.00001, weight_decay=0, warmup_ratio=0.03, @@ -1147,7 +1148,13 @@ def 
_validate_hf_resource_scanner_file(tempdir): def _get_checkpoint_path(dir_path): - return os.path.join(dir_path, "checkpoint-5") + checkpoint_dirs = [ + d + for d in os.listdir(dir_path) + if os.path.isdir(os.path.join(dir_path, d)) and re.match(r"^checkpoint-\d+$", d) + ] + checkpoint_dirs.sort(key=lambda name: int(name.split("-")[-1])) + return os.path.join(dir_path, checkpoint_dirs[-1]) def _get_adapter_config(dir_path): diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 433ddbece..730bc318c 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -74,7 +74,7 @@ def _process_dataconfig_file( # Data Format 1: Pretokenized Data -def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): +def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized): # if the provided train dataset is pretokenized # however user provides formatting flags, error out @@ -96,12 +96,6 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): along with pretokenized train data" ) - # Support for packing pretokenized datasets has been merged in trl library - # see: https://github.com/huggingface/trl/pull/2011 - # but we wait till a new transformers version is released to remove this check. - if packing: - raise ValueError("packing will not be used when datasets are pretokenized") - # We do not need a handler here as this is tokenized dataset return [], None @@ -264,7 +258,7 @@ def _process_raw_data_args( if is_traindata_tokenized: # Data Format 1: Pretokenized Data handlers, dataset_text_field = _get_pretokenized_dataset_handlers( - data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized) + data_args, (is_eval_dataset_present and not is_evaldata_tokenized) ) elif data_args.instruction_template and data_args.response_template: # Data Format 2: Chat dataset with instruction and response template From fb3ace8397223932e176de604703c54a14e1ebf0 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Fri, 14 Feb 2025 22:12:56 +0530 Subject: [PATCH 11/11] feat: adding eos token to be made a flag so we don't force it on every handler (#467) * Rethink add_eos by making it a flag so we don't force it on every handler. 
Signed-off-by: Dushyant Behl * Merge with main before adding unit tests Signed-off-by: Abhishek * Added documentation and test case Signed-off-by: Abhishek * Added documentation and test case Signed-off-by: Abhishek * Added documentation and test case Signed-off-by: Abhishek * Updated test case Signed-off-by: Abhishek --------- Signed-off-by: Dushyant Behl Signed-off-by: Abhishek Co-authored-by: Abhishek --- docs/advanced-data-preprocessing.md | 7 +- .../apply_custom_jinja_template.yaml | 3 +- .../apply_custom_template.yaml | 3 +- .../tokenize_and_apply_input_masking.yaml | 3 +- tests/data/test_data_preprocessing.py | 96 +++++++++++++++++++ tuning/data/data_handlers.py | 18 +++- 6 files changed, 122 insertions(+), 8 deletions(-) diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index ad2434143..3476a0e9c 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -214,14 +214,17 @@ Users can also pass any number of `kwargs` arguments required for each data hand #### Preexisting data handlers This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156): - - `tokenize_and_apply_input_masking`: - Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets. - `add_tokenizer_eos_token`: Appends the tokenizer's EOS token to a specified dataset field. - `apply_custom_data_formatting_template`: Applies a custom template (e.g., Alpaca style) to format dataset elements. + By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_template.yaml) + - `tokenize_and_apply_input_masking`: + Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets. + By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml) - `apply_custom_jinja_template`: Applies a custom jinja template (e.g., Alpaca style) to format dataset elements. + By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml) - `apply_tokenizer_chat_template`: Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates. 
- `duplicate_columns`: diff --git a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml index 6dcf031d3..b7c757dc1 100644 --- a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml +++ b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml @@ -11,4 +11,5 @@ datasets: batched: false fn_kwargs: dataset_text_field: "dataset_text_field" - template: "dataset_template" \ No newline at end of file + template: "dataset_template" + add_eos_token: true \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml index c41797624..946431105 100644 --- a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml +++ b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml @@ -11,4 +11,5 @@ datasets: batched: false fn_kwargs: dataset_text_field: "dataset_text_field" - template: "dataset_template" \ No newline at end of file + template: "dataset_template" + add_eos_token: true \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml index ac7e07030..f5d28d6c6 100644 --- a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -11,4 +11,5 @@ datasets: batched: false fn_kwargs: input_field_name: input - output_field_name: output \ No newline at end of file + output_field_name: output + add_eos_token: true \ No newline at end of file diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 9e65b1302..22e7e5166 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -762,6 +762,102 @@ def test_process_dataconfig_file(data_config_path, data_path): assert formatted_dataset_field in set(train_set.column_names) +@pytest.mark.parametrize( + "data_config_path, data_path, add_eos_token", + [ + (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, True), + (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, False), + ( + DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, + TWITTER_COMPLAINTS_DATA_JSON, + True, + ), + ( + DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, + TWITTER_COMPLAINTS_DATA_JSON, + False, + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + True, + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + False, + ), + ], +) +def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_token): + """Ensure that the data handlers correctly apply add_eos_token flag to append/remove eos_token.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"][0] = data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "input_field_name" + ] = "input" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "output_field_name" + ] = "output" + 
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "add_eos_token" + ] = add_eos_token + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "dataset_text_field" + ] = formatted_dataset_field + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "template" + ] = template + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ + "add_eos_token" + ] = add_eos_token + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + tokenizer.add_special_tokens({"eos_token": ""}) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + if datasets_name == "text_dataset_input_output_masking": + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + assert ( + train_set[0]["input_ids"][-1] == tokenizer.eos_token_id + if add_eos_token + else train_set[0]["input_ids"][-1] != tokenizer.eos_token_id + ) + elif datasets_name == "pretokenized_dataset": + assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) + elif datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): + assert formatted_dataset_field in set(train_set.column_names) + assert ( + train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token) + if add_eos_token + else not train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token) + ) + + @pytest.mark.parametrize( "data_config_path, data_path_list", [ diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index d91326040..c56813b1a 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -58,6 +58,7 @@ def tokenize_and_apply_input_masking( column_names: List[str], input_field_name: str, output_field_name: str, + add_eos_token: bool = True, **kwargs, ): """Function (data handler) to tokenize and apply instruction masking on dataset @@ -68,6 +69,7 @@ def tokenize_and_apply_input_masking( column_names: Name of all the columns in the dataset. 
input_field_name: Name of the input (instruction) field in dataset output_field_name: Name of the output field in dataset + add_eos_token: should add tokenizer.eos_token to text or not, defaults to True **kwargs: Any additional args passed to the handler Returns: Formatted Dataset element with input_ids, labels and attention_mask columns @@ -83,7 +85,11 @@ def tokenize_and_apply_input_masking( input_text = element[input_field_name] output_text = element[output_field_name] - combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token) + eos_token = "" + if add_eos_token: + eos_token = tokenizer.eos_token + + combined = combine_sequence(input_text, output_text, eos_token=eos_token) tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) @@ -131,6 +137,7 @@ def apply_custom_data_formatting_template( tokenizer: AutoTokenizer, dataset_text_field: str, template: str, + add_eos_token: bool = True, **kwargs, ): """Function (data handler) to format datasets with Alpaca style / other templates. @@ -142,12 +149,14 @@ def apply_custom_data_formatting_template( dataset_text_field: Text column name of the dataset where formatted text is saved. template: Template to format data with. Features of Dataset should be referred to by {{key}} + add_eos_token: should add tokenizer.eos_token to text or not, defaults to True Returns: Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN Saves the result to dataset_text_field argument. """ - template += tokenizer.eos_token + if add_eos_token: + template += tokenizer.eos_token def replace_text(match_obj): captured_groups = match_obj.groups() @@ -174,6 +183,7 @@ def apply_custom_jinja_template( tokenizer: AutoTokenizer, dataset_text_field: str, template: str, + add_eos_token: bool = True, **kwargs, ): """Function (data handler) to format datasets with jinja templates. @@ -185,12 +195,14 @@ def apply_custom_jinja_template( dataset_text_field: formatted_dataset_field. template: Template to format data with. Features of Dataset should be referred to by {{key}}. + add_eos_token: should add tokenizer.eos_token to text or not, defaults to True Returns: Formatted HF Dataset element by formatting dataset with provided jinja template Saves the result to dataset_text_field argument. """ + if add_eos_token: + template += tokenizer.eos_token - template += tokenizer.eos_token template = process_jinja_placeholders(template) env = SandboxedEnvironment(undefined=StrictUndefined)