From 428b4d9d12b51dbbedb0e294c1f291445814a757 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 20 Jun 2024 03:37:13 +0000
Subject: [PATCH] remove test helpers

Signed-off-by: Yu Chin Fabian Lim
---
 .../test_acceleration_framework.py | 132 ++++++++----
 tests/helpers.py                   |  47 -------
 2 files changed, 57 insertions(+), 122 deletions(-)
 delete mode 100644 tests/helpers.py

diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py
index 38d7d6bf9..3a133e2f2 100644
--- a/tests/acceleration/test_acceleration_framework.py
+++ b/tests/acceleration/test_acceleration_framework.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 # Standard
-from dataclasses import dataclass
+import copy
+from dataclasses import dataclass, replace
 from typing import Annotated
 from unittest.mock import patch
 import tempfile
@@ -23,8 +24,7 @@
 import torch

 # First Party
-from tests.helpers import causal_lm_train_kwargs
-from tests.test_sft_trainer import BASE_FT_KWARGS, BASE_LORA_KWARGS
+from tests.test_sft_trainer import MODEL_ARGS, DATA_ARGS, TRAIN_ARGS, PEFT_LORA_ARGS

 # Local
 from .spying_utils import create_mock_plugin_class_and_spy
@@ -214,52 +214,45 @@ def test_framework_raises_if_used_with_missing_package():
     """Ensure that trying the use the framework, without first installing fms_acceleration
     will raise.
     """
-    with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **BASE_LORA_KWARGS,
-            **{
-                "output_dir": tempdir,
-            },
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
-        quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.output_dir = None

-        # patch is_fms_accelerate_available to return False inside sft_trainer
-        # to simulate fms_acceleration not installed
-        with patch(
-            "tuning.config.acceleration_configs.acceleration_framework_config."
-            "is_fms_accelerate_available",
-            return_value=False,
+    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
+
+    # patch is_fms_accelerate_available to return False inside sft_trainer
+    # to simulate fms_acceleration not installed
+    with patch(
+        "tuning.config.acceleration_configs.acceleration_framework_config."
+        "is_fms_accelerate_available",
+        return_value=False,
+    ):
+        with pytest.raises(
+            ValueError, match="No acceleration framework package found."
         ):
-            with pytest.raises(
-                ValueError, match="No acceleration framework package found."
-            ):
-                sft_trainer.train(
-                    model_args,
-                    data_args,
-                    training_args,
-                    tune_config,
-                    quantized_lora_config=quantized_lora_config,
-                )
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_LORA_ARGS,
+                quantized_lora_config=quantized_lora_config,
+            )


 invalid_kwargs_map = [
     (
         {
-            **BASE_LORA_KWARGS,
             "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
         },
+        PEFT_LORA_ARGS,
         AssertionError,
         "need to run in fp16 mixed precision or load model in fp16",
     ),
     (
         {
-            **BASE_FT_KWARGS,
             "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
             "torch_dtype": torch.float16,
         },
+        None,
         AssertionError,
         "need peft_config to install PEFT adapters",
     ),
@@ -271,24 +264,22 @@ def test_framework_raises_if_used_with_missing_package():
     reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
 )
 @pytest.mark.parametrize(
-    "train_kwargs,exception",
+    "bad_kwargs,peft_config,exception,exception_msg",
     invalid_kwargs_map,
     ids=["triton_v2 requires fp16", "accelerated peft requires peft config"],
 )
 def test_framework_raises_due_to_invalid_arguments(
-    bad_train_kwargs, exception, exception_msg
+    bad_kwargs, peft_config, exception, exception_msg
 ):
     """Ensure that invalid arguments
     will be checked by the activated framework plugin.
     """
     with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **bad_train_kwargs,
-            "output_dir": tempdir,
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args = replace(model_args, **bad_kwargs)
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
         quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

         # 1. activate the accelerated peft plugin
@@ -296,9 +287,9 @@
         with pytest.raises(exception, match=exception_msg):
             sft_trainer.train(
                 model_args,
-                data_args,
-                training_args,
-                tune_config,
+                DATA_ARGS,
+                train_args,
+                peft_config,
                 quantized_lora_config=quantized_lora_config,
             )

@@ -342,18 +333,14 @@ def test_framework_intialized_properly_peft(
     properly activates the framework plugin and runs the train sucessfully.
     """
     with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **BASE_LORA_KWARGS,
-            **{
-                "fp16": True,
-                "model_name_or_path": model_name_or_path,
-                "output_dir": tempdir,
-                "save_strategy": "no",
-            },
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = model_name_or_path
+        model_args.torch_dtype = torch.float16
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_strategy = 'no'
+        train_args.fp16 = True
+
         installation_path, (MockedPlugin, spy) = mock_and_spy

         # 1. mock a plugin class
@@ -365,9 +352,9 @@
         ):
             sft_trainer.train(
                 model_args,
-                data_args,
-                training_args,
-                tune_config,
+                DATA_ARGS,
+                train_args,
+                PEFT_LORA_ARGS,
                 quantized_lora_config=quantized_lora_config,
             )

@@ -393,19 +380,14 @@ def test_framework_intialized_properly_foak():
     properly activates the framework plugin and runs the train sucessfully.
""" with tempfile.TemporaryDirectory() as tempdir: - TRAIN_KWARGS = { - **BASE_LORA_KWARGS, - **{ - "fp16": True, - "torch_dtype": torch.float16, - "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - "output_dir": tempdir, - "save_strategy": "no", - }, - } - model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( - TRAIN_KWARGS - ) + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ" + model_args.torch_dtype = torch.float16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = 'no' + train_args.fp16 = True # setup default quantized lora args dataclass # - with auth gptq as the quantized method @@ -437,9 +419,9 @@ def test_framework_intialized_properly_foak(): ): sft_trainer.train( model_args, - data_args, - training_args, - tune_config, + DATA_ARGS, + train_args, + PEFT_LORA_ARGS, quantized_lora_config=quantized_lora_config, fusedops_kernels_config=fusedops_kernels_config, ) diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 221697973..000000000 --- a/tests/helpers.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Third Party -import transformers - -# Local -from tuning.config import configs, peft_config - - -def causal_lm_train_kwargs(train_kwargs): - """Parse the kwargs for a valid train call to a Causal LM.""" - parser = transformers.HfArgumentParser( - dataclass_types=( - configs.ModelArguments, - configs.DataArguments, - configs.TrainingArguments, - peft_config.LoraConfig, - peft_config.PromptTuningConfig, - ) - ) - ( - model_args, - data_args, - training_args, - lora_config, - prompt_tuning_config, - ) = parser.parse_dict(train_kwargs, allow_extra_keys=True) - return ( - model_args, - data_args, - training_args, - lora_config - if train_kwargs.get("peft_method") == "lora" - else (None if train_kwargs.get("peft_method") == "" else prompt_tuning_config,), - )