Skip to content

Commit

Permalink
perf: cache clip embeddings under /tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Feb 11, 2025
1 parent 46f8317 commit cec3bce
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions detic/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,27 @@
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer
from pathlib import Path
from hashlib import md5

from .modeling.utils import reset_cls_test


def get_clip_embeddings(vocabulary, prompt='a '):
from detic.modeling.text.text_encoder import build_text_encoder
text_encoder = build_text_encoder(pretrain=True)
text_encoder.eval()
texts = [prompt + x for x in vocabulary]
emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
return emb
hash_value = md5(prompt.encode()).hexdigest()
cache_file_path = f"/tmp/detic-clip-embeddings-{hash_value}.pt"
if Path(cache_file_path).exists():
print(f"loading embeddings for {vocabulary} from {cache_file_path}")
return torch.load(cache_file_path)
else:
from detic.modeling.text.text_encoder import build_text_encoder
text_encoder = build_text_encoder(pretrain=True)
text_encoder.eval()
texts = [prompt + x for x in vocabulary]
emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
print(f"saved embeddings for {vocabulary} to {cache_file_path}")
torch.save(emb, cache_file_path)
return emb

BUILDIN_CLASSIFIER = {
'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
Expand Down

0 comments on commit cec3bce

Please sign in to comment.