Merge pull request #422 from aluu317/scanner_tracker
feat: add scanner tracker
aluu317 authored Jan 23, 2025
2 parents 2a9faec + ee46cd2 commit c0362ad
Showing 9 changed files with 260 additions and 7 deletions.
8 changes: 7 additions & 1 deletion build/Dockerfile
@@ -23,6 +23,7 @@ ARG WHEEL_VERSION=""
ARG ENABLE_AIM=false
ARG ENABLE_MLFLOW=false
ARG ENABLE_FMS_ACCELERATION=true
ARG ENABLE_SCANNER=false

## Base Layer ##################################################################
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
@@ -111,6 +112,7 @@ ARG USER
ARG USER_UID
ARG ENABLE_FMS_ACCELERATION
ARG ENABLE_AIM
ARG ENABLE_SCANNER

RUN dnf install -y git && \
# perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies
@@ -154,7 +156,11 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \

RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[mlflow]"; \
fi
fi

RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[scanner-dev]"; \
fi

# Clean up the wheel module. It's only needed by flash-attn install
RUN python -m pip uninstall wheel build -y && \
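Note: ENABLE_SCANNER defaults to false, so the HFResourceScanner extra stays out of standard image builds. Enabling it at build time would look something like docker build -f build/Dockerfile --build-arg ENABLE_SCANNER=true . (a hedged example; the repo's actual build command and its other required build args may differ).
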
1 change: 1 addition & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ aim = ["aim>=3.19.0,<4.0"]
mlflow = ["mlflow"]
fms-accel = ["fms-acceleration>=0.6"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
scanner-dev = ["HFResourceScanner>=0.1.0"]


[tool.setuptools.packages.find]
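Outside the image build, the same opt-in dependency can be installed through the new extra, e.g. pip install ".[scanner-dev]" from a checkout, or pip install "fms-hf-tuning[scanner-dev]" if fms-hf-tuning is the published distribution name (an assumption here, not stated in this diff).
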
39 changes: 38 additions & 1 deletion tests/build/test_launch_script.py
@@ -16,12 +16,14 @@
"""

# Standard
import json
import os
import tempfile
import glob

# Third Party
import pytest
from transformers.utils.import_utils import _is_package_available

# First Party
from build.accelerate_launch import main
@@ -31,7 +33,10 @@
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.config.tracker_configs import (
FileLoggingTrackerConfig,
HFResourceScannerConfig,
)

SCRIPT = "tuning/sft_trainer.py"
MODEL_NAME = "Maykeye/TinyLLama-v0"
@@ -246,6 +251,38 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
assert os.path.exists(new_embeddings_file_path)


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_launch_with_HFResourceScanner_enabled():
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
scanner_outfile = os.path.join(
tempdir, HFResourceScannerConfig.scanner_output_filename
)
TRAIN_KWARGS = {
**BASE_LORA_KWARGS,
**{
"output_dir": tempdir,
"save_model_dir": tempdir,
"lora_post_process_for_vllm": True,
"gradient_accumulation_steps": 1,
"trackers": ["hf_resource_scanner"],
"scanner_output_filename": scanner_outfile,
},
}
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0
assert os.path.exists(scanner_outfile) is True
with open(scanner_outfile, "r", encoding="utf-8") as f:
scanner_res = json.load(f)
assert scanner_res["time_data"] is not None
assert scanner_res["mem_data"] is not None


def test_bad_script_path():
"""Check for appropriate error for an invalid training script location"""
with tempfile.TemporaryDirectory() as tempdir:
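The launch test above exercises the scanner end to end through the serialized JSON config. A minimal sketch of the same flow outside pytest, assuming serialize_args lives in build.utils (the import path is a guess; the test's own import of it is elided from this diff) and encodes the dict the way accelerate_launch expects:

    # Standard
    import os

    # First Party
    from build.accelerate_launch import main
    from build.utils import serialize_args  # assumed module path for the helper the tests use

    config = {
        "model_name_or_path": "Maykeye/TinyLLama-v0",
        "training_data_path": "data.jsonl",  # placeholder; field name assumed, not shown in this diff
        "output_dir": "/tmp/scanner_run",
        "trackers": ["hf_resource_scanner"],
        "scanner_output_filename": "/tmp/scanner_run/scanner_output.json",
    }
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialize_args(config)
    exit_code = main()  # 0 on success; the scanner file then holds time_data and mem_data
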
26 changes: 23 additions & 3 deletions tests/test_sft_trainer.py
@@ -363,6 +363,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -390,6 +391,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -400,14 +402,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
@@ -1053,12 +1055,18 @@ def _test_run_inference(checkpoint_path):


def _validate_training(
tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
tempdir,
check_eval=False,
train_logs_file="training_logs.jsonl",
check_scanner_file=False,
):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
_validate_logfile(train_logs_file_path, check_eval)

if check_scanner_file:
_validate_hf_resource_scanner_file(tempdir)


def _validate_logfile(log_file_path, check_eval=False):
train_log_contents = ""
@@ -1073,6 +1081,18 @@ def _validate_logfile(log_file_path, check_eval=False):
assert "validation_loss" in train_log_contents


def _validate_hf_resource_scanner_file(tempdir):
scanner_file_path = os.path.join(tempdir, "scanner_output.json")
assert os.path.exists(scanner_file_path) is True
assert os.path.getsize(scanner_file_path) > 0

with open(scanner_file_path, "r", encoding="utf-8") as f:
scanner_contents = json.load(f)

assert scanner_contents["time_data"] is not None
assert scanner_contents["mem_data"] is not None


def _get_checkpoint_path(dir_path):
return os.path.join(dir_path, "checkpoint-5")

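Note that the new _validate_hf_resource_scanner_file helper only asserts that the scanner file exists, is non-empty, and has non-null top-level time_data and mem_data entries; the metric values themselves are left unchecked, since they vary from run to run.
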
88 changes: 88 additions & 0 deletions tests/trackers/test_hf_resource_scanner_tracker.py
@@ -0,0 +1,88 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


# Standard
import copy
import os
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
DATA_ARGS,
MODEL_ARGS,
TRAIN_ARGS,
_get_checkpoint_path,
_test_run_causallm_ft,
_test_run_inference,
_validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

## HF Resource Scanner Tracker Tests


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_run_with_hf_resource_scanner_tracker():
"""Ensure that training succeeds with a good tracker name"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.trackers = ["hf_resource_scanner"]

_test_run_causallm_ft(train_args, MODEL_ARGS, DATA_ARGS, tempdir)
_test_run_inference(_get_checkpoint_path(tempdir))


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_sample_run_with_hf_resource_scanner_updated_filename():
"""Ensure that hf_resource_scanner output filename can be updated"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.gradient_accumulation_steps = 1
train_args.output_dir = tempdir

# add hf_resource_scanner to the list of requested trackers
train_args.trackers = ["hf_resource_scanner"]

scanner_output_file = "scanner_output.json"

tracker_configs = TrackerConfigFactory(
hf_resource_scanner_config=HFResourceScannerConfig(
scanner_output_filename=os.path.join(tempdir, scanner_output_file)
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate ft tuning configs
_validate_training(tempdir, check_scanner_file=True)
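
Condensed, the pattern these tests establish for enabling the scanner programmatically looks like the sketch below, assuming model_args, data_args, and train_args are already built the way the surrounding test suite builds them:

    from tuning import sft_trainer
    from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

    train_args.trackers = ["hf_resource_scanner"]  # request the tracker by name
    tracker_configs = TrackerConfigFactory(
        hf_resource_scanner_config=HFResourceScannerConfig(
            scanner_output_filename="run/scanner_output.json"
        )
    )
    sft_trainer.train(model_args, data_args, train_args, tracker_configs=tracker_configs)
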
6 changes: 6 additions & 0 deletions tuning/config/tracker_configs.py
@@ -16,6 +16,11 @@
from dataclasses import dataclass


@dataclass
class HFResourceScannerConfig:
scanner_output_filename: str = "scanner_output.json"


@dataclass
class FileLoggingTrackerConfig:
training_logs_filename: str = "training_logs.jsonl"
@@ -80,3 +85,4 @@ class TrackerConfigFactory:
file_logger_config: FileLoggingTrackerConfig = None
aim_config: AimConfig = None
mlflow_config: MLflowConfig = None
hf_resource_scanner_config: HFResourceScannerConfig = None
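
HFResourceScannerConfig mirrors the existing tracker config dataclasses: a single scanner_output_filename field defaulting to "scanner_output.json", plus a matching hf_resource_scanner_config slot on TrackerConfigFactory, so the scanner is wired through the same factory object as the file logger, Aim, and MLflow trackers.
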
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
@@ -52,6 +52,7 @@
from tuning.config.tracker_configs import (
AimConfig,
FileLoggingTrackerConfig,
HFResourceScannerConfig,
MLflowConfig,
TrackerConfigFactory,
)
@@ -458,6 +459,7 @@ def get_parser():
peft_config.PromptTuningConfig,
FileLoggingTrackerConfig,
AimConfig,
HFResourceScannerConfig,
QuantizedLoraConfig,
FusedOpsAndKernelsConfig,
AttentionAndDistributedPackingConfig,
@@ -506,6 +508,8 @@ def parse_arguments(parser, json_config=None):
Configuration for training log file.
AimConfig
Configuration for AIM stack.
HFResourceScannerConfig
Configuration for HFResourceScanner.
QuantizedLoraConfig
Configuration for quantized LoRA (a form of PEFT).
FusedOpsAndKernelsConfig
@@ -529,6 +533,7 @@
prompt_tuning_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -547,6 +552,7 @@
prompt_tuning_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -574,6 +580,7 @@
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -597,6 +604,7 @@ def main():
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -611,7 +619,7 @@
logger.debug(
"Input args parsed: \
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \
quantized_lora_config %s, fusedops_kernels_config %s, \
attention_and_distributed_packing_config, %s,\
mlflow_config %s, fast_moe_config %s, \
@@ -623,6 +631,7 @@
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -656,6 +665,7 @@
file_logger_config=file_logger_config,
aim_config=aim_config,
mlflow_config=mlflow_config,
hf_resource_scanner_config=hf_resource_scanner_config,
)

if training_args.output_dir:
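Because parse_arguments returns its configs positionally, adding HFResourceScannerConfig widens the returned tuple from 13 to 14 values, which is why every unpacking site (main() here and the assertions in tests/test_sft_trainer.py) grows by one slot. A sketch of a caller that only wants the scanner config, with positions inferred from the diff (scanner config appears to be the 8th value; verify against the actual signature):

    parser = sft_trainer.get_parser()
    (
        _model_args, _data_args, _training_args, _trainer_controller_args,
        _tune_config, _file_logger_config, _aim_config,
        hf_resource_scanner_config,
        *_rest,  # remaining configs, names omitted here
    ) = sft_trainer.parse_arguments(parser, job_config)  # job_config: dict as built in the tests
    print(hf_resource_scanner_config.scanner_output_filename)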