index to neo4j working and tested. check queries next.
This commit is contained in:
@@ -4,24 +4,42 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
import inspect
|
||||
|
||||
|
||||
|
||||
|
||||
# import package.common.utils as utils
|
||||
|
||||
|
||||
# import package.common.utils as utils
|
||||
from lightrag.utils import load_json, logger, write_json
|
||||
from ..base import (
|
||||
BaseGraphStorage
|
||||
)
|
||||
from neo4j import GraphDatabase
|
||||
from neo4j import GraphDatabase, exceptions as neo4jExceptions
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
|
||||
# during indexing.
|
||||
|
||||
|
||||
|
||||
|
||||
# Replace with your actual URI, username, and password
|
||||
#local
|
||||
URI = "neo4j://localhost:7687"
|
||||
USERNAME = "neo4j"
|
||||
PASSWORD = "your_password"
|
||||
PASSWORD = "password"
|
||||
|
||||
#aura
|
||||
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
|
||||
# USERNAME = "neo4j"
|
||||
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
|
||||
# Create a driver object
|
||||
|
||||
|
||||
@@ -33,7 +51,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
def __post_init__(self):
|
||||
# self._graph = preloaded_graph or nx.Graph()
|
||||
self._driver = GraphDatabase.driver("neo4j+s://91fbae6c.databases.neo4j.io", auth=("neo4j", "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"))
|
||||
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
@@ -129,13 +147,11 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
# degree = session.read_transaction(get_edge_degree, 1, 2)
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
entity_name__label_source = src_id.strip('\"')
|
||||
entity_name_label_source = src_id.strip('\"')
|
||||
entity_name_label_target = tgt_id.strip('\"')
|
||||
with self._driver.session() as session:
|
||||
query = """MATCH (n1:`{node_label1}`)-[r]-(n2:`{node_label2}`)
|
||||
RETURN count(r) AS degree""".format(entity_name__label_source=entity_name__label_source,
|
||||
entity_name_label_target=entity_name_label_target)
|
||||
|
||||
query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
|
||||
RETURN count(r) AS degree"""
|
||||
result = session.run(query)
|
||||
record = result.single()
|
||||
logger.info(
|
||||
@@ -144,7 +160,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
return record["degree"]
|
||||
|
||||
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
||||
entity_name__label_source = source_node_id.strip('\"')
|
||||
entity_name_label_source = source_node_id.strip('\"')
|
||||
entity_name_label_target = target_node_id.strip('\"')
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
@@ -156,28 +172,25 @@ class GraphStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
with self._driver.session() as session:
|
||||
with self._driver.session() as session:
|
||||
query = f"""
|
||||
MATCH (source:`{entity_name__label_source}`)-[r]-(target:`{entity_name_label_target}`)
|
||||
RETURN r
|
||||
"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
|
||||
|
||||
result = session.run(query)
|
||||
for logrecord in result:
|
||||
result = session.run(query)
|
||||
record = result.single()
|
||||
if record:
|
||||
result = dict(record["edge_properties"])
|
||||
logger.info(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{logrecord["r"]}'
|
||||
)
|
||||
|
||||
|
||||
return [record["r"] for record in result]
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
node_label = source_node_id.strip('\"')
|
||||
|
||||
@@ -208,8 +221,8 @@ class GraphStorage(BaseGraphStorage):
|
||||
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
|
||||
if source_label and target_label:
|
||||
print (f"appending: {[source_label, target_label]}")
|
||||
edges.append([source_label, target_label])
|
||||
print (f"appending: {(source_label, target_label)}")
|
||||
edges.append((source_label, target_label))
|
||||
|
||||
return edges
|
||||
|
||||
@@ -218,57 +231,54 @@ class GraphStorage(BaseGraphStorage):
|
||||
return edges
|
||||
|
||||
|
||||
# try:
|
||||
# with self._driver.session() as session:
|
||||
# if self.has_node(node_label):
|
||||
# edges = session.read_transaction(fetch_edges,node_label)
|
||||
# return list(edges)
|
||||
# return edges
|
||||
# finally:
|
||||
# print ("consider closing driver here")
|
||||
# # driver.close()
|
||||
|
||||
from typing import List, Tuple
|
||||
async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
|
||||
def run_query(tx):
|
||||
query = f"""
|
||||
MATCH (n:`{label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected
|
||||
"""
|
||||
results = tx.run(query)
|
||||
# from typing import List, Tuple
|
||||
# async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
|
||||
# def get_connections_for_node(tx):
|
||||
# query = f"""
|
||||
# MATCH (n:`{label}`)
|
||||
# OPTIONAL MATCH (n)-[r]-(connected)
|
||||
# RETURN n, r, connected
|
||||
# """
|
||||
# results = tx.run(query)
|
||||
|
||||
connections = []
|
||||
for record in results:
|
||||
source_node = record['n']
|
||||
connected_node = record['connected']
|
||||
|
||||
source_label = list(source_node.labels)[0] if source_node.labels else None
|
||||
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
|
||||
if source_label and target_label:
|
||||
connections.append((source_label, target_label))
|
||||
|
||||
return connections
|
||||
# connections = []
|
||||
# for record in results:
|
||||
# source_node = record['n']
|
||||
# connected_node = record['connected']
|
||||
|
||||
# source_label = list(source_node.labels)[0] if source_node.labels else None
|
||||
# target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
|
||||
# if source_label and target_label:
|
||||
# connections.append((source_label, target_label))
|
||||
|
||||
with driver.session() as session:
|
||||
return session.read_transaction(run_query)
|
||||
# logger.info(
|
||||
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
|
||||
# )
|
||||
# return connections
|
||||
|
||||
# with driver.session() as session:
|
||||
|
||||
# return session.read_transaction(get_connections_for_node)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#upsert_node
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
label = node_id.strip('\"')
|
||||
properties = node_data
|
||||
"""
|
||||
Upsert a node with the given label and properties within a transaction.
|
||||
If a node with the same label exists, it will:
|
||||
- Update existing properties with new values
|
||||
- Add new properties that don't exist
|
||||
If no node exists, creates a new node with all properties.
|
||||
|
||||
Args:
|
||||
label: The node label to search for and apply
|
||||
properties: Dictionary of node properties
|
||||
@@ -355,7 +365,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
result = tx.run(query, properties=edge_properties)
|
||||
logger.info(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{None}'
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
|
||||
)
|
||||
return result.single()
|
||||
|
||||
@@ -369,6 +379,8 @@ class GraphStorage(BaseGraphStorage):
|
||||
# return result
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print ("this is never called. checking to be sure.")
|
||||
|
||||
# async def _node2vec_embed(self):
|
||||
with self._driver.session() as session:
|
||||
#Define the Cypher query
|
||||
|
@@ -102,8 +102,8 @@ class LightRAG:
|
||||
|
||||
# module = importlib.import_module('kg.neo4j')
|
||||
# Neo4JStorage = getattr(module, 'GraphStorage')
|
||||
|
||||
if True==True:
|
||||
print ("using KG")
|
||||
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
|
||||
else:
|
||||
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
|
||||
|
@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
from graspologic import embed
|
||||
|
||||
print ("is this ever called?")
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
|
Reference in New Issue
Block a user