Skip to content
This repository has been archived by the owner on Jul 2, 2024. It is now read-only.

Commit

Permalink
Error catching
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed May 5, 2024
1 parent 25dfea4 commit 439c26a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 62 deletions.
2 changes: 1 addition & 1 deletion adapters/ukp/gpt2_sentiment_sst-2_houlsby.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
type: text_task

# The string identifier of the task this adapter belongs to.
task: 'sentiment '
task: sentiment

# The string identifier of the subtask this adapter belongs to.
subtask: sst-2
Expand Down
149 changes: 88 additions & 61 deletions migrate_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


OUTPUT_FOLDER = "hf_hub"
ERROR_FILE = "migration_errors.txt"
HUB_URL = "https://github.com/Adapter-Hub/Hub/blob/master/"
# Map from head types to HF labels used for widgets
MODEL_HEAD_MAP = {
Expand Down Expand Up @@ -207,6 +208,77 @@ def create_adapter_card(
return adapter_card.strip()


def migrate_file(
    file: str,
    push: bool,
    hf_org_name: str,
    skip_existing: bool,
    subtasks_dict: dict,
    api=None,
) -> None:
    """Migrate one legacy adapter YAML entry to the Hugging Face Hub layout.

    Reads the adapter metadata from ``file``, downloads every listed checkpoint
    version, re-saves it via ``AutoAdapterModel``, writes an adapter card, and
    (optionally) pushes each version to a repo under ``hf_org_name``.

    Args:
        file: Path to the legacy adapter YAML file. The filename stem (text
            before the first ``.``) becomes the adapter/repo name.
        push: If True, upload the converted folders to the Hugging Face Hub.
        hf_org_name: Hub organization the target repo lives under
            (repo id is ``hf_org_name + "/" + adapter_name``).
        skip_existing: If True (and ``push`` is set), return early without any
            work when the target repo already exists on the Hub.
        subtasks_dict: Lookup of ``"<task>/<subtask>"`` keys to subtask info,
            passed through to ``create_adapter_card``.
        api: ``HfApi`` client; required whenever ``push`` is True.
            NOTE(review): if ``push`` is True and ``api`` is None this raises
            ``AttributeError`` — callers must supply the client themselves.
    """
    adapter_name = os.path.basename(file).split(".")[0]
    print(f"Migrating {adapter_name} ...")
    with open(file, "r") as f:
        # NOTE(review): FullLoader — only safe because these YAML files come
        # from the project's own repo, not untrusted input.
        data = yaml.load(f, yaml.FullLoader)
    # May be None when the task/subtask combination has no extra info.
    subtask_info = subtasks_dict.get(data["task"] + "/" + data["subtask"])

    if push and skip_existing:
        if api.repo_exists(repo_id=hf_org_name + "/" + adapter_name):
            print(f"Skipping {adapter_name} as it already exists.")
            return

    # create a subfolder for each version in the output
    for version_data in data["files"]:
        version = version_data["version"]
        # The default version is uploaded to "main"; others to a version branch.
        is_default = version == data["default_version"]
        version_folder = os.path.join(OUTPUT_FOLDER, adapter_name, version)
        # Only the parent (OUTPUT_FOLDER/adapter_name) is created here;
        # shutil.move below materializes version_folder itself.
        os.makedirs(os.path.dirname(version_folder), exist_ok=True)

        # download the checkpoint
        dl_folder = download_cached(version_data["url"])
        shutil.move(dl_folder, version_folder)

        # try loading the adapter
        # Re-saving through the model round-trips the checkpoint into the
        # current adapters save format before upload.
        model = AutoAdapterModel.from_pretrained(data["model_name"])
        loaded_name = model.load_adapter(version_folder, set_active=True)
        model.save_adapter(version_folder, loaded_name)

        # Head type (if the checkpoint included a prediction head) feeds the
        # widget metadata in the adapter card.
        if loaded_name in model.heads:
            head_type = model.heads[loaded_name].config["head_type"]
        else:
            head_type = None

        adapter_card = create_adapter_card(
            file,
            adapter_name,
            data,
            subtask_info,
            version=version,
            head_type=head_type,
            hf_org_name=hf_org_name,
        )

        # write the adapter card
        with open(os.path.join(version_folder, "README.md"), "w") as f:
            f.write(adapter_card)

        # Drop the model before the next iteration to free memory.
        del model

        if push:
            repo_id = hf_org_name + "/" + adapter_name
            api.create_repo(repo_id, exist_ok=True)
            # Non-default versions need their branch to exist before upload;
            # the default uploads to "main" and gets its version branch
            # created afterwards (pointing at the freshly pushed main).
            if not is_default:
                api.create_branch(repo_id=repo_id, branch=version, exist_ok=True)
            api.upload_folder(
                repo_id=repo_id,
                folder_path=version_folder,
                revision="main" if is_default else version,
                commit_message=f"Add adapter {adapter_name} version {version}",
            )
            if is_default:
                api.create_branch(repo_id=repo_id, branch=version, exist_ok=True)


def migrate(
files,
push: bool = False,
Expand All @@ -216,68 +288,18 @@ def migrate(
subtasks_dict = load_subtasks()
if push:
api = HfApi()
else:
api = None
errors = []
for file in files:
adapter_name = os.path.basename(file).split(".")[0]
print(f"Migrating {adapter_name} ...")
with open(file, "r") as f:
data = yaml.load(f, yaml.FullLoader)
subtask_info = subtasks_dict.get(data["task"] + "/" + data["subtask"])

if push and skip_existing:
if api.repo_exists(repo_id=hf_org_name + "/" + adapter_name):
print(f"Skipping {adapter_name} version {version_data['version']}")
continue

# create a subfolder for each version in the output
for version_data in data["files"]:
version = version_data["version"]
is_default = version == data["default_version"]
version_folder = os.path.join(OUTPUT_FOLDER, adapter_name, version)
os.makedirs(os.path.dirname(version_folder), exist_ok=True)

# download the checkpoint
dl_folder = download_cached(version_data["url"])
shutil.move(dl_folder, version_folder)

# try loading the adapter
model = AutoAdapterModel.from_pretrained(data["model_name"])
loaded_name = model.load_adapter(version_folder, set_active=True)
model.save_adapter(version_folder, loaded_name)

if loaded_name in model.heads:
head_type = model.heads[loaded_name].config["head_type"]
else:
head_type = None

adapter_card = create_adapter_card(
file,
adapter_name,
data,
subtask_info,
version=version,
head_type=head_type,
hf_org_name=hf_org_name,
)

# write the adapter card
with open(os.path.join(version_folder, "README.md"), "w") as f:
f.write(adapter_card)
try:
migrate_file(file, push, hf_org_name, skip_existing, subtasks_dict, api)
except Exception as e:
errors.append(file)
print(f"Error migrating {file}: {e}")

del model

if push:
repo_id = hf_org_name + "/" + adapter_name
api.create_repo(repo_id, exist_ok=True)
if not is_default:
api.create_branch(repo_id=repo_id, branch=version, exist_ok=True)
api.upload_folder(
repo_id=repo_id,
folder_path=version_folder,
revision="main" if is_default else version,
commit_message=f"Add adapter {adapter_name} version {version}",
)
if is_default:
api.create_branch(repo_id=repo_id, branch=version, exist_ok=True)
with open(ERROR_FILE, "w") as f:
f.write("\n".join(errors))


if __name__ == "__main__":
Expand All @@ -293,4 +315,9 @@ def migrate(
args = parser.parse_args()

files = glob(os.path.join(REPO_FOLDER, args.folder, "*"))
migrate(files, push=args.push, hf_org_name=args.org_name)
migrate(
files,
push=args.push,
hf_org_name=args.org_name,
skip_existing=args.skip_existing,
)

0 comments on commit 439c26a

Please sign in to comment.