Skip to content

Commit 98d98c6

Browse files
committed
Attempt to fix the relays projection in a generic way.
1 parent 7dbb95c commit 98d98c6

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

opensensor/collection_apis.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import logging
22
from datetime import datetime, timedelta, timezone
3-
from typing import Dict, Generic, List, Optional, Type, TypeVar, get_args, get_origin
3+
from enum import Enum
4+
from typing import (
5+
Any,
6+
Dict,
7+
Generic,
8+
List,
9+
Optional,
10+
Type,
11+
TypeVar,
12+
get_args,
13+
get_origin,
14+
)
415

516
from bson import Binary
617
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Response, status
@@ -184,10 +195,6 @@ def get_initial_match_clause(
184195
return match_clause
185196

186197

187-
def is_pydantic_model(obj):
188-
return isinstance(obj, type) and issubclass(obj, BaseModel)
189-
190-
191198
def get_nested_fields(model: Type[BaseModel]) -> Dict[str, Type[BaseModel]]:
192199
nested_fields = {}
193200
for field_name, field in model.__fields__.items():
@@ -198,17 +205,14 @@ def get_nested_fields(model: Type[BaseModel]) -> Dict[str, Type[BaseModel]]:
198205
return nested_fields
199206

200207

201-
def is_list_of_models(field: ModelField) -> bool:
202-
if field.shape != 2: # 2 represents a list shape in Pydantic
203-
return False
204-
args = field.sub_fields[0].type_.__args__
205-
return args and is_pydantic_model(args[0])
206-
207-
208208
def get_list_item_type(field: ModelField) -> Type[BaseModel]:
209209
return field.sub_fields[0].type_.__args__[0]
210210

211211

212+
def is_enum(field_type):
213+
return issubclass(field_type, Enum)
214+
215+
212216
def create_nested_pipeline(model: Type[BaseModel], prefix=""):
213217
logger.debug(f"Creating nested pipeline for model: {model.__name__}, prefix: {prefix}")
214218
match_conditions = {}
@@ -229,13 +233,21 @@ def create_nested_pipeline(model: Type[BaseModel], prefix=""):
229233
unit_field_name = f"{prefix}{mongo_field}_unit"
230234
pipeline["unit"] = f"${unit_field_name}"
231235
match_conditions[unit_field_name] = {"$exists": True}
232-
elif is_list_of_models(field):
233-
item_model = get_list_item_type(field)
236+
elif is_pydantic_model(field.type_):
237+
nested_pipeline, nested_match = create_nested_pipeline(field.type_, f"{field_name}.")
238+
pipeline[field_name] = nested_pipeline
239+
match_conditions.update({f"{field_name}.{k}": v for k, v in nested_match.items()})
240+
elif get_origin(field.type_) is List and is_pydantic_model(get_args(field.type_)[0]):
241+
item_model = get_args(field.type_)[0]
234242
nested_pipeline, nested_match = create_nested_pipeline(item_model, "")
235243
pipeline[field_name] = {
236244
"$map": {"input": f"${full_mongo_field_name}", "as": "item", "in": nested_pipeline}
237245
}
238246
match_conditions[full_mongo_field_name] = {"$exists": True, "$ne": []}
247+
elif is_enum(field.type_):
248+
# Handle enum fields as simple fields
249+
pipeline[field_name] = f"${full_mongo_field_name}"
250+
match_conditions[full_mongo_field_name] = {"$exists": True}
239251
else:
240252
# Handle simple field
241253
pipeline[field_name] = f"${full_mongo_field_name}"
@@ -248,6 +260,10 @@ def create_nested_pipeline(model: Type[BaseModel], prefix=""):
248260
return pipeline, match_conditions
249261

250262

263+
def is_pydantic_model(obj: Any) -> bool:
264+
return isinstance(obj, type) and issubclass(obj, BaseModel)
265+
266+
251267
def create_model_instance(model: Type[BaseModel], data: dict, target_unit: Optional[str] = None):
252268
nested_fields = get_nested_fields(model)
253269

0 commit comments

Comments
 (0)