
Commit

fmt + lint
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Jun 13, 2024
1 parent 9580b99 commit d2d7f9b
Showing 7 changed files with 154 additions and 152 deletions.
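
The hunks below are mechanical formatter output: long call sites are split into one argument per line, argument lists that fit are collapsed back onto one line, two blank lines are enforced between top-level definitions, and imports are regrouped. As a minimal before/after sketch of the long-line rule (assuming black's default 88-character limit), the change in test_construct_framework_with_auto_gptq_peft reads:

    # Before formatting: the call overflows the line-length limit.
    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)

    # After formatting: the argument moves to its own line and the
    # closing parenthesis is dedented to the statement's indent.
    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(
        quantized_lora_config
    )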
41 changes: 25 additions & 16 deletions tests/acceleration/test_acceleration_framework.py
@@ -18,28 +18,27 @@
 
 # Third Party
 import pytest
 import yaml
 
 # First Party
 from tests.helpers import causal_lm_train_kwargs
 from tests.test_sft_trainer import BASE_LORA_KWARGS
 
 # Local
 from tuning import sft_trainer
-from tuning.utils.import_utils import is_fms_accelerate_available
 import tuning.config.configs as config
 from tuning.config.acceleration_configs import (
-    AccelerationFrameworkConfig, QuantizedLoraConfig
+    AccelerationFrameworkConfig,
+    QuantizedLoraConfig,
 )
 from tuning.config.acceleration_configs.quantized_lora_config import (
-    AutoGPTQLoraConfig, BNBQLoraConfig
+    AutoGPTQLoraConfig,
+    BNBQLoraConfig,
 )
+from tuning.utils.import_utils import is_fms_accelerate_available
 
 # pylint: disable=import-error
 if is_fms_accelerate_available():
 
     # Third Party
     from fms_acceleration.framework import KEY_PLUGINS, AccelerationFramework
     from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate
 
 if is_fms_accelerate_available(plugins="peft"):
@@ -92,7 +91,8 @@ def test_construct_framework_config_with_incorrect_configurations():
     "Ensure that framework configuration cannot have empty body"
 
     with pytest.raises(
-        ValueError, match="AccelerationFrameworkConfig construction requires at least one dataclass"
+        ValueError,
+        match="AccelerationFrameworkConfig construction requires at least one dataclass",
     ):
         AccelerationFrameworkConfig.from_dataclasses()
 
@@ -102,14 +102,18 @@ def test_construct_framework_config_with_incorrect_configurations():
     ):
         AutoGPTQLoraConfig(from_quantized=False)
 
-    # test an invalid activation of two standalone configs.
+    # test an invalid activation of two standalone configs.
     quantized_lora_config = QuantizedLoraConfig(
         auto_gptq=AutoGPTQLoraConfig(), bnb_qlora=BNBQLoraConfig()
     )
     with pytest.raises(
-        ValueError, match="Configuration path 'peft.quantization' already has one standalone config."
+        ValueError,
+        match="Configuration path 'peft.quantization' already has one standalone config.",
     ):
-        AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config).get_framework()
+        AccelerationFrameworkConfig.from_dataclasses(
+            quantized_lora_config
+        ).get_framework()
 
+
 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="peft"),
@@ -119,7 +123,9 @@ def test_construct_framework_with_auto_gptq_peft():
     "Ensure that framework object is correctly configured."
 
     quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
-    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)
+    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(
+        quantized_lora_config
+    )
 
     # for this test we skip the require package check as second order package
     # dependencies of accelerated_peft is not required
@@ -133,6 +139,7 @@ def test_construct_framework_with_auto_gptq_peft():
     # the configuration file should successfully activate the plugin
     assert len(framework.active_plugins) == 1
 
+
 @pytest.mark.skipif(
     not is_fms_accelerate_available(),
     reason="Only runs if fms-accelerate is installed",
@@ -156,20 +163,22 @@ def test_framework_not_installed_or_initalized_properly():
     # patch is_fms_accelerate_available to return False inside sft_trainer
     # to simulate fms_acceleration not installed
     with patch(
-        "tuning.config.acceleration_configs.acceleration_framework_config.is_fms_accelerate_available", return_value=False
+        "tuning.config.acceleration_configs.acceleration_framework_config."
+        "is_fms_accelerate_available",
+        return_value=False,
     ):
         with pytest.raises(
-            ValueError,
-            match="No acceleration framework package found."
+            ValueError, match="No acceleration framework package found."
         ):
             sft_trainer.train(
                 model_args,
                 data_args,
                 training_args,
                 tune_config,
-                quantized_lora_config=quantized_lora_config
+                quantized_lora_config=quantized_lora_config,
             )
 
+
 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="peft"),
     reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
@@ -206,7 +215,7 @@ def test_framework_intialized_properly():
             training_args,
             tune_config,
             # acceleration_framework_args=framework_args,
-            quantized_lora_config=quantized_lora_config
+            quantized_lora_config=quantized_lora_config,
         )
 
     # spy to ensure that the plugin functions were called.
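
Taken together, the tests in this file pin down a small construction API. The following is a minimal usage sketch, mirroring only calls that appear in the diff above; it assumes fms-acceleration is installed with the accelerated-peft plugin (and its second-order dependencies), since the tests skip otherwise:

    # Local
    from tuning.config.acceleration_configs import (
        AccelerationFrameworkConfig,
        QuantizedLoraConfig,
    )
    from tuning.config.acceleration_configs.quantized_lora_config import (
        AutoGPTQLoraConfig,
    )

    # Exactly one standalone config may occupy the 'peft.quantization'
    # path; passing both auto_gptq and bnb_qlora raises ValueError.
    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

    # from_dataclasses() requires at least one dataclass, and
    # get_framework() raises if no acceleration framework package is found.
    framework = AccelerationFrameworkConfig.from_dataclasses(
        quantized_lora_config
    ).get_framework()
    assert len(framework.active_plugins) == 1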
4 changes: 2 additions & 2 deletions tuning/config/acceleration_configs/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Local
 from .acceleration_framework_config import AccelerationFrameworkConfig
-
+from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
 from .quantized_lora_config import QuantizedLoraConfig
-from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
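
After this change the package re-exports all three config dataclasses from one place, so callers can import them together; a sketch of the resulting import surface:

    from tuning.config.acceleration_configs import (
        AccelerationFrameworkConfig,
        FusedOpsAndKernelsConfig,
        QuantizedLoraConfig,
    )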
