
🍟 [SFT] Handles the dataset if it has been preprocessed #2863

Merged: 16 commits, Feb 18, 2025
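The PR description itself is not captured here. As context, below is a minimal usage sketch of what this change enables, assembled from the test added in this PR: passing a dataset that has already been tokenized (i.e., that contains an `input_ids` column) directly to `SFTTrainer`. The tiny model and the `zen` dataset are the internal testing artifacts used in the test; any pre-tokenized dataset with an `input_ids` column should behave the same way.

```python
# Minimal sketch, assuming the tiny test model and `zen` dataset from the test below.
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)

dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
# Pre-tokenize outside the trainer; the resulting dataset has an `input_ids` column.
tokenized_dataset = dataset.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

# With this change, SFTTrainer detects the `input_ids` column and skips re-tokenization.
trainer = SFTTrainer(
    model=model_id,
    args=SFTConfig(output_dir="sft-pretokenized", report_to="none"),
    train_dataset=tokenized_dataset,
)
trainer.train()
```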
34 changes: 33 additions & 1 deletion tests/test_sft_trainer.py
@@ -288,7 +288,7 @@ def test_sft_trainer(self):

            self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

    def test_sft_trainer_with_pretokenzied_data_packing(self):
    def test_sft_trainer_with_pretokenized_data_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
@@ -1370,3 +1370,35 @@ def test_train_peft_model(self):
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_sft_trainer_directly_with_pretokenized_data(self):
        # Get the model and dataset
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        def tokenize_example(example):
            return tokenizer(example["text"])

        # Apply tokenization
        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
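Not part of this PR's test additions: a hedged sketch of a companion test that exercises the new `UserWarning` raised in `trl/trainer/sft_trainer.py` (see the diff below) when a pre-tokenized dataset is passed together with a `formatting_func`. The test name and assertion style are assumptions; it would live in the same test class and reuse the imports already present in this file.

```python
    # Hypothetical companion test (not in this PR): the warning fires at trainer
    # construction, since the dataset is prepared in SFTTrainer.__init__.
    def test_sft_trainer_warns_on_pretokenized_data_with_formatting_func(self):
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        tokenized_dataset = dataset.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            with self.assertWarns(UserWarning):
                SFTTrainer(
                    args=training_args,
                    model=model_id,
                    train_dataset=tokenized_dataset,
                    formatting_func=lambda ex: "",  # ignored for processed datasets; only here to trigger the warning
                )
```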
28 changes: 22 additions & 6 deletions trl/trainer/sft_trainer.py
Expand Up @@ -109,6 +109,8 @@ class SFTTrainer(Trainer):
            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
@@ -370,14 +372,27 @@ def _prepare_dataset(
        if isinstance(dataset, ConstantLengthDataset):
            return dataset

        # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
        # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset
        column_names = list(next(iter(dataset)).keys())
        is_processed = "input_ids" in column_names

        # Build the kwargs for the `map` function
        map_kwargs = {}
        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
            map_kwargs["num_proc"] = args.dataset_num_proc

        with PartialState().local_main_process_first():
            # Apply the formatting function if any
            if formatting_func is not None:
            if formatting_func is not None and is_processed:
                warnings.warn(
                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
                    "`formatting_func` or pass a dataset that is not already processed.",
                    UserWarning,
                )

            if formatting_func is not None and not is_processed:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

@@ -407,10 +422,11 @@ def concat_prompt_completion(example):
                **map_kwargs,
            )

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
            dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
            # Tokenize the dataset if needed
            if not is_processed:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
                dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)

            # Pack or truncate
            if packing:
@@ -424,7 +440,7 @@ def concat_prompt_completion(example):
                )
            elif args.max_seq_length is not None:
                dataset = dataset.map(
                    lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
                    lambda ex: {"input_ids": ex["input_ids"][: args.max_seq_length]},
                    **map_kwargs,
                )
            # For Liger kernel, ensure only input_ids is present
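To summarize the control flow this diff adds to `_prepare_dataset`, here is a condensed sketch. It is a simplified paraphrase of the hunks above, not the literal implementation, and it reuses the surrounding names (`dataset`, `formatting_func`, `processing_class`, `args`, `map_kwargs`); the other preparation steps (chat template, packing, truncation) are omitted.

```python
# Condensed sketch of the added logic, assuming `dataset` is a datasets.Dataset
# or datasets.IterableDataset (the check does not cover plain torch datasets).
column_names = list(next(iter(dataset)).keys())
is_processed = "input_ids" in column_names  # the user already tokenized the dataset

if formatting_func is not None and is_processed:
    # A formatting function cannot be applied to already-tokenized rows, so it is ignored.
    warnings.warn("`formatting_func` is ignored for pre-tokenized datasets.", UserWarning)

if not is_processed:
    # Only tokenize when the user has not done it already.
    dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
```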