-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: adding eos token to be made a flag so we don't force it on every handler #467
Changes from 7 commits
e5f1fbb
adafe9a
22f17ab
4ffa8fe
10024b6
38fce96
450c52d
0533ba8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -682,39 +682,75 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): | |
|
||
|
||
@pytest.mark.parametrize( | ||
"data_config_path, data_path", | ||
"data_config_path, data_path, add_eos_token", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we do a separate unit test for this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good idea. Pushed the changes. Thanks! |
||
[ | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), | ||
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), | ||
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), | ||
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), | ||
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, True), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL, False), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET, True), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW, False), | ||
( | ||
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, | ||
TWITTER_COMPLAINTS_DATA_JSON, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, | ||
TWITTER_COMPLAINTS_DATA_JSONL, | ||
False, | ||
), | ||
( | ||
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, | ||
TWITTER_COMPLAINTS_DATA_PARQUET, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, | ||
TWITTER_COMPLAINTS_DATA_ARROW, | ||
False, | ||
), | ||
( | ||
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
TWITTER_COMPLAINTS_TOKENIZED_JSON, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
TWITTER_COMPLAINTS_TOKENIZED_JSONL, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
TWITTER_COMPLAINTS_TOKENIZED_PARQUET, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
TWITTER_COMPLAINTS_TOKENIZED_ARROW, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, | ||
False, | ||
), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
True, | ||
), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, | ||
False, | ||
), | ||
], | ||
) | ||
def test_process_dataconfig_file(data_config_path, data_path): | ||
def test_process_dataconfig_file(data_config_path, data_path, add_eos_token): | ||
"""Ensure that datasets are formatted and validated correctly based on the arguments passed in config file.""" | ||
with open(data_config_path, "r") as f: | ||
yaml_content = yaml.safe_load(f) | ||
|
@@ -723,10 +759,15 @@ def test_process_dataconfig_file(data_config_path, data_path): | |
|
||
# Modify input_field_name and output_field_name according to dataset | ||
if datasets_name == "text_dataset_input_output_masking": | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
"input_field_name": "input", | ||
"output_field_name": "output", | ||
} | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"input_field_name" | ||
] = "input" | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"output_field_name" | ||
] = "output" | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"add_eos_token" | ||
] = add_eos_token | ||
|
||
# Modify dataset_text_field and template according to dataset | ||
formatted_dataset_field = "formatted_data_field" | ||
|
@@ -735,10 +776,15 @@ def test_process_dataconfig_file(data_config_path, data_path): | |
"apply_custom_data_jinja_template", | ||
): | ||
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
"dataset_text_field": formatted_dataset_field, | ||
"template": template, | ||
} | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"dataset_text_field" | ||
] = formatted_dataset_field | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"template" | ||
] = template | ||
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][ | ||
"add_eos_token" | ||
] = add_eos_token | ||
|
||
with tempfile.NamedTemporaryFile( | ||
"w", delete=False, suffix=".yaml" | ||
|
@@ -748,18 +794,30 @@ def test_process_dataconfig_file(data_config_path, data_path): | |
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
tokenizer.add_special_tokens({"eos_token": "</s>"}) | ||
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) | ||
assert isinstance(train_set, Dataset) | ||
if datasets_name == "text_dataset_input_output_masking": | ||
column_names = set(["input_ids", "attention_mask", "labels"]) | ||
assert set(train_set.column_names) == column_names | ||
print("INFO", train_set[8]["input_ids"], tokenizer.eos_token_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove print statement here |
||
assert ( | ||
train_set[0]["input_ids"][-1] == tokenizer.eos_token_id | ||
if add_eos_token | ||
else train_set[0]["input_ids"][-1] != tokenizer.eos_token_id | ||
) | ||
elif datasets_name == "pretokenized_dataset": | ||
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) | ||
elif datasets_name in ( | ||
"apply_custom_data_template", | ||
"apply_custom_data_jinja_template", | ||
): | ||
assert formatted_dataset_field in set(train_set.column_names) | ||
assert ( | ||
train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token) | ||
if add_eos_token | ||
else not train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token) | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it defaults to
true
do we need to add this here? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added it here so that the documentation can point to this config file (see the URL in the documentation changes of this PR), letting users know that they can use this flag to disable
EOS_TOKEN
from the handler. Additionally, I am assigning it the value
False
in Unit test.