Skip to content

Commit

Permalink
support writing pandas DataFrame and Series
Browse files (browse the repository at this point in the history)
  • Loading branch information
briancappello committed Aug 8, 2020
1 parent dc8a6e8 commit e832205
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 58 deletions.
46 changes: 14 additions & 32 deletions pymarketstore/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import grpc
import logging
import numpy as np
import pandas as pd

from typing import List, Union

from .params import Params, ListSymbolsFormat
from .proto import marketstore_pb2 as proto
from .proto import marketstore_pb2_grpc as gp
from .results import QueryReply
from .utils import is_iterable
from .utils import is_iterable, timeseries_data_to_write_request

logger = logging.getLogger(__name__)

Expand All @@ -27,37 +28,18 @@ def query(self, params: Union[Params, List[Params]]) -> QueryReply:
reply = self.stub.Query(self._build_query(params))
return QueryReply.from_grpc_response(reply)

# NOTE(review): this appears to be the pre-change implementation of write()
# (the removed side of this diff hunk). It builds the MultiWriteRequest by
# hand from a numpy array with named fields; the statement that actually
# sends the request is the `return self.stub.Write(req)` context line shown
# after the replacement implementation.
def write(self, recarray: np.array, tbk: str, isvariablelength: bool = False) -> proto.MultiServerResponse:
# dtype string of every named field, with any '<' character removed
types = [
recarray.dtype[name].str.replace('<', '')
for name in recarray.dtype.names
]
names = recarray.dtype.names
# the values of each named field, copied into a bytes object
data = [
bytes(memoryview(recarray[name]))
for name in recarray.dtype.names
]
length = len(recarray)
# the whole array is registered under the single key `tbk`, starting at index 0
start_index = {tbk: 0}
lengths = {tbk: len(recarray)}

# a single WriteRequest wrapping one NumpyMultiDataset
req = proto.MultiWriteRequest(requests=[
proto.WriteRequest(
data=proto.NumpyMultiDataset(
data=proto.NumpyDataset(
column_types=types,
column_names=names,
column_data=data,
length=length,
# data_shapes = [],
),
start_index=start_index,
lengths=lengths,
),
is_variable_length=isvariablelength,
)
])

def write(self, data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
          tbk: str,
          isvariablelength: bool = False,
          ) -> proto.MultiServerResponse:
    """Write timeseries data under the time bucket key ``tbk``.

    :param data: the records to write, as a DataFrame, a Series or a numpy
        array with named fields.
    :param tbk: the time bucket key the records are written to.
    :param isvariablelength: forwarded as ``is_variable_length`` on the request.
    :return: the value returned by ``self.stub.Write``.
    :raises TypeError: propagated from ``timeseries_data_to_write_request``
        when ``data`` is not a supported type.
    """
    dataset = timeseries_data_to_write_request(data, tbk)
    req = proto.MultiWriteRequest(requests=[dict(
        data=dict(
            data=dataset,
            start_index={tbk: 0},
            # Use the row count computed for the dataset, not len(data):
            # for a single-row Series len() is the number of fields,
            # while the dataset holds exactly one record.
            lengths={tbk: dataset['length']},
        ),
        is_variable_length=isvariablelength,
    )])
    return self.stub.Write(req)

def _build_query(self, params: Union[Params, List[Params]]) -> proto.MultiQueryRequest:
Expand Down
43 changes: 17 additions & 26 deletions pymarketstore/jsonrpc_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import numpy as np
import pandas as pd
import re
import requests

Expand All @@ -9,7 +10,7 @@
from .params import Params, ListSymbolsFormat
from .results import QueryReply
from .stream import StreamConn
from .utils import is_iterable
from .utils import is_iterable, timeseries_data_to_write_request

logger = logging.getLogger(__name__)

Expand All @@ -36,31 +37,21 @@ def query(self, params: Union[Params, List[Params]]) -> QueryReply:
])
return QueryReply.from_response(reply)

def write(self, recarray: np.array, tbk: str, isvariablelength: bool = False) -> str:
    """Send a numpy array with named fields to ``DataService.Write``.

    :param recarray: array whose dtype declares named fields; each field
        becomes one column of the dataset.
    :param tbk: key under which the start index and length are registered.
    :param isvariablelength: forwarded as ``is_variable_length``.
    :return: whatever ``self.rpc.call`` returns.
    :raises requests.exceptions.ConnectionError: re-raised with the message
        "Could not contact server" when the call fails to connect.
    """
    field_names = recarray.dtype.names
    dataset = {
        'types': [recarray.dtype[field].str.replace('<', '') for field in field_names],
        'names': field_names,
        'data': [bytes(memoryview(recarray[field])) for field in field_names],
        'length': len(recarray),
        'startindex': {tbk: 0},
        'lengths': {tbk: len(recarray)},
    }
    payload = {
        'requests': [
            {'dataset': dataset, 'is_variable_length': isvariablelength},
        ],
    }

    try:
        return self.rpc.call("DataService.Write", **payload)
    except requests.exceptions.ConnectionError:
        raise requests.exceptions.ConnectionError(
            "Could not contact server")
