Enable pylint in the GitHub workflow
Signed-off-by: ted chang <htchang@us.ibm.com>
tedhtchang committed Feb 29, 2024
1 parent 5a0cf5c commit 3ae3d2f
Showing 9 changed files with 61 additions and 50 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/format.yml
@@ -35,4 +35,6 @@ jobs:
python -m pip install -r setup_requirements.txt
- name: Check Formatting
run: tox -e fmt
- name: Run pylint
run: tox -e lint

17 changes: 12 additions & 5 deletions scripts/run_inference.py
Expand Up @@ -77,11 +77,15 @@ def _apply_config_changes(self, overrides: dict) -> dict:
# If we have no overrides, this context manager is a noop; no need to do anything
if not overrides:
return {}
with open(self.config_path, "r") as config_file:
with open(
self.config_path, "r"
) as config_file: # pylint: disable=unspecified-encoding
adapter_config = json.load(config_file)
overridden_values = self._get_old_config_values(adapter_config, overrides)
adapter_config = {**adapter_config, **overrides}
with open(self.config_path, "w") as config_file:
with open(
self.config_path, "w"
) as config_file: # pylint: disable=unspecified-encoding
json.dump(adapter_config, config_file, indent=4)
return overridden_values

@@ -213,7 +217,8 @@ def main():
)
parser.add_argument(
"--base_model_name_or_path",
help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
help="Override for base model to be used for non-merged models \
[default: value in model adapter_config.json]",
default=None,
)
parser.add_argument(
@@ -243,7 +248,9 @@ def main():
if args.text:
texts = [args.text]
else:
with open(args.text_file, "r") as text_file:
with open(
args.text_file, "r"
) as text_file: # pylint: disable=unspecified-encoding
texts = [line.strip() for line in text_file.readlines()]

# TODO: we should add batch inference support
@@ -256,7 +263,7 @@
]

# Export the results to a file
with open(args.out_file, "w") as out_file:
with open(args.out_file, "w") as out_file: # pylint: disable=unspecified-encoding
json.dump(results, out_file, sort_keys=True, indent=4)

print(f"Exported results to: {args.out_file}")
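The unspecified-encoding disables in scripts/run_inference.py silence the warning without changing behavior. A minimal sketch of the alternative, assuming the JSON and text files are UTF-8 (an assumption this commit does not make), passes the encoding explicitly so no disable comment is needed:

    import json

    config_path = "adapter_config.json"  # hypothetical example path
    # An explicit encoding satisfies pylint's unspecified-encoding check.
    with open(config_path, "r", encoding="utf-8") as config_file:
        adapter_config = json.load(config_file)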
1 change: 1 addition & 0 deletions tox.ini
@@ -9,5 +9,6 @@ allowlist_externals = ./scripts/fmt.sh
[testenv:lint]
description = lint with pylint
deps = pylint>=2.16.2,<=3.1.0
-r requirements.txt
commands = pylint tuning scripts/*.py
allowlist_externals = pylint
5 changes: 3 additions & 2 deletions tuning/config/configs.py
@@ -1,6 +1,6 @@
# Standard
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
from typing import Optional, Union

# Third Party
import torch
@@ -50,7 +50,8 @@ class TrainingArguments(transformers.TrainingArguments):
model_max_length: int = field(
default=DEFAULT_CONTEXT_LENGTH,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
"help": "Maximum sequence length. Sequences will be right padded \
(and possibly truncated)."
},
)
packing: bool = field(
6 changes: 4 additions & 2 deletions tuning/config/peft_config.py
@@ -10,8 +10,10 @@ class LoraConfig:
target_modules: List[str] = field(
default_factory=lambda: ["q_proj", "v_proj"],
metadata={
"help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
"help": "The names of the modules to apply LORA to. LORA selects modules which either \
completely match or "
'end with one of the strings. If the value is ["all-linear"], \
then LORA selects all linear and Conv1D '
"modules except for the output layer."
},
)
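A caveat about the wrapped help strings above (in configs.py and peft_config.py): a backslash continuation inside a string literal keeps the next source line's leading indentation in the resulting text, so the rendered help message contains a long run of spaces. A minimal sketch, not part of this commit, of implicit string concatenation, which keeps lines short without embedding that whitespace:

    # Hypothetical illustration: adjacent string literals are concatenated,
    # so each source line stays short and the help text stays clean.
    help_text = (
        "The names of the modules to apply LORA to. LORA selects modules which "
        "either completely match or end with one of the strings."
    )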
10 changes: 1 addition & 9 deletions tuning/data/tokenizer_data_utils.py
@@ -1,17 +1,9 @@
# Standard
from typing import Dict, Sequence
import copy
import json
import logging
from typing import Dict

# Third Party
from torch.utils.data import Dataset
import torch
import transformers

# Local
from tuning.config import configs


def tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
62 changes: 34 additions & 28 deletions tuning/sft_trainer.py
@@ -3,6 +3,7 @@
from typing import Optional, Union
import json
import os
import sys

# Third Party
from peft.utils.other import fsdp_auto_wrap_policy
@@ -88,7 +89,7 @@ def train(
model_args: configs.ModelArguments,
data_args: configs.DataArguments,
train_args: configs.TrainingArguments,
peft_config: Optional[
peft_configs: Optional[
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
] = None,
):
@@ -98,7 +99,7 @@
model_args: tuning.config.configs.ModelArguments
data_args: tuning.config.configs.DataArguments
train_args: tuning.config.configs.TrainingArguments
peft_config: peft_config.LoraConfig for Lora tuning | \
peft_configs: peft_config.LoraConfig for Lora tuning | \
peft_config.PromptTuningConfig for prompt tuning | \
None for fine tuning
The peft configuration to pass to trainer
@@ -130,7 +131,7 @@
use_flash_attention_2=model_args.use_flash_attn,
)

peft_config = get_hf_peft_config(task_type, peft_config)
peft_configs = get_hf_peft_config(task_type, peft_configs)

model.gradient_checkpointing_enable()

@@ -140,9 +141,7 @@
)

# TODO: understand if we need to hardcode these here or just use defaults in model
if isinstance(tokenizer, LlamaTokenizer) or isinstance(
tokenizer, LlamaTokenizerFast
):
if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
tokenizer.add_special_tokens(
{
"bos_token": "<s>",
@@ -151,33 +150,36 @@
"pad_token": "<pad>",
}
)
elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
tokenizer, GPT2Tokenizer
):
elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)

"""TODO: near term - how response template ids are parsed out needs to be cleaned.
The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
We will create issue to clean this out after we discuss data formats and collators we will support
"""
# TODO: near term - how response template ids are parsed out needs to be cleaned.
# The [2:] here applies if response template has \n prefix, it is needed to strip \n,
# otherwise template is not found. We will create issue to clean this out after we discuss
# data formats and collators we will support.
response_template_ids = tokenizer.encode(
data_args.response_template, add_special_tokens=False
)[2:]
# TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
# as in current main. We need to change name of this parameter we expose to users.
# TODO: This is actually max_seq_length and not model_max_length. we should not override
# model_max_length as in current main. We need to change name of this parameter we expose
# to users.
model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
logger.info(f"Model max length {model_max_length}")
logger.info(
f"Model max length {model_max_length}"
) # pylint: disable=logging-fstring-interpolation
if train_args.model_max_length > tokenizer.model_max_length:
logger.warning(
f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
logger.warning( # pylint: disable=logging-fstring-interpolation
f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length \
{tokenizer.model_max_length}, using tokenizer.model_max_length \
{tokenizer.model_max_length}"
)

# TODO: we need to change this, perhaps follow what open instruct does?
special_tokens_dict = dict()
special_tokens_dict = {}
if tokenizer.pad_token is None:
logger.warning("PAD token set to default, missing in tokenizer")
special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
@@ -205,19 +207,23 @@ def train(
if data_args.validation_data_path:
data_files["validation"] = data_args.validation_data_path

format_dataset = lambda example: {
format_dataset = lambda example: { # pylint: disable=unnecessary-lambda-assignment
f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
+ tokenizer.eos_token
}

json_dataset = datasets.load_dataset("json", data_files=data_files)
formatted_train_dataset = json_dataset["train"].map(format_dataset)
logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
logger.info(
f"Training dataset length is {len(formatted_train_dataset)}"
) # pylint: disable=logging-fstring-interpolation

formatted_validation_dataset = None
if data_args.validation_data_path:
formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
logger.info(
f"Validation dataset length is {len(formatted_validation_dataset)}"
) # pylint: disable=logging-fstring-interpolation

aim_callback = get_aimstack_callback()
file_logger_callback = FileLoggingCallback(logger)
@@ -234,13 +240,13 @@ def train(
logger.error(
"Error, response template is None, needs to be set for training"
)
exit(-1)
sys.exit(-1)

if data_args.dataset_text_field is None:
logger.error(
"Error, dataset_text_field is None, needs to be set for training"
)
exit(-1)
sys.exit(-1)

data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
@@ -260,17 +266,17 @@ def train(
args=train_args,
max_seq_length=model_max_length,
callbacks=callbacks,
peft_config=peft_config,
peft_config=peft_configs,
)

if run_distributed and peft_config is not None:
if run_distributed and peft_configs is not None:
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
)
trainer.train()


def main(**kwargs):
def main():
parser = transformers.HfArgumentParser(
dataclass_types=(
configs.ModelArguments,
Expand All @@ -293,7 +299,7 @@ def main(**kwargs):
lora_config,
prompt_tuning_config,
peft_method,
_,
*_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if peft_method.peft_method == "lora":
tune_config = lora_config
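The logging-fstring-interpolation disables in sft_trainer.py keep the existing f-strings. A minimal sketch, using only the standard logging module, of the lazy percent-style formatting that pylint prefers, where the message is interpolated only if the record is actually emitted:

    import logging

    logger = logging.getLogger(__name__)

    model_max_length = 4096  # hypothetical example value
    # Arguments are passed separately; formatting happens only when the
    # log level is enabled, which is what the pylint check encourages.
    logger.info("Model max length %s", model_max_length)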
3 changes: 2 additions & 1 deletion tuning/utils/data_type_utils.py
@@ -1,5 +1,6 @@
# Standard
from typing import Union
import sys

# Third Party
from transformers.utils import logging
@@ -22,7 +23,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
dt = getattr(torch, dtype_str, None)
if not isinstance(dt, torch.dtype):
logger.error(" ValueError: Unrecognized data type of a torch.Tensor")
exit(-1)
sys.exit(-1)
return dt


5 changes: 2 additions & 3 deletions tuning/utils/merge_model_utils.py
@@ -1,6 +1,5 @@
# Standard
from typing import Union
import argparse
import json
import os

@@ -27,7 +26,7 @@ def create_merged_model(
References:
- https://github.com/huggingface/peft/issues/1040
- https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long
Args:
checkpoint_model: Union[str, list[str]]
@@ -82,7 +81,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
if not os.path.isfile(adapter_config):
raise FileNotFoundError("Unable to locate adapter config to infer base model!")

with open(adapter_config, "r") as cfg:
with open(adapter_config, "r") as cfg: # pylint: disable=unspecified-encoding
adapter_dict = json.load(cfg)
if "base_model_name_or_path" not in adapter_dict:
raise KeyError(
