diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md
index dd22a99d9..2e53dc61e 100644
--- a/docs/advanced-data-preprocessing.md
+++ b/docs/advanced-data-preprocessing.md
@@ -210,6 +210,8 @@ This library currently supports the following [preexisting data handlers](https:
     Formats a dataset by appending an EOS token to a specified field.
  - `apply_custom_data_formatting_template`:
     Applies a custom template (e.g., Alpaca style) to format dataset elements.
+ - `apply_custom_data_formatting_jinja_template`:
+    Applies a custom jinja template (e.g., Alpaca style) to format dataset elements.
  - `apply_tokenizer_chat_template`:
     Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
 
diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py
index c199406c6..8ecc07b8f 100644
--- a/tests/artifacts/predefined_data_configs/__init__.py
+++ b/tests/artifacts/predefined_data_configs/__init__.py
@@ -22,6 +22,9 @@
 DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
 )
+DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml"
+)
 DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
 )
diff --git a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml
new file mode 100644
index 000000000..474068fe8
--- /dev/null
+++ b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml
@@ -0,0 +1,14 @@
+dataprocessor:
+    type: default
+datasets:
+  - name: apply_custom_data_jinja_template
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: apply_custom_data_formatting_jinja_template
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            dataset_text_field: "dataset_text_field"
+            template: "dataset_template"
\ No newline at end of file
diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py
index d2a390fe9..bfe366ef8 100644
--- a/tests/data/test_data_handlers.py
+++ b/tests/data/test_data_handlers.py
@@ -25,6 +25,7 @@
 
 # Local
 from tuning.data.data_handlers import (
+    apply_custom_data_formatting_jinja_template,
     apply_custom_data_formatting_template,
     combine_sequence,
 )
@@ -57,6 +58,32 @@ def test_apply_custom_formatting_template():
     assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response
 
 
+def test_apply_custom_formatting_jinja_template():
+    json_dataset = datasets.load_dataset(
+        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
+    )
+    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    formatted_dataset_field = "formatted_data_field"
+    formatted_dataset = json_dataset.map(
+        apply_custom_data_formatting_jinja_template,
+        fn_kwargs={
+            "tokenizer": tokenizer,
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        },
+    )
+    # First response from the data file that is read.
+    expected_response = (
+        "### Input: @HMRCcustomers No this is my first job"
+        + " \n\n ### Response: no complaint"
+        + tokenizer.eos_token
+    )
+
+    assert formatted_dataset_field in formatted_dataset["train"][0]
+    assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response
+
+
 def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
     """Tests that the formatting function will throw error if wrong keys are passed to template"""
     json_dataset = datasets.load_dataset(
@@ -76,6 +103,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
     )
 
 
