From bafcb1190f782e93e26505bc59badd79d6a90427 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 13 Aug 2024 08:58:34 +0200 Subject: [PATCH 01/10] First draft --- vggsfm/models/vggsfm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index 558fe3e..a500302 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -6,14 +6,15 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict from hydra.utils import instantiate +from huggingface_hub import PyTorchModelHubMixin -class VGGSfM(nn.Module): + +class VGGSfM(nn.Module, PyTorchModelHubMixin): def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None): """ Initializes a VGGSfM model From 86ae562dff19b62b3df751c3cf577bd0a5fe936b Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 13 Aug 2024 09:20:29 +0200 Subject: [PATCH 02/10] Update device --- vggsfm/models/vggsfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index a500302..d1b977a 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -42,7 +42,7 @@ def from_pretrained(self, model_name): ckpt_path = hf_hub_download( repo_id="facebook/VGGSfM", filename=model_name + ".bin" ) - checkpoint = torch.load(ckpt_path) + checkpoint = torch.load(ckpt_path, map_location="cpu") except: # In case the model is not hosted on huggingface # or the user cannot import huggingface_hub correctly From ae1b8b3c3a0964f33eadd71b0b6848a238069acd Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 13 Aug 2024 09:30:50 +0200 Subject: [PATCH 03/10] Comment out from_pretrained --- vggsfm/models/vggsfm.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index d1b977a..eb95c49 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -35,18 +35,18 @@ def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None): # models.Triangulator self.triangulator = instantiate(TRIANGULAE, _recursive_=False, cfg=cfg) - def from_pretrained(self, model_name): - try: - from huggingface_hub import hf_hub_download - - ckpt_path = hf_hub_download( - repo_id="facebook/VGGSfM", filename=model_name + ".bin" - ) - checkpoint = torch.load(ckpt_path, map_location="cpu") - except: - # In case the model is not hosted on huggingface - # or the user cannot import huggingface_hub correctly - _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin" - checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL) - - self.load_state_dict(checkpoint, strict=True) + # def from_pretrained(self, model_name): + # try: + # from huggingface_hub import hf_hub_download + + # ckpt_path = hf_hub_download( + # repo_id="facebook/VGGSfM", filename=model_name + ".bin" + # ) + # checkpoint = torch.load(ckpt_path, map_location="cpu") + # except: + # # In case the model is not hosted on huggingface + # # or the user cannot import huggingface_hub correctly + # _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin" + # checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL) + + # self.load_state_dict(checkpoint, strict=True) From aa55371909405f87ba164c49e61ca24aa8117e57 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 13 Aug 2024 09:33:32 +0200 Subject: [PATCH 04/10] Comment out from_pretrained --- vggsfm/runners/runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vggsfm/runners/runner.py b/vggsfm/runners/runner.py index 4956874..c984100 100644 --- a/vggsfm/runners/runner.py +++ b/vggsfm/runners/runner.py @@ -112,11 +112,11 @@ def build_vggsfm_model(self): self.cfg.MODEL, _recursive_=False, cfg=self.cfg ) - if self.cfg.auto_download_ckpt: - vggsfm.from_pretrained(self.cfg.model_name) - else: - checkpoint = torch.load(self.cfg.resume_ckpt) - vggsfm.load_state_dict(checkpoint, strict=True) + # if self.cfg.auto_download_ckpt: + # vggsfm.from_pretrained(self.cfg.model_name) + # else: + checkpoint = torch.load(self.cfg.resume_ckpt) + vggsfm.load_state_dict(checkpoint, strict=True) self.vggsfm_model = vggsfm.to(self.device).eval() print("VGGSfM built successfully") From e6ea0e0508630a1dc2da7276798c610c16a32a6c Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 13 Aug 2024 10:57:05 +0200 Subject: [PATCH 05/10] Use hf_hub_download --- vggsfm/runners/runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vggsfm/runners/runner.py b/vggsfm/runners/runner.py index c984100..c9c5f6f 100644 --- a/vggsfm/runners/runner.py +++ b/vggsfm/runners/runner.py @@ -115,7 +115,9 @@ def build_vggsfm_model(self): # if self.cfg.auto_download_ckpt: # vggsfm.from_pretrained(self.cfg.model_name) # else: - checkpoint = torch.load(self.cfg.resume_ckpt) + from huggingface_hub import hf_hub_download + filepath = hf_hub_download(repo_id="facebook/VGGSfM", filename="vggsfm_v2_0_0.bin", repo_type="model") + checkpoint = torch.load(filepath, map_location="cpu") vggsfm.load_state_dict(checkpoint, strict=True) self.vggsfm_model = vggsfm.to(self.device).eval() print("VGGSfM built successfully") From cd3d66e2925356f6f03388676644b9c523fd4762 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 14 Aug 2024 22:46:02 +0200 Subject: [PATCH 06/10] Add print statements --- vggsfm/models/vggsfm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index eb95c49..e703061 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -22,6 +22,13 @@ def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None): TRACK, CAMERA, TRIANGULAE are the dicts to construct the model modules cfg is the whole hydra config """ + print("TRACK", TRACK) + print("Type of TRACK", type(TRACK)) + print("CAMERA", CAMERA) + print("Type of CAMERA", type(CAMERA)) + print("TRIANGULAE", TRIANGULAE) + print("CFG", cfg) + print("Type of CFG", type(cfg)) super().__init__() self.cfg = cfg From 59936099fa1103c1a01d9956e1d524ea029b3bac Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 14 Aug 2024 22:57:47 +0200 Subject: [PATCH 07/10] Add coders --- vggsfm/models/vggsfm.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index e703061..89409fd 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -11,11 +11,22 @@ from hydra.utils import instantiate +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig + from huggingface_hub import PyTorchModelHubMixin -class VGGSfM(nn.Module, PyTorchModelHubMixin): - def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None): +class VGGSfM(nn.Module, + PyTorchModelHubMixin, + coders={ + DictConfig : ( + lambda x: OmegaConf.to_container(x, resolve=True), # Encoder: how to convert a `DictConfig` to a valid jsonable value? + lambda data: OmegaConf.create(data), # Decoder: how to reconstruct a `DictConfig` from a dictionary? + ), + } + ): + def __init__(self, TRACK: DictConfig, CAMERA: DictConfig, TRIANGULAE: DictConfig, cfg: DictConfig = None): """ Initializes a VGGSfM model From 445979aaeb380cbd3c5d73d4fb15d2ff667ab2a9 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 14 Aug 2024 23:04:38 +0200 Subject: [PATCH 08/10] Clean up --- vggsfm/models/vggsfm.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index 89409fd..b6a06b6 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -19,6 +19,8 @@ class VGGSfM(nn.Module, PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/vggsfm", + pipeline_tag="image-to-3D", coders={ DictConfig : ( lambda x: OmegaConf.to_container(x, resolve=True), # Encoder: how to convert a `DictConfig` to a valid jsonable value? @@ -33,13 +35,6 @@ def __init__(self, TRACK: DictConfig, CAMERA: DictConfig, TRIANGULAE: DictConfig TRACK, CAMERA, TRIANGULAE are the dicts to construct the model modules cfg is the whole hydra config """ - print("TRACK", TRACK) - print("Type of TRACK", type(TRACK)) - print("CAMERA", CAMERA) - print("Type of CAMERA", type(CAMERA)) - print("TRIANGULAE", TRIANGULAE) - print("CFG", cfg) - print("Type of CFG", type(cfg)) super().__init__() self.cfg = cfg @@ -52,19 +47,3 @@ def __init__(self, TRACK: DictConfig, CAMERA: DictConfig, TRIANGULAE: DictConfig # models.Triangulator self.triangulator = instantiate(TRIANGULAE, _recursive_=False, cfg=cfg) - - # def from_pretrained(self, model_name): - # try: - # from huggingface_hub import hf_hub_download - - # ckpt_path = hf_hub_download( - # repo_id="facebook/VGGSfM", filename=model_name + ".bin" - # ) - # checkpoint = torch.load(ckpt_path, map_location="cpu") - # except: - # # In case the model is not hosted on huggingface - # # or the user cannot import huggingface_hub correctly - # _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin" - # checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL) - - # self.load_state_dict(checkpoint, strict=True) From c904c7fd2929cd7e2c95793c0b0cc581dea532f1 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 15 Aug 2024 09:56:51 +0200 Subject: [PATCH 09/10] Add license --- vggsfm/models/vggsfm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index b6a06b6..37ae5f4 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -21,6 +21,7 @@ class VGGSfM(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/facebookresearch/vggsfm", pipeline_tag="image-to-3D", + license="cc-by-nc-sa-4.0", coders={ DictConfig : ( lambda x: OmegaConf.to_container(x, resolve=True), # Encoder: how to convert a `DictConfig` to a valid jsonable value? From 64f9ce12ed979c68ec2205e741e989c3926dd413 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 15 Aug 2024 10:15:35 +0200 Subject: [PATCH 10/10] Update pipeline_tag --- vggsfm/models/vggsfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vggsfm/models/vggsfm.py b/vggsfm/models/vggsfm.py index 37ae5f4..52d3345 100644 --- a/vggsfm/models/vggsfm.py +++ b/vggsfm/models/vggsfm.py @@ -20,7 +20,7 @@ class VGGSfM(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/facebookresearch/vggsfm", - pipeline_tag="image-to-3D", + pipeline_tag="image-to-3d", license="cc-by-nc-sa-4.0", coders={ DictConfig : (