From 0c8949bcaebab158417c3dafcd1db5873003491c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Fri, 5 Jan 2024 15:32:27 +0100
Subject: [PATCH] Split get responses into own function

---
 src/ert/dark_storage/common.py              | 62 +++++++++++++++------
 src/ert/dark_storage/endpoints/records.py   | 19 +------
 src/ert/dark_storage/endpoints/responses.py | 28 +++++++---
 3 files changed, 66 insertions(+), 43 deletions(-)

diff --git a/src/ert/dark_storage/common.py b/src/ert/dark_storage/common.py
index 0cf4bfee031..2bf15d8cfd3 100644
--- a/src/ert/dark_storage/common.py
+++ b/src/ert/dark_storage/common.py
@@ -1,7 +1,9 @@
+import io
 from typing import Any, Dict, List, Optional, Sequence, Union
 from uuid import UUID
 
 import pandas as pd
+from fastapi.responses import Response
 
 from ert.config import EnkfObservationImplementationType
 from ert.libres_facade import LibresFacade
@@ -27,6 +29,28 @@ def get_response_names(ensemble: EnsembleReader) -> List[str]:
     return result
 
 
+def get_response(ensemble: EnsembleReader, key: str) -> pd.DataFrame:
+    if key in ensemble.get_summary_keyset():
+        data = ensemble.load_all_summary_data([key])
+        data = data[key].unstack(level="Date")
+    elif key in ensemble.get_gen_data_keyset():
+        key_parts = key.split("@")
+        key = key_parts[0]
+        report_step = int(key_parts[1]) if len(key_parts) > 1 else 0
+
+        try:
+            data = ensemble.load_gen_data(
+                key,
+                report_step,
+                None,
+            ).T
+        except (ValueError, KeyError):
+            return pd.DataFrame()
+    else:
+        return pd.DataFrame()
+    return data
+
+
 def data_for_key(
     ensemble: EnsembleReader,
     key: str,
@@ -38,30 +62,14 @@ def data_for_key(
     if key.startswith("LOG10_"):
         key = key[6:]
 
-    if key in ensemble.get_summary_keyset():
-        data = ensemble.load_all_summary_data([key], realization_index)
-        data = data[key].unstack(level="Date")
-    elif key in ensemble.get_gen_kw_keyset():
+    if key in ensemble.get_gen_kw_keyset():
         data = ensemble.load_all_gen_kw_data(key.split(":")[0], realization_index)
         if data.empty:
             return pd.DataFrame()
         data = data[key].to_frame().dropna()
         data.columns = pd.Index([0])
-    elif key in ensemble.get_gen_data_keyset():
-        key_parts = key.split("@")
-        key = key_parts[0]
-        report_step = int(key_parts[1]) if len(key_parts) > 1 else 0
-
-        try:
-            data = ensemble.load_gen_data(
-                key,
-                report_step,
-                realization_index,
-            ).T
-        except (ValueError, KeyError):
-            return pd.DataFrame()
     else:
-        return pd.DataFrame()
+        data = get_response(ensemble, key)
 
     try:
         return data.astype(float)
@@ -124,3 +132,21 @@ def _prepare_x_axis(
         return [pd.Timestamp(x).isoformat() for x in x_axis]
 
     return [str(x) for x in x_axis]
+
+
+def format_dataframe(df: pd.DataFrame, media_type) -> Response:
+    if media_type == "application/x-parquet":
+        df.columns = [str(s) for s in df.columns]
+        stream = io.BytesIO()
+        df.to_parquet(stream)
+        return Response(
+            content=stream.getvalue(),
+            media_type="application/x-parquet",
+        )
+    elif media_type == "application/json":
+        return Response(df.to_json(), media_type="application/json")
+    else:
+        return Response(
+            content=df.to_csv().encode(),
+            media_type="text/csv",
+        )
diff --git a/src/ert/dark_storage/endpoints/records.py b/src/ert/dark_storage/endpoints/records.py
index 5de1cf58247..4f7767ea98c 100644
--- a/src/ert/dark_storage/endpoints/records.py
+++ b/src/ert/dark_storage/endpoints/records.py
@@ -1,17 +1,16 @@
-import io
 from itertools import chain
 from typing import Any, Dict, List, Mapping, Optional, Union
 from uuid import UUID, uuid4
 
 import pandas as pd
 from fastapi import APIRouter, Body, Depends, File, Header, Request, UploadFile, status
-from fastapi.responses import Response
 from typing_extensions import Annotated
 
 from ert.dark_storage import json_schema as js
 from ert.dark_storage.common import (
     data_for_key,
     ensemble_parameters,
+    format_dataframe,
     get_observation_name,
     observations_for_obs_keys,
 )
@@ -185,21 +184,7 @@ async def get_ensemble_record(
         dataframe = pd.DataFrame(dataframe.loc[realization_index]).T
 
     media_type = accept if accept is not None else "text/csv"
-    if media_type == "application/x-parquet":
-        dataframe.columns = [str(s) for s in dataframe.columns]
-        stream = io.BytesIO()
-        dataframe.to_parquet(stream)
-        return Response(
-            content=stream.getvalue(),
-            media_type="application/x-parquet",
-        )
-    elif media_type == "application/json":
-        return Response(dataframe.to_json(), media_type="application/json")
-    else:
-        return Response(
-            content=dataframe.to_csv().encode(),
-            media_type="text/csv",
-        )
+    return format_dataframe(dataframe, media_type)
 
 
 @router.get("/ensembles/{ensemble_id}/records/{name}/labels", response_model=List[str])
diff --git a/src/ert/dark_storage/endpoints/responses.py b/src/ert/dark_storage/endpoints/responses.py
index d54825e1631..cdd9b2d4386 100644
--- a/src/ert/dark_storage/endpoints/responses.py
+++ b/src/ert/dark_storage/endpoints/responses.py
@@ -1,9 +1,11 @@
+from typing import Union
 from uuid import UUID
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Header, status
 from fastapi.responses import Response
+from typing_extensions import Annotated
 
-from ert.dark_storage.common import data_for_key
+from ert.dark_storage.common import format_dataframe, get_response
 from ert.dark_storage.enkf import get_storage
 from ert.storage import StorageReader
 
@@ -12,15 +14,25 @@
 DEFAULT_STORAGE = Depends(get_storage)
 
 
-@router.get("/ensembles/{ensemble_id}/responses/{response_name}/data")
+@router.get(
+    "/ensembles/{ensemble_id}/responses/{response_name}/data",
+    responses={
+        status.HTTP_200_OK: {
+            "content": {
+                "application/json": {},
+                "text/csv": {},
+                "application/x-parquet": {},
+            }
+        }
+    },
+)
 async def get_ensemble_response_dataframe(
     *,
     db: StorageReader = DEFAULT_STORAGE,
     ensemble_id: UUID,
     response_name: str,
+    accept: Annotated[Union[str, None], Header()] = None,
 ) -> Response:
-    dataframe = data_for_key(db.get_ensemble(ensemble_id), response_name)
-    return Response(
-        content=dataframe.to_csv().encode(),
-        media_type="text/csv",
-    )
+    dataframe = get_response(db.get_ensemble(ensemble_id), response_name)
+    media_type = accept if accept is not None else "text/csv"
+    return format_dataframe(dataframe, media_type)