Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: cache CLIP embeddings under /tmp #123

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions detic/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,28 @@
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer
from pathlib import Path
from hashlib import md5

from .modeling.utils import reset_cls_test


def get_clip_embeddings(vocabulary, prompt='a '):
from detic.modeling.text.text_encoder import build_text_encoder
text_encoder = build_text_encoder(pretrain=True)
text_encoder.eval()
texts = [prompt + x for x in vocabulary]
emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
return emb
# NOTE: need hashing due to filename length limit
hash_value = md5("-".join(sorted(vocabulary)).encode()).hexdigest()
cache_file_path = f"/tmp/detic-clip-embeddings-{hash_value}.pt"
if Path(cache_file_path).exists():
print(f"loading embeddings for {vocabulary} from {cache_file_path}")
return torch.load(cache_file_path)
else:
from detic.modeling.text.text_encoder import build_text_encoder
text_encoder = build_text_encoder(pretrain=True)
text_encoder.eval()
texts = [prompt + x for x in vocabulary]
emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
print(f"saved embeddings for {vocabulary} to {cache_file_path}")
torch.save(emb, cache_file_path)
return emb

BUILDIN_CLASSIFIER = {
'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
Expand Down