Skip to content

Commit ea93f3a

Browse files
vidyasiv and regisss
authored
Support for custom files for run_lora_clm.py (huggingface#1039)
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
1 parent 59d182d commit ea93f3a

File tree

6 files changed

+320
-13
lines changed

6 files changed

+320
-13
lines changed

Makefile

+5
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ fast_tests_table_transformers:
7272
python -m pip install .[tests]
7373
python -m pytest tests/test_table_transformer.py
7474

75+
# Run non-performance regressions
76+
slow_tests_custom_file_input: test_installs
77+
python -m pip install -r examples/language-modeling/requirements.txt
78+
python -m pytest tests/test_custom_file_input.py
79+
7580
# Run single-card non-regression tests
7681
slow_tests_1x: test_installs
7782
python -m pytest tests/test_examples.py -v -s -k "single_card"

examples/language-modeling/README.md

+42
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,48 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ..
693693
```
694694
The default `peft_type` is `lora`. You can enable AdaLoRA or IA3 with `--peft_type adalora` or `--peft_type ia3`, or enable llama-adapter for Llama models with `--peft_type llama-adapter`.
695695

696+
#### Custom Files
697+
698+
To run on your own training and validation files, use the following command:
699+
700+
```bash
701+
python run_lora_clm.py \
702+
--model_name_or_path bigcode/starcoder \
703+
--train_file path_to_train_file \
704+
--validation_file path_to_validation_file \
705+
--per_device_train_batch_size 8 \
706+
--per_device_eval_batch_size 8 \
707+
--do_train \
708+
--do_eval \
709+
--output_dir /tmp/test-lora-clm \
710+
--bf16 \
711+
--use_habana \
712+
--use_lazy_mode \
713+
--use_hpu_graphs_for_inference \
714+
--dataset_concatenation \
715+
--throughput_warmup_steps 3
716+
```
717+
718+
The format of the JSON Lines files (with extensions .json or .jsonl) is expected to be:
719+
720+
```json
721+
{"text": "<text>"}
722+
{"text": "<text>"}
723+
{"text": "<text>"}
724+
{"text": "<text>"}
725+
```
726+
727+
The format of the text files (with extensions .text or .txt) is expected to be:
728+
729+
```text
730+
"<text>"
731+
"<text>"
732+
"<text>"
733+
"<text>"
734+
```
735+
736+
> Note: When using both `--train_file` and `--validation_file`, all files are expected to be of the same type, i.e. either JSON or text.
737+
696738
### Prompt/Prefix/P-tuning
697739

698740
To run prompt tuning finetuning, you can use `run_prompt_tuning_clm.py`.

examples/language-modeling/run_lora_clm.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ class DataArguments:
240240
default=False,
241241
metadata={"help": "Whether to keep in memory the loaded dataset. Defaults to False."},
242242
)
243+
keep_linebreaks: bool = field(
244+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
245+
)
243246
dataset_seed: int = field(
244247
default=42,
245248
metadata={
@@ -510,6 +513,10 @@ def main():
510513
)
511514

512515
if "validation" not in raw_datasets.keys() and training_args.do_eval:
516+
if not data_args.validation_split_percentage:
517+
raise ValueError(
518+
"Please set --validation_split_percentage as dataset does not contain `validation` key"
519+
)
513520
raw_datasets["validation"] = load_dataset(
514521
data_args.dataset_name,
515522
data_args.dataset_config_name,
@@ -538,9 +545,11 @@ def main():
538545
if data_args.train_file is not None
539546
else data_args.validation_file.split(".")[-1]
540547
)
541-
if extension == "txt":
548+
if extension in ("txt", "text"):
542549
extension = "text"
543550
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
551+
if extension in ("json", "jsonl"):
552+
extension = "json"
544553
raw_datasets = load_dataset(
545554
extension,
546555
data_files=data_files,
@@ -551,6 +560,10 @@ def main():
551560

552561
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
553562
if "validation" not in raw_datasets.keys() and training_args.do_eval:
563+
if not data_args.validation_split_percentage:
564+
raise ValueError(
565+
"Please set --validation_split_percentage as dataset does not contain `validation` key"
566+
)
554567
raw_datasets["validation"] = load_dataset(
555568
extension,
556569
data_files=data_files,
@@ -567,20 +580,32 @@ def main():
567580
token=model_args.token,
568581
**dataset_args,
569582
)
570-
583+
single_column_dataset = False
584+
# For named dataset (timdettmers/openassistant-guanaco) or custom dataset with a single column "text"
571585
if (
572-
data_args.dataset_name == "timdettmers/openassistant-guanaco"
573-
): # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
586+
training_args.do_train
587+
and raw_datasets["train"].num_columns == 1
588+
or training_args.do_eval
589+
and raw_datasets["validation"].num_columns == 1
590+
):
591+
single_column_dataset = True
574592
raw_datasets = raw_datasets.map(
575593
lambda x: {
576594
"input": "",
577595
"output": x["text"],
578596
}
579597
)
580-
# Remove unused columns.
581-
raw_datasets = raw_datasets.remove_columns(
582-
[col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
583-
)
598+
if training_args.do_train:
599+
# Remove unused columns.
600+
raw_datasets = raw_datasets.remove_columns(
601+
[col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
602+
)
603+
604+
if training_args.do_eval:
605+
# Remove unused columns.
606+
raw_datasets = raw_datasets.remove_columns(
607+
[col for col in raw_datasets.column_names["validation"] if col not in ["input", "output"]]
608+
)
584609
else:
585610
# Preprocessing the datasets.
586611
for key in raw_datasets:
@@ -680,7 +705,7 @@ def tokenize(prompt, add_eos_token=True):
680705
def preprocess_function(examples):
681706
keys = list(examples.data.keys())
682707
if len(keys) != 2:
683-
raise ValueError("Unsupported dataset format")
708+
raise ValueError(f"Unsupported dataset format, number of keys {keys} !=2")
684709

685710
st = [s + t for s, t in zip(examples[keys[0]], examples[keys[1]])]
686711

@@ -717,17 +742,18 @@ def concatenate_data(dataset, max_seq_length):
717742
concatenated_dataset[column] = reshaped_data
718743
return datasets.Dataset.from_dict(concatenated_dataset)
719744

720-
if data_args.dataset_name == "timdettmers/openassistant-guanaco":
745+
if single_column_dataset:
721746
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
722747
if training_args.do_eval:
723-
tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
748+
tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(["input", "output"])
724749
else:
725750
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
726751
if training_args.do_eval:
727752
tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
728753
["prompt_sources", "prompt_targets"]
729754
)
730-
tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
755+
if training_args.do_train:
756+
tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
731757
if training_args.do_eval:
732758
tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)
733759
if training_args.do_train:
@@ -849,7 +875,7 @@ def compute_metrics(eval_preds):
849875
trainer.log_metrics("train", metrics)
850876
trainer.save_metrics("train", metrics)
851877

852-
# Evaluation
878+
# Evaluation
853879
if training_args.do_eval:
854880
logger.info("*** Evaluate ***")
855881
metrics = trainer.evaluate()

0 commit comments

Comments
 (0)