GPU Embedding Algorithms (cblearn#18)
* Torch implementations
* CKL
* FORTE
* SOE
* GNMDS
* Embedding overview to Readme

Co-authored-by: Leena C Vankadara <vleena@amazon.com>
Co-authored-by: Michael Lohaus <mlohaus>
1 parent 53c9b49 · commit 180a6d2
Showing 15 changed files with 561 additions and 87 deletions.
.gitignore
@@ -131,3 +131,5 @@ dmypy.json
.pyre/

.idea/

*.DS_Store
(Diffs of several other changed files are collapsed and not shown here.)
cblearn/embedding/__init__.py
@@ -1 +1,4 @@
from cblearn.embedding._ckl import CKL
from cblearn.embedding._forte import FORTE
from cblearn.embedding._gnmds import GNMDS
from cblearn.embedding._soe import SOE
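All four estimators exported here follow the scikit-learn fit/transform interface. As a minimal usage sketch (illustrative only, not part of the commit; it assumes cblearn is installed with the pytorch extra, and only CKL and FORTE are shown in the diffs below, with GNMDS and SOE following the same pattern per the commit message):

# Usage sketch (illustrative, not part of the commit): fit the new
# estimators on random triplets; device="auto" picks CUDA if available.
import numpy as np
from cblearn import datasets
from cblearn.embedding import CKL, FORTE

np.random.seed(42)
points = np.random.rand(15, 2)
triplets = datasets.make_random_triplets(points, result_format='list-order', size=1000)

for Estimator in (CKL, FORTE):
    estimator = Estimator(n_components=2, device="auto")
    embedding = estimator.fit_transform(triplets)
    print(Estimator.__name__, embedding.shape, round(estimator.score(triplets), 2))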
cblearn/embedding/_ckl.py (new file)
@@ -0,0 +1,148 @@
from typing import Optional, Union

from sklearn.base import BaseEstimator
from sklearn.utils import check_random_state
import numpy as np

from cblearn import utils
from cblearn.embedding._base import TripletEmbeddingMixin
from cblearn.embedding import _torch_utils

class CKL(BaseEstimator, TripletEmbeddingMixin):
    """ Crowd Kernel Learning (CKL) embedding kernel for triplet data.

    CKL [1]_ searches for a Euclidean representation of objects.
    The model is regularized through the rank of the embedding's kernel matrix.

    This estimator supports multiple implementations which can be selected by the `backend` parameter.
    The *torch* backend uses the ADAM optimizer and backpropagation [2]_.
    It can be executed on CPUs, but also on CUDA GPUs.

    .. note::
        The *torch* backend requires the *pytorch* python package (see :ref:`extras_install`).

    Attributes:
        embedding_: Final embedding, shape (n_objects, n_components)
        stress_: Final value of the CKL loss corresponding to the embedding.
        n_iter_: Final number of optimization steps.

    Examples:
        >>> from cblearn import datasets
        >>> np.random.seed(42)
        >>> true_embedding = np.random.rand(15, 2)
        >>> triplets = datasets.make_random_triplets(true_embedding, result_format='list-order', size=1000)
        >>> triplets.shape, np.unique(triplets).shape
        ((1000, 3), (15,))
        >>> estimator = CKL(n_components=2)
        >>> embedding = estimator.fit_transform(triplets)
        >>> embedding.shape
        (15, 2)
        >>> round(estimator.score(triplets), 1) > 0.6
        True
        >>> estimator = CKL(n_components=2, kernel=True)
        >>> embedding = estimator.fit_transform(triplets)
        >>> embedding.shape
        (15, 2)

    References
    ----------
    .. [1] Tamuz, O., Liu, C., Belongie, S., Shamir, O., & Kalai, A. T. (2011).
           Adaptively Learning the Crowd Kernel. International Conference on Machine Learning.
    .. [2] Vankadara, L. C., Haghiri, S., Lohaus, M., Wahab, F. U., & von Luxburg, U. (2020).
           Insights into Ordinal Embedding Algorithms: A Systematic Evaluation. ArXiv:1912.01666 [Cs, Stat].
    """

    def __init__(self, n_components=2, mu=0.0, verbose=False,
                 random_state: Union[None, int, np.random.RandomState] = None, max_iter=2000,
                 backend: str = 'torch', kernel: bool = False, learning_rate=None, batch_size=50000,
                 device: str = "auto"):
        """ Initialize the estimator.

        Args:
            n_components: The dimension of the embedding.
            mu: Regularization parameter >= 0. Increasing mu has the effect of enforcing a larger margin.
            verbose: Enable verbose output.
            random_state: The seed of the pseudo random number generator used to initialize the optimization.
            max_iter: Maximum number of optimization iterations.
            backend: The optimization backend for fitting. {"torch"}
            kernel: Whether to optimize the kernel matrix or the embedding (default).
            learning_rate: Learning rate of the gradient-based optimizer.
                If None, then 100 is used, or 1 if kernel=True.
                Only used with the *torch* backend, else ignored.
            batch_size: Batch size of stochastic optimization. Only used with the *torch* backend, else ignored.
            device: The device on which pytorch computes. {"auto", "cpu", "cuda"}
                "auto" chooses cuda (GPU) if available, but falls back on cpu if not.
                Only used with the *torch* backend, else ignored.
        """
        self.n_components = n_components
        self.max_iter = max_iter
        self.mu = mu
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.kernel = kernel
        self.verbose = verbose
        self.random_state = random_state
        self.backend = backend
        self.device = device

    def fit(self, X: utils.Questions, y: np.ndarray = None, init: np.ndarray = None,
            n_objects: Optional[int] = None) -> 'CKL':
        """Computes the embedding.

        Args:
            X: The training input samples, shape (n_samples, 3)
            y: Ignored
            init: Initial embedding for optimization
        Returns:
            self.
        """
        triplets = utils.check_triplet_answers(X, y, result_format='list-order')
        if not n_objects:
            n_objects = len(np.unique(triplets))
        random_state = check_random_state(self.random_state)
        if init is None:
            init = random_state.multivariate_normal(
                np.zeros(self.n_components), np.eye(self.n_components), size=n_objects)

        if self.backend != 'torch':
            raise ValueError(f"Invalid backend '{self.backend}'")

        _torch_utils.assert_torch_is_available()
        if self.kernel:
            result = _torch_utils.torch_minimize_kernel(
                'adam', _ckl_kernel_loss_torch, init, data=(triplets.astype(int),), args=(self.mu,),
                device=self.device, max_iter=self.max_iter, batch_size=self.batch_size, lr=self.learning_rate or 100,
                seed=random_state.randint(2**31))  # draw a random seed for the torch RNG
        else:
            result = _torch_utils.torch_minimize(
                'adam', _ckl_x_loss_torch, init, data=(triplets.astype(int),), args=(self.mu,),
                device=self.device, max_iter=self.max_iter, lr=self.learning_rate or 1,
                seed=random_state.randint(2**31))

        if self.verbose and not result.success:
            print(f"CKL's optimization failed with reason: {result.message}.")
        self.embedding_ = result.x.reshape(-1, self.n_components)
        self.stress_, self.n_iter_ = result.fun, result.nit
        return self

def _ckl_x_loss_torch(embedding, triplets, mu):
    # Gather the coordinates of anchor i, near object j, and far object k.
    X = embedding[triplets.long()]
    x_i, x_j, x_k = X[:, 0, :], X[:, 1, :], X[:, 2, :]
    numerator = (x_i - x_k).norm(p=2, dim=1) ** 2 + mu
    denominator = (x_i - x_j).norm(p=2, dim=1) ** 2 + (x_i - x_k).norm(p=2, dim=1) ** 2 + 2 * mu
    # Negative log-likelihood of the observed triplet answers.
    return -1 * (numerator.log() - denominator.log()).sum()


def _ckl_kernel_loss_torch(kernel_matrix, triplets, mu):
    triplets = triplets.long()
    diag = kernel_matrix.diag()[:, None]
    # Squared distances from the Gram matrix: d(i, j) = K_ii + K_jj - 2 K_ij.
    dist = -2 * kernel_matrix + diag + diag.transpose(0, 1)
    d_ij = dist[triplets[:, 0], triplets[:, 1]].squeeze()
    d_ik = dist[triplets[:, 0], triplets[:, 2]].squeeze()
    probability = (d_ik + mu).log() - (d_ij + d_ik + 2 * mu).log()
    return -probability.sum()
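To make the objective concrete: for a triplet (i, j, k), read as "i is closer to j than to k", CKL models the probability of that answer as (d_ik + mu) / (d_ij + d_ik + 2 mu) with squared distances d, and _ckl_x_loss_torch sums the negative log of this over all triplets. A plain-numpy re-computation for a single triplet (illustrative only, not part of the commit):

# Illustrative numpy re-computation of the CKL loss for one triplet,
# mirroring _ckl_x_loss_torch above (not part of the commit).
import numpy as np

mu = 0.1
x_i, x_j, x_k = np.array([0.0, 0.0]), np.array([0.1, 0.0]), np.array([1.0, 0.0])

d_ij = np.sum((x_i - x_j) ** 2)  # 0.01: i and j are close
d_ik = np.sum((x_i - x_k) ** 2)  # 1.0: i and k are far apart
p = (d_ik + mu) / (d_ij + d_ik + 2 * mu)  # probability of answering "j is closer"
loss = -np.log(p)
print(round(p, 3), round(loss, 3))  # 0.909 0.095: a satisfied triplet costs little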
cblearn/embedding/_forte.py (new file)
@@ -0,0 +1,117 @@
from typing import Optional, Union

from sklearn.base import BaseEstimator
from sklearn.utils import check_random_state
import numpy as np

from cblearn import utils
from cblearn.embedding._base import TripletEmbeddingMixin
from cblearn.embedding import _torch_utils

class FORTE(BaseEstimator, TripletEmbeddingMixin):
    """ Fast Ordinal Triplet Embedding (FORTE).

    FORTE [1]_ minimizes a kernel version of the triplet hinge soft objective
    as a smooth relaxation of the triplet error.

    This estimator supports multiple implementations which can be selected by the `backend` parameter.
    The *torch* backend uses backpropagation [2]_ and optimizes with L-BFGS and strong Wolfe line search.
    It can be executed on CPUs, but also on CUDA GPUs.

    .. note::
        The *torch* backend requires the *pytorch* python package (see :ref:`extras_install`).

    Attributes:
        embedding_: Final embedding, shape (n_objects, n_components)
        stress_: Final value of the FORTE loss corresponding to the embedding.
        n_iter_: Final number of optimization steps.

    Examples:
        >>> from cblearn import datasets
        >>> np.random.seed(42)
        >>> true_embedding = np.random.rand(15, 2)
        >>> triplets = datasets.make_random_triplets(true_embedding, result_format='list-order', size=1000)
        >>> triplets.shape, np.unique(triplets).shape
        ((1000, 3), (15,))
        >>> estimator = FORTE(n_components=2)
        >>> embedding = estimator.fit_transform(triplets)
        >>> embedding.shape
        (15, 2)
        >>> estimator.score(triplets) > 0.6
        True

    References
    ----------
    .. [1] Jain, L., Jamieson, K. G., & Nowak, R. (2016). Finite Sample Prediction and
           Recovery Bounds for Ordinal Embedding. Advances in Neural Information Processing Systems, 29.
    .. [2] Vankadara, L. C., Haghiri, S., Lohaus, M., Wahab, F. U., & von Luxburg, U. (2020).
           Insights into Ordinal Embedding Algorithms: A Systematic Evaluation. ArXiv:1912.01666 [Cs, Stat].
    """

    def __init__(self, n_components=2, verbose=False, random_state: Union[None, int, np.random.RandomState] = None,
                 max_iter=2000, batch_size=50_000, device: str = "auto"):
        """ Initialize the estimator.

        Args:
            n_components: The dimension of the embedding.
            verbose: Enable verbose output.
            random_state: The seed of the pseudo random number generator used to initialize the optimization.
            max_iter: Maximum number of optimization iterations.
            batch_size: Batch size of stochastic optimization. Only used with the *torch* backend, else ignored.
            device: The device on which pytorch computes. {"auto", "cpu", "cuda"}
                "auto" chooses cuda (GPU) if available, but falls back on cpu if not.
                Only used with the *torch* backend, else ignored.
        """
        self.n_components = n_components
        self.max_iter = max_iter
        self.verbose = verbose
        self.random_state = random_state
        self.device = device
        self.batch_size = batch_size

    def fit(self, X: utils.Questions, y: np.ndarray = None, init: np.ndarray = None,
            n_objects: Optional[int] = None) -> 'FORTE':
        """Computes the embedding.

        Args:
            X: The training input samples, shape (n_samples, 3)
            y: Ignored
            init: Initial embedding for optimization
        Returns:
            self.
        """
        triplets = utils.check_triplet_answers(X, y, result_format='list-order')
        if not n_objects:
            n_objects = len(np.unique(triplets))
        random_state = check_random_state(self.random_state)
        if init is None:
            init = random_state.multivariate_normal(np.zeros(self.n_components),
                                                    np.eye(self.n_components), size=n_objects)

        _torch_utils.assert_torch_is_available()
        result = _torch_utils.torch_minimize_kernel('l-bfgs-b', _torch_forte_loss, init, data=(triplets.astype(int),),
                                                    device=self.device, max_iter=self.max_iter,
                                                    seed=random_state.randint(2**31),
                                                    batch_size=self.batch_size, line_search_fn='strong_wolfe')

        if self.verbose and not result.success:
            print(f"FORTE's optimization failed with reason: {result.message}.")
        self.embedding_ = result.x.reshape(-1, self.n_components)
        self.stress_, self.n_iter_ = result.fun, result.nit
        return self

def _torch_forte_loss(kernel_matrix, triplets):
    triplets = triplets.long()
    diag = kernel_matrix.diag()[:, None]
    # Squared distances from the Gram matrix: d(i, j) = K_ii + K_jj - 2 K_ij.
    dist = -2 * kernel_matrix + diag + diag.transpose(0, 1)
    d_ij = dist[triplets[:, 0], triplets[:, 1]].squeeze()
    d_ik = dist[triplets[:, 0], triplets[:, 2]].squeeze()
    # Softplus of the margin: log(1 + exp(d_ij - d_ik)), summed over triplets.
    return (1 + (d_ij - d_ik).exp()).log().sum()
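Both kernel losses above recover squared distances from the Gram matrix via d(i, j) = K_ii + K_jj - 2 K_ij; FORTE then applies a softplus penalty to the margin d_ij - d_ik. A short numpy check of that identity (illustrative only, not part of the commit):

# Illustrative check of the Gram-to-distance identity used by both kernel
# losses: for K = X X^T, K_ii + K_jj - 2 K_ij equals ||x_i - x_j||^2.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
K = X @ X.T
diag = np.diag(K)[:, None]
dist = -2 * K + diag + diag.T  # matches the `dist` computation above

i, j = 1, 3
print(np.allclose(dist[i, j], np.sum((X[i] - X[j]) ** 2)))  # True

# FORTE's per-triplet penalty is then log(1 + exp(d_ij - d_ik)),
# i.e. a softplus on the squared-distance margin.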
(The remaining file diffs are not shown.)