diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index d926b1220..f8cb15e72 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -35,4 +35,6 @@ jobs:
         python -m pip install -r setup_requirements.txt
     - name: Check Formatting
       run: tox -e fmt
+    - name: Run pylint
+      run: tox -e lint
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 989aaa8ca..bf79ad3d3 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -77,11 +77,15 @@ def _apply_config_changes(self, overrides: dict) -> dict:
         # If we have no overrides, this context manager is a noop; no need to do anything
         if not overrides:
             return {}
-        with open(self.config_path, "r") as config_file:
+        with open(
+            self.config_path, "r"
+        ) as config_file:  # pylint: disable=unspecified-encoding
             adapter_config = json.load(config_file)
         overridden_values = self._get_old_config_values(adapter_config, overrides)
         adapter_config = {**adapter_config, **overrides}
-        with open(self.config_path, "w") as config_file:
+        with open(
+            self.config_path, "w"
+        ) as config_file:  # pylint: disable=unspecified-encoding
             json.dump(adapter_config, config_file, indent=4)
         return overridden_values
@@ -213,7 +217,8 @@ def main():
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models \
+            [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(
@@ -243,7 +248,9 @@ def main():
     if args.text:
         texts = [args.text]
     else:
-        with open(args.text_file, "r") as text_file:
+        with open(
+            args.text_file, "r"
+        ) as text_file:  # pylint: disable=unspecified-encoding
             texts = [line.strip() for line in text_file.readlines()]
 
     # TODO: we should add batch inference support
@@ -256,7 +263,7 @@ def main():
     ]
 
     # Export the results to a file
-    with open(args.out_file, "w") as out_file:
+    with open(args.out_file, "w") as out_file:  # pylint: disable=unspecified-encoding
         json.dump(results, out_file, sort_keys=True, indent=4)
 
     print(f"Exported results to: {args.out_file}")
diff --git a/tox.ini b/tox.ini
index 3ca14c330..5ec9a4620 100644
--- a/tox.ini
+++ b/tox.ini
@@ -9,5 +9,6 @@ allowlist_externals = ./scripts/fmt.sh
 [testenv:lint]
 description = lint with pylint
 deps = pylint>=2.16.2,<=3.1.0
+       -r requirements.txt
 commands = pylint tuning scripts/*.py
 allowlist_externals = pylint
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 0b0a8fb67..e34512a94 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -1,6 +1,6 @@
 # Standard
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 # Third Party
 import torch
@@ -50,7 +50,8 @@ class TrainingArguments(transformers.TrainingArguments):
     model_max_length: int = field(
         default=DEFAULT_CONTEXT_LENGTH,
         metadata={
-            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+            "help": "Maximum sequence length. Sequences will be right padded \
+            (and possibly truncated)."
         },
     )
     packing: bool = field(
diff --git a/tuning/config/peft_config.py b/tuning/config/peft_config.py
index a3d30c763..0c70eb221 100644
--- a/tuning/config/peft_config.py
+++ b/tuning/config/peft_config.py
@@ -10,8 +10,10 @@ class LoraConfig:
     target_modules: List[str] = field(
         default_factory=lambda: ["q_proj", "v_proj"],
         metadata={
-            "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
-            'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
+            "help": "The names of the modules to apply LORA to. LORA selects modules which either \
+                completely match or "
+            'end with one of the strings. If the value is ["all-linear"], \
+                then LORA selects all linear and Conv1D '
             "modules except for the output layer."
         },
     )
diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py
index 3a8a288f3..4177136c8 100644
--- a/tuning/data/tokenizer_data_utils.py
+++ b/tuning/data/tokenizer_data_utils.py
@@ -1,17 +1,9 @@
 # Standard
-from typing import Dict, Sequence
-import copy
-import json
-import logging
+from typing import Dict
 
 # Third Party
-from torch.utils.data import Dataset
-import torch
 import transformers
 
-# Local
-from tuning.config import configs
-
 
 def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index ce9e323e0..404d78a82 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -3,6 +3,7 @@
 from typing import Optional, Union
 import json
 import os
+import sys
 
 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -88,7 +89,7 @@ def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_configs: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
@@ -98,7 +99,7 @@ def train(
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
-        peft_config: peft_config.LoraConfig for Lora tuning | \
+        peft_configs: peft_config.LoraConfig for Lora tuning | \
         peft_config.PromptTuningConfig for prompt tuning | \
         None for fine tuning
             The peft configuration to pass to trainer
@@ -130,7 +131,7 @@ def train(
         use_flash_attention_2=model_args.use_flash_attn,
     )
 
-    peft_config = get_hf_peft_config(task_type, peft_config)
+    peft_configs = get_hf_peft_config(task_type, peft_configs)
 
     model.gradient_checkpointing_enable()
 
@@ -140,9 +141,7 @@ def train(
     )
 
     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -151,33 +150,36 @@ def train(
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )
 
-    """TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info(
+        f"Model max length {model_max_length}"
+    )  # pylint: disable=logging-fstring-interpolation
     if train_args.model_max_length > tokenizer.model_max_length:
-        logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+        logger.warning(  # pylint: disable=logging-fstring-interpolation
+            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length \
+            {tokenizer.model_max_length}, using tokenizer.model_max_length \
+            {tokenizer.model_max_length}"
         )
 
     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
@@ -205,19 +207,23 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path
 
-    format_dataset = lambda example: {
+    format_dataset = lambda example: {  # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }
 
     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info(
+        f"Training dataset length is {len(formatted_train_dataset)}"
+    )  # pylint: disable=logging-fstring-interpolation
 
     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            f"Validation dataset length is {len(formatted_validation_dataset)}"
+        )  # pylint: disable=logging-fstring-interpolation
 
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
@@ -234,13 +240,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)
 
     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)
 
     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,
@@ -260,17 +266,17 @@ def train(
         args=train_args,
         max_seq_length=model_max_length,
         callbacks=callbacks,
-        peft_config=peft_config,
+        peft_config=peft_configs,
     )
 
-    if run_distributed and peft_config is not None:
+    if run_distributed and peft_configs is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
     trainer.train()
 
 
-def main(**kwargs):
+def main():
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,
@@ -293,7 +299,7 @@ def main(**kwargs):
         lora_config,
         prompt_tuning_config,
         peft_method,
-        _,
+        *_,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
     if peft_method.peft_method == "lora":
         tune_config = lora_config
diff --git a/tuning/utils/data_type_utils.py b/tuning/utils/data_type_utils.py
index 42b058cde..375d04aeb 100644
--- a/tuning/utils/data_type_utils.py
+++ b/tuning/utils/data_type_utils.py
@@ -1,5 +1,6 @@
 # Standard
 from typing import Union
+import sys
 
 # Third Party
 from transformers.utils import logging
@@ -22,7 +23,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
     dt = getattr(torch, dtype_str, None)
     if not isinstance(dt, torch.dtype):
         logger.error(" ValueError: Unrecognized data type of a torch.Tensor")
-        exit(-1)
+        sys.exit(-1)
     return dt
diff --git a/tuning/utils/merge_model_utils.py b/tuning/utils/merge_model_utils.py
index a8a41fecb..b8809bd87 100644
--- a/tuning/utils/merge_model_utils.py
+++ b/tuning/utils/merge_model_utils.py
@@ -1,6 +1,5 @@
 # Standard
 from typing import Union
-import argparse
 import json
 import os
 
@@ -27,7 +26,7 @@ def create_merged_model(
     References:
         - https://github.com/huggingface/peft/issues/1040
        - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
-        - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter
+        - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long
 
     Args:
         checkpoint_model: Union[str, list[str]]
@@ -82,7 +81,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
     if not os.path.isfile(adapter_config):
         raise FileNotFoundError("Unable to locate adapter config to infer base model!")
 
-    with open(adapter_config, "r") as cfg:
+    with open(adapter_config, "r") as cfg:  # pylint: disable=unspecified-encoding
         adapter_dict = json.load(cfg)
     if "base_model_name_or_path" not in adapter_dict:
         raise KeyError(