forked from cblearn/cblearn
Datasets: Car, ImageNet v0.1 and v0.2, Things, Nature, Vogue, Material

* Add utility to transform odd-one-out, n-select, and n-rank queries to triplets (see the sketch below).
* Add preprocessing methods for queries from object attributes.
* Fix dangling URL for the musicsim dataset.
* Use the remote-data directive to indicate tests that depend on the internet; disable them by default.
* Add h5py to CI in order to run the dataset tests.
1 parent 180a6d2, commit fcc2816: 29 changed files with 1,341 additions and 31 deletions.
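The odd-one-out-to-triplet utility named in the commit message is not part of this excerpt. The following is a minimal sketch of that reduction with illustrative names (not the function this commit adds): if one object of a three-object query is picked as the odd one out, the two remaining objects are mutually more similar than either is to the picked one, which yields two (pivot, closer, farther) triplets per query.

import numpy as np

def oddoneout_to_triplets(queries: np.ndarray, odd_ix: np.ndarray) -> np.ndarray:
    """queries: (n, 3) object indices; odd_ix: (n,) column index of the odd one out."""
    rows = np.arange(len(queries))
    odd = queries[rows, odd_ix]
    rest = np.stack([queries[rows, (odd_ix + 1) % 3],
                     queries[rows, (odd_ix + 2) % 3]], axis=1)
    # Two triplets (pivot, closer, farther) per odd-one-out answer.
    return np.concatenate([
        np.column_stack([rest[:, 0], rest[:, 1], odd]),
        np.column_stack([rest[:, 1], rest[:, 0], odd]),
    ])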
cblearn/datasets/__init__.py
@@ -1,10 +1,16 @@
from ._musician_similarity import fetch_musician_similarity
from ._food_similarity import fetch_food_similarity
from ._material_similarity import fetch_material_similarity
from ._nature_vogue_similarity import fetch_nature_scene_similarity
from ._nature_vogue_similarity import fetch_vogue_cover_similarity
from ._things_similarity import fetch_things_similarity
from ._imagenet_similarity import fetch_imagenet_similarity
from ._car_similarity import fetch_car_similarity

from ._triplet_simulation import make_all_triplets
from ._triplet_simulation import make_random_triplets

from ._triplet_indices import make_all_triplet_indices
from ._triplet_indices import make_random_triplet_indices
from ._triplet_answers import triplet_answers
from ._triplet_answers import noisy_triplet_answers
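After this change the new fetchers are re-exported from the package namespace; a quick smoke test, assuming the file above is cblearn/datasets/__init__.py:

from cblearn import datasets

assert hasattr(datasets, 'fetch_car_similarity')
assert hasattr(datasets, 'fetch_imagenet_similarity')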
cblearn/datasets/_car_similarity.py
@@ -0,0 +1,132 @@
from pathlib import Path
import logging
import joblib
import os
from typing import Optional, Union
import zipfile

import numpy as np
from sklearn.datasets import _base
from sklearn.utils import check_random_state, Bunch


ARCHIVE = _base.RemoteFileMetadata(
    filename='60_cars_data.zip',
    url='http://www.tml.cs.uni-tuebingen.de/team/luxburg/code_and_data/60_cars_data.zip',
    checksum=('5fa2ad932d48adf5cfe36bd16a08b25fd88d1519d974908f6ccbba769f629640'))

logger = logging.getLogger(__name__)


def fetch_car_similarity(data_home: Optional[os.PathLike] = None, download_if_missing: bool = True,
                         shuffle: bool = True, random_state: Optional[np.random.RandomState] = None,
                         return_triplets: bool = False) -> Union[Bunch, np.ndarray]:
    """ Load the 60-car dataset (most-central triplets).

        ==================   =====================
        Triplets             7097
        Objects (Cars)       60
        Query                3 cars, most-central
        ==================   =====================

    See :ref:`central_car_dataset` for a detailed description.

    >>> dataset = fetch_car_similarity(shuffle=False)  # doctest: +REMOTE_DATA
    >>> dataset.class_name.tolist()  # doctest: +REMOTE_DATA
    ['OFF-ROAD / SPORT UTILITY VEHICLES', 'ORDINARY CARS', 'OUTLIERS', 'SPORTS CARS']
    >>> dataset.triplet.shape  # doctest: +REMOTE_DATA
    (7097, 3)

    Args:
        data_home : optional, default: None
            Specify another download and cache folder for the datasets. By default
            all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
        download_if_missing : optional, default=True
            If False, raise an IOError instead of downloading missing data.
        shuffle : default=True
            Shuffle the order of the triplet constraints.
        random_state : optional, default=None
            Initialization of the shuffle random generator.
        return_triplets : boolean, default=False
            If True, returns a numpy array instead of a Bunch object.

    Returns:
        dataset : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            triplet : ndarray, shape (n_triplets, 3)
                Each row corresponds to a triplet constraint.
                The columns represent the three indices shown per most-central question.
            response : ndarray, shape (n_triplets, )
                The car per question (0, 1, or 2) that was selected as "most-central".
            rt_ms : ndarray, shape (n_triplets, )
                Reaction time of the response in milliseconds.
            class_id : np.ndarray (60, )
                The class assigned to each object.
            class_name : list (4)
                Names of the classes.
            DESCR : string
                Description of the dataset.
        triplets : numpy array (n_triplets, 3)
            Only present when `return_triplets=True`.

    Raises:
        IOError: If the data is not locally available, but download_if_missing=False.
    """
    data_home = Path(_base.get_data_home(data_home=data_home))
    if not data_home.exists():
        data_home.mkdir()

    filepath = Path(_base._pkl_filepath(data_home, 'car_centrality.pkz'))
    if not filepath.exists():
        if not download_if_missing:
            raise IOError("Data not found and `download_if_missing` is False")

        logger.info('Downloading 60-car dataset from {} to {}'.format(ARCHIVE.url, data_home))

        archive_path = _base._fetch_remote(ARCHIVE, dirname=data_home)
        with zipfile.ZipFile(archive_path) as zf:
            with zf.open('60_cars_data/survey_data.csv', 'r') as f:
                survey_data = np.loadtxt(f, dtype=str, delimiter=',', skiprows=1)

        joblib.dump(survey_data, filepath, compress=6)
        os.remove(archive_path)
    else:
        survey_data = joblib.load(filepath)

    class_map = {
        'ORDINARY CARS': [2, 6, 7, 8, 9, 10, 11, 12, 16, 17, 25, 32, 35, 36, 37, 38,
                          39, 41, 44, 45, 46, 55, 58, 60],
        'SPORTS CARS': [15, 19, 20, 28, 40, 42, 47, 48, 49, 50, 51, 52, 54, 56, 59],
        'OFF-ROAD / SPORT UTILITY VEHICLES': [1, 3, 4, 5, 13, 14, 18, 22, 24, 26, 27,
                                              29, 31, 33, 34, 43, 57],
        'OUTLIERS': [21, 23, 30, 53],
    }
    class_names = np.asarray(sorted(class_map.keys()))
    classes = np.empty(60, dtype=int)
    for cls_ix, cls_name in enumerate(class_names):
        # class_map uses 1-based car indices, so shift to 0-based.
        classes[np.array(class_map[cls_name]) - 1] = cls_ix

    if shuffle:
        random_state = check_random_state(random_state)
        shuffle_ix = random_state.permutation(len(survey_data))
        survey_data = survey_data[shuffle_ix]

    raw_triplets = survey_data[:, [2, 3, 4]].astype(int)
    triplets = raw_triplets - 1  # store 0-based object indices
    response = (survey_data[:, [1]].astype(int) == raw_triplets).nonzero()[1]
    rt_ms = survey_data[:, [5]].astype(float)
    if return_triplets:
        return triplets

    module_path = Path(__file__).parent
    with module_path.joinpath('descr', 'car_similarity.rst').open() as rst_file:
        fdescr = rst_file.read()

    return Bunch(triplet=triplets,
                 response=response,
                 rt_ms=rt_ms,
                 class_id=classes,
                 class_name=class_names,
                 DESCR=fdescr)
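A hypothetical usage sketch for the fetcher above, including one plausible reduction of most-central answers to ordinal triplets (this reduction is an assumption for illustration, not the conversion utility the commit adds): if the chosen car lies between the other two, each outer car is closer to the central car than to the opposite outer car.

import numpy as np
from cblearn.datasets import fetch_car_similarity

data = fetch_car_similarity(shuffle=False)  # downloads on first call
rows = np.arange(len(data.triplet))
central = data.triplet[rows, data.response]  # most-central car per query
outer = np.stack([data.triplet[rows, (data.response + 1) % 3],
                  data.triplet[rows, (data.response + 2) % 3]], axis=1)
triplets = np.concatenate([  # (pivot, closer, farther)
    np.column_stack([outer[:, 0], central, outer[:, 1]]),
    np.column_stack([outer[:, 1], central, outer[:, 0]]),
])
print(triplets.shape)  # (14194, 3): two triplets per most-central answer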
cblearn/datasets/_imagenet_similarity.py
@@ -0,0 +1,157 @@
from pathlib import Path
import logging
import joblib
import os
from os.path import join
from typing import Optional, Union
from urllib.request import urlretrieve
import zipfile

import numpy as np
from sklearn.datasets import _base
from sklearn.utils import check_random_state, Bunch

ARCHIVE = _base.RemoteFileMetadata(
    filename='osfstorage-archive.zip',
    url='https://files.osf.io/v1/resources/7f96y/providers/osfstorage/?zip=',
    checksum=('cannot check - zip involves randomness'))

logger = logging.getLogger(__name__)
__doctest_requires__ = {'fetch_imagenet_similarity': ['h5py']}


def fetch_imagenet_similarity(data_home: Optional[os.PathLike] = None, download_if_missing: bool = True,
                              shuffle: bool = True, random_state: Optional[np.random.RandomState] = None,
                              version: str = '0.1', return_data: bool = False) -> Union[Bunch, np.ndarray]:
    """ Load the imagenet similarity dataset (rank 2 from 8).

        ==================   =====================
        Trials v0.1/v0.2     25,273 / 384,277
        Objects (Images)     1,000 / 50,000
        Classes              1,000
        Query                rank 2 from 8
        ==================   =====================

    See :ref:`imagenet_similarity_dataset` for a detailed description.

    .. note::
        Loading this dataset requires the package `h5py`_, which is not installed
        as a dependency of cblearn.

    .. _`h5py`: https://docs.h5py.org/en/stable/build.html

    >>> dataset = fetch_imagenet_similarity(shuffle=True, version='0.1')  # doctest: +REMOTE_DATA
    >>> dataset.class_label[[0, -1]].tolist()  # doctest: +REMOTE_DATA
    ['n01440764', 'n15075141']
    >>> dataset.n_select, dataset.is_ranked  # doctest: +REMOTE_DATA
    (2, True)
    >>> dataset.data.shape  # doctest: +REMOTE_DATA
    (25273, 9)

    Args:
        data_home : optional, default: None
            Specify another download and cache folder for the datasets. By default
            all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
        download_if_missing : optional, default=True
            If False, raise an IOError instead of downloading missing data.
        shuffle : default=True
            Shuffle the order of the queries.
        random_state : optional, default=None
            Initialization of the shuffle random generator.
        version : Version of the dataset.
            '0.1' contains one object per class,
            '0.2' contains 50 objects per class.
        return_data : boolean, default=False
            If True, returns a numpy array instead of a Bunch object.

    Returns:
        dataset : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            data : ndarray, shape (n_query, 9)
                Each row corresponds to a rank-2-of-8 query; entries are object indices.
                The first column is the reference, the second column the most similar,
                and the third column the second most similar object.
            rt_ms : ndarray, shape (n_query, )
                Reaction time in milliseconds.
            n_select : int
                Number of selected objects per trial.
            is_ranked : bool
                Whether the selection is ranked by similarity to the reference.
            session_id : (n_query, )
                Ids of the survey sessions in which the queries were recorded.
            stimulus_id : (n_query, )
                Ids of the image stimuli (objects).
            class_id : (n_query, )
                Imagenet class assigned to each image.
            class_label : (50000, )
                WordNet labels of the classes.
            DESCR : string
                Description of the dataset.
        data : numpy array (n_query, 9)
            Only present when `return_data=True`.

    Raises:
        IOError: If the data is not locally available, but download_if_missing=False.
    """
    data_home = Path(_base.get_data_home(data_home=data_home))
    if not data_home.exists():
        data_home.mkdir()

    filepath = Path(_base._pkl_filepath(data_home, 'imagenet_similarity.pkz'))
    if not filepath.exists():
        if not download_if_missing:
            raise IOError("Data not found and `download_if_missing` is False")

        logger.info('Downloading imagenet similarity data from {} to {}'.format(ARCHIVE.url, data_home))

        archive_path = (ARCHIVE.filename if data_home is None
                        else join(data_home, ARCHIVE.filename))
        urlretrieve(ARCHIVE.url, archive_path)

        with zipfile.ZipFile(archive_path) as zf:
            import h5py

            with zf.open('val/obs/psiz0.4.1/obs-118.hdf5', 'r') as f:
                data_v1 = {k: np.asarray(v[()]) for k, v in h5py.File(f, mode='r').items()}

            with zf.open('val/obs/psiz0.4.1/obs-195.hdf5', 'r') as f:
                data_v2 = {k: np.asarray(v[()]) for k, v in h5py.File(f, mode='r').items()}

            with zf.open('val/catalogs/psiz0.4.1/catalog.hdf5', 'r') as f:
                catalog = {k: np.asarray(v[()]) for k, v in h5py.File(f, mode='r').items()}

        joblib.dump((data_v1, data_v2, catalog), filepath, compress=6)
        os.remove(archive_path)
    else:
        (data_v1, data_v2, catalog) = joblib.load(filepath)

    if str(version) == '0.1':
        data = data_v1
    elif str(version) == '0.2':
        data = data_v2
    else:
        raise ValueError(f"Expects version '0.1' or '0.2', got '{version}'.")

    data.pop('trial_type')
    catalog['class_map_label'] = catalog['class_map_label'].astype(str)

    if shuffle:
        random_state = check_random_state(random_state)
        ix = random_state.permutation(len(data['stimulus_set']))
        data = {k: v[ix] for k, v in data.items()}

    if return_data:
        return data['stimulus_set']

    module_path = Path(__file__).parent
    with module_path.joinpath('descr', 'imagenet_similarity.rst').open() as rst_file:
        fdescr = rst_file.read()

    return Bunch(data=data['stimulus_set'],
                 rt_ms=data['rt_ms'],
                 n_select=int(np.unique(data['n_select'])),
                 is_ranked=bool(np.unique(data['is_ranked'])),
                 session_id=data['session_id'],
                 stimulus_id=catalog['stimulus_id'],
                 class_id=catalog['class_id'],
                 class_label=catalog['class_map_label'][1:],
                 DESCR=fdescr)
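A hypothetical sketch of expanding one rank-2-of-8 query from the fetcher above into (reference, closer, farther) triplets, assuming the column layout documented in the docstring (column 0 the reference, column 1 the most similar, column 2 the second most similar, columns 3-8 the remaining candidates):

import numpy as np
from cblearn.datasets import fetch_imagenet_similarity

dataset = fetch_imagenet_similarity(shuffle=False, version='0.1')  # needs h5py and internet
ref, first, second, rest = (dataset.data[0, 0], dataset.data[0, 1],
                            dataset.data[0, 2], dataset.data[0, 3:])
triplets = [(ref, first, other) for other in np.r_[second, rest]]  # ranked 1st beats all 7
triplets += [(ref, second, other) for other in rest]               # ranked 2nd beats the rest
print(len(triplets))  # 13 triplets from a single query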