feat: Add support for renaming and retaining columns in data preprocessor #466

Merged
10 changes: 10 additions & 0 deletions docs/advanced-data-preprocessing.md
@@ -60,6 +60,10 @@ definitions:
          type: float
        builder:
          type: string
        rename_columns:
          type: object
        retain_columns:
          type: object
        data_paths:
          type: array
          items:
@@ -118,6 +122,8 @@ Users can create a data config file in any of YAML or JSON format they choose (w
- `name` (optional, str): A unique identifier for the dataset.
- `data_paths` (optional, list): A `list` of file paths or directories containing the dataset.
- `builder` (optional, str): Specifies a [Hugging Face dataset builder](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/loading_methods#datasets.load_dataset.path), if applicable.
 - `rename_columns` (optional, dict[str, str]): A dictionary of columns to rename, e.g. `{"old_name": "new_name"}`, applied at dataset load time. *Applied before `retain_columns` if both are specified* (see the sketch after this list).
 - `retain_columns` (optional, list[str]): A list of columns to retain, e.g. `["input_ids", "labels"]`; every other column is dropped at dataset load time. *Applied strictly after `rename_columns` if both are specified*.
- `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving.
- `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.
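
For example, a dataset entry that uses both options might look like the following minimal sketch (the dataset name, file path, and column names are placeholders):

```yaml
datasets:
  - name: my_dataset
    data_paths:
      - "/path/to/dataset.jsonl"
    rename_columns:
      "input": "instruction"
      "output": "response"
    retain_columns:
      - "instruction"
      - "response"
```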

@@ -149,6 +155,10 @@ Not Supported:
Currently there is no support for sampling across multiple data paths defined inside a single dataset definition.
All data paths specified inside one dataset are [concatenated](https://huggingface.co/docs/datasets/v3.2.0/en/process#concatenate) after loading, while across datasets users can specify [mixing via sampling datasets](#data-mixing).

Additionally, while loading the dataset, users can specify which columns to rename via `rename_columns` and which columns to retain via `retain_columns`, as described above.
These operations are applied *strictly in the order rename followed by retain*, so a column's old name is no longer available once it has been renamed and cannot be listed in `retain_columns`. The code throws a `ValueError` if a column requested to be renamed via `rename_columns` also appears in `retain_columns`.
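
For instance, a fragment like the following hypothetical sketch would fail validation, because `input` is renamed and is therefore no longer available under its old name:

```yaml
rename_columns:
  "input": "instruction"
retain_columns:
  - "input"   # ValueError: "input" is renamed to "instruction" above
```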

### How can users specify data handlers?

Expand Down
4 changes: 2 additions & 2 deletions docs/ept.md
@@ -32,7 +32,7 @@ datasets:
    data_paths:
      - "<path-to-the-jsonl-dataset>"
    data_handlers:
-     - name: apply_custom_data_formatting
+     - name: add_tokenizer_eos_token
        arguments:
          remove_columns: all
          batched: false
@@ -109,4 +109,4 @@ The code again would add `EOS_TOKEN` to the non tokenized data before using it a

### Additional Information
This feature is supported in releases after [v2.3.1](https://github.com/foundation-model-stack/fms-hf-tuning/releases/tag/v2.3.1) of this library.
- Post Last Updated On: 10-02-2025
+ Post Last Updated On: 12-02-2025
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
@@ -37,3 +37,6 @@
DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join(
    PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml"
)
DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
    PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
)
20 changes: 20 additions & 0 deletions tests/artifacts/predefined_data_configs/rename_retain_columns.yaml
@@ -0,0 +1,20 @@
dataprocessor:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    rename_columns:
      "input": "instruction"
      "output": "response"
    retain_columns:
      - "instruction"
      - "response"
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: instruction
            output_field_name: response
55 changes: 55 additions & 0 deletions tests/data/test_data_preprocessing.py
@@ -33,6 +33,7 @@
    DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
)
from tests.artifacts.testdata import (
@@ -1365,3 +1366,57 @@ def test_process_dataset_configs_with_sampling_error(
        (_, _, _, _, _, _) = process_dataargs(
            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
        )


@pytest.mark.parametrize(
    "datafile, rename, retain, final, datasetconfigname",
    [
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            {"input": "instruction", "output": "response"},
            None,
            ["ID", "Label", "instruction", "response"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            None,
            ["ID", "input", "output"],
            ["ID", "input", "output"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            {"input": "instruction", "output": "response"},
            ["Label", "instruction", "response"],
            ["Label", "instruction", "response"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
    ],
)
def test_rename_and_retain_dataset_columns(
    datafile, rename, retain, final, datasetconfigname
):
    """Test process_dataset_configs for expected output."""
    dataprocessor_config = DataPreProcessorConfig()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    processor = DataPreProcessor(
        processor_config=dataprocessor_config,
        tokenizer=tokenizer,
    )
    datasetconfig = [
        DataSetConfig(
            name=datasetconfigname,
            data_paths=[datafile],
            rename_columns=rename,
            retain_columns=retain,
        )
    ]
    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)

    assert isinstance(train_dataset, Dataset)
    assert set(train_dataset.column_names) == set(final)

    with open(datafile, "r") as file:
        data = json.load(file)
    assert len(train_dataset) == len(data)
5 changes: 5 additions & 0 deletions tests/test_sft_trainer.py
@@ -38,6 +38,7 @@
from tests.artifacts.predefined_data_configs import (
    DATA_CONFIG_DUPLICATE_COLUMNS,
    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
)
from tests.artifacts.testdata import (
@@ -837,6 +838,10 @@ def test_run_causallm_ft_pretokenized(dataset_path):
            ],
            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
        ),
        (
            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
    ],
)
def test_run_causallm_ft_and_inference_with_multiple_dataset(
14 changes: 14 additions & 0 deletions tuning/data/data_config.py
@@ -36,6 +36,8 @@ class DataSetConfig:
    data_paths: List[str]
    builder: Optional[str] = None  # Referring to Hugging Face dataset builder
    sampling: Optional[float] = None
    rename_columns: Optional[Dict] = None  # {old_name: new_name}, applied first
    retain_columns: Optional[List] = None  # columns to keep; all others dropped
    data_handlers: Optional[List[DataHandlerConfig]] = None


@@ -100,6 +102,18 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
            0 <= ratio <= 1.0
        ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
        c.sampling = ratio
    if "rename_columns" in kwargs and kwargs["rename_columns"] is not None:
        rename = kwargs["rename_columns"]
        assert isinstance(
            rename, dict
        ), "rename_columns should be a dict of {current_name: new_name} pairs"
        c.rename_columns = rename
    if "retain_columns" in kwargs and kwargs["retain_columns"] is not None:
        retain = kwargs["retain_columns"]
        assert isinstance(
            retain, list
        ), "retain_columns should be a list[str] of column names to retain"
        c.retain_columns = retain
    if "data_handlers" in kwargs:
        c.data_handlers = []
        for handler in kwargs["data_handlers"]:
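
For reference, a `DataSetConfig` equivalent to what the validation above produces could also be constructed directly in Python; a minimal sketch, where the dataset name, path, and column names are placeholders:

```python
# hypothetical direct construction; mirrors what _validate_dataset_config
# builds from a parsed YAML/JSON data config
cfg = DataSetConfig(
    name="my_dataset",
    data_paths=["/path/to/data.jsonl"],
    rename_columns={"input": "instruction", "output": "response"},
    retain_columns=["instruction", "response"],
)
```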
20 changes: 20 additions & 0 deletions tuning/data/data_processors.py
@@ -243,6 +243,26 @@ def _process_dataset_configs(

logger.info("Loaded raw dataset : %s", str(raw_dataset))

        # Check for conflicting rename/retain options before proceeding.
        if d.rename_columns and d.retain_columns:
            common = set(d.rename_columns.keys()) & set(d.retain_columns)
            if common:
                raise ValueError(
                    f"You are trying to retain {str(common)} columns"
                    " which will be renamed via the rename operation."
                )

        if d.rename_columns:
            logger.info("Renaming %s columns", str(d.rename_columns))
            raw_dataset = raw_dataset.rename_columns(
                column_mapping=d.rename_columns
            )
            logger.info("Done")
        if d.retain_columns:
            logger.info("Retaining %s columns", str(d.retain_columns))
            raw_dataset = raw_dataset.select_columns(column_names=d.retain_columns)
            logger.info("Done")

        raw_datasets = DatasetDict()

        # Assume all is train split
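
The two dataset operations used above are standard Hugging Face `datasets` methods; a minimal standalone sketch of the same rename-then-retain flow on a toy in-memory dataset:

```python
from datasets import Dataset

# toy dataset standing in for a loaded raw_dataset
ds = Dataset.from_dict(
    {"ID": [0, 1], "input": ["hi", "bye"], "output": ["hello", "goodbye"]}
)

# rename first, then retain; the old names "input"/"output" no longer exist
ds = ds.rename_columns({"input": "instruction", "output": "response"})
ds = ds.select_columns(["instruction", "response"])

print(ds.column_names)  # ['instruction', 'response']
```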