Skip to content

Commit 33c8063

Browse files
author
zilto
committed
added pymongoarrow_schema; linting
1 parent 39a6448 commit 33c8063

File tree

3 files changed

+160
-38
lines changed

3 files changed

+160
-38
lines changed

sources/mongodb/__init__.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def mongodb(
2424
parallel: Optional[bool] = dlt.config.value,
2525
limit: Optional[int] = None,
2626
filter_: Optional[Dict[str, Any]] = None,
27+
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
28+
pymongoarrow_schema: Optional[Any] = None
2729
) -> Iterable[DltResource]:
2830
"""
2931
A DLT source which loads data from a mongo database using PyMongo.
@@ -41,6 +43,13 @@ def mongodb(
4143
The maximum number of documents to load. The limit is
4244
applied to each requested collection separately.
4345
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
46+
projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields of a collection
47+
when loading the collection. Supported inputs:
48+
include (list) - ["year", "title"]
49+
include (dict) - {"year": True, "title": True}
50+
exclude (dict) - {"released": False, "runtime": False}
51+
Note: Can't mix include and exclude statements '{"title": True, "released": False}`
52+
pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow
4453
4554
Returns:
4655
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -73,12 +82,15 @@ def mongodb(
7382
parallel=parallel,
7483
limit=limit,
7584
filter_=filter_ or {},
76-
projection=None,
85+
projection=projection,
86+
pymongoarrow_schema=pymongoarrow_schema,
7787
)
7888

7989

80-
@dlt.common.configuration.with_config(
81-
sections=("sources", "mongodb"), spec=MongoDbCollectionResourceConfiguration
90+
@dlt.resource(
91+
name=lambda args: args["collection"],
92+
standalone=True,
93+
spec=MongoDbCollectionResourceConfiguration,
8294
)
8395
def mongodb_collection(
8496
connection_url: str = dlt.secrets.value,
@@ -91,7 +103,8 @@ def mongodb_collection(
91103
chunk_size: Optional[int] = 10000,
92104
data_item_format: Optional[TDataItemFormat] = "object",
93105
filter_: Optional[Dict[str, Any]] = None,
94-
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
106+
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value,
107+
pymongoarrow_schema: Optional[Any] = None
95108
) -> Any:
96109
"""
97110
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -111,12 +124,13 @@ def mongodb_collection(
111124
object - Python objects (dicts, lists).
112125
arrow - Apache Arrow tables.
113126
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
114-
projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
127+
projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
115128
when loading the collection. Supported inputs:
116129
include (list) - ["year", "title"]
117-
include (dict) - {"year": 1, "title": 1}
118-
exclude (dict) - {"released": 0, "runtime": 0}
119-
Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
130+
include (dict) - {"year": True, "title": True}
131+
exclude (dict) - {"released": False, "runtime": False}
132+
Note: Can't mix include and exclude statements '{"title": True, "released": False}`
133+
pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types to convert BSON to Arrow
120134
121135
Returns:
122136
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -145,4 +159,5 @@ def mongodb_collection(
145159
data_item_format=data_item_format,
146160
filter_=filter_ or {},
147161
projection=projection,
162+
pymongoarrow_schema=pymongoarrow_schema,
148163
)

sources/mongodb/helpers.py

+118-18
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def _filter_op(self) -> Dict[str, Any]:
107107
filt[self.cursor_field]["$gt"] = self.incremental.end_value
108108

109109
return filt
110-
111-
def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
110+
111+
def _projection_op(self, projection:Optional[Union[Mapping[str, Any], Iterable[str]]]) -> Optional[Dict[str, Any]]:
112112
"""Build a projection operator.
113-
113+
114114
A tuple of fields to include or a dict specifying fields to include or exclude.
115115
The incremental `primary_key` needs to be handle differently for inclusion
116116
and exclusion projections.
@@ -123,17 +123,16 @@ def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
123123

124124
projection_dict = dict(_fields_list_to_dict(projection, "projection"))
125125

126-
# NOTE we can still filter on primary_key if it's excluded from projection
127126
if self.incremental:
128127
# this is an inclusion projection
129-
if any(v == 1 for v in projection.values()):
128+
if any(v == 1 for v in projection_dict.values()):
130129
# ensure primary_key is included
131-
projection_dict.update({self.incremental.primary_key: 1})
130+
projection_dict.update(m={self.incremental.primary_key: 1})
132131
# this is an exclusion projection
133132
else:
134133
try:
135134
# ensure primary_key isn't excluded
136-
projection_dict.pop(self.incremental.primary_key)
135+
projection_dict.pop(self.incremental.primary_key) # type: ignore
137136
except KeyError:
138137
pass # primary_key was properly not included in exclusion projection
139138
else:
@@ -174,6 +173,7 @@ def load_documents(
174173
Args:
175174
filter_ (Dict[str, Any]): The filter to apply to the collection.
176175
limit (Optional[int]): The number of documents to load.
176+
projection: selection of fields to create Cursor
177177
178178
Yields:
179179
Iterator[TDataItem]: An iterator of the loaded documents.
@@ -279,6 +279,7 @@ def load_documents(
279279
Args:
280280
filter_ (Dict[str, Any]): The filter to apply to the collection.
281281
limit (Optional[int]): The number of documents to load.
282+
projection: selection of fields to create Cursor
282283
283284
Yields:
284285
Iterator[TDataItem]: An iterator of the loaded documents.
@@ -300,19 +301,22 @@ def load_documents(
300301
filter_: Dict[str, Any],
301302
limit: Optional[int] = None,
302303
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
304+
pymongoarrow_schema: Any = None,
303305
) -> Iterator[Any]:
304306
"""
305307
Load documents from the collection in Apache Arrow format.
306308
307309
Args:
308310
filter_ (Dict[str, Any]): The filter to apply to the collection.
309311
limit (Optional[int]): The number of documents to load.
312+
projection: selection of fields to create Cursor
313+
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
310314
311315
Yields:
312316
Iterator[Any]: An iterator of the loaded documents.
313317
"""
314318
from pymongoarrow.context import PyMongoArrowContext # type: ignore
315-
from pymongoarrow.lib import process_bson_stream
319+
from pymongoarrow.lib import process_bson_stream # type: ignore
316320

317321
filter_op = self._filter_op
318322
_raise_if_intersection(filter_op, filter_)
@@ -330,7 +334,8 @@ def load_documents(
330334
cursor = self._limit(cursor, limit) # type: ignore
331335

332336
context = PyMongoArrowContext.from_schema(
333-
None, codec_options=self.collection.codec_options
337+
schema=pymongoarrow_schema,
338+
codec_options=self.collection.codec_options
334339
)
335340
for batch in cursor:
336341
process_bson_stream(batch, context)
@@ -343,6 +348,58 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
343348
Mongo DB collection parallel loader, which uses
344349
Apache Arrow for data processing.
345350
"""
351+
def load_documents(
352+
self,
353+
filter_: Dict[str, Any],
354+
limit: Optional[int] = None,
355+
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
356+
pymongoarrow_schema: Any = None,
357+
) -> Iterator[TDataItem]:
358+
"""Load documents from the collection in parallel.
359+
360+
Args:
361+
filter_ (Dict[str, Any]): The filter to apply to the collection.
362+
limit (Optional[int]): The number of documents to load.
363+
projection: selection of fields to create Cursor
364+
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
365+
366+
Yields:
367+
Iterator[TDataItem]: An iterator of the loaded documents.
368+
"""
369+
yield from self._get_all_batches(
370+
limit=limit,
371+
filter_=filter_,
372+
projection=projection,
373+
pymongoarrow_schema=pymongoarrow_schema
374+
)
375+
376+
def _get_all_batches(
377+
self,
378+
filter_: Dict[str, Any],
379+
limit: Optional[int] = None,
380+
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
381+
pymongoarrow_schema: Any = None,
382+
) -> Iterator[TDataItem]:
383+
"""Load all documents from the collection in parallel batches.
384+
385+
Args:
386+
filter_ (Dict[str, Any]): The filter to apply to the collection.
387+
limit (Optional[int]): The maximum number of documents to load.
388+
projection: selection of fields to create Cursor
389+
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
390+
391+
Yields:
392+
Iterator[TDataItem]: An iterator of the loaded documents.
393+
"""
394+
batches = self._create_batches(limit=limit)
395+
cursor = self._get_cursor(filter_=filter_, projection=projection)
396+
for batch in batches:
397+
yield self._run_batch(
398+
cursor=cursor,
399+
batch=batch,
400+
pymongoarrow_schema=pymongoarrow_schema,
401+
)
402+
346403
def _get_cursor(
347404
self,
348405
filter_: Dict[str, Any],
@@ -352,6 +409,7 @@ def _get_cursor(
352409
353410
Args:
354411
filter_ (Dict[str, Any]): The filter to apply to the collection.
412+
projection: selection of fields to create Cursor
355413
356414
Returns:
357415
Cursor: The cursor for the collection.
@@ -371,14 +429,20 @@ def _get_cursor(
371429
return cursor
372430

373431
@dlt.defer
374-
def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
432+
def _run_batch(
433+
self,
434+
cursor: TCursor,
435+
batch: Dict[str, int],
436+
pymongoarrow_schema: Any = None,
437+
) -> TDataItem:
375438
from pymongoarrow.context import PyMongoArrowContext
376439
from pymongoarrow.lib import process_bson_stream
377440

378441
cursor = cursor.clone()
379442

380443
context = PyMongoArrowContext.from_schema(
381-
None, codec_options=self.collection.codec_options
444+
schema=pymongoarrow_schema,
445+
codec_options=self.collection.codec_options
382446
)
383447
for chunk in cursor.skip(batch["skip"]).limit(batch["limit"]):
384448
process_bson_stream(chunk, context)
@@ -390,7 +454,8 @@ def collection_documents(
390454
client: TMongoClient,
391455
collection: TCollection,
392456
filter_: Dict[str, Any],
393-
projection: Union[Dict[str, Any], List[str]], # TODO kwargs reserved for dlt?
457+
projection: Union[Dict[str, Any], List[str]],
458+
pymongoarrow_schema: "pymongoarrow.schema.Schema",
394459
incremental: Optional[dlt.sources.incremental[Any]] = None,
395460
parallel: bool = False,
396461
limit: Optional[int] = None,
@@ -413,12 +478,13 @@ def collection_documents(
413478
Supported formats:
414479
object - Python objects (dicts, lists).
415480
arrow - Apache Arrow tables.
416-
projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
481+
projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
417482
when loading the collection. Supported inputs:
418483
include (list) - ["year", "title"]
419-
include (dict) - {"year": 1, "title": 1}
420-
exclude (dict) - {"released": 0, "runtime": 0}
421-
Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
484+
include (dict) - {"year": True, "title": True}
485+
exclude (dict) - {"released": False, "runtime": False}
486+
Note: Can't mix include and exclude statements '{"title": True, "released": False}`
487+
pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow
422488
423489
Returns:
424490
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -429,6 +495,19 @@ def collection_documents(
429495
)
430496
data_item_format = "object"
431497

498+
if data_item_format != "arrow" and pymongoarrow_schema:
499+
dlt.common.logger.warn(
500+
"Received value for `pymongoarrow_schema`, but `data_item_format=='object'` "
501+
"Use `data_item_format=='arrow'` to enforce schema."
502+
)
503+
504+
if data_item_format == "arrow" and pymongoarrow_schema and projection:
505+
dlt.common.logger.warn(
506+
"Received values for both `pymongoarrow_schema` and `projection`. Since both "
507+
"create a projection to select fields, `projection` will be ignored."
508+
)
509+
510+
432511
if parallel:
433512
if data_item_format == "arrow":
434513
LoaderClass = CollectionArrowLoaderParallel
@@ -443,11 +522,24 @@ def collection_documents(
443522
loader = LoaderClass(
444523
client, collection, incremental=incremental, chunk_size=chunk_size
445524
)
446-
for data in loader.load_documents(limit=limit, filter_=filter_, projection=projection):
447-
yield data
525+
if isinstance(loader, (CollectionArrowLoader, CollectionArrowLoaderParallel)):
526+
yield from loader.load_documents(
527+
limit=limit,
528+
filter_=filter_,
529+
projection=projection,
530+
pymongoarrow_schema=pymongoarrow_schema,
531+
)
532+
else:
533+
yield from loader.load_documents(limit=limit, filter_=filter_, projection=projection)
448534

449535

450536
def convert_mongo_objs(value: Any) -> Any:
537+
"""MongoDB to dlt type conversion when using Python loaders.
538+
539+
Notes:
540+
The method `ObjectId.__str__()` creates an hexstring using `binascii.hexlify(__id).decode()`
541+
542+
"""
451543
if isinstance(value, (ObjectId, Decimal128)):
452544
return str(value)
453545
if isinstance(value, _datetime.datetime):
@@ -464,6 +556,13 @@ def convert_mongo_objs(value: Any) -> Any:
464556
def convert_arrow_columns(table: Any) -> Any:
465557
"""Convert the given table columns to Python types.
466558
559+
Notes:
560+
Calling str() matches the `convert_mongo_objs()` used in non-arrow code.
561+
Pymongoarrow converts ObjectId to `fixed_size_binary[12]`, which can't be
562+
converted to a string as a vectorized operation because it contains ASCII characters.
563+
564+
Instead, you need to loop over values using: `value.as_buffer().hex().decode()`
565+
467566
Args:
468567
table (pyarrow.lib.Table): The table to convert.
469568
@@ -539,6 +638,7 @@ class MongoDbCollectionResourceConfiguration(BaseConfiguration):
539638
incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg]
540639
write_disposition: Optional[str] = dlt.config.value
541640
parallel: Optional[bool] = False
641+
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value
542642

543643

544644
__source_name__ = "mongodb"

0 commit comments

Comments
 (0)