Skip to content

Commit

Permalink
Allow running DPO from a local model (#49)
Browse files Browse the repository at this point in the history
* Update model_utils.py

Check if a model is adapter model when a local path is supplied instead of HF model

* Cleaner solution, thanks to lewtun
  • Loading branch information
dmilcevski authored Nov 27, 2023
1 parent f025057 commit 80e952e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/alignment/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer

from accelerate import Accelerator
from huggingface_hub import list_repo_files
from huggingface_hub.utils._validators import HFValidationError
from peft import LoraConfig, PeftConfig

from .configs import DataArguments, ModelArguments
Expand Down Expand Up @@ -96,5 +97,10 @@ def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:


def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
repo_files = list_repo_files(model_name_or_path, revision=revision)
try:
# Try first if model on a Hub repo
repo_files = list_repo_files(model_name_or_path, revision=revision)
except HFValidationError:
# If not, check local repo
repo_files = os.listdir(model_name_or_path)
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files

0 comments on commit 80e952e

Please sign in to comment.