Skip to content

Commit

Permalink
Adds database, timeout, and sanitize options to schema functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Feb 20, 2025
1 parent 5101575 commit 34368cf
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 19 deletions.
165 changes: 147 additions & 18 deletions src/neo4j_graphrag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,13 @@ def query_database(
return json_data


def get_schema(driver: neo4j.Driver, is_enhanced: bool = False) -> str:
def get_schema(
driver: neo4j.Driver,
is_enhanced: bool = False,
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> str:
"""
Returns the schema of the graph as a string with following format:
Expand All @@ -197,16 +203,34 @@ def get_schema(driver: neo4j.Driver, is_enhanced: bool = False) -> str:
driver (neo4j.Driver): Neo4j Python driver instance.
is_enhanced (bool): Flag indicating whether to format the schema with
detailed statistics (True) or in a simpler overview format (False).
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
str: the graph schema information in a serialized format.
"""
structured_schema = get_structured_schema(driver, is_enhanced)
structured_schema = get_structured_schema(
driver=driver,
is_enhanced=is_enhanced,
database=database,
timeout=timeout,
sanitize=sanitize,
)
return format_schema(structured_schema, is_enhanced)


def get_structured_schema(
driver: neo4j.Driver, is_enhanced: bool = False
driver: neo4j.Driver,
is_enhanced: bool = False,
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> dict[str, Any]:
"""
Returns the structured schema of the graph.
Expand Down Expand Up @@ -249,45 +273,75 @@ def get_structured_schema(
driver (neo4j.Driver): Neo4j Python driver instance.
is_enhanced (bool): Flag indicating whether to format the schema with
detailed statistics (True) or in a simpler overview format (False).
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
dict[str, Any]: the graph schema information in a structured format.
"""
node_properties = [
data["output"]
for data in query_database(
driver,
NODE_PROPERTIES_QUERY,
driver=driver,
query=NODE_PROPERTIES_QUERY,
params={
"EXCLUDED_LABELS": EXCLUDED_LABELS
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
},
database=database,
timeout=timeout,
sanitize=sanitize,
)
]

rel_properties = [
data["output"]
for data in query_database(
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
driver=driver,
query=REL_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_RELS},
database=database,
timeout=timeout,
sanitize=sanitize,
)
]

relationships = [
data["output"]
for data in query_database(
driver,
REL_QUERY,
driver=driver,
query=REL_QUERY,
params={
"EXCLUDED_LABELS": EXCLUDED_LABELS
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
},
database=database,
timeout=timeout,
sanitize=sanitize,
)
]

# Get constraints and indexes
try:
constraint = query_database(driver, "SHOW CONSTRAINTS")
index = query_database(driver, INDEX_QUERY)
constraint = query_database(
driver=driver,
query="SHOW CONSTRAINTS",
database=database,
timeout=timeout,
sanitize=sanitize,
)
index = query_database(
driver=driver,
query=INDEX_QUERY,
database=database,
timeout=timeout,
sanitize=sanitize,
)
except ClientError:
constraint = []
index = []
Expand All @@ -299,7 +353,13 @@ def get_structured_schema(
"metadata": {"constraint": constraint, "index": index},
}
if is_enhanced:
enhance_schema(driver=driver, structured_schema=structured_schema)
enhance_schema(
driver=driver,
structured_schema=structured_schema,
database=database,
timeout=timeout,
sanitize=sanitize,
)
return structured_schema


Expand Down Expand Up @@ -436,6 +496,9 @@ def _build_str_clauses(
label_or_type: str,
exhaustive: bool,
prop_index: Optional[List[Any]] = None,
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Build Cypher clauses for string property statistics.
Expand All @@ -455,6 +518,13 @@ def _build_str_clauses(
prop_index (Optional[List[Any]]): Optional metadata about the property's
index. If provided, certain optimizations are applied based on
distinct value limits and index availability.
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
Tuple[List[str], List[str]]:
Expand All @@ -471,9 +541,14 @@ def _build_str_clauses(
and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT
):
distinct_values = query_database(
driver,
f"CALL apoc.schema.properties.distinct("
f"'{label_or_type}', '{prop_name}') YIELD value",
driver=driver,
query=(
f"CALL apoc.schema.properties.distinct("
f"'{label_or_type}', '{prop_name}') YIELD value"
),
database=database,
timeout=timeout,
sanitize=sanitize,
)[0]["value"]
return_clauses.append(
(f"values: {distinct_values}," f" distinct_count: {len(distinct_values)}")
Expand Down Expand Up @@ -582,6 +657,9 @@ def get_enhanced_schema_cypher(
exhaustive: bool,
sample_size: int = 5,
is_relationship: bool = False,
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> str:
"""
Build a Cypher query for enhanced schema information.
Expand All @@ -605,6 +683,13 @@ def get_enhanced_schema_cypher(
exhaustive is False. Defaults to 5.
is_relationship (bool, optional): Indicates if the query is for
a relationship type (True) or a node label (False). Defaults to False.
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
str: A Cypher query string that gathers enhanced property metadata.
Expand Down Expand Up @@ -643,6 +728,9 @@ def get_enhanced_schema_cypher(
label_or_type=label_or_type,
exhaustive=exhaustive,
prop_index=prop_index,
database=database,
timeout=timeout,
sanitize=sanitize,
)
with_clauses += str_w_clauses
return_clauses += str_r_clauses
Expand Down Expand Up @@ -682,6 +770,9 @@ def enhance_properties(
structured_schema: Dict[str, Any],
prop_dict: Dict[str, Any],
is_relationship: bool,
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> None:
"""
Enhance the structured schema with detailed statistics for a single node label or relationship type.
Expand All @@ -699,6 +790,13 @@ def enhance_properties(
relationship type to be enhanced.
is_relationship (bool): Indicates whether the properties to be enhanced belong to a relationship
(True) or a node (False).
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
None
Expand All @@ -720,6 +818,9 @@ def enhance_properties(
properties=props,
exhaustive=count < EXHAUSTIVE_SEARCH_LIMIT,
is_relationship=is_relationship,
database=database,
timeout=timeout,
sanitize=sanitize,
)
# Due to schema-flexible nature of neo4j errors can happen
try:
Expand All @@ -733,9 +834,12 @@ def enhance_properties(
else {}
)
enhanced_info = query_database(
driver,
enhanced_cypher,
driver=driver,
query=enhanced_cypher,
session_params=session_params,
database=database,
timeout=timeout,
sanitize=sanitize,
)[0]["output"]
for prop in props:
if prop["property"] in enhanced_info:
Expand All @@ -744,7 +848,13 @@ def enhance_properties(
return


def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> None:
def enhance_schema(
driver: neo4j.Driver,
structured_schema: Dict[str, Any],
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
) -> None:
"""
Enhance the structured schema with detailed property statistics.
Expand All @@ -759,18 +869,34 @@ def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> N
structured_schema (Dict[str, Any]): The initial structured schema
containing node and relationship properties, which will be updated
with enhanced statistics.
database (Optional[str]): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
Returns:
None
"""
schema_counts = query_database(driver, SCHEMA_COUNTS_QUERY)
schema_counts = query_database(
driver=driver,
query=SCHEMA_COUNTS_QUERY,
database=database,
timeout=timeout,
sanitize=sanitize,
)
# Update node info
for node in schema_counts[0]["nodes"]:
enhance_properties(
driver=driver,
structured_schema=structured_schema,
prop_dict=node,
is_relationship=False,
database=database,
timeout=timeout,
sanitize=sanitize,
)
# Update rel info
for rel in schema_counts[0]["relationships"]:
Expand All @@ -779,4 +905,7 @@ def enhance_schema(driver: neo4j.Driver, structured_schema: Dict[str, Any]) -> N
structured_schema=structured_schema,
prop_dict=rel,
is_relationship=True,
database=database,
timeout=timeout,
sanitize=sanitize,
)
2 changes: 1 addition & 1 deletion tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


def _query_return_value(*args: Any, **kwargs: Any) -> list[Any]:
query = args[1]
query = kwargs.get("query", args[1] if len(args) > 1 else None)
if NODE_PROPERTIES_QUERY in query:
return [
{
Expand Down

0 comments on commit 34368cf

Please sign in to comment.