From cc180d3ec3bc0fd283b33924743066124b03def6 Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 11:17:20 -0400 Subject: [PATCH 1/7] Drive by: remove unused method parameter --- cellarium/cas/service.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cellarium/cas/service.py b/cellarium/cas/service.py index f878ed8..9e7d711 100644 --- a/cellarium/cas/service.py +++ b/cellarium/cas/service.py @@ -372,7 +372,7 @@ def get_user_quota(self) -> t.Dict[str, t.Any]: return self.get_json(endpoint=endpoints.GET_USER_QUOTA) def query_cells_by_ids( - self, model_name: str, cell_ids: t.List[int], metadata_feature_names: t.List[str] + self, cell_ids: t.List[int], metadata_feature_names: t.List[str] ) -> t.List[t.Dict[str, t.Any]]: """ Retrieve cells by their ids from Cellarium Cloud database. @@ -380,14 +380,12 @@ def query_cells_by_ids( Refer to API Docs: {api_url}/api/docs#/cell-analysis/get_cells_by_ids_api_cellarium_cas_query_cells_by_ids_post - :param model_name: Name of the model to use. Model name is required to locate the correct database. :param cell_ids: List of cell ids from Cellarium Cloud database to query by. :param metadata_feature_names: List of metadata feature names to include in the response. :return: List of cells with metadata. 
""" request_data = { - "model_name": model_name, "cas_cell_ids": cell_ids, "metadata_feature_names": metadata_feature_names, } From ac5b6eef334ca1898ca147caf2770ac5212cc05e Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 11:19:13 -0400 Subject: [PATCH 2/7] Add model classes that will be used by client --- cellarium/cas/constants.py | 29 ++++++ cellarium/cas/models.py | 188 +++++++++++++++++++++++++++++++++++++ requirements/base.txt | 3 +- 3 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 cellarium/cas/models.py diff --git a/cellarium/cas/constants.py b/cellarium/cas/constants.py index 5f397af..84dc543 100644 --- a/cellarium/cas/constants.py +++ b/cellarium/cas/constants.py @@ -49,3 +49,32 @@ class Headers: client_session_id = "x-client-session-id" # The client action id that is used to track a user's logical action that may span multiple requests. client_action_id = "x-client-action-id" + + +class CellMetadataFeatures(Enum): + """ + Represents the cell features that can be queried for in the CAS API. 
+ + """ + + CAS_CELL_INDEX: str = "cas_cell_index" + CELL_TYPE: str = "cell_type" + ASSAY: str = "assay" + DISEASE: str = "disease" + DONOR_ID: str = "donor_id" + IS_PRIMARY_DATA: str = "is_primary_data" + DEVELOPMENT_STAGE: str = "development_stage" + ORGANISM: str = "organism" + SELF_REPORTED_ETHNICITY: str = "self_reported_ethnicity" + SEX: str = "sex" + SUSPENSION_TYPE: str = "suspension_type" + TISSUE: str = "tissue" + TOTAL_MRNA_UMIS: str = "total_mrna_umis" + CELL_TYPE_ONTOLOGY_TERM_ID: str = "cell_type_ontology_term_id" + ASSAY_ONTOLOGY_TERM_ID: str = "assay_ontology_term_id" + DISEASE_ONTOLOGY_TERM_ID: str = "disease_ontology_term_id" + DEVELOPMENT_STAGE_ONTOLOGY_TERM_ID: str = "development_stage_ontology_term_id" + ORGANISM_ONTOLOGY_TERM_ID: str = "organism_ontology_term_id" + SELF_REPORTED_ETHNICITY_ONTOLOGY_TERM_ID: str = "self_reported_ethnicity_ontology_term_id" + SEX_ONTOLOGY_TERM_ID: str = "sex_ontology_term_id" + TISSUE_ONTOLOGY_TERM_ID: str = "tissue_ontology_term_id" diff --git a/cellarium/cas/models.py b/cellarium/cas/models.py new file mode 100644 index 0000000..decf781 --- /dev/null +++ b/cellarium/cas/models.py @@ -0,0 +1,188 @@ +import typing as t + +from pydantic import BaseModel, Field + + +class CellTypeSummaryStatisticsResults(BaseModel): + """ + Represents the data object returned by the CAS API for nearest neighbor annotations. 
+ """ + + class DatasetStatistics(BaseModel): + dataset_id: str = Field( + description="The ID of the dataset containing cells", examples=["a7a92fb49-50741b00a-244955d47"] + ) + count_per_dataset: int = Field(description="The number of cells found in the dataset", examples=[10]) + min_distance: float = Field( + description="The minimum distance between the query cell and the dataset cells", + examples=[1589.847900390625], + ) + max_distance: float = Field( + description="The maximum distance between the query cell and the dataset cells", + examples=[1840.047119140625], + ) + median_distance: float = Field( + description="The median distance between the query cell and the dataset cells", examples=[1791.372802734375] + ) + mean_distance: float = Field( + description="The mean distance between the query cell and the dataset cells", examples=[1791.372802734375] + ) + + class SummaryStatistics(BaseModel): + cell_type: str = Field(description="The cell type of the cluster of cells", examples=["erythrocyte"]) + cell_count: int = Field(description="The number of cells in the cluster", examples=[94]) + min_distance: float = Field( + description="The minimum distance between the query cell and the cluster cells", + examples=[1589.847900390625], + ) + p25_distance: float = Field( + description="The 25th percentile distance between the query cell and the cluster cells", + examples=[1664.875244140625], + ) + median_distance: float = Field( + description="The median distance between the query cell and the cluster cells", examples=[1791.372802734375] + ) + p75_distance: float = Field( + description="The 75th percentile distance between the query cell and the cluster cells", + examples=[1801.3585205078125], + ) + max_distance: float = Field( + description="The maximum distance between the query cell and the cluster cells", + examples=[1840.047119140625], + ) + dataset_ids_with_counts: t.Optional[t.List["CellTypeSummaryStatisticsResults.DatasetStatistics"]] = None + + class 
NeighborhoodAnnotation(BaseModel): + """ + Represents the data object returned by the CAS API for a single nearest neighbor annotation. + """ + + query_cell_id: str = Field(description="The ID of the querying cell", examples=["ATTACTTATTTAGTT-12311"]) + matches: t.List["CellTypeSummaryStatisticsResults.SummaryStatistics"] + + data: t.List["CellTypeSummaryStatisticsResults.NeighborhoodAnnotation"] = Field(description="The annotations found") + + +CellTypeSummaryStatisticsResults.model_rebuild() + + +class CellTypeOntologyAwareResults(BaseModel): + """ + Represents the data object returned by the CAS API for ontology-aware annotations. + """ + + class Matches(BaseModel): + score: float = Field(description="The score of the match", examples=[0.789]) + cell_type_ontology_term_id: str = Field( + description="The ontology term ID of the cell type for the match", examples=["CL:0000121"] + ) + cell_type: str = Field(description="The cell type of the match", examples=["erythrocyte"]) + + class OntologyAwareAnnotation(BaseModel): + """ + Represents the data object returned by the CAS API for a single ontology-aware annotation. 
+ """ + + query_cell_id: str = Field(description="The ID of the querying cell", examples=["ATTACTTATTTAGTT-12311"]) + matches: t.List["CellTypeOntologyAwareResults.Matches"] = Field( + description="The matches found for the querying cell" + ) + total_weight: float = Field(description="The total weight of the matches", examples=[11.23232]) + total_neighbors: int = Field(description="The total number of neighbors matched", examples=[1023]) + total_neighbors_unrecognized: int = Field( + description="The total number of neighbors that were not recognized", examples=[5] + ) + + data: t.List["CellTypeOntologyAwareResults.OntologyAwareAnnotation"] = Field(description="The annotations found") + + +CellTypeOntologyAwareResults.model_rebuild() + + +class MatrixQueryResults(BaseModel): + """ + Represents the data object returned by the CAS API when performing a cell matrix query + (e.g. a query of the cell database using a matrix). + """ + + class Matches(BaseModel): + cas_cell_index: float = Field(description="CAS-specific ID of a single cell", examples=[123]) + distance: float = Field( + description="The distance between this querying cell and the found cell", examples=[0.123] + ) + + class MatrixQueryResult(BaseModel): + """ + Represents the data object returned by the CAS API for a single cell query. + """ + + query_cell_id: str = Field(description="The ID of the querying cell", examples=["ATTACTTATTTAGTT-12311"]) + neighbors: t.List["MatrixQueryResults.Matches"] + + data: t.List["MatrixQueryResults.MatrixQueryResult"] = Field(description="The results of the query") + + +MatrixQueryResults.model_rebuild() + + +class CellQueryResults(BaseModel): + """ + Represents the data object returned by the CAS API for a cell query. 
+ """ + + class CellariumCellMetadata(BaseModel): + cas_cell_index: int = Field(description="The CAS-specific ID of the cell", examples=[123]) + cell_type: t.Optional[str] = Field(description="The cell type of the cell", examples=["enterocyte"]) + assay: t.Optional[str] = Field(description="The assay used to generate the cell", examples=["10x 3' v2"]) + disease: t.Optional[str] = Field(description="The disease state of the cell", examples=["glioblastoma"]) + donor_id: t.Optional[str] = Field(description="The ID of the donor of the cell", examples=["H20.33.013"]) + is_primary_data: t.Optional[bool] = Field(description="Whether the cell is primary data", examples=[True]) + development_stage: t.Optional[str] = Field( + description="The development stage of the cell donor", examples=["human adult stage"] + ) + organism: t.Optional[str] = Field(description="The organism of the cell", examples=["Homo sapiens"]) + self_reported_ethnicity: t.Optional[str] = Field( + description="The self reported ethnicity of the cell donor", examples=["Japanese"] + ) + sex: t.Optional[str] = Field(description="The sex of the cell donor", examples=["male"]) + suspension_type: t.Optional[str] = Field(description="The cell suspension types used", examples=["nucleus"]) + tissue: t.Optional[str] = Field( + description="The tissue-type that the cell was a part of", examples=["cerebellum"] + ) + total_mrna_umis: t.Optional[int] = Field( + description="The count of mRNA UMIs associated with this cell", examples=[24312] + ) + + # Ontology term IDs for the fields + cell_type_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the type of the cell", examples=["CL:0000121"] + ) + assay_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the assay used to generate the cell", examples=["EFO:0010550"] + ) + disease_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the disease state of the 
cell", examples=["PATO:0000461"] + ) + development_stage_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the development stage of the cell donor", + examples=["HsapDv:0000053"], + ) + organism_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the organism of the cell", examples=["NCBITaxon:9606"] + ) + self_reported_ethnicity_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the self reported ethnicity of the cell donor", + examples=["HANCESTRO:0019"], + ) + sex_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the sex of the cell donor", examples=["PATO:0000384"] + ) + tissue_ontology_term_id: t.Optional[str] = Field( + description="The ID used by the ontology for the tissue type that the cell was a part of", + examples=["UBERON:0002037"], + ) + + data: t.List["CellQueryResults.CellariumCellMetadata"] = Field(description="The metadata of the found cells") + + +CellQueryResults.model_rebuild() diff --git a/requirements/base.txt b/requirements/base.txt index 0a635b0..0261db3 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -9,4 +9,5 @@ Deprecated~=1.2 tqdm~=4.66 typing_extensions~=4.7.1 owlready2 -networkx \ No newline at end of file +networkx +pydantic==2.5.3 \ No newline at end of file From 1fecbb6152fa9e34eff53fb9a02f68d7cc9dbb0f Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 11:20:02 -0400 Subject: [PATCH 3/7] Update client to use new model classes to return --- cellarium/cas/client.py | 165 +++++---- .../cas/postprocessing/ontology_aware.py | 22 +- .../circular_tree_plot_umap_dash_app/app.py | 3 +- tests/unit/test_cas_client.py | 335 +++++++++++++++--- 4 files changed, 387 insertions(+), 138 deletions(-) diff --git a/cellarium/cas/client.py b/cellarium/cas/client.py index 5afe326..3fe3519 100644 --- a/cellarium/cas/client.py +++ b/cellarium/cas/client.py @@ -16,7 +16,7 
@@ from cellarium.cas.service import action_context_manager -from . import _io, constants, exceptions, preprocessing, service, settings, version +from . import _io, constants, exceptions, models, preprocessing, service, settings, version @contextmanager @@ -44,7 +44,7 @@ class CASClient: `Default:` ``3`` """ - def _print_models(self, models): + def __print_models(self, models): s = "Allowed model list in Cellarium CAS:\n" for model in models: model_name = model["model_name"] @@ -56,7 +56,7 @@ def _print_models(self, models): s += f" - {model_name}\n Description: {description}\n Schema: {model_schema}\n Embedding dimension: {embedding_dimension}\n" - self._print(s) + self.__print(s) @action_context_manager() def __init__( @@ -70,10 +70,10 @@ def __init__( api_token=api_token, api_url=api_url, client_session_id=self.client_session_id ) - self._print(f"Connecting to the Cellarium Cloud backend with session {self.client_session_id}...") + self.__print(f"Connecting to the Cellarium Cloud backend with session {self.client_session_id}...") self.user_info = self.cas_api_service.validate_token() username = self.user_info["username"] - self._print(f"User is {username}") + self.__print(f"User is {username}") self.should_show_feedback = True if "should_ask_for_feedback" in self.user_info: self.should_show_feedback = self.user_info["should_ask_for_feedback"] @@ -92,9 +92,9 @@ def __init__( self._feature_schemas_cache = {} self.num_attempts_per_chunk = num_attempts_per_chunk - self._print(f"Authenticated in Cellarium Cloud v. {application_info['application_version']}") + self.__print(f"Authenticated in Cellarium Cloud v. 
{application_info['application_version']}") - self._print_models(self.model_objects_list) + self.__print_models(self.model_objects_list) @property def allowed_models_list(self): @@ -104,17 +104,17 @@ def allowed_models_list(self): return self.__allowed_models_list @staticmethod - def _get_number_of_chunks(adata, chunk_size): + def __get_number_of_chunks(adata, chunk_size): return math.ceil(len(adata) / chunk_size) @staticmethod - def _get_timestamp() -> str: + def __get_timestamp() -> str: return datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] - def _print(self, str_to_print: str) -> None: - print(f"* [{self._get_timestamp()}] {str_to_print}") + def __print(self, str_to_print: str) -> None: + print(f"* [{self.__get_timestamp()}] {str_to_print}") - def _render_feedback_link(self): + def __render_feedback_link(self): try: if settings.is_interactive_environment() and self.should_show_feedback: # only import IPython if we are in an interactive environment @@ -127,7 +127,7 @@ def _render_feedback_link(self): def feedback_opt_out(self): self.should_show_feedback = False self.user_info = self.cas_api_service.feedback_opt_out() - self._print("Successfully opted out. You will no longer receive requests to provide feedback.") + self.__print("Successfully opted out. You will no longer receive requests to provide feedback.") def validate_version(self): """ @@ -136,7 +136,7 @@ def validate_version(self): client_version = version.get_version() version_validation_info = self.cas_api_service.validate_version(version_str=client_version) if version_validation_info["is_valid"]: - self._print(f"Client version {client_version} is compatible with selected server.") + self.__print(f"Client version {client_version} is compatible with selected server.") else: raise exceptions.ClientTooOldError( f"Client version {client_version} is older than the minimum version for this server {version_validation_info['min_version']}. 
" @@ -194,24 +194,24 @@ def validate_and_sanitize_input_data( ) except exceptions.DataValidationError as e: if e.extra_features > 0: - self._print( + self.__print( f"The input data matrix has {e.extra_features} extra features compared to '{feature_schema_name}' " f"CAS schema ({len(cas_feature_schema_list)}). " f"Extra input features will be dropped." ) if e.missing_features > 0: - self._print( + self.__print( f"The input data matrix has {e.missing_features} missing features compared to " f"'{feature_schema_name}' CAS schema ({len(cas_feature_schema_list)}). " f"Missing features will be imputed with zeros." ) if e.extra_features == 0 and e.missing_features == 0: - self._print( + self.__print( f"Input datafile has all the necessary features as {feature_schema_name}, but it's still " f"incompatible because of the different order. The features will be reordered " f"according to {feature_schema_name}..." ) - self._print( + self.__print( f"The input data matrix contains all of the features specified in '{feature_schema_name}' " f"CAS schema but in a different order. 
The input features will be reordered according to " f"'{feature_schema_name}'" @@ -224,7 +224,7 @@ def validate_and_sanitize_input_data( feature_names_column_name=feature_names_column_name, ) else: - self._print(f"The input data matrix conforms with the '{feature_schema_name}' CAS schema.") + self.__print(f"The input data matrix conforms with the '{feature_schema_name}' CAS schema.") return new_adata def print_user_quota(self) -> None: @@ -232,7 +232,7 @@ def print_user_quota(self) -> None: Print the user's quota information """ user_quota = self.cas_api_service.get_user_quota() - self._print( + self.__print( f"User quota: {user_quota['quota']}, Remaining quota: {user_quota['remaining_quota']}, " f"Reset date: {user_quota['quota_reset_date']}" ) @@ -273,8 +273,8 @@ async def sharded_request_task(**callback_kwargs): results[chunk_index] = await service_request_callback(**callback_kwargs) except (exceptions.HTTPError5XX, exceptions.HTTPClientError) as e: - self._print(str(e)) - self._print( + self.__print(str(e)) + self.__print( f"Resubmitting chunk #{chunk_index + 1:2.0f} ({chunk_start_i:5.0f}, " f"{chunk_end_i:5.0f}) to CAS ..." ) @@ -282,16 +282,16 @@ async def sharded_request_task(**callback_kwargs): retry_delay = min(retry_delay * 2, settings.MAX_RETRY_DELAY) continue except exceptions.HTTPError401: - self._print("Unauthorized token. Please check your API token or request a new one.") + self.__print("Unauthorized token. Please check your API token or request a new one.") break except exceptions.HTTPError403 as e: - self._print(str(e)) + self.__print(str(e)) break except Exception as e: - self._print(f"Unexpected error: {e.__class__.__name__}; Message: {str(e)}") + self.__print(f"Unexpected error: {e.__class__.__name__}; Message: {str(e)}") break else: - self._print( + self.__print( f"Received the result for cell chunk #{chunk_index + 1:2.0f} ({chunk_start_i:5.0f}, " f"{chunk_end_i:5.0f}) ..." 
) @@ -311,14 +311,14 @@ async def sharded_request(): i, j = 0, chunk_size tasks = [] semaphore = asyncio.Semaphore(settings.MAX_NUM_REQUESTS_AT_A_TIME) - number_of_chunks = self._get_number_of_chunks(adata, chunk_size=chunk_size) + number_of_chunks = self.__get_number_of_chunks(adata, chunk_size=chunk_size) results = [[] for _ in range(number_of_chunks)] for chunk_index in range(number_of_chunks): chunk = adata[i:j, :] chunk_start_i = i chunk_end_i = i + len(chunk) - self._print( + self.__print( f"Submitting cell chunk #{chunk_index + 1:2.0f} ({chunk_start_i:5.0f}, {chunk_end_i:5.0f}) " f"to CAS ..." ) @@ -373,7 +373,7 @@ def __postprocess_sharded_response( processed_response.append(query_item) if num_unannotated_cells > 0: - self._print(f"{num_unannotated_cells} cells were not processed by CAS") + self.__print(f"{num_unannotated_cells} cells were not processed by CAS") return processed_response @@ -382,6 +382,11 @@ def __postprocess_annotations( ) -> t.List[t.Dict[str, t.Any]]: """ Postprocess results by matching the order of cells in the response with the order of cells in the input + + :param query_response: List of dictionaries with annotations for each of the cells from input adata + :param adata: :class:`anndata.AnnData` instance to annotate + + :return: A list of dictionaries with annotations for each of the cells from input adata """ return self.__postprocess_sharded_response( query_response=query_response, @@ -395,6 +400,11 @@ def __postprocess_nearest_neighbor_search_response( """ Postprocess nearest neighbor search response by matching the order of cells in the response with the order of cells in the input + + :param query_response: List of dictionaries with annotations for each of the cells from input adata + :param adata: :class:`anndata.AnnData` instance to annotate + + :return: A list of dictionaries with nearest neighbor search results for each of the cells from input adata """ return self.__postprocess_sharded_response( 
query_response=query_response, @@ -481,8 +491,8 @@ def __prepare_input_for_sharded_request( cas_model = self._model_name_obj_map[cas_model_name] cas_model_name = cas_model["model_name"] - self._print(f"Cellarium CAS (Model ID: {cas_model_name})") - self._print(f"Total number of input cells: {len(adata)}") + self.__print(f"Cellarium CAS (Model ID: {cas_model_name})") + self.__print(f"Total number of input cells: {len(adata)}") self.__validate_cells_under_quota(cell_count=len(adata)) @@ -505,7 +515,7 @@ def annotate_anndata( feature_ids_column_name: str = "index", feature_names_column_name: t.Optional[str] = None, include_dev_metadata: bool = False, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.CellTypeSummaryStatisticsResults: """ Send an instance of :class:`anndata.AnnData` to the Cellarium Cloud backend for annotations. The function splits the ``adata`` into smaller chunks and asynchronously sends them to the backend API service. Each chunk is @@ -532,7 +542,8 @@ def annotate_anndata( :param include_dev_metadata: Boolean indicating whether to include a breakdown of the number of cells by dataset - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.CellTypeSummaryStatisticsResults` object with annotations for each of the cells from the + adata input """ cas_model_name = self.default_model_name if cas_model_name == "default" else cas_model_name @@ -554,8 +565,10 @@ def annotate_anndata( }, ) result = self.__postprocess_annotations(results, adata) - self._print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") - self._render_feedback_link() + # cast the object to the correct type + result = models.CellTypeSummaryStatisticsResults(data=result) + self.__print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") + self.__render_feedback_link() return result @deprecated(version="1.4.3", reason="Use :meth:`annotate_matrix_cell_type_statistics_strategy` instead") @@ 
-569,7 +582,7 @@ def annotate_anndata_file( feature_ids_column_name: str = "index", feature_names_column_name: t.Optional[str] = None, include_dev_metadata: bool = False, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.CellTypeSummaryStatisticsResults: """ Read the 'h5ad' file into a :class:`anndata.AnnData` matrix and apply the :meth:`annotate_anndata` method to it. @@ -594,7 +607,8 @@ def annotate_anndata_file( :param include_dev_metadata: Boolean indicating whether to include a breakdown of the number of cells per dataset - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.CellTypeSummaryStatisticsResults` object with annotations for each of the cells from + the input adata """ with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -621,7 +635,7 @@ def annotate_10x_h5_file( feature_ids_column_name: str = "index", feature_names_column_name: t.Optional[str] = None, include_dev_metadata: bool = False, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.CellTypeSummaryStatisticsResults: """ Parse the 10x 'h5' matrix and apply the :meth:`annotate_anndata` method to it. @@ -644,7 +658,9 @@ def annotate_10x_h5_file( column. 
|br| `Default:` ``None`` :param include_dev_metadata: Boolean indicating whether to include a breakdown of the number of cells by dataset - :return: A list of dictionaries with annotations for each of the cells from input adata + + :return: A :class:`~.models.CellTypeSummaryStatisticsResults` object with annotations for each of the cells from + the input adata """ adata = _io.read_10x_h5(filepath) @@ -668,7 +684,7 @@ def annotate_matrix_cell_type_summary_statistics_strategy( include_extended_statistics: bool = True, cas_model_name: t.Optional[str] = None, feature_names_column_name: t.Optional[str] = None, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.CellTypeSummaryStatisticsResults: """ Send an instance of :class:`anndata.AnnData` to the Cellarium Cloud backend for annotations. The function splits the ``adata`` into smaller chunks and asynchronously sends them to the backend API service. Each chunk is @@ -696,7 +712,8 @@ def annotate_matrix_cell_type_summary_statistics_strategy( column. 
|br| `Default:` ``None`` - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.CellTypeSummaryStatisticsResults` object with annotations for each of the cells from + the input adata """ if isinstance(matrix, str): matrix = _io.read_h5_or_h5ad(filename=matrix) @@ -721,9 +738,11 @@ def annotate_matrix_cell_type_summary_statistics_strategy( "include_extended_output": include_extended_statistics, }, ) - result = self.__postprocess_annotations(results, matrix) - self._print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") - self._render_feedback_link() + result = self.__postprocess_annotations(query_response=results, adata=matrix) + # cast the object to the correct type + result = models.CellTypeSummaryStatisticsResults(data=result) + self.__print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") + self.__render_feedback_link() return result @action_context_manager() @@ -737,7 +756,7 @@ def annotate_matrix_cell_type_ontology_aware_strategy( feature_names_column_name: t.Optional[str] = None, prune_threshold: float = DEFAULT_PRUNE_THRESHOLD, weighting_prefactor: float = DEFAULT_WEIGHTING_PREFACTOR, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.CellTypeOntologyAwareResults: """ Send an instance of :class:`anndata.AnnData` to the Cellarium Cloud backend for annotations using ontology aware strategy . 
The function splits the ``adata`` into smaller chunks and asynchronously sends them to the @@ -768,7 +787,8 @@ def annotate_matrix_cell_type_ontology_aware_strategy( weighting_prefactor results in a steeper decay (weights drop off more quickly as distance increases), whereas a smaller absolute value results in a slower decay - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.CellTypeOntologyAwareResults` object with annotations for each of the cells from + the input adata """ if isinstance(matrix, str): matrix = _io.read_h5_or_h5ad(filename=matrix) @@ -794,8 +814,10 @@ def annotate_matrix_cell_type_ontology_aware_strategy( }, ) result = self.__postprocess_annotations(results, matrix) - self._print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") - self._render_feedback_link() + # cast the object to the correct type + result = models.CellTypeOntologyAwareResults(data=result) + self.__print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") + self.__render_feedback_link() return result @deprecated(version="1.4.3", reason="Use :meth:`search_matrix` instead") @@ -808,7 +830,7 @@ def search_anndata( count_matrix_input: constants.CountMatrixInput = constants.CountMatrixInput.X, feature_ids_column_name: str = "index", feature_names_column_name: t.Optional[str] = None, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.MatrixQueryResults: """ Send an instance of :class:`anndata.AnnData` to the Cellarium Cloud backend for nearest neighbor search. The function splits the ``adata`` into smaller chunks and asynchronously sends them to the backend API service. @@ -834,7 +856,8 @@ def search_anndata( column. 
|br| `Default:` ``None`` - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.MatrixQueryResults` object with search results for each of the cells from + the input adata """ if chunk_size > settings.MAX_CHUNK_SIZE_SEARCH_METHOD: raise ValueError("Chunk size greater than 500 not supported yet.") @@ -856,7 +879,9 @@ def search_anndata( request_callback_kwargs={"model_name": cas_model_name}, ) result = self.__postprocess_nearest_neighbor_search_response(results, adata) - self._print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") + # cast the object to the correct type + result = models.MatrixQueryResults(data=result) + self.__print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") return result @deprecated(version="1.4.3", reason="Use :meth:`search_matrix` instead") @@ -869,7 +894,7 @@ def search_10x_h5_file( count_matrix_input: constants.CountMatrixInput = constants.CountMatrixInput.X, feature_ids_column_name: str = "index", feature_names_column_name: t.Optional[str] = None, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.MatrixQueryResults: """ Parse the 10x 'h5' matrix and apply the :meth:`search_anndata` method to it. @@ -892,7 +917,8 @@ def search_10x_h5_file( column. |br| `Default:` ``None`` - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.MatrixQueryResults` object with search results for each of the cells from + the input adata """ adata = _io.read_10x_h5(filepath) @@ -914,7 +940,7 @@ def search_matrix( feature_ids_column_name: str = "index", cas_model_name: t.Optional[str] = None, feature_names_column_name: t.Optional[str] = None, - ) -> t.List[t.Dict[str, t.Any]]: + ) -> models.MatrixQueryResults: """ Send an instance of :class:`anndata.AnnData` to the Cellarium Cloud backend for nearest neighbor search. 
The function splits the ``adata`` into smaller chunks and asynchronously sends them to the backend API service. @@ -941,7 +967,8 @@ def search_matrix( column. |br| `Default:` ``None`` - :return: A list of dictionaries with annotations for each of the cells from input adata + :return: A :class:`~.models.MatrixQueryResults` object with search results for each of the cells from + the input adata """ if isinstance(matrix, str): matrix = _io.read_h5_or_h5ad(filename=matrix) @@ -966,36 +993,34 @@ def search_matrix( request_callback_kwargs={"model_name": cas_model_name}, ) result = self.__postprocess_nearest_neighbor_search_response(results, matrix) - self._print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") - self._render_feedback_link() + # cast the object to the correct type + result = models.MatrixQueryResults(data=result) + self.__print(f"Total wall clock time: {f'{time.time() - start:10.4f}'} seconds") + self.__render_feedback_link() return result @action_context_manager() def query_cells_by_ids( - self, cell_ids: t.List[int], model_name: t.Optional[str], metadata_feature_names: t.List[str] = None - ) -> t.List[t.Dict[str, t.Any]]: + self, cell_ids: t.List[int], metadata_feature_names: t.List[constants.CellMetadataFeatures] = None + ) -> models.CellQueryResults: """ Query cells by their ids from a single anndata file with Cellarium CAS. Input file should be validated and sanitized according to the model schema. :param cell_ids: List of cell ids to query - :param model_name: Model name to use for annotation. |br| - `Allowed Values:` Model name from the :attr:`allowed_models_list` list or ``None`` - keyword, which refers to the default selected model in the Cellarium backend. |br| - `Default:` ``None`` - :param metadata_feature_names: List of metadata feature names to include in the response. |br| + :param metadata_feature_names: List of metadata features to include in the response. 
|br| - :return: List of cells with metadata + :return: A :class:`~.models.CellQueryResults` object with cell query results """ - model_name = self.default_model_name if model_name is None else model_name results = self.cas_api_service.query_cells_by_ids( cell_ids=cell_ids, - model_name=model_name, metadata_feature_names=metadata_feature_names, ) - cells = self.__postprocess_query_cells_by_ids_response(query_response=results) - self._render_feedback_link() - return cells + result = self.__postprocess_query_cells_by_ids_response(query_response=results) + # cast the object to the correct type + result = models.CellQueryResults(data=result) + self.__render_feedback_link() + return result def validate_model_name(self, model_name: t.Optional[str] = None) -> None: """ diff --git a/cellarium/cas/postprocessing/ontology_aware.py b/cellarium/cas/postprocessing/ontology_aware.py index 6310f1a..4f60e7b 100644 --- a/cellarium/cas/postprocessing/ontology_aware.py +++ b/cellarium/cas/postprocessing/ontology_aware.py @@ -8,6 +8,8 @@ import scipy.sparse as sp from anndata import AnnData +from cellarium.cas.models import CellTypeOntologyAwareResults + from .cell_ontology.cell_ontology_cache import CL_CELL_ROOT_NODE, CL_EUKARYOTIC_CELL_ROOT_NODE, CellOntologyCache from .common import get_obs_indices_for_cluster @@ -22,7 +24,7 @@ def convert_cas_ontology_aware_response_to_score_matrix( - adata: AnnData, cas_ontology_aware_response: list, cl: CellOntologyCache + adata: AnnData, cas_ontology_aware_response: CellTypeOntologyAwareResults, cl: CellOntologyCache ) -> sp.csr_matrix: """ Generate a sparse matrix of CAS ontology-aware scores. @@ -35,7 +37,7 @@ def convert_cas_ontology_aware_response_to_score_matrix( :type adata: AnnData :param cas_ontology_aware_response: A list of CAS ontology-aware responses. 
- :type cas_ontology_aware_response: list + :type cas_ontology_aware_response: CellTypeOntologyAwareResults :param cl: A CellOntologyCache object containing the cell ontology information. :type cl: CellOntologyCache @@ -48,27 +50,27 @@ def convert_cas_ontology_aware_response_to_score_matrix( data = [] obs_values = adata.obs.index.values - for obs_idx, cas_cell_response in enumerate(cas_ontology_aware_response): - assert cas_cell_response["query_cell_id"] == obs_values[obs_idx] - for match in cas_cell_response["matches"]: + for obs_idx, cas_cell_response in enumerate(cas_ontology_aware_response.data): + assert cas_cell_response.query_cell_id == obs_values[obs_idx] + for match in cas_cell_response.matches: row.append(obs_idx) - col.append(cl.cl_names_to_idx_map[match["cell_type_ontology_term_id"]]) - data.append(match["score"]) + col.append(cl.cl_names_to_idx_map[match.cell_type_ontology_term_id]) + data.append(match.score) - n_obs = len(cas_ontology_aware_response) + n_obs = len(cas_ontology_aware_response.data) n_cl_names = len(cl.cl_names) return sp.coo_matrix((data, (row, col)), shape=(n_obs, n_cl_names)).tocsr() def insert_cas_ontology_aware_response_into_adata( - cas_ontology_aware_response: list, adata: AnnData, cl: CellOntologyCache + cas_ontology_aware_response: CellTypeOntologyAwareResults, adata: AnnData, cl: CellOntologyCache ) -> None: """ Inserts Cellarium CAS ontology aware response into `obsm` property of a provided AnnData file as a :class:`scipy.sparse.csr_matrix` named `cas_cl_scores`. :param cas_ontology_aware_response: The Cellarium CAS ontology aware response. - :type cas_ontology_aware_response: list + :type cas_ontology_aware_response: CellTypeOntologyAwareResults :param adata: The AnnData object to insert the response into. 
:type adata: AnnData diff --git a/cellarium/cas/visualization/circular_tree_plot_umap_dash_app/app.py b/cellarium/cas/visualization/circular_tree_plot_umap_dash_app/app.py index eb76e12..92affca 100644 --- a/cellarium/cas/visualization/circular_tree_plot_umap_dash_app/app.py +++ b/cellarium/cas/visualization/circular_tree_plot_umap_dash_app/app.py @@ -15,6 +15,7 @@ from dash.development.base_component import Component from plotly.express.colors import sample_colorscale +from cellarium.cas.models import CellTypeOntologyAwareResults from cellarium.cas.postprocessing import ( CAS_CL_SCORES_ANNDATA_OBSM_KEY, CellOntologyScoresAggregationDomain, @@ -175,7 +176,7 @@ class CASCircularTreePlotUMAPDashApp: def __init__( self, adata: AnnData, - cas_ontology_aware_response: list, + cas_ontology_aware_response: CellTypeOntologyAwareResults, cluster_label_obs_column: t.Optional[str] = None, aggregation_op: CellOntologyScoresAggregationOp = CellOntologyScoresAggregationOp.MEAN, aggregation_domain: CellOntologyScoresAggregationDomain = CellOntologyScoresAggregationDomain.OVER_THRESHOLD, diff --git a/tests/unit/test_cas_client.py b/tests/unit/test_cas_client.py index 5f406ea..9d69bc4 100644 --- a/tests/unit/test_cas_client.py +++ b/tests/unit/test_cas_client.py @@ -10,6 +10,7 @@ import scipy.sparse as sp from mockito import ANY, captor, mock, unstub, verify, when from mockito.matchers import ArgumentCaptor +from parameterized import parameterized from cellarium.cas import constants from cellarium.cas.client import CASClient @@ -35,8 +36,14 @@ def teardown_method(self) -> None: unstub() self.async_post_mocks = {} - def test_initialize(self): - self._mock_constructor_calls() + @parameterized.expand( + [ + (False), + (True), + ] + ) + def test_initialize(self, include_extended_output): + self.__mock_constructor_calls() cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) # Verify that the expected header values were sent. 
@@ -70,20 +77,31 @@ def test_initialize(self): assert len(sent_actions) == 1 - def test_annotate_matrix_cell_type_summary_statistics_strategy(self): num_cells = 10 - self._mock_constructor_calls() - self._mock_annotate_matrix_cell_type_summary_statistics_strategy_calls(num_cells=num_cells) + self.__mock_constructor_calls() + self.__mock_annotate_matrix_cell_type_summary_statistics_strategy_calls( + num_cells=num_cells, include_extended_output=include_extended_output + ) cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) response = cas_client.annotate_matrix_cell_type_summary_statistics_strategy( - matrix=self._mock_anndata_matrix(num_cells=num_cells), chunk_size=100 + matrix=self.__mock_anndata_matrix(num_cells=num_cells), + chunk_size=100, + include_extended_statistics=include_extended_output, ) - assert len(response) == num_cells - self._verify_headers( - urls=[ + assert len(response.data) == num_cells + + if include_extended_output: + for i in range(num_cells): + assert response.data[i].matches[0].dataset_ids_with_counts is not None + else: + for i in range(num_cells): + assert response.data[i].matches[0].dataset_ids_with_counts is None + + self.__verify_headers( + get_urls=[ f"{TEST_URL}/api/cellarium-general/validate-token", f"{TEST_URL}/api/cellarium-general/quota", f"{TEST_URL}/api/cellarium-general/application-info", @@ -95,24 +113,24 @@ def test_annotate_matrix_cell_type_summary_statistics_strategy(self): f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-summary-statistics-strategy" ], active_session_id=cas_client.client_session_id, - num__expected_actions=2, # one for initialization and one for the annotation + num_expected_actions=2, # one for initialization and one for the annotation ) def test_annotate_matrix_cell_type_summary_statistics_strategy_with_chunking(self): num_cells = 100 - self._mock_constructor_calls() - self._mock_annotate_matrix_cell_type_summary_statistics_strategy_calls(num_cells=num_cells) + 
self.__mock_constructor_calls() + self.__mock_annotate_matrix_cell_type_summary_statistics_strategy_calls(num_cells=num_cells) cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) # This should cause 10 chunks to be sent response = cas_client.annotate_matrix_cell_type_summary_statistics_strategy( - matrix=self._mock_anndata_matrix(num_cells=num_cells), chunk_size=10 + matrix=self.__mock_anndata_matrix(num_cells=num_cells), chunk_size=10 ) - assert len(response) == num_cells - self._verify_headers( - urls=[ + assert len(response.data) == num_cells + self.__verify_headers( + get_urls=[ f"{TEST_URL}/api/cellarium-general/validate-token", f"{TEST_URL}/api/cellarium-general/quota", f"{TEST_URL}/api/cellarium-general/application-info", @@ -124,29 +142,29 @@ def test_annotate_matrix_cell_type_summary_statistics_strategy_with_chunking(sel f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-summary-statistics-strategy" ], active_session_id=cas_client.client_session_id, - num__expected_actions=2, # one for initialization and one for the annotation + num_expected_actions=2, # one for initialization and one for the annotation ) def test_annotate_matrix_cell_type_summary_statistics_strategy_with_several_calls(self): num_cells = 100 - self._mock_constructor_calls() - self._mock_annotate_matrix_cell_type_summary_statistics_strategy_calls(num_cells=num_cells) + self.__mock_constructor_calls() + self.__mock_annotate_matrix_cell_type_summary_statistics_strategy_calls(num_cells=num_cells) cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) # This should cause 10 chunks to be sent response1 = cas_client.annotate_matrix_cell_type_summary_statistics_strategy( - matrix=self._mock_anndata_matrix(num_cells=num_cells), chunk_size=10 + matrix=self.__mock_anndata_matrix(num_cells=num_cells), chunk_size=10 ) - assert len(response1) == num_cells + assert len(response1.data) == num_cells response2 = cas_client.annotate_matrix_cell_type_summary_statistics_strategy( 
- matrix=self._mock_anndata_matrix(num_cells=num_cells), chunk_size=10 + matrix=self.__mock_anndata_matrix(num_cells=num_cells), chunk_size=10 ) - assert len(response2) == num_cells + assert len(response2.data) == num_cells - self._verify_headers( - urls=[ + self.__verify_headers( + get_urls=[ f"{TEST_URL}/api/cellarium-general/validate-token", f"{TEST_URL}/api/cellarium-general/quota", f"{TEST_URL}/api/cellarium-general/application-info", @@ -158,24 +176,101 @@ def test_annotate_matrix_cell_type_summary_statistics_strategy_with_several_call f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-summary-statistics-strategy" ], active_session_id=cas_client.client_session_id, - num__expected_actions=3, # one for initialization and one for *each* annotation call + num_expected_actions=3, # one for initialization and one for *each* annotation call + ) + + def test_annotate_matrix_cell_type_ontology_aware_strategy(self): + num_cells = 100 + self.__mock_constructor_calls() + self.__mock_annotate_matrix_cell_type_ontology_aware_strategy_calls(num_cells=num_cells) + + cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) + + # This should cause 10 chunks to be sent + response = cas_client.annotate_matrix_cell_type_ontology_aware_strategy( + matrix=self.__mock_anndata_matrix(num_cells=num_cells), chunk_size=10 + ) + assert len(response.data) == num_cells + + self.__verify_headers( + get_urls=[ + f"{TEST_URL}/api/cellarium-general/validate-token", + f"{TEST_URL}/api/cellarium-general/quota", + f"{TEST_URL}/api/cellarium-general/application-info", + f"{TEST_URL}/api/cellarium-general/list-models", + f"{TEST_URL}/api/cellarium-general/feature-schemas", + f"{TEST_URL}/api/cellarium-general/feature-schema/{TEST_SCHEMA}", + ], + async_post_urls=[f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-ontology-aware-strategy"], + active_session_id=cas_client.client_session_id, + num_expected_actions=2, # one for initialization and one for *each* annotation 
call + ) + + def test_search_nearest_neighbors_by_matrix(self): + num_cells = 100 + self.__mock_constructor_calls() + self.__mock_search_nearest_neighbor_by_matrix_calls(num_cells=num_cells) + + cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) + + # This should cause 10 chunks to be sent + response = cas_client.search_matrix(matrix=self.__mock_anndata_matrix(num_cells=num_cells), chunk_size=10) + assert len(response.data) == num_cells + + self.__verify_headers( + get_urls=[ + f"{TEST_URL}/api/cellarium-general/validate-token", + f"{TEST_URL}/api/cellarium-general/quota", + f"{TEST_URL}/api/cellarium-general/application-info", + f"{TEST_URL}/api/cellarium-general/list-models", + f"{TEST_URL}/api/cellarium-general/feature-schemas", + f"{TEST_URL}/api/cellarium-general/feature-schema/{TEST_SCHEMA}", + ], + async_post_urls=[f"{TEST_URL}/api/cellarium-cell-operations/nearest-neighbor-search"], + active_session_id=cas_client.client_session_id, + num_expected_actions=2, # one for initialization and one for *each* search call ) - def _mock_constructor_calls(self): + def test_query_cells(self): + num_cells = 1 + self.__mock_constructor_calls() + self.__mock_cell_query(num_cells=num_cells) + + cas_client = CASClient(api_token=TEST_TOKEN, api_url=TEST_URL) + + response = cas_client.query_cells_by_ids( + cell_ids=range(num_cells), + metadata_feature_names=[constants.CellMetadataFeatures.CELL_TYPE, constants.CellMetadataFeatures.ASSAY], + ) + assert len(response.data) == num_cells + + self.__verify_headers( + get_urls=[ + f"{TEST_URL}/api/cellarium-general/validate-token", + f"{TEST_URL}/api/cellarium-general/application-info", + f"{TEST_URL}/api/cellarium-general/list-models", + f"{TEST_URL}/api/cellarium-general/feature-schemas", + ], + post_urls=[f"{TEST_URL}/api/cellarium-cell-operations/query-cells-by-ids"], + active_session_id=cas_client.client_session_id, + num_expected_actions=2, # one for initialization and one for *each* search call + ) + + def 
__mock_constructor_calls(self): """ Mocks the calls made by the CASClient constructor """ - self._mock_response( + self.__mock_response( url=f"{TEST_URL}/api/cellarium-general/validate-token", status_code=200, response_body={"username": "foo", "email": "foo@bar.com"}, ) - self._mock_response( + self.__mock_response( url=f"{TEST_URL}/api/cellarium-general/application-info", status_code=200, response_body={"application_version": "1.0.0", "default_feature_schema": "foo"}, ) - self._mock_response( + self.__mock_response( url=f"{TEST_URL}/api/cellarium-general/list-models", status_code=200, response_body=[ @@ -188,42 +283,29 @@ def _mock_constructor_calls(self): } ], ) - self._mock_response( + self.__mock_response( url=f"{TEST_URL}/api/cellarium-general/feature-schemas", status_code=200, response_body=[ {"schema_name": TEST_SCHEMA}, ], ) - self._mock_response( + self.__mock_response( url=f"{TEST_URL}/api/cellarium-general/validate-client-version", status_code=200, response_body={"is_valid": True, "min_version": "1.4.0"}, method="post", ) - def _mock_annotate_matrix_cell_type_summary_statistics_strategy_calls( - self, num_cells: int = 3, num_features: int = 3 + def __mock_annotate_matrix_cell_type_summary_statistics_strategy_calls( + self, num_cells: int = 3, num_features: int = 3, include_extended_output: bool = False ): """ Mocks the calls made by the CASClient to do an annotation call with the summary statistics strategy """ - self._mock_response( - url=f"{TEST_URL}/api/cellarium-general/feature-schema/{TEST_SCHEMA}", - status_code=200, - response_body=[f"field{i}" for i in range(num_features)], - ) - self._mock_response( - url=f"{TEST_URL}/api/cellarium-general/quota", - status_code=200, - response_body={ - "user_id": 0, - "quota": 1000, - "remaining_quota": 1000, - "quota_reset_date": datetime.datetime.today() + 7 * datetime.timedelta(days=1), - }, - ) - self._mock_async_post_response( + self.__mock_pre_call_requests(num_features=num_features) + + 
self.__mock_async_post_response( url=f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-summary-statistics-strategy", status_code=200, response_body=[ @@ -238,6 +320,20 @@ def _mock_annotate_matrix_cell_type_summary_statistics_strategy_calls( "median_distance": 10.0, "p75_distance": 9.0, "max_distance": 13.0, + "dataset_ids_with_counts": ( + [ + { + "dataset_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "count_per_dataset": 10, + "min_distance": 1.0, + "max_distance": 20.0, + "median_distance": 10.0, + "mean_distance": 11.0, + } + ] + if include_extended_output + else None + ), } ], } @@ -245,13 +341,122 @@ def _mock_annotate_matrix_cell_type_summary_statistics_strategy_calls( ], ) - def _mock_anndata_matrix(self, num_features: int = 3, num_cells: int = 3) -> anndata.AnnData: + def __mock_annotate_matrix_cell_type_ontology_aware_strategy_calls(self, num_cells: int = 3, num_features: int = 3): + """ + Mocks the calls made by the CASClient to do an annotation call with the ontology aware strategy + """ + self.__mock_pre_call_requests(num_features=num_features) + + self.__mock_async_post_response( + url=f"{TEST_URL}/api/cellarium-cell-operations/annotate-cell-type-ontology-aware-strategy", + status_code=200, + response_body=[ + { + "query_cell_id": f"cell{i}", + "matches": [ + {"score": 1, "cell_type_ontology_term_id": "CL_0000000", "cell_type": "cell"}, + {"score": 0.9, "cell_type_ontology_term_id": "CL_0000540", "cell_type": "neuron"}, + {"score": 0.6, "cell_type_ontology_term_id": "CL_0000255", "cell_type": "eukaryotic cell"}, + { + "score": 0.5, + "cell_type_ontology_term_id": "CL_0000404", + "cell_type": "electrically signaling cell", + }, + ], + "total_weight": 286.18, + "total_neighbors": 500, + "total_neighbors_unrecognized": 0, + } + for i in range(num_cells) + ], + ) + + def __mock_search_nearest_neighbor_by_matrix_calls(self, num_cells: int = 3, num_features: int = 3): + """ + Mocks the calls made by the CASClient to do a nearest neighbor search 
+ """ + self.__mock_pre_call_requests(num_features=num_features) + + self.__mock_async_post_response( + url=f"{TEST_URL}/api/cellarium-cell-operations/nearest-neighbor-search", + status_code=200, + response_body=[ + { + "query_cell_id": f"cell{i}", + "neighbors": [ + {"cas_cell_index": 123, "distance": 0.123}, + {"cas_cell_index": 456, "distance": 0.456}, + {"cas_cell_index": 789, "distance": 0.789}, + ], + } + for i in range(num_cells) + ], + ) + + def __mock_cell_query(self, num_cells: int = 3): + """ + Mocks the calls made by the CASClient to do a cell query + """ + + self.__mock_response( + url=f"{TEST_URL}/api/cellarium-cell-operations/query-cells-by-ids", + status_code=200, + method="post", + response_body=[ + { + "cas_cell_index": i, + "cell_type": "enterocyte", + "assay": "10x 3' v2", + "disease": "glioblastoma", + "donor_id": "H20.33.013", + "is_primary_data": True, + "development_stage": "human adult stage", + "organism": "Homo sapiens", + "self_reported_ethnicity": "Japanese", + "sex": "male", + "suspension_type": "nucleus", + "tissue": "cerebellum", + "total_mrna_umis": 24312, + "cell_type_ontology_term_id": "CL:0000121", + "assay_ontology_term_id": "EFO:0010550", + "disease_ontology_term_id": "PATO:0000461", + "development_stage_ontology_term_id": "HsapDv:0000053", + "organism_ontology_term_id": "NCBITaxon:9606", + "self_reported_ethnicity_ontology_term_id": "HANCESTRO:0019", + "sex_ontology_term_id": "PATO:0000384", + "tissue_ontology_term_id": "UBERON:0002037", + } + for i in range(num_cells) + ], + ) + + def __mock_pre_call_requests(self, num_features: int = 3): + """ + Mocks the calls made by the CASClient before the actual annotation or query/search calls + """ + self.__mock_response( + url=f"{TEST_URL}/api/cellarium-general/feature-schema/{TEST_SCHEMA}", + status_code=200, + response_body=[f"field{i}" for i in range(num_features)], + ) + self.__mock_response( + url=f"{TEST_URL}/api/cellarium-general/quota", + status_code=200, + response_body={ + 
"user_id": 0, + "quota": 1000, + "remaining_quota": 1000, + "quota_reset_date": datetime.datetime.today() + 7 * datetime.timedelta(days=1), + }, + ) + + def __mock_anndata_matrix(self, num_features: int = 3, num_cells: int = 3) -> anndata.AnnData: d = NP_RANDOM_STATE.randint(0, 500, size=(num_cells, num_features)) X = sp.csr_matrix(d) obs = pd.DataFrame(index=[f"cell{i}" for i in range(num_cells)]) return anndata.AnnData(X=X, obs=obs, dtype=np.float32) - def _mock_response( + def __mock_response( self, url: str, status_code: int, @@ -271,7 +476,7 @@ def _mock_response( else: raise ValueError(f"Unsupported method: {method}") - def _mock_async_post_response( + def __mock_async_post_response( self, url: str, status_code: int, response_body: t.Union[dict, list], post_data: t.Union[dict, list] = None ): # Mock response @@ -290,8 +495,13 @@ def _mock_async_post_response( when(aiohttp).ClientSession(connector=ANY, timeout=ANY).thenReturn(session) self.async_post_mocks[url] = session - def _verify_headers( - self, urls: t.List[str], async_post_urls: t.List[str], active_session_id: str, num__expected_actions: int + def __verify_headers( + self, + get_urls: t.List[str] = [], + post_urls: t.List[str] = [], + async_post_urls: t.List[str] = [], + active_session_id: str = None, + num_expected_actions: int = None, ): """ Verify that the expected header values were sent. 
@@ -301,7 +511,7 @@ def _verify_headers( sent_sessions: t.Set[str] = set() sent_actions: t.Set[str] = set() - for url in urls: + for url in get_urls: header_captor: ArgumentCaptor[t.Dict[str, any]] = captor() verify(requests, atleast=1).get(url=url, headers=header_captor) headers = header_captor.value @@ -312,6 +522,17 @@ def _verify_headers( if constants.Headers.client_action_id in headers: sent_actions.add(headers[constants.Headers.client_action_id]) + for url in post_urls: + header_captor: ArgumentCaptor[t.Dict[str, any]] = captor() + verify(requests, atleast=1).post(url=url, headers=header_captor, json=ANY) + headers = header_captor.value + if constants.Headers.authorization in headers: + sent_tokens.add(headers[constants.Headers.authorization]) + if constants.Headers.client_session_id in headers: + sent_sessions.add(headers[constants.Headers.client_session_id]) + if constants.Headers.client_action_id in headers: + sent_actions.add(headers[constants.Headers.client_action_id]) + for url in async_post_urls or []: header_captor: ArgumentCaptor[t.Dict[str, any]] = captor() verify(self.async_post_mocks[url], atleast=1).post(url, headers=header_captor, data=ANY) @@ -329,4 +550,4 @@ def _verify_headers( assert len(sent_sessions) == 1 assert str(active_session_id) in sent_sessions - assert len(sent_actions) == num__expected_actions + assert len(sent_actions) == num_expected_actions From af775720ff0e448f2e87909a774458b4989f1003 Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 11:21:02 -0400 Subject: [PATCH 4/7] Update quickstart tutorial to use new return object --- notebooks/quickstart_tutorial.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/quickstart_tutorial.ipynb b/notebooks/quickstart_tutorial.ipynb index aba3eb3..0004763 100644 --- a/notebooks/quickstart_tutorial.ipynb +++ b/notebooks/quickstart_tutorial.ipynb @@ -263,7 +263,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let us take a quick look 
at the anatomy of the CAS ontology-aware cell type query response. In brief, the response is a Python list with as many elements as the number of cells in the queried AnnData file:" + "Let us take a quick look at the anatomy of the CAS ontology-aware cell type query response. In brief, the response is a Python object of type CellTypeOntologyAwareResults with results that contain as many elements as the number of cells in the queried AnnData file:" ] }, { @@ -281,7 +281,7 @@ "metadata": {}, "outputs": [], "source": [ - "len(cas_ontology_aware_response)" + "len(cas_ontology_aware_response.data)" ] }, { @@ -297,7 +297,7 @@ "metadata": {}, "outputs": [], "source": [ - "cas_ontology_aware_response[2425]" + "cas_ontology_aware_response.data[2425]" ] }, { @@ -471,7 +471,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.9" } }, "nbformat": 4, From 68bbe6747373fb19a72082d31b476dba4f876f82 Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 11:21:20 -0400 Subject: [PATCH 5/7] Update docs to show new model --- docs/source/automodules/client.rst | 19 ++++++++++++++++++- docs/source/conf.py | 10 ++++++++++ requirements/docs.txt | 3 ++- tox.ini | 2 +- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/source/automodules/client.rst b/docs/source/automodules/client.rst index efede2c..a71af85 100644 --- a/docs/source/automodules/client.rst +++ b/docs/source/automodules/client.rst @@ -7,4 +7,21 @@ Client .. autoclass:: cellarium.cas.constants.CountMatrixInput :members: :undoc-members: - :member-order: bysource \ No newline at end of file + :member-order: bysource + +.. autoclass:: cellarium.cas.constants.CellMetadataFeatures + :members: + :undoc-members: + :member-order: bysource + +.. autopydantic_model:: cellarium.cas.models::CellTypeSummaryStatisticsResults + :member-order: bysource + +.. 
autopydantic_model:: cellarium.cas.models::CellTypeOntologyAwareResults + :member-order: bysource + +.. autopydantic_model:: cellarium.cas.models::MatrixQueryResults + :member-order: bysource + +.. autopydantic_model:: cellarium.cas.models::CellQueryResults + :member-order: bysource diff --git a/docs/source/conf.py b/docs/source/conf.py index 6c4906d..960425f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,6 +25,7 @@ "sphinx.ext.viewcode", "sphinx.ext.intersphinx", "sphinx_substitution_extensions", + "sphinxcontrib.autodoc_pydantic", ] # Provide substitutions for common values @@ -50,3 +51,12 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "sphinx_rtd_theme" + +nitpicky = True +nitpick_ignore_regex = [ + # Ignore exceptions from nested Pydantic models + (r'py:.*', r'cellarium\.cas\.models\..*'), +] + +# The JSON schema is a bit much in the docs +autodoc_pydantic_model_show_json = False diff --git a/requirements/docs.txt b/requirements/docs.txt index 3452cb0..df3e15a 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -3,4 +3,5 @@ Sphinx~=7.4 sphinx_gallery~=0.14 sphinx_rtd_theme~=2.0 sphinx_substitution_extensions==2024.8.6 -setuptools-git-versioning==2.0.0 \ No newline at end of file +setuptools-git-versioning==2.0.0 +autodoc_pydantic==2.2.0 \ No newline at end of file diff --git a/tox.ini b/tox.ini index 55ee250..210a5a3 100644 --- a/tox.ini +++ b/tox.ini @@ -69,7 +69,7 @@ deps = changedir = {toxinidir}/docs commands = - make html SPHINXOPTS="-W --keep-going -n" + make html SPHINXOPTS="-W --keep-going" [gh-actions] From e7bd72873082ce0d7b4efc92209a5e316bc87591 Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Thu, 5 Sep 2024 14:29:01 -0400 Subject: [PATCH 6/7] Rename Matches to Match --- cellarium/cas/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cellarium/cas/models.py b/cellarium/cas/models.py index decf781..3d7e51a 100644 --- 
a/cellarium/cas/models.py +++ b/cellarium/cas/models.py @@ -71,7 +71,7 @@ class CellTypeOntologyAwareResults(BaseModel): Represents the data object returned by the CAS API for a ontology-aware annotations. """ - class Matches(BaseModel): + class Match(BaseModel): score: float = Field(description="The score of the match", examples=[0.789]) cell_type_ontology_term_id: str = Field( description="The ontology term ID of the cell type for the match", examples=["CL:0000121"] @@ -84,7 +84,7 @@ class OntologyAwareAnnotation(BaseModel): """ query_cell_id: str = Field(description="The ID of the querying cell", examples=["ATTACTTATTTAGTT-12311"]) - matches: t.List["CellTypeOntologyAwareResults.Matches"] = Field( + matches: t.List["CellTypeOntologyAwareResults.Match"] = Field( description="The matches found for the querying cell" ) total_weight: float = Field(description="The total weight of the matches", examples=[11.23232]) @@ -105,7 +105,7 @@ class MatrixQueryResults(BaseModel): (e.g. a query of the cell database using a matrix). 
""" - class Matches(BaseModel): + class Match(BaseModel): cas_cell_index: float = Field(description="CAS-specific ID of a single cell", examples=[123]) distance: float = Field( description="The distance between this querying cell and the found cell", examples=[0.123] @@ -117,7 +117,7 @@ class MatrixQueryResult(BaseModel): """ query_cell_id: str = Field(description="The ID of the querying cell", examples=["ATTACTTATTTAGTT-12311"]) - neighbors: t.List["MatrixQueryResults.Matches"] + neighbors: t.List["MatrixQueryResults.Match"] data: t.List["MatrixQueryResults.MatrixQueryResult"] = Field(description="The results of the query") From 6e66f2da3cff81d351fb27a8946fd98663b19315 Mon Sep 17 00:00:00 2001 From: nmalfroy Date: Fri, 6 Sep 2024 12:56:23 -0400 Subject: [PATCH 7/7] Update changelog --- CHANGELOG.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c26e890..bdb0e97 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,9 +17,11 @@ and this project adheres to `Semantic Versioning