From 7b21b5301e7cf5b7e457aa972d72883c92ab472f Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 17 Dec 2024 13:23:03 -0700 Subject: [PATCH 1/6] feat: add scanner tracker Signed-off-by: Angel Luu --- pyproject.toml | 1 + tests/build/test_launch_script.py | 34 +++++++- tests/test_sft_trainer.py | 20 ++++- .../test_hf_resource_scanner_tracker.py | 85 +++++++++++++++++++ tuning/config/tracker_configs.py | 6 ++ tuning/sft_trainer.py | 12 ++- .../trackers/hf_resource_scanner_tracker.py | 47 ++++++++++ tuning/trackers/tracker_factory.py | 39 ++++++++- 8 files changed, 240 insertions(+), 4 deletions(-) create mode 100644 tests/trackers/test_hf_resource_scanner_tracker.py create mode 100644 tuning/trackers/hf_resource_scanner_tracker.py diff --git a/pyproject.toml b/pyproject.toml index b930f7680..0041804bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ aim = ["aim>=3.19.0,<4.0"] mlflow = ["mlflow"] fms-accel = ["fms-acceleration>=0.6"] gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"] +scanner-dev = ["HFResourceScanner>=0.1.0"] [tool.setuptools.packages.find] diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index e331a5e9b..58f24eea0 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -22,6 +22,7 @@ # Third Party import pytest +from transformers.utils.import_utils import _is_package_available # First Party from build.accelerate_launch import main @@ -31,7 +32,10 @@ USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, ) -from tuning.config.tracker_configs import FileLoggingTrackerConfig +from tuning.config.tracker_configs import ( + FileLoggingTrackerConfig, + HFResourceScannerConfig, +) SCRIPT = "tuning/sft_trainer.py" MODEL_NAME = "Maykeye/TinyLLama-v0" @@ -246,6 +250,34 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true(): assert os.path.exists(new_embeddings_file_path) +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_launch_with_HFResourceScanner_enabled(): + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": tempdir, + "save_model_dir": tempdir, + "lora_post_process_for_vllm": True, + "gradient_accumulation_steps": 1, + "trackers": ["hf_resource_scanner"], + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + + scanner_outfile = os.path.join( + tempdir, HFResourceScannerConfig.scanner_output_filename + ) + assert os.path.exists(scanner_outfile) + + def test_bad_script_path(): """Check for appropriate error for an invalid training script location""" with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index f2d4a1ee1..6f08b3bfe 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1053,12 +1053,18 @@ def _test_run_inference(checkpoint_path): def _validate_training( - tempdir, check_eval=False, train_logs_file="training_logs.jsonl" + tempdir, + check_eval=False, + train_logs_file="training_logs.jsonl", + check_scanner_file=False, ): assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) train_logs_file_path = "{}/{}".format(tempdir, train_logs_file) _validate_logfile(train_logs_file_path, check_eval) + if check_scanner_file: + _validate_hf_resource_scanner_file(tempdir) + def _validate_logfile(log_file_path, check_eval=False): 
train_log_contents = "" @@ -1073,6 +1079,18 @@ def _validate_logfile(log_file_path, check_eval=False): assert "validation_loss" in train_log_contents +def _validate_hf_resource_scanner_file(tempdir): + scanner_file_path = os.path.join(tempdir, "scanner_output.json") + assert os.path.exists(scanner_file_path) + assert os.path.getsize(scanner_file_path) > 0 + + scanner_contents = "" + with open(scanner_file_path, encoding="utf-8") as f: + scanner_contents = f.read() + + assert "ResourceScanner Memory Data:" in scanner_contents + + def _get_checkpoint_path(dir_path): return os.path.join(dir_path, "checkpoint-5") diff --git a/tests/trackers/test_hf_resource_scanner_tracker.py b/tests/trackers/test_hf_resource_scanner_tracker.py new file mode 100644 index 000000000..3c095042d --- /dev/null +++ b/tests/trackers/test_hf_resource_scanner_tracker.py @@ -0,0 +1,85 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + + +# Standard +import copy +import tempfile + +# Third Party +from transformers.utils.import_utils import _is_package_available +import pytest + +# First Party +from tests.test_sft_trainer import ( + DATA_ARGS, + MODEL_ARGS, + TRAIN_ARGS, + _get_checkpoint_path, + _test_run_causallm_ft, + _test_run_inference, + _validate_training, +) + +# Local +from tuning import sft_trainer +from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory + +## HF Resource Scanner Tracker Tests + + +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_run_with_hf_resource_scanner_tracker(): + """Ensure that training succeeds with a good tracker name""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.trackers = ["hf_resource_scanner"] + + _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) + _test_run_inference(_get_checkpoint_path(tempdir)) + + +@pytest.mark.skipif( + not _is_package_available("HFResourceScanner"), + reason="Only runs if HFResourceScanner is installed", +) +def test_sample_run_with_hf_resource_scanner_updated_filename(): + """Ensure that hf_resource_scanner output filename can be updated""" + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + train_args.trackers = ["hf_resource_scanner"] + + scanner_output_file = "scanner_output.json" + + tracker_configs = TrackerConfigFactory( + hf_resource_scanner_config=HFResourceScannerConfig( + scanner_output_filename=scanner_output_file + ) + ) + + sft_trainer.train( + MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs + ) + + # validate ft tuning configs + _validate_training(tempdir, check_scanner_file=True) diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index 51c44aed1..bcadc7776 100644 --- 
a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -16,6 +16,11 @@ from dataclasses import dataclass +@dataclass +class HFResourceScannerConfig: + scanner_output_filename: str = "scanner_output.json" + + @dataclass class FileLoggingTrackerConfig: training_logs_filename: str = "training_logs.jsonl" @@ -80,3 +85,4 @@ class TrackerConfigFactory: file_logger_config: FileLoggingTrackerConfig = None aim_config: AimConfig = None mlflow_config: MLflowConfig = None + hf_resource_scanner_config: HFResourceScannerConfig = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b3e28f686..27af7078c 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -53,6 +53,7 @@ AimConfig, FileLoggingTrackerConfig, MLflowConfig, + HFResourceScannerConfig, TrackerConfigFactory, ) from tuning.data.setup_dataprocessor import process_dataargs @@ -458,6 +459,7 @@ def get_parser(): peft_config.PromptTuningConfig, FileLoggingTrackerConfig, AimConfig, + HFResourceScannerConfig, QuantizedLoraConfig, FusedOpsAndKernelsConfig, AttentionAndDistributedPackingConfig, @@ -506,6 +508,8 @@ def parse_arguments(parser, json_config=None): Configuration for training log file. AimConfig Configuration for AIM stack. + HFResourceScannerConfig + Configuration for HFResourceScanner. QuantizedLoraConfig Configuration for quantized LoRA (a form of PEFT). FusedOpsAndKernelsConfig @@ -529,6 +533,7 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -547,6 +552,7 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -574,6 +580,7 @@ def parse_arguments(parser, json_config=None): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -597,6 +604,7 @@ def main(): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -611,7 +619,7 @@ def main(): logger.debug( "Input args parsed: \ model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ - tune_config %s, file_logger_config, %s aim_config %s, \ + tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ attention_and_distributed_packing_config, %s,\ mlflow_config %s, fast_moe_config %s, \ @@ -623,6 +631,7 @@ def main(): tune_config, file_logger_config, aim_config, + hf_resource_scanner_config, quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, @@ -656,6 +665,7 @@ def main(): file_logger_config=file_logger_config, aim_config=aim_config, mlflow_config=mlflow_config, + hf_resource_scanner_config=hf_resource_scanner_config, ) if training_args.output_dir: diff --git a/tuning/trackers/hf_resource_scanner_tracker.py b/tuning/trackers/hf_resource_scanner_tracker.py new file mode 100644 index 000000000..50f503df4 --- /dev/null +++ b/tuning/trackers/hf_resource_scanner_tracker.py @@ -0,0 +1,47 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging + +# Third Party +from HFResourceScanner import Scanner # pylint: disable=import-error + +# Local +from .tracker import Tracker +from tuning.config.tracker_configs import HFResourceScannerConfig + + +class HFResourceScannerTracker(Tracker): + def __init__(self, tracker_config: HFResourceScannerConfig): + """Tracker which encodes callback to scan for resources using HFResourceScanner + + Args: + tracker_config (HFResourceScannerConfig): An instance of HFResourceScanner + tracker config which contains the location of output file. + """ + super().__init__(name="hf_resource_scanner", tracker_config=tracker_config) + # Get logger with root log level + self.logger = logging.getLogger() + + def get_hf_callback(self): + """Returns the HFResourceScanner object associated with this tracker. + + Returns: + HFResourceScanner: The file logging callback which inherits + transformers.TrainerCallback and records the metrics to a file. + """ + output_filename = self.config.scanner_output_filename + self.hf_callback = Scanner(output_fmt=output_filename) + return self.hf_callback diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index a550250a8..8627863f4 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -29,8 +29,14 @@ AIMSTACK_TRACKER = "aim" FILE_LOGGING_TRACKER = "file_logger" MLFLOW_TRACKER = "mlflow" +HF_RESOURCE_SCANNER_TRACKER = "hf_resource_scanner" -AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, MLFLOW_TRACKER] +AVAILABLE_TRACKERS = [ + AIMSTACK_TRACKER, + FILE_LOGGING_TRACKER, + HF_RESOURCE_SCANNER_TRACKER, + MLFLOW_TRACKER +] # Trackers which can be used @@ -39,6 +45,7 @@ # One time package check for list of external trackers. 
_is_aim_available = _is_package_available("aim") _is_mlflow_available = _is_package_available("mlflow") +_is_hf_resource_scanner_available = _is_package_available("HFResourceScanner") def _get_tracker_class(T, C): @@ -90,6 +97,34 @@ def _register_mlflow_tracker(): "\t pip install mlflow" ) +def _register_hf_resource_scanner_tracker(): + # pylint: disable=import-outside-toplevel + if _is_hf_resource_scanner_available: + # Local + from .hf_resource_scanner_tracker import HFResourceScannerTracker + from tuning.config.tracker_configs import HFResourceScannerConfig + + HFResourceScannerTracker = _get_tracker_class( + HFResourceScannerTracker, HFResourceScannerConfig + ) + + REGISTERED_TRACKERS[HF_RESOURCE_SCANNER_TRACKER] = HFResourceScannerTracker + logger.info("Registered HFResourceScanner tracker") + else: + logger.info( + "Not registering HFResourceScanner tracker due to unavailablity of package.\n" + "Please install HFResourceScanner if you intend to use it.\n" + "\t pip install HFResourceScanner" + ) + + +def _is_tracker_installed(name): + if name == AIMSTACK_TRACKER: + return _is_aim_available + if name == HF_RESOURCE_SCANNER_TRACKER: + return _is_hf_resource_scanner_available + return False + def _register_file_logging_tracker(): FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig) @@ -109,6 +144,8 @@ def _register_trackers(): _register_file_logging_tracker() if MLFLOW_TRACKER not in REGISTERED_TRACKERS: _register_mlflow_tracker() + if HF_RESOURCE_SCANNER_TRACKER not in REGISTERED_TRACKERS: + _register_hf_resource_scanner_tracker() def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory): From 3ea9368957d800da11073f184201a73bb877804b Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 17 Dec 2024 14:40:09 -0700 Subject: [PATCH 2/6] Add installation for HFResourceScanner if enabled in Dockerfile Signed-off-by: Angel Luu --- build/Dockerfile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/build/Dockerfile b/build/Dockerfile index 9a6a5583f..99f91c040 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -23,6 +23,7 @@ ARG WHEEL_VERSION="" ARG ENABLE_AIM=false ARG ENABLE_MLFLOW=false ARG ENABLE_FMS_ACCELERATION=true +ARG ENABLE_SCANNER=false ## Base Layer ################################################################## FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base @@ -111,6 +112,7 @@ ARG USER ARG USER_UID ARG ENABLE_FMS_ACCELERATION ARG ENABLE_AIM +ARG ENABLE_SCANNER RUN dnf install -y git && \ # perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies @@ -155,6 +157,9 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[mlflow]"; \ fi +RUN if [[ "${ENABLE_SCANNER}}" == "true" ]]; then \ + python -m pip install --user "$(head bdist_name)[scanner-dev]"; \ + fi # Clean up the wheel module. 
It's only needed by flash-attn install RUN python -m pip uninstall wheel build -y && \ From 39d070a28ac119f4761673d94598005f153dd555 Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 17 Dec 2024 16:43:50 -0700 Subject: [PATCH 3/6] fix: remove extra } Signed-off-by: Angel Luu --- build/Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/build/Dockerfile b/build/Dockerfile index 99f91c040..66aa40a6e 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -156,8 +156,9 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[mlflow]"; \ -fi -RUN if [[ "${ENABLE_SCANNER}}" == "true" ]]; then \ + fi + +RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[scanner-dev]"; \ fi From 481ef10f42db92b679c01897ea42e4a3d77da3c3 Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Wed, 18 Dec 2024 10:08:58 -0700 Subject: [PATCH 4/6] test: make the test more explicit Signed-off-by: Angel Luu --- tests/build/test_launch_script.py | 15 ++++++++++----- tests/test_sft_trainer.py | 10 +++++----- .../trackers/test_hf_resource_scanner_tracker.py | 5 ++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 58f24eea0..c699e16da 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -16,6 +16,7 @@ """ # Standard +import json import os import tempfile import glob @@ -257,6 +258,9 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true(): def test_launch_with_HFResourceScanner_enabled(): with tempfile.TemporaryDirectory() as tempdir: setup_env(tempdir) + scanner_outfile = os.path.join( + tempdir, HFResourceScannerConfig.scanner_output_filename + ) TRAIN_KWARGS = { **BASE_LORA_KWARGS, **{ @@ -265,17 +269,18 @@ def test_launch_with_HFResourceScanner_enabled(): "lora_post_process_for_vllm": True, "gradient_accumulation_steps": 1, "trackers": ["hf_resource_scanner"], + "scanner_output_filename": scanner_outfile, }, } serialized_args = serialize_args(TRAIN_KWARGS) os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - - scanner_outfile = os.path.join( - tempdir, HFResourceScannerConfig.scanner_output_filename - ) - assert os.path.exists(scanner_outfile) + assert os.path.exists(scanner_outfile) is True + with open(scanner_outfile, "r", encoding="utf-8") as f: + scanner_res = json.load(f) + assert scanner_res["time_data"] is not None + assert scanner_res["mem_data"] is not None def test_bad_script_path(): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 6f08b3bfe..ae90605e7 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1081,14 +1081,14 @@ def _validate_logfile(log_file_path, check_eval=False): def _validate_hf_resource_scanner_file(tempdir): scanner_file_path = os.path.join(tempdir, "scanner_output.json") - assert os.path.exists(scanner_file_path) + assert os.path.exists(scanner_file_path) is True assert os.path.getsize(scanner_file_path) > 0 - scanner_contents = "" - with open(scanner_file_path, encoding="utf-8") as f: - scanner_contents = f.read() + with open(scanner_file_path, "r", encoding="utf-8") as f: + scanner_contents = json.load(f) - assert "ResourceScanner Memory Data:" in scanner_contents + assert scanner_contents["time_data"] is not None + assert scanner_contents["mem_data"] is not None def _get_checkpoint_path(dir_path): diff --git 
a/tests/trackers/test_hf_resource_scanner_tracker.py b/tests/trackers/test_hf_resource_scanner_tracker.py index 3c095042d..04ce0e20c 100644 --- a/tests/trackers/test_hf_resource_scanner_tracker.py +++ b/tests/trackers/test_hf_resource_scanner_tracker.py @@ -18,6 +18,7 @@ # Standard import copy +import os import tempfile # Third Party @@ -65,15 +66,17 @@ def test_sample_run_with_hf_resource_scanner_updated_filename(): with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) + train_args.gradient_accumulation_steps = 1 train_args.output_dir = tempdir + # add hf_resource_scanner to the list of requested tracker train_args.trackers = ["hf_resource_scanner"] scanner_output_file = "scanner_output.json" tracker_configs = TrackerConfigFactory( hf_resource_scanner_config=HFResourceScannerConfig( - scanner_output_filename=scanner_output_file + scanner_output_filename=os.path.join(tempdir, scanner_output_file) ) ) From 7841a9d552f2a1909ce51b740575eba8098558fd Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 14 Jan 2025 11:47:54 -0700 Subject: [PATCH 5/6] chore: Run fmt Signed-off-by: Angel Luu --- tuning/sft_trainer.py | 2 +- tuning/trackers/tracker_factory.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 27af7078c..0bf3a3b08 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -52,8 +52,8 @@ from tuning.config.tracker_configs import ( AimConfig, FileLoggingTrackerConfig, - MLflowConfig, HFResourceScannerConfig, + MLflowConfig, TrackerConfigFactory, ) from tuning.data.setup_dataprocessor import process_dataargs diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 8627863f4..be1057fac 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -35,7 +35,7 @@ AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, HF_RESOURCE_SCANNER_TRACKER, - MLFLOW_TRACKER + MLFLOW_TRACKER, ] @@ -97,6 +97,7 @@ def _register_mlflow_tracker(): "\t pip install mlflow" ) + def _register_hf_resource_scanner_tracker(): # pylint: disable=import-outside-toplevel if _is_hf_resource_scanner_available: From ee46cd242b6f2f75f6a58d29ffcdd5ba3e8794e1 Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 14 Jan 2025 13:37:09 -0700 Subject: [PATCH 6/6] fix: fix unit tests Signed-off-by: Angel Luu --- tests/test_sft_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index ae90605e7..8faa3746c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -363,6 +363,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = 
copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig)
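
Usage sketch (based on the tests added in this series; assumes the optional
dependency is installed via pip install ".[scanner-dev]", and that model_args,
data_args and train_args stand in for your usual ModelArguments, DataArguments
and TrainingArguments instances):

# Standard
import os

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

# Opt in to the new tracker by name; it is only registered when the
# HFResourceScanner package is importable.
train_args.trackers = ["hf_resource_scanner"]

# Optionally point the scanner output somewhere convenient; the default
# filename is "scanner_output.json" (see HFResourceScannerConfig).
tracker_configs = TrackerConfigFactory(
    hf_resource_scanner_config=HFResourceScannerConfig(
        scanner_output_filename=os.path.join(train_args.output_dir, "scanner_output.json")
    )
)

sft_trainer.train(model_args, data_args, train_args, tracker_configs=tracker_configs)

# After training, the output JSON carries "time_data" and "mem_data" sections,
# which is what the updated unit tests assert on.

When launching through build/accelerate_launch.py instead, the same effect comes
from adding "trackers": ["hf_resource_scanner"] and "scanner_output_filename" to
the serialized job config, as exercised by test_launch_with_HFResourceScanner_enabled.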