Skip to content

Commit

Permalink
Merge branch 'main' into fix-io-error-during-pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
yunzheng authored May 17, 2024
2 parents b4e41b2 + 0865b50 commit 7622cda
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 11 deletions.
45 changes: 39 additions & 6 deletions flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import queue
import threading
from typing import Iterator, Union
from typing import Iterator, Optional, Union

import elasticsearch
import elasticsearch.helpers
Expand All @@ -22,9 +22,11 @@
[PROTOCOL]: http or https. Defaults to https when "+[PROTOCOL]" is omitted
Optional arguments:
[API_KEY]: base64 encoded api key to authenticate with (default: None)
[INDEX]: name of the index to use (default: records)
[VERIFY_CERTS]: verify certs of Elasticsearch instance (default: True)
[HASH_RECORD]: make record unique by hashing record [slow] (default: False)
[_META_*]: record metadata fields (default: None)
"""

log = logging.getLogger(__name__)
Expand All @@ -38,14 +40,25 @@ def __init__(
verify_certs: Union[str, bool] = True,
http_compress: Union[str, bool] = True,
hash_record: Union[str, bool] = False,
api_key: Optional[str] = None,
**kwargs,
) -> None:
self.index = index
self.uri = uri
verify_certs = str(verify_certs).lower() in ("1", "true")
http_compress = str(http_compress).lower() in ("1", "true")
self.hash_record = str(hash_record).lower() in ("1", "true")
self.es = elasticsearch.Elasticsearch(uri, verify_certs=verify_certs, http_compress=http_compress)

if not uri.lower().startswith(("http://", "https://")):
uri = "http://" + uri

self.es = elasticsearch.Elasticsearch(
uri,
verify_certs=verify_certs,
http_compress=http_compress,
api_key=api_key,
)

self.json_packer = JsonRecordPacker()
self.queue: queue.Queue[Union[Record, StopIteration]] = queue.Queue()
self.event = threading.Event()
Expand All @@ -58,25 +71,34 @@ def __init__(

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self.metadata_fields = {}
for arg_key, arg_val in kwargs.items():
if arg_key.startswith("_meta_"):
self.metadata_fields[arg_key[6:]] = arg_val

def record_to_document(self, record: Record, index: str) -> dict:
"""Convert a record to an Elasticsearch compatible document dictionary"""
rdict = record._asdict()

# Store record metadata under `_record_metadata`
# Store record metadata under `_record_metadata`.
rdict_meta = {
"descriptor": {
"name": record._desc.name,
"hash": record._desc.descriptor_hash,
},
}

# Move all dunder fields to `_record_metadata` to avoid naming clash with ES.
dunder_keys = [key for key in rdict if key.startswith("_")]
for key in dunder_keys:
rdict_meta[key.lstrip("_")] = rdict.pop(key)
# remove _generated field from metadata to ensure determinstic documents

# Remove _generated field from metadata to ensure deterministic documents.
if self.hash_record:
rdict_meta.pop("generated", None)
rdict["_record_metadata"] = rdict_meta

rdict["_record_metadata"] = rdict_meta.copy()
rdict["_record_metadata"].update(self.metadata_fields)

document = {
"_index": index,
Expand Down Expand Up @@ -106,6 +128,7 @@ def streaming_bulk_thread(self) -> None:
):
if not ok:
log.error("Failed to insert %r", item)

self.event.set()

def write(self, record: Record) -> None:
Expand All @@ -129,14 +152,24 @@ def __init__(
verify_certs: Union[str, bool] = True,
http_compress: Union[str, bool] = True,
selector: Union[None, Selector, CompiledSelector] = None,
api_key: Optional[str] = None,
**kwargs,
) -> None:
self.index = index
self.uri = uri
self.selector = selector
verify_certs = str(verify_certs).lower() in ("1", "true")
http_compress = str(http_compress).lower() in ("1", "true")
self.es = elasticsearch.Elasticsearch(uri, verify_certs=verify_certs, http_compress=http_compress)

if not uri.lower().startswith(("http://", "https://")):
uri = "http://" + uri

self.es = elasticsearch.Elasticsearch(
uri,
verify_certs=verify_certs,
http_compress=http_compress,
api_key=api_key,
)

if not verify_certs:
# Disable InsecureRequestWarning of urllib3, caused by the verify_certs flag.
Expand Down
15 changes: 13 additions & 2 deletions flow/record/fieldtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,15 +663,15 @@ def __new__(cls, *args):
#
# This construction works around that by converting all path parts
# to strings first.
args = (str(arg) for arg in args)
args = tuple(str(arg) for arg in args)
elif isinstance(path_part, pathlib.PurePosixPath):
cls = posix_path
elif _is_windowslike_path(path_part):
# This handles any custom PurePath based implementations that have a windows
# like path separator (\).
cls = windows_path
if not PY_312:
args = (str(arg) for arg in args)
args = tuple(str(arg) for arg in args)
elif _is_posixlike_path(path_part):
# This handles any custom PurePath based implementations that don't have a
# windows like path separator (\).
Expand All @@ -684,13 +684,24 @@ def __new__(cls, *args):
obj = super().__new__(cls)
else:
obj = cls._from_parts(args)

obj._empty_path = False
if not args or args == ("",):
obj._empty_path = True
return obj

def __eq__(self, other: Any) -> bool:
    """Compare this path with a string or another object.

    A string matches when it equals ``str(self)`` or when wrapping it in this
    path's class yields an equal path. An empty path is only equal to another
    empty path of the same class, whichever side of ``==`` it appears on.

    Args:
        other: The value to compare against.

    Returns:
        ``True`` when both operands are considered equal, otherwise the
        result of the parent class comparison.
    """
    if isinstance(other, str):
        return str(self) == other or self == self.__class__(other)

    if self._empty_path:
        return isinstance(other, self.__class__) and other._empty_path

    # Mirror the check above for the right-hand operand. Without it the
    # comparison is asymmetric: ``path("") == path(".")`` is False, while
    # ``path(".") == path("")`` would reach the parent class comparison,
    # which cannot tell an empty path apart from ".".
    # ``getattr`` is defensive in case ``other`` was not built by ``__new__``.
    if isinstance(other, self.__class__) and getattr(other, "_empty_path", False):
        return False

    return super().__eq__(other)

def __str__(self) -> str:
    """Return the path as text, rendering an empty path as ``""``."""
    # Empty paths are flagged explicitly; everything else uses the parent rendering.
    return "" if self._empty_path else super().__str__()

def __repr__(self) -> str:
    """Return the quoted string form of the path."""
    as_text = str(self)
    return repr(as_text)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ splunk = [
test = [
"flow.record[compression]",
"flow.record[avro]",
"flow.record[elastic]",
"duckdb; platform_python_implementation != 'PyPy' and python_version < '3.12'", # duckdb
"pytz; platform_python_implementation != 'PyPy' and python_version < '3.12'", # duckdb
]
Expand Down
53 changes: 53 additions & 0 deletions tests/test_elastic_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json

import pytest

from flow.record import RecordDescriptor
from flow.record.adapter.elastic import ElasticWriter

# Minimal record type with two string fields, shared by the tests in this module.
MyRecord = RecordDescriptor(
    "my/record",
    [
        ("string", "field_one"),
        ("string", "field_two"),
    ],
)


@pytest.mark.parametrize(
    "record",
    [
        MyRecord("first", "record"),
        MyRecord("second", "record"),
    ],
)
def test_elastic_writer_metadata(record):
    """Check that ``_meta_*`` writer options end up in ``_record_metadata`` of the document."""
    meta_options = {"_meta_foo": "some value", "_meta_bar": "another value"}

    # Key order matters here: the document source is compared as a JSON string.
    expected_metadata = {
        "descriptor": {
            "name": "my/record",
            "hash": record._desc.descriptor_hash,
        },
        "source": None,
        "classification": None,
        "generated": record._generated.isoformat(),
        "version": 1,
        "foo": "some value",
        "bar": "another value",
    }
    expected_source = json.dumps(
        {
            "field_one": record.field_one,
            "field_two": record.field_two,
            "_record_metadata": expected_metadata,
        }
    )

    with ElasticWriter(uri="elasticsearch:9200", **meta_options) as writer:
        # The `_meta_` prefix is stripped from the option names.
        assert writer.metadata_fields == {"foo": "some value", "bar": "another value"}

        document = writer.record_to_document(record, "some-index")
        assert document == {"_index": "some-index", "_source": expected_source}
80 changes: 77 additions & 3 deletions tests/test_fieldtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

import flow.record.fieldtypes
from flow.record import RecordDescriptor, RecordReader, RecordWriter
from flow.record import RecordDescriptor, RecordReader, RecordWriter, fieldtypes
from flow.record.fieldtypes import (
PY_312,
TYPE_POSIX,
Expand Down Expand Up @@ -617,8 +617,8 @@ def test_path():
assert r.value is None

r = TestRecord("")
assert str(r.value) == "."
assert r.value == "."
assert str(r.value) == ""
assert r.value == ""

if os.name == "nt":
native_path_str = windows_path_str
Expand Down Expand Up @@ -1132,5 +1132,79 @@ def test_command_failed() -> None:
command(b"failed")


@pytest.mark.parametrize(
    "path_cls",
    [
        fieldtypes.posix_path,
        fieldtypes.windows_path,
        fieldtypes.path,
    ],
)
def test_empty_path(path_cls) -> None:
    """An empty path equals ``""``, renders as ``""`` and differs from ``"."``."""
    # Build an empty path both ways: from an empty string and without any arguments.
    from_empty_string = path_cls("")
    from_no_args = path_cls()

    for empty in (from_empty_string, from_no_args):
        assert empty == ""
        assert empty._empty_path
        assert str(empty) == ""
        assert empty != path_cls(".")

    assert from_empty_string == from_no_args


def test_empty_path_different_types() -> None:
    """Empty paths of different path classes must not compare equal."""
    empty_posix = fieldtypes.posix_path("")
    empty_windows = fieldtypes.windows_path("")
    assert empty_posix != empty_windows


def test_record_empty_path() -> None:
    """A path field is ``None`` when omitted and ``""`` when given an empty string."""
    TestRecord = RecordDescriptor("test/path", [("path", "value")])

    # No value supplied: the field stays None.
    unset = TestRecord()
    assert unset.value is None
    assert repr(unset) == "<test/path value=None>"

    # Empty string supplied: the field is an empty path, not ".".
    empty = TestRecord("")
    assert empty.value == ""
    assert repr(empty) == "<test/path value=''>"


def test_empty_path_serialization(tmp_path) -> None:
    """Round-trip ``None`` and ``""`` path values through a record file.

    Args:
        tmp_path: Temporary directory in which the record files are created.
    """
    TestRecord = RecordDescriptor(
        "test/path",
        [
            ("path", "value"),
        ],
    )

    # Test path value=None serialization
    p_tmp_records = tmp_path / "none_path"
    with RecordWriter(p_tmp_records) as writer:
        writer.write(TestRecord())
    with RecordReader(p_tmp_records) as reader:
        records = list(reader)
    # Check the count too: asserting inside a loop over the reader would pass
    # vacuously if nothing was read back.
    assert len(records) == 1
    assert records[0].value is None

    # Test path value="" serialization
    p_tmp_records = tmp_path / "empty_str"
    with RecordWriter(p_tmp_records) as writer:
        writer.write(TestRecord(""))
    with RecordReader(p_tmp_records) as reader:
        records = list(reader)
    assert len(records) == 1
    assert records[0].value == ""


if __name__ == "__main__":
    # When executed as a script, pass this module's globals to the
    # `standalone_test` module's `main` function.
    __import__("standalone_test").main(globals())

0 comments on commit 7622cda

Please sign in to comment.