diff --git a/src/bioclip_vector_db/vector_db.py b/src/bioclip_vector_db/vector_db.py index f1df4e8..2afe6e2 100644 --- a/src/bioclip_vector_db/vector_db.py +++ b/src/bioclip_vector_db/vector_db.py @@ -31,10 +31,10 @@ def _get_device() -> torch.device: if torch.cuda.is_available(): - logger.info("CUDA is available") + logger.debug("CUDA is available") return torch.device("cuda") elif torch.mps.is_available(): - logger.info("MPS is available") + logger.debug("MPS is available") return torch.device("mps") else: logger.warning("CUDA and MPS are not available. Default to CPU") @@ -55,6 +55,7 @@ def __init__( split: str, local_dataset: str = None, batch_size: int = 10, + fp_precision: str = "fp16", ): self._dataset_type = dataset_type self._classifier = TreeOfLifeClassifier(device=_get_device()) @@ -64,6 +65,7 @@ def __init__( self._collection = None self._use_local_dataset = local_dataset is not None self._batch_size = batch_size + self._fp_precision = torch.float16 if fp_precision == "fp16" else torch.float32 self._prepare_dataset(split=split, local_dataset=local_dataset) self._init_collection() @@ -181,12 +183,14 @@ def _load_database_local(self): data_batch[2], ) ) - embeddings = list( - map( - lambda x: x.tolist(), - self._classifier.create_image_features(imgs, normalize=True), + + with torch.autocast(_get_device().type, self._fp_precision): + embeddings = list( + map( + lambda x: x.tolist(), + self._classifier.create_image_features(imgs, normalize=True), + ) ) - ) self._collection.add(embeddings=embeddings, ids=ids, metadatas=taxon_tags) num_records += len(ids) @@ -252,6 +256,14 @@ def main(): help="Specifies the batch size which determine the number of datapoints which will be read at once from the dataset.", ) + parser.add_argument( + "--fp_precision", + type=str, + default="fp16", + choices=["fp16", "fp32"], + help="Floating point precision to use for the model.", + ) + args = parser.parse_args() dataset = args.dataset output_dir = args.output_dir @@ -269,6 +281,7 @@ def main(): split=split, local_dataset=local_dataset, batch_size=args.batch_size, + fp_precision=args.fp_precision, ) vdb.load_database(reset=args.reset)