def write(self, data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
          tbk: str,
          isvariablelength: bool = False,
          ) -> dict:
    """Write timeseries data under the time bucket key ``tbk``.

    :param data: the records to write, as a DataFrame, a Series or a numpy
        array with named fields.
    :param tbk: the time bucket key the records are written to.
    :param isvariablelength: forwarded as ``is_variable_length`` on the request.
    :return: the value returned by ``self.rpc.call`` for ``DataService.Write``.
    :raises TypeError: propagated from ``timeseries_data_to_write_request``
        when ``data`` is not a supported type.
    """
    dataset = timeseries_data_to_write_request(data, tbk)
    return self.rpc.call("DataService.Write", requests=[dict(
        dataset=dict(
            types=dataset['column_types'],
            names=dataset['column_names'],
            data=dataset['column_data'],
            # the record count, as sent by the previous implementation
            length=dataset['length'],
            startindex={tbk: 0},
            # Use the row count computed for the dataset, not len(data):
            # for a single-row Series len() is the number of fields,
            # while the dataset holds exactly one record.
            lengths={tbk: dataset['length']},
        ),
        is_variable_length=isvariablelength,
    )])

def list_symbols(self, fmt: ListSymbolsFormat = ListSymbolsFormat.SYMBOL) -> List[str]:
reply = self._request('DataService.ListSymbols', format=fmt.value)
Expand Down
52 changes: 52 additions & 0 deletions pymarketstore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,55 @@ def is_iterable(something: Any) -> bool:
:return: bool. true if something is a list, tuple or set
"""
return isinstance(something, (list, tuple, set))


def timeseries_data_to_write_request(data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
                                     tbk: str,
                                     ) -> dict:
    """Build the dataset parameters of a write request from ``data``.

    Dispatches on the type of ``data`` to the matching converter.

    :param data: a DataFrame, a Series or a numpy array.
    :param tbk: time bucket key; only the Series converter receives it.
    :return: dict with the keys ``column_types``, ``column_names``,
        ``column_data`` and ``length``.
    :raises TypeError: if ``data`` is none of the supported types.
    """
    if isinstance(data, pd.DataFrame):
        return _pd_dataframe_to_dataset_params(data)

    if isinstance(data, pd.Series):
        return _pd_series_to_dataset_params(data, tbk)

    if isinstance(data, (np.ndarray, np.recarray)):
        return _np_array_to_dataset_params(data)

    raise TypeError('data must be pd.DataFrame, pd.Series, np.ndarray, or np.recarray')


def _np_array_to_dataset_params(data: Union[np.ndarray, np.recarray]) -> dict:
    """Convert a numpy array with named fields into dataset parameters.

    :param data: array whose dtype declares named fields.
    :return: dict with one entry per field in ``column_types`` (dtype string
        with ``'<'`` removed), ``column_names`` and ``column_data`` (the
        field's values as bytes), plus ``length`` (``len(data)``).
    :raises TypeError: if the dtype of ``data`` has no named fields.
    """
    field_names = data.dtype.names
    if not field_names:
        raise TypeError('numpy arrays must declare named column dtypes')

    column_types = []
    column_data = []
    for field in field_names:
        column_types.append(data.dtype[field].str.replace('<', ''))
        column_data.append(bytes(memoryview(data[field])))

    return {
        'column_types': column_types,
        'column_names': list(field_names),
        'column_data': column_data,
        'length': len(data),
    }


def _pd_series_to_dataset_params(data: pd.Series, tbk: str) -> dict:
    """Convert a pandas Series into dataset parameters for a write request.

    Two layouts are supported:

    * a single column of values indexed by timestamp (for example
      ``ohlcv_df['ColName']``): the index is named ``'Epoch'`` or is a
      ``pd.DatetimeIndex``;
    * a single row of values for one timestamp (for example
      ``ohlcv_df.iloc[N]``): ``data.name`` is a ``pd.Timestamp`` and the
      index holds the column names.

    :param data: the series to convert.
    :param tbk: time bucket key; its last ``/``-separated component names the
        value column when a column series has no name of its own.
    :return: dict with ``column_types``, ``column_names``, ``column_data``
        (one bytes object per column, ``Epoch`` first, in whole seconds)
        and ``length``.
    :raises TypeError: if the series matches neither layout.
    """
    value_type = data.dtype.str.replace('<', '')
    is_row = isinstance(data.name, pd.Timestamp)

    # single column of data (indexed by timestamp, eg from ohlcv_df['ColName'])
    if data.index.name == 'Epoch' or (not is_row and isinstance(data.index, pd.DatetimeIndex)):
        # Normalise to nanoseconds before dividing: timestamps may carry a
        # coarser resolution (s/ms/us), in which case the raw integers are
        # not nanoseconds and `// 10**9` alone would give wrong epochs.
        nanos = data.index.values.astype('datetime64[ns]').astype('i8')
        return dict(column_types=['i8', value_type],
                    column_names=['Epoch', data.name or tbk.split('/')[-1]],
                    column_data=[bytes(memoryview(nanos // 10**9)),
                                 bytes(memoryview(data.to_numpy()))],
                    length=len(data))

    if not is_row:
        raise TypeError('pd.Series must be indexed by timestamp (an index named "Epoch" '
                        'or a DatetimeIndex) or be a single row named by its pd.Timestamp')

    # single row of data (named indexes for one timestamp, eg from ohlcv_df.iloc[N])
    nanos = data.name.to_numpy().astype('datetime64[ns]').astype('i8')
    return dict(column_types=['i8'] + [value_type] * len(data),
                column_names=['Epoch'] + data.index.to_list(),
                # every value of the row becomes its own one-element column
                column_data=([bytes(memoryview(nanos // 10**9))]
                             + [bytes(memoryview(val)) for val in data.to_numpy()]),
                length=1)


def _pd_dataframe_to_dataset_params(data: pd.DataFrame) -> dict:
    """Convert a timestamp-indexed DataFrame into dataset parameters.

    The index becomes the leading ``Epoch`` column (whole seconds, ``i8``);
    every DataFrame column follows in its original order.

    :param data: the frame to convert.
        NOTE(review): the index is assumed to hold timestamps; it is not
        validated here - confirm callers never pass a plain integer index.
    :return: dict with ``column_types``, ``column_names``, ``column_data``
        (one bytes object per column) and ``length`` (``len(data)``).
    """
    # Normalise to nanoseconds before dividing: a DatetimeIndex may carry a
    # coarser resolution (s/ms/us), in which case the raw integers are not
    # nanoseconds and `// 10**9` alone would give wrong epochs.
    nanos = data.index.values.astype('datetime64[ns]').astype('i8')

    column_types = ['i8']
    column_types.extend(dtype.str.replace('<', '') for dtype in data.dtypes)

    column_data = [bytes(memoryview(nanos // 10**9))]
    column_data.extend(bytes(memoryview(data[col].to_numpy())) for col in data.columns)

    return dict(column_types=column_types,
                column_names=['Epoch'] + data.columns.to_list(),
                column_data=column_data,
                length=len(data))
7 changes: 7 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pandas as pd

from ast import literal_eval
from pymarketstore import results

Expand Down Expand Up @@ -35,6 +37,11 @@
'version': 'dev'}
""") # noqa: E501

