-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from lezhang7/le
update readme
- Loading branch information
Showing
7 changed files
with
377 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import sys | ||
import argparse | ||
|
||
import torch | ||
from torch import nn | ||
import torch.distributed as dist | ||
import torch.backends.cudnn as cudnn | ||
from torchvision import datasets | ||
from torchvision import transforms as pth_transforms | ||
from torchvision import models as torchvision_models | ||
from transformers import AutoModel | ||
|
||
import knn_utils as utils | ||
|
||
|
||
def extract_feature_pipeline(args): | ||
# ============ preparing data ... ============ | ||
transform = pth_transforms.Compose([ | ||
pth_transforms.Resize(224, interpolation=3), # Corresponds to "do_resize" and "size.shortest_edge" | ||
pth_transforms.CenterCrop(224), # Corresponds to "do_center_crop" and "crop_size" | ||
pth_transforms.ToTensor(), # Converts to tensor | ||
pth_transforms.Normalize( | ||
mean=(0.48145466, 0.4578275, 0.40821073), # Corresponds to "image_mean" | ||
std=(0.26862954, 0.26130258, 0.27577711) # Corresponds to "image_std" | ||
), | ||
]) | ||
dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform) | ||
dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform) | ||
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) | ||
data_loader_train = torch.utils.data.DataLoader( | ||
dataset_train, | ||
sampler=sampler, | ||
batch_size=args.batch_size_per_gpu, | ||
num_workers=args.num_workers, | ||
pin_memory=True, | ||
drop_last=False, | ||
) | ||
data_loader_val = torch.utils.data.DataLoader( | ||
dataset_val, | ||
batch_size=args.batch_size_per_gpu, | ||
num_workers=args.num_workers, | ||
pin_memory=True, | ||
drop_last=False, | ||
) | ||
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") | ||
|
||
# ============ building network ... ============ | ||
|
||
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-224", trust_remote_code=True) | ||
model.cuda() | ||
model.eval() | ||
|
||
# ============ extract features ... ============ | ||
print("Extracting features for train set...") | ||
train_features = extract_features(model, data_loader_train, args.use_cuda) | ||
print("Extracting features for val set...") | ||
test_features = extract_features(model, data_loader_val, args.use_cuda) | ||
|
||
if utils.get_rank() == 0: | ||
train_features = nn.functional.normalize(train_features, dim=1, p=2) | ||
test_features = nn.functional.normalize(test_features, dim=1, p=2) | ||
|
||
train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long() | ||
test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long() | ||
# save features and labels | ||
if args.dump_features and dist.get_rank() == 0: | ||
torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) | ||
torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) | ||
torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) | ||
torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) | ||
return train_features, test_features, train_labels, test_labels | ||
|
||
|
||
@torch.no_grad() | ||
def extract_features(model, data_loader, use_cuda=True, multiscale=False): | ||
metric_logger = utils.MetricLogger(delimiter=" ") | ||
features = None | ||
for samples, index in metric_logger.log_every(data_loader, 10): | ||
samples = samples.cuda(non_blocking=True) | ||
index = index.cuda(non_blocking=True) | ||
if multiscale: | ||
feats = utils.multi_scale(samples, model) | ||
else: | ||
feats = torch.mean(model(samples).last_hidden_state,dim=1).clone() | ||
|
||
# init storage feature matrix | ||
if dist.get_rank() == 0 and features is None: | ||
features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) | ||
if use_cuda: | ||
features = features.cuda(non_blocking=True) | ||
print(f"Storing features into tensor of shape {features.shape}") | ||
|
||
# get indexes from all processes | ||
y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) | ||
y_l = list(y_all.unbind(0)) | ||
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) | ||
y_all_reduce.wait() | ||
index_all = torch.cat(y_l) | ||
|
||
# share features between processes | ||
feats_all = torch.empty( | ||
dist.get_world_size(), | ||
feats.size(0), | ||
feats.size(1), | ||
dtype=feats.dtype, | ||
device=feats.device, | ||
) | ||
output_l = list(feats_all.unbind(0)) | ||
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) | ||
output_all_reduce.wait() | ||
|
||
# update storage feature matrix | ||
if dist.get_rank() == 0: | ||
if use_cuda: | ||
print("features shape:", features.shape) | ||
print("concatenated output shape:", torch.cat(output_l).shape) | ||
print("index_all shape:", index_all.shape) | ||
|
||
features.index_copy_(0, index_all, torch.cat(output_l)) | ||
else: | ||
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) | ||
return features | ||
|
||
|
||
@torch.no_grad() | ||
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000): | ||
top1, top5, total = 0.0, 0.0, 0 | ||
train_features = train_features.t() | ||
num_test_images, num_chunks = test_labels.shape[0], 100 | ||
imgs_per_chunk = num_test_images // num_chunks | ||
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) | ||
for idx in range(0, num_test_images, imgs_per_chunk): | ||
# get the features for test images | ||
features = test_features[ | ||
idx : min((idx + imgs_per_chunk), num_test_images), : | ||
] | ||
targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] | ||
batch_size = targets.shape[0] | ||
|
||
# calculate the dot product and compute top-k neighbors | ||
similarity = torch.mm(features, train_features) | ||
distances, indices = similarity.topk(k, largest=True, sorted=True) | ||
candidates = train_labels.view(1, -1).expand(batch_size, -1) | ||
retrieved_neighbors = torch.gather(candidates, 1, indices) | ||
|
||
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() | ||
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) | ||
distances_transform = distances.clone().div_(T).exp_() | ||
probs = torch.sum( | ||
torch.mul( | ||
retrieval_one_hot.view(batch_size, -1, num_classes), | ||
distances_transform.view(batch_size, -1, 1), | ||
), | ||
1, | ||
) | ||
_, predictions = probs.sort(1, True) | ||
|
||
# find the predictions that match the target | ||
correct = predictions.eq(targets.data.view(-1, 1)) | ||
top1 = top1 + correct.narrow(1, 0, 1).sum().item() | ||
top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5 | ||
total += targets.size(0) | ||
top1 = top1 * 100.0 / total | ||
top5 = top5 * 100.0 / total | ||
return top1, top5 | ||
|
||
|
||
class ReturnIndexDataset(datasets.ImageFolder): | ||
def __getitem__(self, idx): | ||
img, lab = super(ReturnIndexDataset, self).__getitem__(idx) | ||
return img, idx | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet') | ||
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') | ||
parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int, | ||
help='Number of NN to use. 20 is usually working the best.') | ||
parser.add_argument('--temperature', default=0.07, type=float, | ||
help='Temperature used in the voting coefficient') | ||
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") | ||
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag, | ||
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM") | ||
parser.add_argument('--arch', default='vit_small', type=str, help='Architecture') | ||
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') | ||
parser.add_argument("--checkpoint_key", default="teacher", type=str, | ||
help='Key to use in the checkpoint (example: "teacher")') | ||
parser.add_argument('--dump_features', default=None, | ||
help='Path where to save computed features, empty for no saving') | ||
parser.add_argument('--load_features', default=None, help="""If the features have | ||
already been computed, where to find them.""") | ||
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') | ||
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up | ||
distributed training; see https://pytorch.org/docs/stable/distributed.html""") | ||
parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.") | ||
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) | ||
args = parser.parse_args() | ||
|
||
utils.init_distributed_mode(args) | ||
print("git:\n {}\n".format(utils.get_sha())) | ||
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) | ||
cudnn.benchmark = True | ||
|
||
if args.load_features: | ||
train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) | ||
test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) | ||
train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) | ||
test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) | ||
else: | ||
# need to extract features ! | ||
train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) | ||
|
||
if utils.get_rank() == 0: | ||
if args.use_cuda: | ||
train_features = train_features.cuda() | ||
test_features = test_features.cuda() | ||
train_labels = train_labels.cuda() | ||
test_labels = test_labels.cuda() | ||
|
||
print("Features are ready!\nStart the k-NN classification.") | ||
for k in args.nb_knn: | ||
top1, top5 = knn_classifier(train_features, train_labels, | ||
test_features, test_labels, k, args.temperature) | ||
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") | ||
dist.barrier() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.