Skip to content

Commit

Permalink
Split get responses into own function
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 5, 2024
1 parent f346b4e commit 0c8949b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 43 deletions.
62 changes: 44 additions & 18 deletions src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import io
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

import pandas as pd
from fastapi.responses import Response

from ert.config import EnkfObservationImplementationType
from ert.libres_facade import LibresFacade
Expand All @@ -27,6 +29,28 @@ def get_response_names(ensemble: EnsembleReader) -> List[str]:
return result


def get_response(ensemble: EnsembleReader, key: str) -> pd.DataFrame:
if key in ensemble.get_summary_keyset():
data = ensemble.load_all_summary_data([key])
data = data[key].unstack(level="Date")
elif key in ensemble.get_gen_data_keyset():
key_parts = key.split("@")
key = key_parts[0]
report_step = int(key_parts[1]) if len(key_parts) > 1 else 0

try:
data = ensemble.load_gen_data(
key,
report_step,
None,
).T
except (ValueError, KeyError):
return pd.DataFrame()
else:
return pd.DataFrame()
return data


def data_for_key(
ensemble: EnsembleReader,
key: str,
Expand All @@ -38,30 +62,14 @@ def data_for_key(

if key.startswith("LOG10_"):
key = key[6:]
if key in ensemble.get_summary_keyset():
data = ensemble.load_all_summary_data([key], realization_index)
data = data[key].unstack(level="Date")
elif key in ensemble.get_gen_kw_keyset():
if key in ensemble.get_gen_kw_keyset():
data = ensemble.load_all_gen_kw_data(key.split(":")[0], realization_index)
if data.empty:
return pd.DataFrame()
data = data[key].to_frame().dropna()
data.columns = pd.Index([0])
elif key in ensemble.get_gen_data_keyset():
key_parts = key.split("@")
key = key_parts[0]
report_step = int(key_parts[1]) if len(key_parts) > 1 else 0

try:
data = ensemble.load_gen_data(
key,
report_step,
realization_index,
).T
except (ValueError, KeyError):
return pd.DataFrame()
else:
return pd.DataFrame()
data = get_response(ensemble, key)

try:
return data.astype(float)
Expand Down Expand Up @@ -124,3 +132,21 @@ def _prepare_x_axis(
return [pd.Timestamp(x).isoformat() for x in x_axis]

return [str(x) for x in x_axis]


def format_dataframe(df: pd.DataFrame, media_type) -> Response:
    """Serialize *df* into an HTTP response for the given *media_type*.

    Supported media types are ``application/x-parquet`` and
    ``application/json``; any other value falls back to CSV.

    NOTE(review): the parquet branch stringifies ``df.columns`` in place,
    mutating the caller's frame — inherited behavior, confirm intentional.
    """
    if media_type == "application/x-parquet":
        # Parquet requires string column names.
        df.columns = [str(col) for col in df.columns]
        buffer = io.BytesIO()
        df.to_parquet(buffer)
        return Response(
            content=buffer.getvalue(),
            media_type="application/x-parquet",
        )
    if media_type == "application/json":
        return Response(df.to_json(), media_type="application/json")
    # Default for unrecognized (or missing) accept values: CSV.
    return Response(
        content=df.to_csv().encode(),
        media_type="text/csv",
    )
19 changes: 2 additions & 17 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import io
from itertools import chain
from typing import Any, Dict, List, Mapping, Optional, Union
from uuid import UUID, uuid4

import pandas as pd
from fastapi import APIRouter, Body, Depends, File, Header, Request, UploadFile, status
from fastapi.responses import Response
from typing_extensions import Annotated

from ert.dark_storage import json_schema as js
from ert.dark_storage.common import (
data_for_key,
ensemble_parameters,
format_dataframe,
get_observation_name,
observations_for_obs_keys,
)
Expand Down Expand Up @@ -185,21 +184,7 @@ async def get_ensemble_record(
dataframe = pd.DataFrame(dataframe.loc[realization_index]).T

media_type = accept if accept is not None else "text/csv"
if media_type == "application/x-parquet":
dataframe.columns = [str(s) for s in dataframe.columns]
stream = io.BytesIO()
dataframe.to_parquet(stream)
return Response(
content=stream.getvalue(),
media_type="application/x-parquet",
)
elif media_type == "application/json":
return Response(dataframe.to_json(), media_type="application/json")
else:
return Response(
content=dataframe.to_csv().encode(),
media_type="text/csv",
)
return format_dataframe(dataframe, media_type)


@router.get("/ensembles/{ensemble_id}/records/{name}/labels", response_model=List[str])
Expand Down
28 changes: 20 additions & 8 deletions src/ert/dark_storage/endpoints/responses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Union
from uuid import UUID

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Header, status
from fastapi.responses import Response
from typing_extensions import Annotated

from ert.dark_storage.common import data_for_key
from ert.dark_storage.common import format_dataframe, get_response
from ert.dark_storage.enkf import get_storage
from ert.storage import StorageReader

Expand All @@ -12,15 +14,25 @@
DEFAULT_STORAGE = Depends(get_storage)


@router.get(
    "/ensembles/{ensemble_id}/responses/{response_name}/data",
    responses={
        status.HTTP_200_OK: {
            "content": {
                "application/json": {},
                "text/csv": {},
                "application/x-parquet": {},
            }
        }
    },
)
async def get_ensemble_response_dataframe(
    *,
    db: StorageReader = DEFAULT_STORAGE,
    ensemble_id: UUID,
    response_name: str,
    accept: Annotated[Union[str, None], Header()] = None,
) -> Response:
    """Return the values of *response_name* for an ensemble.

    The ``Accept`` header selects the serialization: CSV (the default),
    JSON, or parquet — handled by ``format_dataframe``.
    """
    dataframe = get_response(db.get_ensemble(ensemble_id), response_name)
    # Fall back to CSV when the client did not state a preference.
    media_type = accept if accept is not None else "text/csv"
    return format_dataframe(dataframe, media_type)

0 comments on commit 0c8949b

Please sign in to comment.