+def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
+    """Tests that the jinja formatting function will throw error if wrong keys are passed to template"""
+    json_dataset = datasets.load_dataset(
+        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
+    )
+    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
+    formatted_dataset_field = "formatted_data_field"
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    with pytest.raises(KeyError):
+        json_dataset.map(
+            apply_custom_data_formatting_jinja_template,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "dataset_text_field": formatted_dataset_field,
+                "template": template,
+            },
+        )
+
+
 @pytest.mark.parametrize(
     "input_element,output_element,expected_res",
     [
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 8de5dfc36..95153d1a4 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -29,6 +29,7 @@
 
 # First Party
 from tests.artifacts.predefined_data_configs import (
+    DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
@@ -693,6 +694,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
+        (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
@@ -731,7 +736,10 @@ def test_process_dataconfig_file(data_config_path, data_path):
 
     # Modify dataset_text_field and template according to dataset
     formatted_dataset_field = "formatted_data_field"
-    if datasets_name == "apply_custom_data_template":
+    if datasets_name in (
+        "apply_custom_data_template",
+        "apply_custom_data_jinja_template",
+    ):
         template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
         yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
             "dataset_text_field": formatted_dataset_field,
@@ -753,7 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path):
         assert set(train_set.column_names) == column_names
     elif datasets_name == "pretokenized_dataset":
         assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
-    elif datasets_name == "apply_custom_data_template":
+    elif datasets_name in (
+        "apply_custom_data_template",
+        "apply_custom_data_jinja_template",
+    ):
         assert formatted_dataset_field in set(train_set.column_names)
 
 
diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py
index 5b80dc4bb..d993dee31 100644
--- a/tuning/data/data_handlers.py
+++ b/tuning/data/data_handlers.py
@@ -19,8 +19,12 @@
 import re
 
 # Third Party
+from jinja2 import Environment, StrictUndefined
 from transformers import AutoTokenizer
 
+# Local
+from tuning.utils.config_utils import process_jinja_placeholders
+
 
 ### Utils for custom masking / manipulating input / output strs, etc
 def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
@@ -108,6 +112,8 @@ def apply_custom_data_formatting_template(
     Expects to be run as a HF Map API function.
     Args:
         element: the HF Dataset element loaded from a JSON or DatasetDict object.
+        tokenizer: Tokenizer to be used for the EOS token, which will be appended
+            when formatting the data into a single sequence. Defaults to empty.
         template: Template to format data with. Features of Dataset
             should be referred to by {{key}}
         formatted_dataset_field: Dataset_text_field
@@ -137,6 +143,39 @@ def replace_text(match_obj):
     }
 
 
+def apply_custom_data_formatting_jinja_template(
+    element: Dict[str, str],
+    tokenizer: AutoTokenizer,
+    dataset_text_field: str,
+    template: str,
+    **kwargs,
+):
+    """Function to format datasets with jinja templates.
+    Expects to be run as a HF Map API function.
+    Args:
+        element: the HF Dataset element loaded from a JSON or DatasetDict object.
+        tokenizer: Tokenizer to be used for the EOS token, which will be appended
+            when formatting the data into a single sequence. Defaults to empty.
+        dataset_text_field: formatted_dataset_field.
+        template: Template to format data with. Features of Dataset
+            should be referred to by {{key}}.
+    Returns:
+        Formatted HF Dataset
+    """
+
+    template += tokenizer.eos_token
+    template = process_jinja_placeholders(template)
+    env = Environment(undefined=StrictUndefined)
+    jinja_template = env.from_string(template)
+
+    try:
+        rendered_text = jinja_template.render(element=element, **element)
+    except Exception as e:
+        raise KeyError(f"Dataset does not contain field in template. {e}") from e
+
+    return {dataset_text_field: rendered_text}
+
+
 def apply_tokenizer_chat_template(
     element: Dict[str, str],
     tokenizer: AutoTokenizer,
@@ -157,5 +196,6 @@
     "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
     "apply_dataset_formatting": apply_dataset_formatting,
     "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
+    "apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template,
    "apply_tokenizer_chat_template": apply_tokenizer_chat_template,
 }
diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py
index b5dede937..061d6017b 100644
--- a/tuning/utils/config_utils.py
+++ b/tuning/utils/config_utils.py
@@ -18,6 +18,7 @@
 import json
 import os
 import pickle
+import re
 
 # Third Party
 from peft import LoraConfig, PromptTuningConfig
@@ -135,3 +136,34 @@ def txt_to_obj(txt):
     except UnicodeDecodeError:
         # Otherwise the bytes are a pickled python dictionary
         return pickle.loads(message_bytes)
+
+
+def process_jinja_placeholders(template: str) -> str:
+    """
+    Function to detect all placeholders of the form {{...}}.
+    - If the inside has a space (e.g. {{Tweet text}}),
+      rewrite to {{ element['Tweet text'] }}.
+    - If it doesn't have a space (e.g. {{text_label}}), leave it as is.
+    - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing.
+
+    Args:
+        template: str
+    Return: template: str
+    """
+
+    pattern = r"\{\{([^}]+)\}\}"
+    matches = re.findall(pattern, template)
+
+    for match in matches:
+        original_placeholder = f"{{{{{match}}}}}"
+        trimmed = match.strip()
+
+        if trimmed.startswith("element["):
+            continue
+
+        # If there's a space in the placeholder name, rewrite it to dictionary-style
+        if " " in trimmed:
+            new_placeholder = f"{{{{ element['{trimmed}'] }}}}"
+            template = template.replace(original_placeholder, new_placeholder)
+
+    return template