
feat: adding eos token to be made a flag so we don't force it on every handler #467

Merged: 8 commits, Feb 14, 2025
7 changes: 5 additions & 2 deletions docs/advanced-data-preprocessing.md
@@ -214,14 +214,17 @@ Users can also pass any number of `kwargs` arguments required for each data handler

#### Preexisting data handlers
This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156):
- `add_tokenizer_eos_token`:
Appends the tokenizer's EOS token to a specified dataset field.
- `apply_custom_data_formatting_template`:
Applies a custom template (e.g., Alpaca style) to format dataset elements.
By default this handler appends `EOS_TOKEN`, which can be disabled via a handler argument ([example config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_template.yaml)).
- `tokenize_and_apply_input_masking`:
Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
By default this handler appends `EOS_TOKEN`, which can be disabled via a handler argument ([example config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml)).
- `apply_custom_jinja_template`:
Applies a custom jinja template (e.g., Alpaca style) to format dataset elements.
By default this handler appends `EOS_TOKEN`, which can be disabled via a handler argument ([example config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml)).
- `apply_tokenizer_chat_template`:
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
- `duplicate_columns`:
@@ -11,4 +11,5 @@ datasets:
batched: false
fn_kwargs:
dataset_text_field: "dataset_text_field"
template: "dataset_template"
add_eos_token: true
Collaborator:
Since it defaults to true, do we need to add this here?

Collaborator:
I added it here so the documentation can point to this config file (see the URL in this PR's documentation changes), letting users know they can use this flag to disable `EOS_TOKEN` in the handler.
Additionally, I assign it the value False in a unit test.

@@ -11,4 +11,5 @@ datasets:
batched: false
fn_kwargs:
dataset_text_field: "dataset_text_field"
template: "dataset_template"
add_eos_token: true
@@ -11,4 +11,5 @@ datasets:
batched: false
fn_kwargs:
input_field_name: input
output_field_name: output
add_eos_token: true
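These `fn_kwargs` are forwarded by the preprocessing pipeline to the data handler when it maps over the dataset. As a rough illustration of the plumbing (the handler body and tokenizer here are toy stand-ins, not the library's real implementations), the flag travels like this:

```python
# Toy sketch of how fn_kwargs from the YAML config reach a data handler.
# Only the add_eos_token plumbing is shown; the real handlers live in
# tuning/data/data_handlers.py.

class ToyTokenizer:
    eos_token = "</s>"

def toy_handler(element, tokenizer, input_field_name, output_field_name,
                add_eos_token=True, **kwargs):
    # Append EOS only when the flag is on (the PR's new behavior).
    eos = tokenizer.eos_token if add_eos_token else ""
    return {"text": element[input_field_name] + element[output_field_name] + eos}

fn_kwargs = {  # as parsed from the YAML fragment above
    "input_field_name": "input",
    "output_field_name": "output",
    "add_eos_token": True,
}
row = {"input": "Q: hi ", "output": "A: hello"}
out = toy_handler(row, ToyTokenizer(), **fn_kwargs)
print(out["text"])  # Q: hi A: hello</s>
```

With `add_eos_token: false` in the config, the same row would come back without the trailing `</s>`.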
96 changes: 96 additions & 0 deletions tests/data/test_data_preprocessing.py
@@ -762,6 +762,102 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert formatted_dataset_field in set(train_set.column_names)


@pytest.mark.parametrize(
"data_config_path, data_path, add_eos_token",
[
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, True),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, False),
(
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
TWITTER_COMPLAINTS_DATA_JSON,
True,
),
(
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
TWITTER_COMPLAINTS_DATA_JSON,
False,
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
True,
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
False,
),
],
)
def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_token):
"""Ensure that the data handlers correctly apply add_eos_token flag to append/remove eos_token."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"][0] = data_path
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"input_field_name"
] = "input"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"output_field_name"
] = "output"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"add_eos_token"
] = add_eos_token

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"dataset_text_field"
] = formatted_dataset_field
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"template"
] = template
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
"add_eos_token"
] = add_eos_token

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.add_special_tokens({"eos_token": "</s>"})
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
assert set(train_set.column_names) == column_names
assert (
train_set[0]["input_ids"][-1] == tokenizer.eos_token_id
if add_eos_token
else train_set[0]["input_ids"][-1] != tokenizer.eos_token_id
)
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
assert formatted_dataset_field in set(train_set.column_names)
assert (
train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token)
if add_eos_token
else not train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token)
)


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
18 changes: 15 additions & 3 deletions tuning/data/data_handlers.py
@@ -58,6 +58,7 @@ def tokenize_and_apply_input_masking(
column_names: List[str],
input_field_name: str,
output_field_name: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to tokenize and apply instruction masking on dataset
@@ -68,6 +69,7 @@
column_names: Name of all the columns in the dataset.
column_names: Name of all the columns in the dataset.
input_field_name: Name of the input (instruction) field in dataset
output_field_name: Name of the output field in dataset
add_eos_token: Whether to append tokenizer.eos_token to the text. Defaults to True.
**kwargs: Any additional args passed to the handler
Returns:
Formatted Dataset element with input_ids, labels and attention_mask columns
@@ -83,7 +85,11 @@
input_text = element[input_field_name]
output_text = element[output_field_name]

eos_token = ""
if add_eos_token:
eos_token = tokenizer.eos_token

combined = combine_sequence(input_text, output_text, eos_token=eos_token)

tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
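`combine_sequence` is a helper defined elsewhere in the library. Assuming it simply concatenates the input and output and appends the given `eos_token` (a simplification of the real helper, which also handles whitespace between the pieces), the effect of the flag can be sketched as:

```python
# Simplified stand-in for combine_sequence: concatenate input and output,
# then append eos_token (which is "" when add_eos_token is False).
def combine_sequence(input_text, output_text, eos_token=""):
    return input_text + output_text + eos_token

add_eos_token = False
eos_token = "</s>" if add_eos_token else ""
combined = combine_sequence("### Input: hi\n", "### Response: hello", eos_token)
print(combined)  # no trailing </s> because the flag is off
```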

@@ -131,6 +137,7 @@ def apply_custom_data_formatting_template(
tokenizer: AutoTokenizer,
dataset_text_field: str,
template: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to format datasets with Alpaca style / other templates.
@@ -142,12 +149,14 @@
dataset_text_field: Text column name of the dataset where formatted text is saved.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}
add_eos_token: Whether to append tokenizer.eos_token to the text. Defaults to True.
Returns:
Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN
Saves the result to dataset_text_field argument.
"""

if add_eos_token:
template += tokenizer.eos_token

def replace_text(match_obj):
captured_groups = match_obj.groups()
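The `{{key}}` substitution this handler performs (via `replace_text` and a regex over the template) can be sketched as follows. This is a simplified, self-contained version of the pattern, not the library's exact implementation, with the EOS append gated by the new flag:

```python
import re

# Sketch of apply_custom_data_formatting_template: substitute {{key}}
# markers with dataset fields, appending eos_token only when requested.
def format_with_template(element, template, eos_token, add_eos_token=True):
    if add_eos_token:
        template += eos_token

    def replace_text(match_obj):
        key = match_obj.group(1)  # the name captured inside {{...}}
        return str(element[key])

    return re.sub(r"{{([^}]+)}}", replace_text, template)

row = {"Tweet text": "my order is late", "text_label": "complaint"}
tmpl = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
print(format_with_template(row, tmpl, "</s>", add_eos_token=True))
# ### Input: my order is late
#
#  ### Response: complaint</s>
```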
@@ -174,6 +183,7 @@ def apply_custom_jinja_template(
tokenizer: AutoTokenizer,
dataset_text_field: str,
template: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to format datasets with jinja templates.
@@ -185,12 +195,14 @@
dataset_text_field: formatted_dataset_field.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}.
add_eos_token: Whether to append tokenizer.eos_token to the text. Defaults to True.
Returns:
Formatted HF Dataset element by formatting dataset with provided jinja template
Saves the result to dataset_text_field argument.
"""
if add_eos_token:
template += tokenizer.eos_token

template = process_jinja_placeholders(template)
env = SandboxedEnvironment(undefined=StrictUndefined)
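For the jinja path, the same gating happens before the template is compiled. A minimal sketch using jinja2's `SandboxedEnvironment` (simplified; `process_jinja_placeholders` and error handling are omitted):

```python
from jinja2 import StrictUndefined
from jinja2.sandbox import SandboxedEnvironment

# Sketch of apply_custom_jinja_template's EOS gating: eos_token is appended
# to the template string before rendering, only when requested.
def render_with_template(element, template, eos_token, add_eos_token=True):
    if add_eos_token:
        template += eos_token
    env = SandboxedEnvironment(undefined=StrictUndefined)
    return env.from_string(template).render(**element)

row = {"Tweet": "my order is late", "label": "complaint"}
out = render_with_template(row, "Input: {{ Tweet }} Response: {{ label }}", "</s>")
print(out)  # Input: my order is late Response: complaint</s>
```

Because the EOS token is appended to the template string itself, it is emitted exactly once per rendered element, after all placeholders are filled.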
