Skip to content

Commit

Permalink
Refactored enhance_schema
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Jan 30, 2025
1 parent 72db53b commit 240d946
Showing 1 changed file with 79 additions and 55 deletions.
134 changes: 79 additions & 55 deletions src/neo4j_graphrag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def value_sanitize(d: Any) -> Any:
def query_database(
driver: neo4j.Driver,
query: str,
params: Optional[dict[str, Any]] = None,
session_params: Optional[dict[str, Any]] = None,
params: Dict[str, Any] = {},
session_params: Dict[str, Any] = {},
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> list[dict[str, Any]]:
) -> List[Dict[str, Any]]:
"""
Queries the database.
Expand All @@ -160,8 +160,6 @@ def query_database(
Returns:
list[dict[str, Any]]: the result of the query in json format.
"""
if params is None:
params = {}
if not session_params:
try:
data = driver.execute_query(
Expand Down Expand Up @@ -194,8 +192,8 @@ def query_database(
)
):
raise
# fallback to allow implicit transactions
session_params = session_params or {"database": database}
# Fallback to allow implicit transactions
session_params.setdefault("database", database)
with driver.session(**session_params) as session:
result = session.run(Query(text=query, timeout=timeout), params)
json_data = [r.data() for r in result]
Expand Down Expand Up @@ -698,6 +696,71 @@ def get_enhanced_schema_cypher(
return cypher_query


def enhance_properties(
    driver: neo4j.Driver,
    structured_schema: Dict[str, Any],
    prop_dict: Dict[str, Any],
    is_relationship: bool,
) -> None:
    """
    Enhance the structured schema with detailed statistics for a single node label or relationship type.

    For the specified node label or relationship type, this function builds an
    enhanced-schema Cypher query, runs it against the database, and merges the
    returned per-property details into the matching property entries of the
    provided structured schema. The schema is updated in place.

    Nothing is changed when the name is in the relevant exclusion list
    (EXCLUDED_RELS or EXCLUDED_LABELS), when the schema lists no properties
    for it, or when the query raises a CypherTypeError.

    Args:
        driver (neo4j.Driver): A Neo4j Python driver instance used to run queries against the database.
        structured_schema (Dict[str, Any]): A dictionary representing the current structured schema,
            which will be updated with enhanced property statistics. Node properties are read
            from its "node_props" entry and relationship properties from its "rel_props" entry.
        prop_dict (Dict[str, Any]): A dictionary containing the "name" and "count" of the node label or
            relationship type to be enhanced.
        is_relationship (bool): Indicates whether the properties to be enhanced belong to a relationship
            (True) or a node (False).

    Returns:
        None
    """
    name = prop_dict["name"]
    count = prop_dict["count"]
    excluded = EXCLUDED_RELS if is_relationship else EXCLUDED_LABELS
    if name in excluded:
        return
    # Relationship properties are stored under "rel_props", node properties
    # under "node_props"; looking relationships up in "node_props" would
    # either skip them or pick up an unrelated node label's properties.
    props_key = "rel_props" if is_relationship else "node_props"
    props = structured_schema[props_key].get(name)
    if not props:  # The node label / relationship type has no properties
        return
    enhanced_cypher = get_enhanced_schema_cypher(
        driver=driver,
        structured_schema=structured_schema,
        label_or_type=name,
        properties=props,
        # Only scan exhaustively when the element count is below the limit
        exhaustive=count < EXHAUSTIVE_SEARCH_LIMIT,
        is_relationship=is_relationship,
    )
    # Disable the
    # Neo.ClientNotification.Statement.AggregationSkippedNull
    # notifications raised by the use of collect in the enhanced
    # schema query for nodes
    session_params: Dict[str, Any] = (
        {}
        if is_relationship
        else {"notifications_disabled_categories": ["UNRECOGNIZED"]}
    )
    # Due to schema-flexible nature of neo4j errors can happen
    try:
        enhanced_info = query_database(
            driver,
            enhanced_cypher,
            session_params=session_params,
        )[0]["output"]
    except CypherTypeError:
        # Best effort: leave this label/type without enhanced statistics
        return
    for prop in props:
        prop_name = prop["property"]
        if prop_name in enhanced_info:
            prop.update(enhanced_info[prop_name])


def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> None:
"""
Enhance the structured schema with detailed property statistics.
Expand All @@ -720,56 +783,17 @@ def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> N
schema_counts = query_database(driver, SCHEMA_COUNTS_QUERY)
# Update node info
for node in schema_counts[0]["nodes"]:
# Skip bloom labels
if node["name"] in EXCLUDED_LABELS:
continue
node_props = structured_schema["node_props"].get(node["name"])
if not node_props: # The node has no properties
continue
enhanced_cypher = get_enhanced_schema_cypher(
driver,
structured_schema,
node["name"],
node_props,
node["count"] < EXHAUSTIVE_SEARCH_LIMIT,
enhance_properties(
driver=driver,
structured_schema=structured_schema,
prop_dict=node,
is_relationship=False,
)
# Due to schema-flexible nature of neo4j errors can happen
try:
enhanced_info = query_database(
driver,
enhanced_cypher,
# Disable the
# Neo.ClientNotification.Statement.AggregationSkippedNull
# notifications raised by the use of collect in the enhanced
# schema query
params={"notifications_disabled_categories": ["UNRECOGNIZED"]},
)[0]["output"]
for prop in node_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
except CypherTypeError:
continue
# Update rel info
for rel in schema_counts[0]["relationships"]:
# Skip bloom labels
if rel["name"] in EXCLUDED_RELS:
continue
rel_props = structured_schema["rel_props"].get(rel["name"])
if not rel_props: # The rel has no properties
continue
enhanced_cypher = get_enhanced_schema_cypher(
driver,
structured_schema,
rel["name"],
rel_props,
rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
enhance_properties(
driver=driver,
structured_schema=structured_schema,
prop_dict=rel,
is_relationship=True,
)
try:
enhanced_info = query_database(driver, enhanced_cypher)[0]["output"]
for prop in rel_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
# Due to schema-flexible nature of neo4j errors can happen
except CypherTypeError:
continue

0 comments on commit 240d946

Please sign in to comment.