@@ -107,10 +107,10 @@ def _filter_op(self) -> Dict[str, Any]:
             filt[self.cursor_field]["$gt"] = self.incremental.end_value
 
         return filt
-
-    def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
+
+    def _projection_op(self, projection: Optional[Union[Mapping[str, Any], Iterable[str]]]) -> Optional[Dict[str, Any]]:
         """Build a projection operator.
-
+
         A tuple of fields to include or a dict specifying fields to include or exclude.
         The incremental `primary_key` needs to be handled differently for inclusion
         and exclusion projections.
@@ -123,17 +123,16 @@ def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
 
         projection_dict = dict(_fields_list_to_dict(projection, "projection"))
 
-        # NOTE we can still filter on primary_key if it's excluded from projection
         if self.incremental:
             # this is an inclusion projection
-            if any(v == 1 for v in projection.values()):
+            if any(v == 1 for v in projection_dict.values()):
                 # ensure primary_key is included
-                projection_dict.update({self.incremental.primary_key: 1})
+                projection_dict[self.incremental.primary_key] = 1
             # this is an exclusion projection
             else:
                 try:
                     # ensure primary_key isn't excluded
-                    projection_dict.pop(self.incremental.primary_key)
+                    projection_dict.pop(self.incremental.primary_key)  # type: ignore
                 except KeyError:
                     pass  # primary_key was properly not included in exclusion projection
         else:
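Context for the branch above: MongoDB treats inclusion and exclusion projections asymmetrically, so the incremental primary key must be forced into the former and kept out of the latter. A minimal sketch (the `_id` key and field names are illustrative, not taken from the source):

```python
# Inclusion projection: only the listed fields come back, so the
# incremental primary key (assumed "_id" here) must be added explicitly.
inclusion = {"year": 1, "title": 1}
inclusion["_id"] = 1

# Exclusion projection: listed fields are dropped, so the primary key
# must simply not appear in it; popping it is a safe no-op otherwise.
exclusion = {"released": 0, "runtime": 0}
exclusion.pop("_id", None)
```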
@@ -174,6 +173,7 @@ def load_documents(
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
+            projection: Fields to include or exclude when building the cursor.
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
@@ -279,6 +279,7 @@ def load_documents(
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
+            projection: Fields to include or exclude when building the cursor.
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
@@ -300,19 +301,22 @@ def load_documents(
         filter_: Dict[str, Any],
         limit: Optional[int] = None,
         projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+        pymongoarrow_schema: Any = None,
     ) -> Iterator[Any]:
         """
         Load documents from the collection in Apache Arrow format.
 
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
+            projection: Fields to include or exclude when building the cursor.
+            pymongoarrow_schema: Mapping of field types used to convert BSON to Arrow.
 
         Yields:
             Iterator[Any]: An iterator of the loaded documents.
         """
         from pymongoarrow.context import PyMongoArrowContext  # type: ignore
-        from pymongoarrow.lib import process_bson_stream
+        from pymongoarrow.lib import process_bson_stream  # type: ignore
 
         filter_op = self._filter_op
         _raise_if_intersection(filter_op, filter_)
@@ -330,7 +334,8 @@ def load_documents(
         cursor = self._limit(cursor, limit)  # type: ignore
 
         context = PyMongoArrowContext.from_schema(
-            None, codec_options=self.collection.codec_options
+            schema=pymongoarrow_schema,
+            codec_options=self.collection.codec_options
         )
         for batch in cursor:
             process_bson_stream(batch, context)
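For readers unfamiliar with `pymongoarrow`: the schema passed to `PyMongoArrowContext.from_schema` is a field-name-to-type mapping that both selects fields and pins their Arrow types. A hedged sketch (field names invented):

```python
import pyarrow as pa
from bson import ObjectId
from pymongoarrow.api import Schema

# Both Python types and pyarrow DataTypes are accepted as values.
schema = Schema({
    "_id": ObjectId,                 # stored as fixed_size_binary[12]
    "title": str,
    "year": int,
    "released": pa.timestamp("ms"),
})
```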
@@ -343,6 +348,58 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
     Mongo DB collection parallel loader, which uses
     Apache Arrow for data processing.
     """
+    def load_documents(
+        self,
+        filter_: Dict[str, Any],
+        limit: Optional[int] = None,
+        projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+        pymongoarrow_schema: Any = None,
+    ) -> Iterator[TDataItem]:
+        """Load documents from the collection in parallel.
+
+        Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
+            limit (Optional[int]): The number of documents to load.
+            projection: Fields to include or exclude when building the cursor.
+            pymongoarrow_schema: Mapping of field types used to convert BSON to Arrow.
+
+        Yields:
+            Iterator[TDataItem]: An iterator of the loaded documents.
+        """
+        yield from self._get_all_batches(
+            limit=limit,
+            filter_=filter_,
+            projection=projection,
+            pymongoarrow_schema=pymongoarrow_schema,
+        )
+
+    def _get_all_batches(
+        self,
+        filter_: Dict[str, Any],
+        limit: Optional[int] = None,
+        projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+        pymongoarrow_schema: Any = None,
+    ) -> Iterator[TDataItem]:
+        """Load all documents from the collection in parallel batches.
+
+        Args:
+            filter_ (Dict[str, Any]): The filter to apply to the collection.
+            limit (Optional[int]): The maximum number of documents to load.
+            projection: Fields to include or exclude when building the cursor.
+            pymongoarrow_schema: Mapping of field types used to convert BSON to Arrow.
+
+        Yields:
+            Iterator[TDataItem]: An iterator of the loaded documents.
+        """
+        batches = self._create_batches(limit=limit)
+        cursor = self._get_cursor(filter_=filter_, projection=projection)
+        for batch in batches:
+            yield self._run_batch(
+                cursor=cursor,
+                batch=batch,
+                pymongoarrow_schema=pymongoarrow_schema,
+            )
+
     def _get_cursor(
         self,
         filter_: Dict[str, Any],
@@ -352,6 +409,7 @@ def _get_cursor(
 
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
+            projection: Fields to include or exclude when building the cursor.
 
         Returns:
             Cursor: The cursor for the collection.
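Under the hood `_get_cursor` boils down to a single `pymongo` `find()` call that applies the filter and the projection server-side; a sketch with invented connection details and field names:

```python
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")  # placeholder URI
collection = client["sample_mflix"]["movies"]      # invented database/collection

cursor = collection.find(
    filter={"year": {"$gte": 2000}},
    projection={"title": True, "year": True},
)
```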
@@ -371,14 +429,20 @@ def _get_cursor(
         return cursor
 
     @dlt.defer
-    def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
+    def _run_batch(
+        self,
+        cursor: TCursor,
+        batch: Dict[str, int],
+        pymongoarrow_schema: Any = None,
+    ) -> TDataItem:
         from pymongoarrow.context import PyMongoArrowContext
         from pymongoarrow.lib import process_bson_stream
 
         cursor = cursor.clone()
 
         context = PyMongoArrowContext.from_schema(
-            None, codec_options=self.collection.codec_options
+            schema=pymongoarrow_schema,
+            codec_options=self.collection.codec_options
         )
         for chunk in cursor.skip(batch["skip"]).limit(batch["limit"]):
             process_bson_stream(chunk, context)
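How the parallel Arrow path divides work, assuming the `skip`/`limit` batch dicts used above (`count` and `chunk_size` values are invented):

```python
# Each batch describes a slice of the collection; _run_batch clones the
# shared cursor per batch, so @dlt.defer can evaluate slices concurrently.
count = 10_000
chunk_size = 2_500
batches = [
    {"skip": skip, "limit": chunk_size}
    for skip in range(0, count, chunk_size)
]
# -> [{'skip': 0, 'limit': 2500}, {'skip': 2500, 'limit': 2500}, ...]
```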
@@ -390,7 +454,8 @@ def collection_documents(
     client: TMongoClient,
     collection: TCollection,
     filter_: Dict[str, Any],
-    projection: Union[Dict[str, Any], List[str]],  # TODO kwargs reserved for dlt?
+    projection: Union[Dict[str, Any], List[str]],
+    pymongoarrow_schema: "pymongoarrow.schema.Schema",
     incremental: Optional[dlt.sources.incremental[Any]] = None,
     parallel: bool = False,
     limit: Optional[int] = None,
@@ -413,12 +478,13 @@ def collection_documents(
             Supported formats:
                 object - Python objects (dicts, lists).
                 arrow - Apache Arrow tables.
-        projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
+        projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
             when loading the collection. Supported inputs:
             include (list) - ["year", "title"]
-            include (dict) - {"year": 1, "title": 1}
-            exclude (dict) - {"released": 0, "runtime": 0}
-            Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
+            include (dict) - {"year": True, "title": True}
+            exclude (dict) - {"released": False, "runtime": False}
+            Note: Can't mix include and exclude statements `{"title": True, "released": False}`
+        pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of the collection's expected field types, used to convert BSON to Arrow.
 
         Returns:
             Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
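The three accepted `projection` spellings from the docstring, side by side (field names invented):

```python
include_list = ["year", "title"]
include_dict = {"year": True, "title": True}     # equivalent to {"year": 1, ...}
exclude_dict = {"released": False, "runtime": False}

# Mixing include and exclude in one projection (except for "_id") is
# rejected by MongoDB with an OperationFailure when the cursor runs:
# {"title": True, "released": False}  # invalid
```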
@@ -429,6 +495,19 @@ def collection_documents(
         )
         data_item_format = "object"
 
+    if data_item_format != "arrow" and pymongoarrow_schema:
+        dlt.common.logger.warn(
+            "Received value for `pymongoarrow_schema`, but `data_item_format == 'object'`. "
+            "Use `data_item_format == 'arrow'` to enforce the schema."
+        )
+
+    if data_item_format == "arrow" and pymongoarrow_schema and projection:
+        dlt.common.logger.warn(
+            "Received values for both `pymongoarrow_schema` and `projection`. Since both "
+            "create a projection to select fields, `projection` will be ignored."
+        )
+
+
     if parallel:
         if data_item_format == "arrow":
             LoaderClass = CollectionArrowLoaderParallel
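The two guards above, restated as runnable conditions with invented stand-in values (a real `pymongoarrow_schema` would be a `pymongoarrow.api.Schema`):

```python
data_item_format = "object"
pymongoarrow_schema = {"title": str}  # stand-in for a pymongoarrow Schema
projection = ["title"]

if data_item_format != "arrow" and pymongoarrow_schema:
    print("schema ignored: only the Arrow loaders enforce it")
if data_item_format == "arrow" and pymongoarrow_schema and projection:
    print("projection ignored: the schema already selects fields")
```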
@@ -443,11 +522,24 @@ def collection_documents(
     loader = LoaderClass(
         client, collection, incremental=incremental, chunk_size=chunk_size
     )
-    for data in loader.load_documents(limit=limit, filter_=filter_, projection=projection):
-        yield data
+    if isinstance(loader, (CollectionArrowLoader, CollectionArrowLoaderParallel)):
+        yield from loader.load_documents(
+            limit=limit,
+            filter_=filter_,
+            projection=projection,
+            pymongoarrow_schema=pymongoarrow_schema,
+        )
+    else:
+        yield from loader.load_documents(limit=limit, filter_=filter_, projection=projection)
 
 
 def convert_mongo_objs(value: Any) -> Any:
+    """MongoDB-to-dlt type conversion used by the Python (non-Arrow) loaders.
+
+    Notes:
+        The method `ObjectId.__str__()` creates a hex string using `binascii.hexlify(__id).decode()`.
+
+    """
     if isinstance(value, (ObjectId, Decimal128)):
         return str(value)
     if isinstance(value, _datetime.datetime):
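What the docstring note means in practice, with an invented id:

```python
from bson import ObjectId

oid = ObjectId("5f0f1f9a1c4ae837f0aaf2a1")
assert str(oid) == "5f0f1f9a1c4ae837f0aaf2a1"  # 24-character hex string
```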
@@ -464,6 +556,13 @@ def convert_mongo_objs(value: Any) -> Any:
 def convert_arrow_columns(table: Any) -> Any:
     """Convert the given table columns to Python types.
 
+    Notes:
+        Calling str() matches the `convert_mongo_objs()` used in non-Arrow code.
+        Pymongoarrow converts ObjectId to `fixed_size_binary[12]`, which can't be
+        converted to a string as a vectorized operation because it contains non-ASCII bytes.
+
+        Instead, you need to loop over values using: `value.as_buffer().hex().decode()`
+
     Args:
         table (pyarrow.lib.Table): The table to convert.
 
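A sketch of the per-value decoding the note describes, using an invented 12-byte value in place of a real ObjectId column:

```python
import pyarrow as pa

col = pa.array(
    [b"\x5f\x0f\x1f\x9a\x1c\x4a\xe8\x37\xf0\xaa\xf2\xa1"],
    type=pa.binary(12),  # fixed_size_binary[12], as pymongoarrow produces
)
# Buffer.hex() returns bytes, hence the .decode() to get a str.
hex_ids = [value.as_buffer().hex().decode() for value in col]
assert hex_ids == ["5f0f1f9a1c4ae837f0aaf2a1"]
```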
@@ -539,6 +638,7 @@ class MongoDbCollectionResourceConfiguration(BaseConfiguration):
     incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]
     write_disposition: Optional[str] = dlt.config.value
     parallel: Optional[bool] = False
+    projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value
 
 
 __source_name__ = "mongodb"