import logging
from datetime import datetime, timedelta, timezone
- from typing import Dict, Generic, List, Optional, Type, TypeVar, get_args, get_origin
+ from enum import Enum
+ from typing import (
+     Any,
+     Dict,
+     Generic,
+     List,
+     Optional,
+     Type,
+     TypeVar,
+     get_args,
+     get_origin,
+ )

from bson import Binary
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Response, status
@@ -184,10 +195,6 @@ def get_initial_match_clause(
    return match_clause


- def is_pydantic_model(obj):
-     return isinstance(obj, type) and issubclass(obj, BaseModel)
-
-
def get_nested_fields(model: Type[BaseModel]) -> Dict[str, Type[BaseModel]]:
    nested_fields = {}
    for field_name, field in model.__fields__.items():
@@ -198,17 +205,14 @@ def get_nested_fields(model: Type[BaseModel]) -> Dict[str, Type[BaseModel]]:
    return nested_fields


- def is_list_of_models(field: ModelField) -> bool:
-     if field.shape != 2:  # 2 represents a list shape in Pydantic
-         return False
-     args = field.sub_fields[0].type_.__args__
-     return args and is_pydantic_model(args[0])
-
-
def get_list_item_type(field: ModelField) -> Type[BaseModel]:
    return field.sub_fields[0].type_.__args__[0]


+ def is_enum(field_type):
+     return issubclass(field_type, Enum)
+
+
def create_nested_pipeline(model: Type[BaseModel], prefix=""):
    logger.debug(f"Creating nested pipeline for model: {model.__name__}, prefix: {prefix}")
    match_conditions = {}
@@ -229,13 +233,21 @@ def create_nested_pipeline(model: Type[BaseModel], prefix=""):
            unit_field_name = f"{prefix}{mongo_field}_unit"
            pipeline["unit"] = f"${unit_field_name}"
            match_conditions[unit_field_name] = {"$exists": True}
-         elif is_list_of_models(field):
-             item_model = get_list_item_type(field)
+         elif is_pydantic_model(field.type_):
+             nested_pipeline, nested_match = create_nested_pipeline(field.type_, f"{field_name}.")
+             pipeline[field_name] = nested_pipeline
+             match_conditions.update({f"{field_name}.{k}": v for k, v in nested_match.items()})
+         elif get_origin(field.type_) is List and is_pydantic_model(get_args(field.type_)[0]):
+             item_model = get_args(field.type_)[0]
            nested_pipeline, nested_match = create_nested_pipeline(item_model, "")
            pipeline[field_name] = {
                "$map": {"input": f"${full_mongo_field_name}", "as": "item", "in": nested_pipeline}
            }
            match_conditions[full_mongo_field_name] = {"$exists": True, "$ne": []}
+         elif is_enum(field.type_):
+             # Handle enum fields as simple fields
+             pipeline[field_name] = f"${full_mongo_field_name}"
+             match_conditions[full_mongo_field_name] = {"$exists": True}
        else:
            # Handle simple field
            pipeline[field_name] = f"${full_mongo_field_name}"
@@ -248,6 +260,10 @@ def create_nested_pipeline(model: Type[BaseModel], prefix=""):
    return pipeline, match_conditions


+ def is_pydantic_model(obj: Any) -> bool:
+     return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
def create_model_instance(model: Type[BaseModel], data: dict, target_unit: Optional[str] = None):
    nested_fields = get_nested_fields(model)
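For orientation, the sketch below isolates the per-field dispatch that the updated create_nested_pipeline applies: a single nested Pydantic model recurses with a dotted prefix, a list of nested models becomes a $map expression, and enums fall through to the simple-field projection. It is a minimal, hypothetical example assuming Pydantic v1's __fields__/outer_type_ field API; Widget, Part, Color, and sketch_pipeline are illustrative names only, and the real function additionally tracks unit fields and the $exists match conditions, which are omitted here.

from enum import Enum
from typing import List, Type, get_args, get_origin

from pydantic import BaseModel


def is_pydantic_model(obj) -> bool:
    # Same shape as the helper added in this diff.
    return isinstance(obj, type) and issubclass(obj, BaseModel)


def is_enum(obj) -> bool:
    # Guarded with isinstance(obj, type) so non-class annotations don't raise.
    return isinstance(obj, type) and issubclass(obj, Enum)


def sketch_pipeline(model: Type[BaseModel], prefix: str = "") -> dict:
    """Toy version of the per-field dispatch (no unit handling, no match conditions)."""
    pipeline = {}
    for name, field in model.__fields__.items():   # Pydantic v1 field objects
        outer = field.outer_type_                  # e.g. List[Part] for list fields
        path = f"{prefix}{name}"
        if get_origin(outer) is list and is_pydantic_model(get_args(outer)[0]):
            # List of nested models -> $map over the array, re-projecting each item.
            item_model = get_args(outer)[0]
            pipeline[name] = {
                "$map": {"input": f"${path}", "as": "item", "in": sketch_pipeline(item_model, "")}
            }
        elif is_pydantic_model(field.type_):
            # Single nested model -> recurse with a dotted prefix.
            pipeline[name] = sketch_pipeline(field.type_, f"{path}.")
        else:
            # Enums and scalars project straight through.
            pipeline[name] = f"${path}"
    return pipeline


class Color(Enum):
    RED = "red"


class Part(BaseModel):
    sku: str


class Widget(BaseModel):
    name: str
    color: Color
    parts: List[Part]


print(sketch_pipeline(Widget))
# {'name': '$name', 'color': '$color',
#  'parts': {'$map': {'input': '$parts', 'as': 'item', 'in': {'sku': '$sku'}}}}

In this sketch the list case is checked before the single-model case because, under Pydantic v1, field.type_ already reports the inner model even for List[...] annotations, while field.outer_type_ keeps the list wrapper.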