Skip to content

Add ability to specify the floating point precision #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/bioclip_vector_db/vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@

def _get_device() -> torch.device:
if torch.cuda.is_available():
logger.info("CUDA is available")
logger.debug("CUDA is available")
return torch.device("cuda")
elif torch.mps.is_available():
logger.info("MPS is available")
logger.debug("MPS is available")
return torch.device("mps")
else:
logger.warning("CUDA and MPS are not available. Default to CPU")
Expand All @@ -55,6 +55,7 @@ def __init__(
split: str,
local_dataset: str = None,
batch_size: int = 10,
fp_precision: str = "fp16",
):
self._dataset_type = dataset_type
self._classifier = TreeOfLifeClassifier(device=_get_device())
Expand All @@ -64,6 +65,7 @@ def __init__(
self._collection = None
self._use_local_dataset = local_dataset is not None
self._batch_size = batch_size
self._fp_precision = torch.float16 if fp_precision == "fp16" else torch.float32

self._prepare_dataset(split=split, local_dataset=local_dataset)
self._init_collection()
Expand Down Expand Up @@ -181,12 +183,14 @@ def _load_database_local(self):
data_batch[2],
)
)
embeddings = list(
map(
lambda x: x.tolist(),
self._classifier.create_image_features(imgs, normalize=True),

with torch.autocast(_get_device().type, self._fp_precision):
embeddings = list(
map(
lambda x: x.tolist(),
self._classifier.create_image_features(imgs, normalize=True),
)
)
)
self._collection.add(embeddings=embeddings, ids=ids, metadatas=taxon_tags)

num_records += len(ids)
Expand Down Expand Up @@ -252,6 +256,14 @@ def main():
help="Specifies the batch size which determine the number of datapoints which will be read at once from the dataset.",
)

parser.add_argument(
"--fp_precision",
type=str,
default="fp16",
choices=["fp16", "fp32"],
help="Floating point precision to use for the model.",
)

args = parser.parse_args()
dataset = args.dataset
output_dir = args.output_dir
Expand All @@ -269,6 +281,7 @@ def main():
split=split,
local_dataset=local_dataset,
batch_size=args.batch_size,
fp_precision=args.fp_precision,
)
vdb.load_database(reset=args.reset)

Expand Down