Merge pull request #19 from lezhang7/le
update readme
lezhang7 authored Nov 24, 2024
2 parents 9b7d9bc + d4b3aee commit 6be76f9
Showing 7 changed files with 377 additions and 41 deletions.
239 changes: 239 additions & 0 deletions evaluation/eval_knn_aimv2.py
@@ -0,0 +1,239 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torchvision import transforms as pth_transforms
from torchvision import models as torchvision_models
from transformers import AutoModel

import knn_utils as utils


def extract_feature_pipeline(args):
# ============ preparing data ... ============
transform = pth_transforms.Compose([
        pth_transforms.Resize(224, interpolation=3),  # 3 = PIL.Image.BICUBIC; corresponds to "do_resize" and "size.shortest_edge"
pth_transforms.CenterCrop(224), # Corresponds to "do_center_crop" and "crop_size"
pth_transforms.ToTensor(), # Converts to tensor
pth_transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073), # Corresponds to "image_mean"
std=(0.26862954, 0.26130258, 0.27577711) # Corresponds to "image_std"
),
])
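    # The constants above mirror the apple/aimv2-large-patch14-224 image
    # processor (CLIP-style mean/std); AutoImageProcessor.from_pretrained on
    # the same checkpoint should produce equivalent preprocessing.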
dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
data_loader_train = torch.utils.data.DataLoader(
dataset_train,
sampler=sampler,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
)
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

# ============ building network ... ============

model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224", trust_remote_code=True)
model.cuda()
model.eval()

# ============ extract features ... ============
print("Extracting features for train set...")
train_features = extract_features(model, data_loader_train, args.use_cuda)
print("Extracting features for val set...")
test_features = extract_features(model, data_loader_val, args.use_cuda)

if utils.get_rank() == 0:
train_features = nn.functional.normalize(train_features, dim=1, p=2)
test_features = nn.functional.normalize(test_features, dim=1, p=2)

train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
# save features and labels
if args.dump_features and dist.get_rank() == 0:
torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
return train_features, test_features, train_labels, test_labels


@torch.no_grad()
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
metric_logger = utils.MetricLogger(delimiter=" ")
features = None
for samples, index in metric_logger.log_every(data_loader, 10):
samples = samples.cuda(non_blocking=True)
index = index.cuda(non_blocking=True)
if multiscale:
feats = utils.multi_scale(samples, model)
else:
            feats = torch.mean(model(samples).last_hidden_state, dim=1).clone()

# init storage feature matrix
if dist.get_rank() == 0 and features is None:
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
if use_cuda:
features = features.cuda(non_blocking=True)
print(f"Storing features into tensor of shape {features.shape}")

# get indexes from all processes
y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
y_l = list(y_all.unbind(0))
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
y_all_reduce.wait()
index_all = torch.cat(y_l)

# share features between processes
feats_all = torch.empty(
dist.get_world_size(),
feats.size(0),
feats.size(1),
dtype=feats.dtype,
device=feats.device,
)
output_l = list(feats_all.unbind(0))
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
output_all_reduce.wait()

# update storage feature matrix
if dist.get_rank() == 0:
if use_cuda:
print("features shape:", features.shape)
print("concatenated output shape:", torch.cat(output_l).shape)
print("index_all shape:", index_all.shape)

features.index_copy_(0, index_all, torch.cat(output_l))
else:
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
return features


@torch.no_grad()
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
top1, top5, total = 0.0, 0.0, 0
train_features = train_features.t()
num_test_images, num_chunks = test_labels.shape[0], 100
imgs_per_chunk = num_test_images // num_chunks
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
for idx in range(0, num_test_images, imgs_per_chunk):
# get the features for test images
features = test_features[
idx : min((idx + imgs_per_chunk), num_test_images), :
]
targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
batch_size = targets.shape[0]

# calculate the dot product and compute top-k neighbors
similarity = torch.mm(features, train_features)
distances, indices = similarity.topk(k, largest=True, sorted=True)
candidates = train_labels.view(1, -1).expand(batch_size, -1)
retrieved_neighbors = torch.gather(candidates, 1, indices)

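        # Build a (batch_size*k, num_classes) one-hot matrix of the retrieved
        # neighbors' labels, then weight each neighbor's vote by exp(sim / T).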
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
distances_transform = distances.clone().div_(T).exp_()
probs = torch.sum(
torch.mul(
retrieval_one_hot.view(batch_size, -1, num_classes),
distances_transform.view(batch_size, -1, 1),
),
1,
)
_, predictions = probs.sort(1, True)

# find the predictions that match the target
correct = predictions.eq(targets.data.view(-1, 1))
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5
total += targets.size(0)
top1 = top1 * 100.0 / total
top5 = top5 * 100.0 / total
return top1, top5


class ReturnIndexDataset(datasets.ImageFolder):
    def __getitem__(self, idx):
        # Return the sample index in place of the label; labels are recovered
        # from dataset.samples in extract_feature_pipeline.
        img, _ = super(ReturnIndexDataset, self).__getitem__(idx)
        return img, idx


if __name__ == '__main__':
parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
                        help='Number of NN to use. 20 usually works best.')
parser.add_argument('--temperature', default=0.07, type=float,
help='Temperature used in the voting coefficient')
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
parser.add_argument("--checkpoint_key", default="teacher", type=str,
help='Key to use in the checkpoint (example: "teacher")')
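    # Note: --pretrained_weights, --arch, --patch_size and --checkpoint_key are
    # never read below; they appear to be carried over from the DINO eval_knn
    # script, as the AIMv2 backbone is hard-coded in extract_feature_pipeline.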
parser.add_argument('--dump_features', default=None,
help='Path where to save computed features, empty for no saving')
parser.add_argument('--load_features', default=None, help="""If the features have
already been computed, where to find them.""")
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
args = parser.parse_args()

utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
cudnn.benchmark = True

if args.load_features:
train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
else:
        # need to extract features!
train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)

if utils.get_rank() == 0:
if args.use_cuda:
train_features = train_features.cuda()
test_features = test_features.cuda()
train_labels = train_labels.cuda()
test_labels = test_labels.cuda()

print("Features are ready!\nStart the k-NN classification.")
for k in args.nb_knn:
top1, top5 = knn_classifier(train_features, train_labels,
test_features, test_labels, k, args.temperature)
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
dist.barrier()
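Aside, not part of this commit: because knn_classifier only needs feature and label tensors, a quick single-process sanity check with random features can be run against this module. Sizes below are toy values I picked (the chunking in knn_classifier requires at least 100 test samples), and the import assumes knn_utils is on the path.

# Hypothetical sanity check, not part of the commit.
import torch
from eval_knn_aimv2 import knn_classifier

num_train, num_test, dim, num_classes = 1000, 200, 768, 10  # toy sizes
train_features = torch.nn.functional.normalize(torch.randn(num_train, dim), p=2, dim=1)
test_features = torch.nn.functional.normalize(torch.randn(num_test, dim), p=2, dim=1)
train_labels = torch.randint(0, num_classes, (num_train,))
test_labels = torch.randint(0, num_classes, (num_test,))

top1, top5 = knn_classifier(train_features, train_labels,
                            test_features, test_labels,
                            k=20, T=0.07, num_classes=num_classes)
print(f"chance-level check: top1={top1:.1f} (~10), top5={top5:.1f} (~50)")

With random features the classifier should score near chance (about 10% top-1 for 10 classes), which makes silent indexing bugs easy to spot.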
2 changes: 1 addition & 1 deletion model/sail_model.py
@@ -227,7 +227,7 @@ def __init__(
super(SAILModel, self).__init__()
self.text_model = SentenceEmbedding(text_model_name)
self.vision_model = ImageEmbedding(vision_model_name, seg=seg, agg_mode=agg_mode)
-        if any(x in vision_model_name for x in ['mae','ibot','dinov1','aim','ijepa','clip']) or 'patch' in agg_mode or 'cls' in agg_mode:
+        if any(x in vision_model_name for x in ['mae','ibot','dinov1','ml-aim','ijepa','clip','aimv2']) or 'patch' in agg_mode or 'cls' in agg_mode:
if hasattr(self.vision_model.model, 'config'):
vision_dimesion = self.vision_model.model.config.hidden_size
else:
41 changes: 20 additions & 21 deletions model/vision_model.py
@@ -68,7 +68,7 @@ def __init__(self, model_name="facebook/dinov2-base", device=None, seg: bool = F
self.agg_mode = agg_mode
self.model_name = model_name

-        if any(x in model_name for x in ['ibot', 'mae', 'dinov1', 'aim', 'ijepa']):
+        if any(x in model_name for x in ['ibot', 'mae', 'dinov1', 'ml-aim', 'ijepa']):
# load from local
if 'ibot' in model_name:
self.model = get_ibot_vit(model_name)
@@ -84,7 +84,7 @@ def __init__(self, model_name="facebook/dinov2-base", device=None, seg: bool = F
elif 'resnet' in model_name:
self.model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
self.model.embed_dim = 2048
-        elif 'aim' in model_name:
+        elif 'ml-aim' in model_name:
if '1B' in model_name:
self.model = torch.hub.load("apple/ml-aim", "aim_1B")
self.model.embed_dim = 2048
@@ -107,6 +107,9 @@ def __init__(self, model_name="facebook/dinov2-base", device=None, seg: bool = F
elif any(x in model_name.lower() for x in ['clip']):
self.model = CLIPVisionModel.from_pretrained(model_name, torch_dtype=torch.float16)
self.image_processor = AutoImageProcessor.from_pretrained(model_name)
+        elif any(x in model_name.lower() for x in ['aimv2']):
+            self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True)
+            self.image_processor = AutoImageProcessor.from_pretrained(model_name)
else:
if seg:
modify_vit('seg')
@@ -133,13 +136,6 @@ def load_single_image(args):
images[idx] = img

return images

-    # def load_images_from_directory(self, images_path: List[str]) -> List[Image.Image]:
-    #     images = []
-    #     for image_path in images_path:
-    #         with Image.open(image_path) as img:
-    #             images.append(img.convert("RGB"))
-    #     return images

def get_visual_embeddings_from_directory(self, images_path: List[str]):
images = self.load_images_from_directory(images_path)
@@ -161,7 +157,7 @@ def forward(self, inputs, patch_mode=False, attetion_type='qk', ignore_residual=
self.model.encoder.attetion_type = attetion_type
self.model.encoder.ignore_residual = ignore_residual

-        if any(x in self.model_name.lower() for x in ['mae', 'convnextv2']):
+        if any(x in self.model_name.lower() for x in ['mae']):
if isinstance(inputs, torch.Tensor):
outputs = self.model.forward_features(inputs)
else:
@@ -170,43 +166,46 @@
if isinstance(inputs, torch.Tensor):
outputs = self.model(inputs)
elif isinstance(inputs, dict) or isinstance(inputs, BaseBatchFeature):
-            if any(x in self.model_name.lower() for x in ['ibot', 'dinov1', 'aim', 'ijepa']):
+            if any(x in self.model_name.lower() for x in ['ibot', 'dinov1', 'ml-aim', 'ijepa']):
outputs = self.model(inputs['pixel_values'])
else:
# huggingface transformer vision model
outputs = self.model(**inputs)
else:
raise ValueError(f"Unsupported input type: {type(inputs)}")

# extract the embeddings
if any(x in self.model_name.lower() for x in ['ijepa']):
-            linear_input = outputs.mean(dim=1)
+            embedding = outputs.mean(dim=1)
elif any(x in self.model_name.lower() for x in ['clip']):
-            linear_input = outputs.pooler_output
+            embedding = outputs.pooler_output
+        elif any(x in self.model_name.lower() for x in ['aimv2']):
+            embedding = torch.mean(outputs.last_hidden_state, dim=1)
else:
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size

if patch_mode:
patch_tokens = sequence_output[:, 1:]
cls_token = sequence_output[:, 0].unsqueeze(1).repeat(1, patch_tokens.shape[1], 1)
-                linear_input = torch.cat([cls_token, patch_tokens], dim=-1)
+                embedding = torch.cat([cls_token, patch_tokens], dim=-1)
else:
-                if any(x in self.model_name.lower() for x in ['ibot', 'r101', 'r152', 'mae', 'convnextv2', 'dinov1']):
-                    linear_input = outputs
+                if any(x in self.model_name.lower() for x in ['ibot', 'mae', 'dinov1']):
+                    embedding = outputs
elif any(x in self.model_name.lower() for x in ['aim']):
-                    linear_input = outputs[1]
+                    embedding = outputs[1]
else:
cls_token = sequence_output[:, 0]
patch_tokens = sequence_output[:, 1:]
if self.agg_mode == 'patch':
-                        linear_input = patch_tokens.mean(dim=1)
+                        embedding = patch_tokens.mean(dim=1)
elif self.agg_mode == 'cls':
-                        linear_input = cls_token
+                        embedding = cls_token
elif self.agg_mode == 'concat':
-                        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+                        embedding = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
else:
raise ValueError(f"Invalid agg_mode: {self.agg_mode}")

-        return linear_input
+        return embedding


if __name__ == "__main__":
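Aside, not part of the diff: the new 'aimv2' branch and eval_knn_aimv2.py both aggregate AIMv2 features by mean-pooling last_hidden_state. A minimal standalone sketch of that path (the checkpoint name comes from the eval script; everything else is illustrative):

import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor

name = "apple/aimv2-large-patch14-224"  # checkpoint used in eval_knn_aimv2.py
model = AutoModel.from_pretrained(name, trust_remote_code=True).eval()
processor = AutoImageProcessor.from_pretrained(name)

image = Image.new("RGB", (224, 224))  # blank stand-in image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(inputs["pixel_values"])
# Same aggregation as the new branch: mean over the patch-token sequence.
embedding = torch.mean(outputs.last_hidden_state, dim=1)
print(embedding.shape)  # expected: (1, hidden_size)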
(Diffs for the remaining 4 changed files are not shown.)