This repository contains an implementation of all Probabilistic Metric Learning (PML) approaches from the Probabilistic Embeddings Revisited paper. It fully supports the following probabilistic methods from previous works:
In addition to PML approaches, classical (deterministic) Metric Learning (ML) methods are supported:
- Clone this repository:
git clone git@github.com:tinkoff-ai/probabilistic-embeddings.git
cd probabilistic-embeddings
- We recommend building our Docker image from the provided Dockerfile.
- The library must be installed before execution. We recommend an editable installation:
pip install -e .
- You can check the installation using tests:
tox -e py38 -r
- Prepare an experiment .yaml config. In this example, a simple ArcFace model is trained on the LFW dataset:
dataset_params:
  name: lfw-openset
  samples_per_class: null
model_params:
  # Embedder maps input image to embedding vector space.
  embedder_params:
    model_type: resnet18
    # Use ImageNet pretrain.
    pretrained: true
  distribution_params:
    # Spherical 512D embeddings.
    spherical: true
    dim: 512
  # For deterministic embeddings specify Dirac distribution (default).
  distribution_type: dirac
  classifier_type: arcface
trainer_params:
  optimizer_type: adam
  optimizer_params:
    lr: 3.0e-4
- Download and unpack the LFW dataset.
- Run training with the command:
python3 -m probabilistic_embeddings train \
--config <path-to-yaml-config> \
--train-root <logs-and-checkpoints-root> \
<path-to-lfw-data-root>
- Logs and checkpoints will be saved to ./<logs-and-checkpoints-root>. The default logging format is TensorBoard.
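To make the example config more concrete, the following is a minimal sketch of what spherical embeddings combined with an ArcFace classifier compute. It is not the repository's implementation: the function names `l2_normalize` and `arcface_logits` are hypothetical, the default scale and margin are illustrative, and the usual handling of angles that exceed pi after adding the margin is omitted for brevity.

```python
import math


def l2_normalize(vector):
    """Project a vector onto the unit sphere (the idea behind `spherical: true`)."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def arcface_logits(embedding, class_centers, target, scale=64.0, margin=0.5):
    """Return ArcFace logits for one sample.

    The target class gets scale * cos(theta + margin); every other class
    gets scale * cos(theta), where theta is the angle between the
    normalized embedding and the normalized class center.
    """
    embedding = l2_normalize(embedding)
    logits = []
    for index, center in enumerate(class_centers):
        center = l2_normalize(center)
        cosine = sum(a * b for a, b in zip(embedding, center))
        cosine = max(-1.0, min(1.0, cosine))  # guard against rounding error
        theta = math.acos(cosine)
        if index == target:
            theta += margin  # additive angular margin on the true class only
        logits.append(scale * math.cos(theta))
    return logits
```

The logits would then be passed to a softmax cross-entropy loss. The margin makes the true class harder to satisfy during training, which pushes embeddings of the same class closer together on the sphere.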
To enable WandB logging, run the experiment with the command:
WANDB_ENTITY=<entity-name> \
WANDB_API_KEY=<api-key> \
CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings train \
--config <path-to-yaml-config> \
--logger wandb:<project-name>:<experiment-name> \
--train-root <logs-and-checkpoints-root> \
<path-to-dataset-root>
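When experiments are launched from a script rather than a shell, the same invocation can be assembled in Python. The sketch below only uses the flags and environment variables shown above; the helper names `build_train_command` and `wandb_environment` are hypothetical and are not part of the library.

```python
import os
import sys


def build_train_command(config, train_root, data_root, logger=None):
    """Assemble the argument list for the `train` command."""
    command = [sys.executable, "-m", "probabilistic_embeddings", "train",
               "--config", config]
    if logger is not None:
        # Expected form: "wandb:<project-name>:<experiment-name>".
        command += ["--logger", logger]
    command += ["--train-root", train_root, data_root]
    return command


def wandb_environment(entity, api_key, gpu_index, base=None):
    """Return a copy of the environment with the WandB and CUDA variables set."""
    env = dict(os.environ if base is None else base)
    env.update({
        "WANDB_ENTITY": entity,
        "WANDB_API_KEY": api_key,
        "CUDA_VISIBLE_DEVICES": str(gpu_index),
    })
    return env


# Usage (requires the library to be installed):
#   import subprocess
#   command = build_train_command("config.yaml", "runs", "data/lfw",
#                                 logger="wandb:my-project:my-experiment")
#   subprocess.run(command, env=wandb_environment("my-entity", "my-key", 0), check=True)
```

Passing the arguments as a list avoids shell quoting issues with paths that contain spaces.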
train runs the standard training pipeline:
python3 -m probabilistic_embeddings train \
--config <path-to-yaml-config> \
--train-root <logs-and-checkpoints-root> \
<path-to-data-root>
To apply a K-fold cross-validation scheme, use the cval command.
test computes metrics for a given checkpoint:
CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings test \
--config <path-to-config> \
--checkpoint <path-to-checkpoint> \
<path-to-data-root>
evaluate performs model evaluation over multiple random seeds. Add the num_evaluation_seeds field to the experiment config to specify the number of random seeds.
Use the evaluate-cval command to evaluate with cross-validation. Add num_validation_folds to dataset_params to set the number of folds.
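As a rough illustration of what evaluation over folds and seeds involves, the sketch below splits sample indices into K folds and summarizes a metric collected across runs. It is a generic example, not the repository's code: the functions `kfold_splits` and `summarize` are hypothetical, and the library may split differently (for instance by class rather than by sample) and may report other statistics.

```python
import random
import statistics


def kfold_splits(num_samples, num_folds, seed=0):
    """Return a list of (train_indices, validation_indices), one pair per fold."""
    if not 2 <= num_folds <= num_samples:
        raise ValueError("num_folds must be between 2 and num_samples")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Deal shuffled indices into folds of (almost) equal size.
    folds = [indices[i::num_folds] for i in range(num_folds)]
    splits = []
    for held_out in range(num_folds):
        validation = sorted(folds[held_out])
        train = sorted(i for k, fold in enumerate(folds) if k != held_out for i in fold)
        splits.append((train, validation))
    return splits


def summarize(values):
    """Return the mean and sample standard deviation of a metric over runs."""
    mean = statistics.mean(values)
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, std
```

Reporting the spread together with the mean makes it easier to tell whether a difference between two methods exceeds the run-to-run variation.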
To run WandB sweeps, use the hopt and hopt-cval commands. Hyperparameter tuning is only supported with the WandB logger.
CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings hopt \
--config <path-to-config> \
--logger wandb:<project-name>:<experiment-name> \
--train-root <training-root> <path-to-data-root>
Hyperparameters to search and their ranges should be specified in the config, as in this example:
...
model_params:
  ...
  classifier_type: arcface
  classifier_params:
    _hopt:
      scale:
        min: 1.0
        max: 64.0
      margin:
        min: 0.0
        max: 1.0
...
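One way to read the `_hopt` block is that each named parameter is replaced by a concrete value drawn from its range before a trial starts. The sketch below shows that idea on a plain dictionary. It rests on assumptions: the function `sample_hopt` is hypothetical, uniform sampling is chosen only for illustration, and in practice the WandB sweep decides which values are tried.

```python
import copy
import random

HOPT_KEY = "_hopt"


def sample_hopt(config, rng):
    """Return a copy of `config` with every `_hopt` block replaced by sampled values.

    Each entry under `_hopt` is expected to look like {"min": ..., "max": ...}.
    Sampled values are written next to the `_hopt` block they came from.
    """
    resolved = {}
    for key, value in config.items():
        if key == HOPT_KEY:
            continue  # handled below, not copied into the result
        if isinstance(value, dict):
            resolved[key] = sample_hopt(value, rng)
        else:
            resolved[key] = copy.deepcopy(value)
    for name, spec in config.get(HOPT_KEY, {}).items():
        # Assumption: uniform sampling within [min, max].
        resolved[name] = rng.uniform(spec["min"], spec["max"])
    return resolved


example = {
    "model_params": {
        "classifier_type": "arcface",
        "classifier_params": {
            "_hopt": {
                "scale": {"min": 1.0, "max": 64.0},
                "margin": {"min": 0.0, "max": 1.0},
            }
        },
    }
}
trial_config = sample_hopt(example, random.Random(0))
```

After the call, `trial_config["model_params"]["classifier_params"]` holds concrete `scale` and `margin` values, while `example` is left unchanged.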
In order to reproduce all the results of the paper, you need to generate configs for all experiments:
mkdir configs/reality/generated
python scripts/configs/generate-reality.py \
configs/reality/templates/ \
configs/reality/generated/ \
--best configs/reality/best/
Our hyperparameter search results are stored in configs/reality/best. You can reproduce the hyperparameter search with the hopt command and download the best parameters from WandB.
To reproduce training and evaluation, please refer to the commands above.
The repository supports multiple datasets.
Face recognition: MS1MV2, MS1MV3, LFW, and CASIA.
Image retrieval: Cars196, CUB200, In-shop clothes (Inshop), and Stanford Online Products (SOP).
We also implement multiple image classification datasets; please refer to ./src/probabilistic_embeddings/dataset for more details.
Serialized datasets used in reality configs can be downloaded via the following links: Cars196, CUB200, In-shop clothes (Inshop) and Stanford Online Products (SOP).
If you use code from this repository in your project, please cite our paper:
@inproceedings{pml2022,
title={Probabilistic Embeddings Revisited},
author={Ivan Karpukhin and Stanislav Dereka and Sergey Kolesnikov},
year={2022},
url={https://arxiv.org/pdf/2202.06768.pdf}