proposed changes to improve neo4j operations
- use neo4j-rust-ext instead of the plain neo4j driver for a 10x perf improvement
- always match on a single node label (equivalent to the constraint), never use blank matches (see the sketch below this list)
- group relationships by type, source label, and target label
- increase the default batch size
- use the vector property procedure to set embeddings as fp32 instead of fp64
- add a method to select the main label for a node
- TODO: creating the vector index needs information from the embedder (dimensions) and the similarity function (from config)
- set extra labels
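The single-label rule matters because a blank node pattern cannot use the index behind the per-label uniqueness constraint, so the planner falls back to scanning every node; a labeled pattern becomes a unique index seek. A hedged sketch (the `Chunk` label and `id` property stand in for the real ones; compare the plans with EXPLAIN):

    # Without a label: AllNodesScan plus a property filter.
    blank_match = "MATCH (u {id: $id}) RETURN u"
    # With a label: NodeUniqueIndexSeek via the uniqueness constraint on :Chunk(id).
    labeled_match = "MATCH (u:Chunk {id: $id}) RETURN u"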
jexp committed Jan 28, 2025
1 parent cbb1f3a commit a2b1944
Showing 3 changed files with 57 additions and 24 deletions.
requirements/connectors/neo4j.in (1 addition & 1 deletion)
@@ -1,3 +1,3 @@
-neo4j
+neo4j-rust-ext
 cymple
 networkx
requirements/connectors/neo4j.txt (1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # uv pip compile neo4j.in --output-file neo4j.txt --no-strip-extras --python-version 3.9
 cymple==0.12.0
     # via -r neo4j.in
-neo4j==5.27.0
+neo4j-rust-ext==5.27.0.0
     # via -r neo4j.in
 networkx==3.2.1
     # via -r neo4j.in
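The dependency swap is the whole change here: neo4j-rust-ext bundles the official Python driver with precompiled Rust extensions for Bolt packing/unpacking and is installed in place of neo4j with no code changes, since the import name stays the same. A minimal sanity check (URI and credentials are placeholders):

    # With neo4j-rust-ext installed, application code still imports `neo4j`.
    import neo4j

    driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
    driver.verify_connectivity()  # raises if the server cannot be reached
    driver.close()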
unstructured_ingest/v2/processes/connectors/neo4j.py (55 additions & 22 deletions)
@@ -187,7 +187,9 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
         edges = [
             _Edge(
                 source_id=u.id_,
+                source_labels=u.labels,
                 destination_id=v.id_,
+                destination_labels=v.labels,
                 relationship=Relationship(data_dict["relationship"]),
             )
             for u, v, data_dict in nx_graph.edges(data=True)
@@ -210,7 +212,9 @@ class _Edge(BaseModel):
     model_config = ConfigDict()
 
     source_id: str
+    source_labels: list[Label]
     destination_id: str
+    destination_labels: list[Label]
     relationship: Relationship
 
 
@@ -229,7 +233,7 @@ class Relationship(Enum):
 
 class Neo4jUploaderConfig(UploaderConfig):
     batch_size: int = Field(
-        default=100, description="Maximal number of nodes/relationships created per transaction."
+        default=1000, description="Maximal number of nodes/relationships created per transaction."
     )
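The larger default only matters because nodes and edges are chunked before upload; `batch_generator` is imported from the ingest utilities elsewhere in the package. A minimal equivalent, as a sketch only (not the actual implementation):

    from itertools import islice
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def batch_generator(iterable: Iterable[T], batch_size: int) -> Iterator[list[T]]:
        # Yield successive lists of at most `batch_size` items.
        iterator = iter(iterable)
        while batch := list(islice(iterator, batch_size)):
            yield batch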


@@ -257,6 +261,8 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None: #
         graph_data = _GraphData.model_validate(staged_data)
         async with self.connection_config.get_client() as client:
             await self._create_uniqueness_constraints(client)
+            # TODO: needs the chunker/embedder config (dimensions, similarity function)
+            # await self._create_vector_index(client, self.upload_config.dimensions, self.upload_config.similarity_function)
             await self._delete_old_data_if_exists(file_data, client=client)
             await self._merge_graph(graph_data=graph_data, client=client)

@@ -274,48 +280,73 @@ async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
             """
         )
 
+    async def _create_vector_index(self, client: AsyncDriver, dimensions: int, similarity_function: str) -> None:
+        label = Label.CHUNK
+        index_name = f"{label.value.lower()}_vector"
+        logger.info(
+            f"Creating vector index '{index_name}' for nodes labeled '{label.value}'"
+            " if it does not already exist."
+        )
+        await client.execute_query(
+            f"""
+            CREATE VECTOR INDEX {index_name} IF NOT EXISTS
+            FOR (n:{label.value}) ON n.embedding
+            OPTIONS {{indexConfig: {{
+                `vector.similarity_function`: '{similarity_function}',
+                `vector.dimensions`: {dimensions}
+            }}}}
+            """
+        )
+
     async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
         logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
         _, summary, _ = await client.execute_query(
             f"""
-            MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
-            MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
-            DETACH DELETE m""",
+            MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
+            MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
+            DETACH DELETE m
+            DETACH DELETE n""",
             identifier=file_data.identifier,
         )
         logger.info(
             f"Deleted {summary.counters.nodes_deleted} nodes"
             f" and {summary.counters.relationships_deleted} relationships."
         )
 
+    def _main_label(self, labels: list[Label]) -> Optional[Label]:
+        if not labels:
+            return None
+        for label in [Label.CHUNK, Label.DOCUMENT, Label.UNSTRUCTURED_ELEMENT]:
+            if label in labels:
+                return label
+        return labels[0]
+
     async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
-        nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+        nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
         for node in graph_data.nodes:
-            nodes_by_labels[tuple(node.labels)].append(node)
-
+            nodes_by_labels[self._main_label(node.labels)].append(node)
         logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
         # NOTE: Processed in parallel as there's no overlap between accessed nodes
         await self._execute_queries(
             [
-                self._create_nodes_query(nodes_batch, labels)
-                for labels, nodes in nodes_by_labels.items()
+                self._create_nodes_query(nodes_batch, label)
+                for label, nodes in nodes_by_labels.items()
                 for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
             ],
             client=client,
             in_parallel=True,
         )
         logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
 
-        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = defaultdict(list)
         for edge in graph_data.edges:
-            edges_by_relationship[edge.relationship].append(edge)
+            key = (edge.relationship, self._main_label(edge.source_labels), self._main_label(edge.destination_labels))
+            edges_by_relationship[key].append(edge)
 
         logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
         # NOTE: Processed sequentially to avoid queries locking node access to one another
         await self._execute_queries(
             [
-                self._create_edges_query(edges_batch, relationship)
-                for relationship, edges in edges_by_relationship.items()
+                self._create_edges_query(edges_batch, relationship_key)
+                for relationship_key, edges in edges_by_relationship.items()
                 for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
             ],
             client=client,
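Once `_create_vector_index` runs and embeddings are populated, the index can be searched with the built-in `db.index.vector.queryNodes` procedure. A hedged usage sketch, assuming the `chunk_vector` index name from above; `client` is the open AsyncDriver and `query_embedding` is a list[float] matching the configured dimensions:

    # Sketch only: retrieve the 5 nearest Chunk nodes by vector similarity.
    records, _, _ = await client.execute_query(
        """
        CALL db.index.vector.queryNodes('chunk_vector', 5, $embedding)
        YIELD node, score
        RETURN node.id AS id, score
        """,
        embedding=query_embedding,
    )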
@@ -348,25 +379,27 @@ async def _execute_queries(
         )
 
     @staticmethod
-    def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
-        labels_string = ", ".join([label.value for label in labels])
-        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+    def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label.value}'.")
         query_string = f"""
             UNWIND $nodes AS node
-            MERGE (n: {labels_string} {{id: node.id}})
+            MERGE (n: `{label.value}` {{id: node.id}})
             SET n += node.properties
+            SET n:$(node.labels)
+            WITH * WHERE node.vector IS NOT NULL
+            CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
             """
-        parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
+        parameters = {
+            "nodes": [
+                {
+                    "id": node.id_,
+                    "labels": [l.value for l in node.labels if l != label],
+                    "vector": node.properties.pop("embedding", None),
+                    "properties": node.properties,
+                }
+                for node in nodes
+            ]
+        }
         return query_string, parameters
 
     @staticmethod
-    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+    def _create_edges_query(edges: list[_Edge], relationship: tuple[Relationship, Label, Label]) -> tuple[str, dict]:
         logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
         query_string = f"""
             UNWIND $edges AS edge
-            MATCH (u {{id: edge.source}})
-            MATCH (v {{id: edge.destination}})
-            MERGE (u)-[:{relationship.value}]->(v)
+            MATCH (u: `{relationship[1].value}` {{id: edge.source}})
+            MATCH (v: `{relationship[2].value}` {{id: edge.destination}})
+            MERGE (u)-[:`{relationship[0].value}`]->(v)
             """
         parameters = {
             "edges": [
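Because every batch now shares one (relationship, source label, target label) key, the rendered Cypher can name a label on both MATCH sides and a single relationship type, so each lookup becomes an index seek against the uniqueness constraint instead of a blank-match scan. A sketch of what one group's query renders to (the enum values shown are illustrative, not the actual Relationship/Label members):

    # Approximate rendering for a hypothetical
    # (Relationship.PART_OF_DOCUMENT, Label.CHUNK, Label.DOCUMENT) group.
    rendered_edge_query = """
    UNWIND $edges AS edge
    MATCH (u:`Chunk` {id: edge.source})
    MATCH (v:`Document` {id: edge.destination})
    MERGE (u)-[:`PART_OF_DOCUMENT`]->(v)
    """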
