Commit 428b4d9

remove test helpers
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Jun 20, 2024
1 parent 06c4872 commit 428b4d9
Showing 2 changed files with 57 additions and 122 deletions.
132 changes: 57 additions & 75 deletions tests/acceleration/test_acceleration_framework.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 # Standard
-from dataclasses import dataclass
+import copy
+from dataclasses import dataclass, replace
 from typing import Annotated
 from unittest.mock import patch
 import tempfile
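The two new imports carry the refactor below: instead of building a flat kwargs dict and converting it through a helper, each test now copies shared dataclass fixtures and overrides individual fields. A minimal sketch of both idioms, using a hypothetical `TrainArgs` stand-in rather than the repo's real config classes:

```python
# Minimal sketch of the copy/replace idiom adopted in this commit.
# `TrainArgs` is a hypothetical stand-in for the real config dataclasses.
import copy
from dataclasses import dataclass, replace

@dataclass
class TrainArgs:
    output_dir: str = "out"
    fp16: bool = False

TRAIN_ARGS = TrainArgs()  # shared module-level fixture

# idiom 1: deep-copy the fixture, then mutate fields on the private copy
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.fp16 = True

# idiom 2: replace() builds a new instance from keyword overrides and
# raises TypeError if a key is not an actual field of the dataclass
train_args2 = replace(TRAIN_ARGS, **{"output_dir": "/tmp/run"})

assert TRAIN_ARGS.fp16 is False and TRAIN_ARGS.output_dir == "out"  # fixture untouched
```

That `replace` rejects unknown keys is what makes it a natural fit for the `bad_kwargs` overrides in the parametrized test further down.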
@@ -23,8 +24,7 @@
 import torch
 
 # First Party
-from tests.helpers import causal_lm_train_kwargs
-from tests.test_sft_trainer import BASE_FT_KWARGS, BASE_LORA_KWARGS
+from tests.test_sft_trainer import MODEL_ARGS, DATA_ARGS, TRAIN_ARGS, PEFT_LORA_ARGS
 
 # Local
 from .spying_utils import create_mock_plugin_class_and_spy
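Swapping the `tests.helpers` import for fixtures from `tests.test_sft_trainer` means every test shares the same module-level config objects, which only stays safe if each test copies before mutating. A short sketch of the hazard being avoided, again with a placeholder dataclass:

```python
# Why each test deep-copies the shared fixture before touching it:
# module-level fixtures are constructed once per test session, so an
# in-place mutation in one test would leak into every later test.
import copy
from dataclasses import dataclass

@dataclass
class TrainingArguments:  # placeholder stand-in, not the repo's class
    fp16: bool = False

TRAIN_ARGS = TrainingArguments()

def test_a():
    args = copy.deepcopy(TRAIN_ARGS)  # isolated copy
    args.fp16 = True                  # safe: the shared fixture is unchanged

def test_b():
    assert TRAIN_ARGS.fp16 is False   # holds because test_a copied first
```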
@@ -214,52 +214,45 @@ def test_framework_raises_if_used_with_missing_package():
     """Ensure that trying to use the framework without first installing fms_acceleration
     will raise.
     """
-    with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **BASE_LORA_KWARGS,
-            **{
-                "output_dir": tempdir,
-            },
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
-        quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.output_dir = None
 
-        # patch is_fms_accelerate_available to return False inside sft_trainer
-        # to simulate fms_acceleration not installed
-        with patch(
-            "tuning.config.acceleration_configs.acceleration_framework_config."
-            "is_fms_accelerate_available",
-            return_value=False,
+    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
+
+    # patch is_fms_accelerate_available to return False inside sft_trainer
+    # to simulate fms_acceleration not installed
+    with patch(
+        "tuning.config.acceleration_configs.acceleration_framework_config."
+        "is_fms_accelerate_available",
+        return_value=False,
+    ):
+        with pytest.raises(
+            ValueError, match="No acceleration framework package found."
         ):
-            with pytest.raises(
-                ValueError, match="No acceleration framework package found."
-            ):
-                sft_trainer.train(
-                    model_args,
-                    data_args,
-                    training_args,
-                    tune_config,
-                    quantized_lora_config=quantized_lora_config,
-                )
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                TRAIN_ARGS,
+                PEFT_LORA_ARGS,
+                quantized_lora_config=quantized_lora_config,
+            )
 
 
 invalid_kwargs_map = [
     (
         {
-            **BASE_LORA_KWARGS,
             "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
         },
+        PEFT_LORA_ARGS,
         AssertionError,
         "need to run in fp16 mixed precision or load model in fp16",
     ),
     (
         {
-            **BASE_FT_KWARGS,
             "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
             "torch_dtype": torch.float16,
         },
+        None,
         AssertionError,
         "need peft_config to install PEFT adapters",
     ),
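Note the patch target: `is_fms_accelerate_available` is patched in the `acceleration_framework_config` module where it is looked up, not where it is defined — the usual `unittest.mock.patch` rule of patching at the point of use. A self-contained sketch of the same pattern, with illustrative names in place of the repo's module paths:

```python
# Sketch of patching an availability probe at its point of use.
# The names are illustrative, not the repo's real module paths.
from unittest.mock import patch

def is_pkg_available() -> bool:  # stands in for is_fms_accelerate_available
    return True

def build_framework():
    # the probe is resolved as a module global at call time,
    # so that global is exactly what a test must patch
    if not is_pkg_available():
        raise ValueError("No acceleration framework package found.")

if __name__ == "__main__":
    with patch("__main__.is_pkg_available", return_value=False):
        try:
            build_framework()
        except ValueError as exc:
            print("raised as expected:", exc)
```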
@@ -271,34 +264,32 @@ def test_framework_raises_if_used_with_missing_package():
     reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
 )
 @pytest.mark.parametrize(
-    "train_kwargs,exception",
+    "bad_kwargs,peft_config,exception,exception_msg",
     invalid_kwargs_map,
     ids=["triton_v2 requires fp16", "accelerated peft requires peft config"],
 )
 def test_framework_raises_due_to_invalid_arguments(
-    bad_train_kwargs, exception, exception_msg
+    bad_kwargs, peft_config, exception, exception_msg
 ):
     """Ensure that invalid arguments will be checked by the activated framework
     plugin.
     """
     with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **bad_train_kwargs,
-            "output_dir": tempdir,
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args = replace(model_args, **bad_kwargs)
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
 
         quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
 
         # 1. activate the accelerated peft plugin
         # 2. demonstrate that the invalid arguments will be checked
         with pytest.raises(exception, match=exception_msg):
             sft_trainer.train(
                 model_args,
-                data_args,
-                training_args,
-                tune_config,
+                DATA_ARGS,
+                train_args,
+                peft_config,
                 quantized_lora_config=quantized_lora_config,
             )
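Each `invalid_kwargs_map` entry now supplies the four values named in the parametrize string: field overrides applied to a copy of `MODEL_ARGS` through `replace`, a PEFT config (or `None`), the expected exception type, and a message pattern for `pytest.raises`. A reduced, self-contained version of that wiring — `ModelArguments` and `fake_train` are stand-ins, not the repo's real config class or trainer:

```python
import copy
from dataclasses import dataclass, replace

import pytest

@dataclass
class ModelArguments:  # stand-in for the real model config dataclass
    model_name_or_path: str = "base-model"
    torch_dtype: object = None

MODEL_ARGS = ModelArguments()

def fake_train(model_args, peft_config):
    # stand-in for sft_trainer.train: reproduce the two checks under test
    if "GPTQ" in model_args.model_name_or_path and model_args.torch_dtype is None:
        raise AssertionError("need to run in fp16 mixed precision or load model in fp16")
    if peft_config is None:
        raise AssertionError("need peft_config to install PEFT adapters")

cases = [
    ({"model_name_or_path": "some-GPTQ-model"}, "lora", AssertionError, "fp16"),
    ({"model_name_or_path": "some-GPTQ-model", "torch_dtype": "float16"},
     None, AssertionError, "peft_config"),
]

@pytest.mark.parametrize("bad_kwargs,peft_config,exception,exception_msg", cases)
def test_invalid_args(bad_kwargs, peft_config, exception, exception_msg):
    # mirror the diff: deep-copy the fixture, then apply the overrides
    model_args = replace(copy.deepcopy(MODEL_ARGS), **bad_kwargs)
    with pytest.raises(exception, match=exception_msg):
        fake_train(model_args, peft_config)
```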

@@ -342,18 +333,14 @@ def test_framework_intialized_properly_peft(
     properly activates the framework plugin and runs the train successfully.
     """
     with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **BASE_LORA_KWARGS,
-            **{
-                "fp16": True,
-                "model_name_or_path": model_name_or_path,
-                "output_dir": tempdir,
-                "save_strategy": "no",
-            },
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = model_name_or_path
+        model_args.torch_dtype = torch.float16
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_strategy = "no"
+        train_args.fp16 = True
 
         installation_path, (MockedPlugin, spy) = mock_and_spy
 
         # 1. mock a plugin class
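The `mock_and_spy` fixture pairs a subclassed plugin with a spy recording hook invocations, so the test can assert the framework actually drove the plugin. A hedged sketch of that technique — the hook names are illustrative, not necessarily the real `create_mock_plugin_class_and_spy` interface:

```python
# Hedged sketch of a mock plugin paired with a call-counting spy.
# Hook names are illustrative; the real interface lives in .spying_utils.
def create_mock_plugin_class_and_spy(base_cls):
    spy = {"model_loader_calls": 0, "augmentation_calls": 0}

    class MockedPlugin(base_cls):
        def model_loader(self, *args, **kwargs):
            spy["model_loader_calls"] += 1
            return super().model_loader(*args, **kwargs)

        def augmentation(self, *args, **kwargs):
            spy["augmentation_calls"] += 1
            return super().augmentation(*args, **kwargs)

    return MockedPlugin, spy

# usage sketch:
#   MockedPlugin, spy = create_mock_plugin_class_and_spy(SomePluginClass)
#   ... register MockedPlugin, run sft_trainer.train(...) ...
#   assert spy["model_loader_calls"] == 1
```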
@@ -365,9 +352,9 @@
         ):
             sft_trainer.train(
                 model_args,
-                data_args,
-                training_args,
-                tune_config,
+                DATA_ARGS,
+                train_args,
+                PEFT_LORA_ARGS,
                 quantized_lora_config=quantized_lora_config,
             )

@@ -393,19 +380,14 @@ def test_framework_intialized_properly_foak():
     properly activates the framework plugin and runs the train successfully.
     """
     with tempfile.TemporaryDirectory() as tempdir:
-        TRAIN_KWARGS = {
-            **BASE_LORA_KWARGS,
-            **{
-                "fp16": True,
-                "torch_dtype": torch.float16,
-                "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-                "output_dir": tempdir,
-                "save_strategy": "no",
-            },
-        }
-        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
-            TRAIN_KWARGS
-        )
-
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
+        model_args.torch_dtype = torch.float16
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_strategy = "no"
+        train_args.fp16 = True
 
         # setup default quantized lora args dataclass
         # - with auto gptq as the quantized method
@@ -437,9 +419,9 @@
         ):
             sft_trainer.train(
                 model_args,
-                data_args,
-                training_args,
-                tune_config,
+                DATA_ARGS,
+                train_args,
+                PEFT_LORA_ARGS,
                 quantized_lora_config=quantized_lora_config,
                 fusedops_kernels_config=fusedops_kernels_config,
             )
47 changes: 0 additions & 47 deletions tests/helpers.py

This file was deleted.
