feat: adding eos token to be made a flag so we don't force it on every handler #467

Merged 8 commits on Feb 14, 2025
7 changes: 5 additions & 2 deletions docs/advanced-data-preprocessing.md
@@ -204,14 +204,17 @@ Users can also pass any number of `kwargs` arguments required for each data handler

#### Preexisting data handlers
This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156):
- `add_tokenizer_eos_token`:
Appends the tokenizer's EOS token to a specified dataset field.
- `apply_custom_data_formatting_template`:
Applies a custom template (e.g., Alpaca style) to format dataset elements.
By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)
- `tokenize_and_apply_input_masking`:
Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)
- `apply_custom_jinja_template`:
Applies a custom jinja template (e.g., Alpaca style) to format dataset elements.
By default this handler adds `EOS_TOKEN` which can be disabled by a handler argument, [see](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)
Collaborator

I think I might be missing something in this PR. How can the user pass the value of `add_eos_token = False`? For example, here in `fn_kwargs` we don't pass the user-supplied value of `add_eos_token` through to the handler. What am I missing?

A small test case for `add_eos_token = False` in `test_data_preprocessing.py` would also be appreciated.

Contributor Author

@Abhishek-TAMU this flag is only for the data_config flow, where users can specify it to disable adding `EOS_TOKEN`.
It is not for the CLI-args flow where the function you pointed to is used; that path serves the specific instruction-masking use case on a single dataset file, so there we always ensure the EOS token is added.

We can try to add a data config test where we set the flag to false and check the behaviour.

Collaborator

Okay, I see. Yes, the user can pass it via `fn_kwargs` in `data_config.yaml`.
Feel free to add a test case for the same.
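As the thread above describes, the flag travels to the handler through `fn_kwargs` in the data config. A sketch of what that could look like (the field layout here is illustrative and should be checked against the library's actual data_config schema):

```yaml
# Hypothetical data_config.yaml fragment: the exact schema may differ,
# but add_eos_token reaches the handler via fn_kwargs.
datasets:
  - name: my_dataset
    data_paths:
      - /path/to/data.jsonl
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          fn_kwargs:
            dataset_text_field: formatted_text
            template: "### Input: {{input}}\n### Response: {{output}}"
            add_eos_token: false
```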

- `apply_tokenizer_chat_template`:
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
- `duplicate_columns`:
18 changes: 15 additions & 3 deletions tuning/data/data_handlers.py
@@ -58,6 +58,7 @@ def tokenize_and_apply_input_masking(
column_names: List[str],
input_field_name: str,
output_field_name: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to tokenize and apply instruction masking on dataset
@@ -68,6 +69,7 @@ def tokenize_and_apply_input_masking(
column_names: Name of all the columns in the dataset.
input_field_name: Name of the input (instruction) field in dataset
output_field_name: Name of the output field in dataset
add_eos_token: whether to append tokenizer.eos_token to the text; defaults to True
**kwargs: Any additional args passed to the handler
Returns:
Formatted Dataset element with input_ids, labels and attention_mask columns
@@ -83,7 +85,11 @@ def tokenize_and_apply_input_masking(
input_text = element[input_field_name]
output_text = element[output_field_name]

eos_token = ""
if add_eos_token:
eos_token = tokenizer.eos_token

combined = combine_sequence(input_text, output_text, eos_token=eos_token)

tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})

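The effect of the flag on the combined sequence can be sketched with a hypothetical stand-in for the library's `combine_sequence` helper (this mirrors only the EOS-appending behaviour, not the real implementation):

```python
def combine_sequence(input_text: str, output_text: str, eos_token: str = "") -> str:
    # Hypothetical stand-in for the library's combine_sequence helper:
    # join instruction and response, then append the (possibly empty) EOS token.
    separator = "" if input_text.endswith((" ", "\n", "\t")) else " "
    return input_text + separator + output_text + eos_token

# With add_eos_token=True the handler passes tokenizer.eos_token;
# with add_eos_token=False it passes the empty string.
with_eos = combine_sequence("### Input: hi", "### Response: hello", eos_token="</s>")
without_eos = combine_sequence("### Input: hi", "### Response: hello", eos_token="")
```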
@@ -131,6 +137,7 @@ def apply_custom_data_formatting_template(
tokenizer: AutoTokenizer,
dataset_text_field: str,
template: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to format datasets with Alpaca style / other templates.
@@ -142,12 +149,14 @@
dataset_text_field: Text column name of the dataset where formatted text is saved.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}
add_eos_token: whether to append tokenizer.eos_token to the text; defaults to True
Returns:
Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN
Saves the result to dataset_text_field argument.
"""

if add_eos_token:
template += tokenizer.eos_token

def replace_text(match_obj):
captured_groups = match_obj.groups()
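The `{{key}}` substitution performed by this handler can be sketched as a minimal regex replacement (a simplified stand-in for the `replace_text` logic above, not the exact implementation):

```python
import re

def render_template(template: str, element: dict) -> str:
    # Simplified sketch: every {{key}} placeholder is replaced by the
    # matching dataset column value, mirroring replace_text above.
    def replace_text(match_obj):
        (key,) = match_obj.groups()
        return str(element[key])
    return re.sub(r"{{([^{}]+)}}", replace_text, template)

rendered = render_template("### Input: {{instruction}}", {"instruction": "say hi"})
```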
@@ -174,6 +183,7 @@ def apply_custom_jinja_template(
tokenizer: AutoTokenizer,
dataset_text_field: str,
template: str,
add_eos_token: bool = True,
**kwargs,
):
"""Function (data handler) to format datasets with jinja templates.
@@ -185,12 +195,14 @@
dataset_text_field: Text column name of the dataset where formatted text is saved.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}.
add_eos_token: whether to append tokenizer.eos_token to the text; defaults to True
Returns:
Formatted HF Dataset element by formatting dataset with provided jinja template
Saves the result to dataset_text_field argument.
"""
if add_eos_token:
template += tokenizer.eos_token

template = process_jinja_placeholders(template)
env = SandboxedEnvironment(undefined=StrictUndefined)

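The sandboxed-rendering step above can be exercised on its own (assuming jinja2 is installed; the template text is illustrative):

```python
from jinja2 import StrictUndefined
from jinja2.sandbox import SandboxedEnvironment

# StrictUndefined makes a missing dataset column raise an error instead of
# rendering silently as empty text; the sandbox restricts what user-supplied
# templates can access.
env = SandboxedEnvironment(undefined=StrictUndefined)
template = env.from_string("### Input: {{ instruction }}\n### Response: {{ output }}")
text = template.render(instruction="say hi", output="hi")
```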