Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Hugging Face integration #50

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions vggsfm/models/vggsfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,30 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict

from hydra.utils import instantiate


class VGGSfM(nn.Module):
def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None):
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from huggingface_hub import PyTorchModelHubMixin


class VGGSfM(nn.Module,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/vggsfm",
pipeline_tag="image-to-3d",
license="cc-by-nc-sa-4.0",
coders={
DictConfig : (
lambda x: OmegaConf.to_container(x, resolve=True), # Encoder: how to convert a `DictConfig` to a valid jsonable value?
lambda data: OmegaConf.create(data), # Decoder: how to reconstruct a `DictConfig` from a dictionary?
),
}
):
def __init__(self, TRACK: DictConfig, CAMERA: DictConfig, TRIANGULAE: DictConfig, cfg: DictConfig = None):
"""
Initializes a VGGSfM model

Expand All @@ -33,19 +48,3 @@ def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None):

# models.Triangulator
self.triangulator = instantiate(TRIANGULAE, _recursive_=False, cfg=cfg)

def from_pretrained(self, model_name):
try:
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
repo_id="facebook/VGGSfM", filename=model_name + ".bin"
)
checkpoint = torch.load(ckpt_path)
except:
# In case the model is not hosted on huggingface
# or the user cannot import huggingface_hub correctly
_VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin"
checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL)

self.load_state_dict(checkpoint, strict=True)
12 changes: 7 additions & 5 deletions vggsfm/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def build_vggsfm_model(self):
self.cfg.MODEL, _recursive_=False, cfg=self.cfg
)

if self.cfg.auto_download_ckpt:
vggsfm.from_pretrained(self.cfg.model_name)
else:
checkpoint = torch.load(self.cfg.resume_ckpt)
vggsfm.load_state_dict(checkpoint, strict=True)
# if self.cfg.auto_download_ckpt:
# vggsfm.from_pretrained(self.cfg.model_name)
# else:
from huggingface_hub import hf_hub_download
filepath = hf_hub_download(repo_id="facebook/VGGSfM", filename="vggsfm_v2_0_0.bin", repo_type="model")
checkpoint = torch.load(filepath, map_location="cpu")
vggsfm.load_state_dict(checkpoint, strict=True)
self.vggsfm_model = vggsfm.to(self.device).eval()
print("VGGSfM built successfully")

Expand Down