From f1fd130d5f6622e0105b232b24b2dae18051283f Mon Sep 17 00:00:00 2001
From: Dushyant Behl
Date: Thu, 13 Feb 2025 21:16:58 +0530
Subject: [PATCH] feat: Add support for renaming and retaining columns in data
 preprocessor (#466)

* Add functionality to rename and retain dataset columns in data preprocessor.

Signed-off-by: Dushyant Behl

* fix fmt

Signed-off-by: Dushyant Behl

* add unit tests

Signed-off-by: Dushyant Behl

* Update advanced-data-preprocessing.md

Signed-off-by: Dushyant Behl

---------

Signed-off-by: Dushyant Behl
---
 docs/advanced-data-preprocessing.md                |  9 ++++
 docs/ept.md                                        |  4 +-
 .../predefined_data_configs/__init__.py            |  3 +
 .../rename_retain_columns.yaml                     | 20 +++++++
 tests/data/test_data_preprocessing.py              | 55 +++++++++++++++++++
 tests/test_sft_trainer.py                          |  5 ++
 tuning/data/data_config.py                         | 14 +++++
 tuning/data/data_processors.py                     | 20 +++++++
 8 files changed, 128 insertions(+), 2 deletions(-)
 create mode 100644 tests/artifacts/predefined_data_configs/rename_retain_columns.yaml

diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md
index 55ae4c47d..ad2434143 100644
--- a/docs/advanced-data-preprocessing.md
+++ b/docs/advanced-data-preprocessing.md
@@ -60,6 +60,10 @@ definitions:
           type: float
         builder:
           type: string
+        rename_columns:
+          type: object
+        retain_columns:
+          type: object
         data_paths:
           type: array
           items:
@@ -118,6 +122,8 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `name` (optional, str): A unique identifier for the dataset.
 - `data_paths` (optional, list): A `list` of file paths or directories containing the dataset.
 - `builder` (optional, str): Specifies a [Hugging Face dataset builder](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/loading_methods#datasets.load_dataset.path), if applicable.
+- `rename_columns` (optional, dict[str, str]): Specifies a dictionary of columns to rename, e.g. `{"old_name": "new_name"}`, at dataset load time. *Applied before `retain_columns` if both are specified*.
+- `retain_columns` (optional, list[str]): Specifies a list of columns to retain, e.g. `["input_ids", "labels"]`; every other column is dropped at dataset load time. *Applied strictly after `rename_columns` if both are specified*.
 - `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving.
 - `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.
@@ -149,6 +155,9 @@ Not Supported:
 Currently there's no support for sampling under multiple data paths which are defined inside a dataset definition.
 All dataset paths that will be specified inside one dataset will be [concatenated](https://huggingface.co/docs/datasets/v3.2.0/en/process#concatenate) after loading them, while across datasets users can specify [mixing via sampling datasets](#data-mixing)
 
+Additionally, while loading the dataset, users can specify which columns to rename via the `rename_columns` argument and which to retain via the `retain_columns` argument described above.
+These operations are applied *strictly in the order rename followed by retain*, so an old column name that has been renamed is no longer available to retain.
+The code throws a `ValueError` if a column listed as an old name in `rename_columns` also appears in `retain_columns`.
 
 ### How can users specify data handlers.
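To make the rename-then-retain semantics documented above concrete, here is a minimal sketch of what the loader effectively does, using only the Hugging Face `datasets` calls this patch delegates to (`rename_columns` followed by `select_columns`); the toy column names are illustrative, not taken from the test data.

```python
from datasets import Dataset

# Toy dataset with illustrative columns "ID", "input", and "output".
ds = Dataset.from_dict({"ID": [1, 2], "input": ["a", "b"], "output": ["x", "y"]})

# Rename is applied first, so the old names "input"/"output" cease to exist.
ds = ds.rename_columns({"input": "instruction", "output": "response"})

# Retain is applied strictly afterwards and must therefore use the new names;
# every column not listed here ("ID") is dropped.
ds = ds.select_columns(["instruction", "response"])

print(ds.column_names)  # ['instruction', 'response']
```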
diff --git a/docs/ept.md b/docs/ept.md
index c6fbdb379..19433bf86 100644
--- a/docs/ept.md
+++ b/docs/ept.md
@@ -32,7 +32,7 @@ datasets:
     data_paths:
       - ""
     data_handlers:
-      - name: apply_custom_data_formatting
+      - name: add_tokenizer_eos_token
         arguments:
           remove_columns: all
           batched: false
@@ -109,4 +109,4 @@ The code again would add `EOS_TOKEN` to the non tokenized data before using it a
 ### Additional Information
 This feature is supported post [v2.3.1](https://github.com/foundation-model-stack/fms-hf-tuning/releases/tag/v2.3.1) of this library.
 
-Post Last Updated On: 10-02-2025
\ No newline at end of file
+Post Last Updated On: 12-02-2025
\ No newline at end of file
diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py
index 7d0317466..63c14867c 100644
--- a/tests/artifacts/predefined_data_configs/__init__.py
+++ b/tests/artifacts/predefined_data_configs/__init__.py
@@ -37,3 +37,6 @@
 DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join(
     PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml"
 )
+DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
+)
diff --git a/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml b/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml
new file mode 100644
index 000000000..ecfe993ca
--- /dev/null
+++ b/tests/artifacts/predefined_data_configs/rename_retain_columns.yaml
@@ -0,0 +1,20 @@
+dataprocessor:
+  type: default
+datasets:
+  - name: text_dataset_input_output_masking
+    rename_columns:
+      "input": "instruction"
+      "output": "response"
+    retain_columns:
+      - "instruction"
+      - "response"
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: instruction
+            output_field_name: response
\ No newline at end of file
diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py
index b435f2d2b..7bf07e58a 100644
--- a/tests/data/test_data_preprocessing.py
+++ b/tests/data/test_data_preprocessing.py
@@ -33,6 +33,7 @@
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
@@ -1365,3 +1366,57 @@ def test_process_dataset_configs_with_sampling_error(
         (_, _, _, _, _, _) = process_dataargs(
             data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
         )
+
+
+@pytest.mark.parametrize(
+    "datafile, rename, retain, final, datasetconfigname",
+    [
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            {"input": "instruction", "output": "response"},
+            None,
+            ["ID", "Label", "instruction", "response"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            None,
+            ["ID", "input", "output"],
+            ["ID", "input", "output"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            {"input": "instruction", "output": "response"},
+            ["Label", "instruction", "response"],
+            ["Label", "instruction", "response"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+    ],
+)
+def test_rename_and_retain_dataset_columns(
+    datafile, rename, retain, final, datasetconfigname
+):
+    """Test that process_dataset_configs renames and retains dataset columns."""
+    dataprocessor_config = DataPreProcessorConfig()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    processor = DataPreProcessor(
+        processor_config=dataprocessor_config,
+        tokenizer=tokenizer,
+    )
+    datasetconfig = [
+        DataSetConfig(
+            name=datasetconfigname,
+            data_paths=[datafile],
+            rename_columns=rename,
+            retain_columns=retain,
+        )
+    ]
+    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)
+
+    assert isinstance(train_dataset, Dataset)
+    assert set(train_dataset.column_names) == set(final)
+
+    with open(datafile, "r") as file:
+        data = json.load(file)
+    assert len(train_dataset) == len(data)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index c7c73891d..b718a97cf 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -38,6 +38,7 @@
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
@@ -837,6 +838,10 @@ def test_run_causallm_ft_pretokenized(dataset_path):
             ],
             DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
         ),
+        (
+            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
     ],
 )
 def test_run_causallm_ft_and_inference_with_multiple_dataset(
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 0c5521baf..6707e8767 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -36,6 +36,8 @@ class DataSetConfig:
     data_paths: List[str]
     builder: Optional[str] = None  # Referring to Hugging Face dataset builder
     sampling: Optional[float] = None
+    rename_columns: Optional[Dict[str, str]] = None
+    retain_columns: Optional[List[str]] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
 
 
@@ -100,6 +102,18 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
             0 <= ratio <= 1.0
         ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
         c.sampling = ratio
+    if "rename_columns" in kwargs and kwargs["rename_columns"] is not None:
+        rename = kwargs["rename_columns"]
+        assert isinstance(
+            rename, dict
+        ), "rename_columns should be a dict of current_name:new_name"
+        c.rename_columns = rename
+    if "retain_columns" in kwargs and kwargs["retain_columns"] is not None:
+        retain = kwargs["retain_columns"]
+        assert isinstance(
+            retain, list
+        ), "retain_columns should be a list[str] of names of columns to retain"
+        c.retain_columns = retain
     if "data_handlers" in kwargs:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index bdac6947b..1f7e92da5 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -243,6 +243,26 @@ def _process_dataset_configs(
 
         logger.info("Loaded raw dataset : %s", str(raw_dataset))
 
+        # Check that rename and retain do not conflict before proceeding.
+        if d.rename_columns and d.retain_columns:
+            common = set(d.rename_columns.keys()) & set(d.retain_columns)
+            if common:
+                raise ValueError(
+                    f"You are trying to retain {str(common)} columns"
+                    " which will be renamed via the rename operation."
+                )
+
+        if d.rename_columns:
+            logger.info("Renaming %s columns", str(d.rename_columns))
+            raw_dataset = raw_dataset.rename_columns(
+                column_mapping=d.rename_columns
+            )
+            logger.info("Done")
+        if d.retain_columns:
+            logger.info("Retaining %s columns", str(d.retain_columns))
+            raw_dataset = raw_dataset.select_columns(column_names=d.retain_columns)
+            logger.info("Done")
+
         raw_datasets = DatasetDict()
 
         # Assume all is train split
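For completeness, here is a standalone restatement of the conflict check added in `tuning/data/data_processors.py` above, showing the `ValueError` a conflicting config produces; the column names are hypothetical and the snippet is not library code.

```python
# Retaining a column under its old name while also renaming it is rejected,
# because rename runs first and the old name no longer exists afterwards.
rename_columns = {"input": "instruction"}  # old_name -> new_name
retain_columns = ["input", "response"]     # "input" is an old name: conflict

common = set(rename_columns.keys()) & set(retain_columns)
if common:
    raise ValueError(
        f"You are trying to retain {str(common)} columns"
        " which will be renamed via the rename operation."
    )
# Raises:
# ValueError: You are trying to retain {'input'} columns which will be
# renamed via the rename operation.
```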