# Shared fixtures derived from testdata1; tests/test_utils.py imports all three.
# Decoded entry for 'BTC/1Min/OHLCV' taken from the first decoded response.
btc_array = results.decode_responses(testdata1['responses'])[0]['BTC/1Min/OHLCV']
# The 'data' payload of the first response's result, as stored in testdata1.
btc_bytes = testdata1['responses'][0]['result']['data']
# DataFrame view of btc_array, indexed by its 'Epoch' column.
btc_df = pd.DataFrame(btc_array).set_index('Epoch')
# Index values scaled by 10**9 and rebuilt as a UTC DatetimeIndex
# (presumably Epoch is in seconds, making the scaled values nanoseconds - TODO confirm).
btc_df.index = pd.DatetimeIndex(btc_df.index * 10**9, tz='UTC')


def test_results():
reply = results.QueryReply.from_response(testdata1)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pandas as pd

from pymarketstore.utils import timeseries_data_to_write_request

from .test_results import btc_array, btc_bytes, btc_df


class TestTimeseriesDataToWriteRequest:
    """Checks the dataset params produced for each supported input type."""

    def test_np_array(self):
        # numpy array with named fields: every field becomes a column
        expected = {
            'column_data': btc_bytes,
            'column_names': ['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
            'column_types': ['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
            'length': 5,
        }
        actual = timeseries_data_to_write_request(btc_array, 'BTC/1Min/OHLCV')
        assert actual == expected

    def test_pd_series_indexed_by_timestamp(self):
        # one column of values indexed by timestamp
        series = pd.Series(btc_df['Open'], index=btc_df.index)
        expected = {
            'column_data': [btc_bytes[0], btc_bytes[1]],
            'column_names': ['Epoch', 'Open'],
            'column_types': ['i8', 'f8'],
            'length': 5,
        }
        actual = timeseries_data_to_write_request(series, 'BTC/1Min/Open')
        assert actual == expected

    def test_pd_series_row_from_df(self):
        # one row of the frame: the series name is the row's timestamp
        series = btc_df.iloc[0]
        epoch_seconds = series.name.to_numpy().astype(dtype='i8') // 10**9
        expected_epoch = bytes(memoryview(epoch_seconds))
        value_bytes = [bytes(memoryview(val)) for val in series.array]
        expected = {
            'column_data': [expected_epoch] + value_bytes,
            'column_names': ['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
            'column_types': ['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
            'length': 1,
        }
        actual = timeseries_data_to_write_request(series, 'BTC/1Min/OHLCV')
        assert actual == expected

    def test_pd_dataframe(self):
        # whole frame: index becomes Epoch, followed by every column
        expected = {
            'column_data': btc_bytes,
            'column_names': ['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
            'column_types': ['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
            'length': 5,
        }
        actual = timeseries_data_to_write_request(btc_df, 'BTC/1Min/OHLCV')
        assert actual == expected

0 comments on commit e832205

Please sign in to comment.