index to neo4j working and tested. check queires next.

This commit is contained in:
Ken Wiltshire
2024-10-30 17:48:14 -04:00
parent 9ab7312ecc
commit e4509327dd
13 changed files with 32177 additions and 9668 deletions

View File

@@ -4,24 +4,42 @@ import os
from dataclasses import dataclass
from typing import Any, Union, cast
import numpy as np
from nano_vectordb import NanoVectorDB
import inspect
# import package.common.utils as utils
# import package.common.utils as utils
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
)
from neo4j import GraphDatabase
from neo4j import GraphDatabase, exceptions as neo4jExceptions
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
# during indexing.
# Replace with your actual URI, username, and password
#local
URI = "neo4j://localhost:7687"
USERNAME = "neo4j"
PASSWORD = "your_password"
PASSWORD = "password"
#aura
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
# USERNAME = "neo4j"
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
# Create a driver object
@@ -33,7 +51,7 @@ class GraphStorage(BaseGraphStorage):
def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph()
self._driver = GraphDatabase.driver("neo4j+s://91fbae6c.databases.neo4j.io", auth=("neo4j", "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"))
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
@@ -129,13 +147,11 @@ class GraphStorage(BaseGraphStorage):
# degree = session.read_transaction(get_edge_degree, 1, 2)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name__label_source = src_id.strip('\"')
entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"')
with self._driver.session() as session:
query = """MATCH (n1:`{node_label1}`)-[r]-(n2:`{node_label2}`)
RETURN count(r) AS degree""".format(entity_name__label_source=entity_name__label_source,
entity_name_label_target=entity_name_label_target)
query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
RETURN count(r) AS degree"""
result = session.run(query)
record = result.single()
logger.info(
@@ -144,7 +160,7 @@ class GraphStorage(BaseGraphStorage):
return record["degree"]
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name__label_source = source_node_id.strip('\"')
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
"""
Find all edges between nodes of two given labels
@@ -156,28 +172,25 @@ class GraphStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
with self._driver.session() as session:
with self._driver.session() as session:
query = f"""
MATCH (source:`{entity_name__label_source}`)-[r]-(target:`{entity_name_label_target}`)
RETURN r
"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = session.run(query)
for logrecord in result:
result = session.run(query)
record = result.single()
if record:
result = dict(record["edge_properties"])
logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{logrecord["r"]}'
)
return [record["r"] for record in result]
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
)
return result
else:
return None
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def get_node_edges(self, source_node_id: str):
node_label = source_node_id.strip('\"')
@@ -208,8 +221,8 @@ class GraphStorage(BaseGraphStorage):
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label:
print (f"appending: {[source_label, target_label]}")
edges.append([source_label, target_label])
print (f"appending: {(source_label, target_label)}")
edges.append((source_label, target_label))
return edges
@@ -218,57 +231,54 @@ class GraphStorage(BaseGraphStorage):
return edges
# try:
# with self._driver.session() as session:
# if self.has_node(node_label):
# edges = session.read_transaction(fetch_edges,node_label)
# return list(edges)
# return edges
# finally:
# print ("consider closign driver here")
# # driver.close()
from typing import List, Tuple
async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
def run_query(tx):
query = f"""
MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
"""
results = tx.run(query)
# from typing import List, Tuple
# async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
# def get_connections_for_node(tx):
# query = f"""
# MATCH (n:`{label}`)
# OPTIONAL MATCH (n)-[r]-(connected)
# RETURN n, r, connected
# """
# results = tx.run(query)
connections = []
for record in results:
source_node = record['n']
connected_node = record['connected']
source_label = list(source_node.labels)[0] if source_node.labels else None
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label:
connections.append((source_label, target_label))
return connections
# connections = []
# for record in results:
# source_node = record['n']
# connected_node = record['connected']
# source_label = list(source_node.labels)[0] if source_node.labels else None
# target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
# if source_label and target_label:
# connections.append((source_label, target_label))
with driver.session() as session:
return session.read_transaction(run_query)
# logger.info(
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
# )
# return connections
# with driver.session() as session:
# return session.read_transaction(get_connections_for_node)
#upsert_node
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
label = node_id.strip('\"')
properties = node_data
"""
Upsert a node with the given label and properties within a transaction.
If a node with the same label exists, it will:
- Update existing properties with new values
- Add new properties that don't exist
If no node exists, creates a new node with all properties.
Args:
label: The node label to search for and apply
properties: Dictionary of node properties
@@ -355,7 +365,7 @@ class GraphStorage(BaseGraphStorage):
result = tx.run(query, properties=edge_properties)
logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{None}'
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
)
return result.single()
@@ -369,6 +379,8 @@ class GraphStorage(BaseGraphStorage):
# return result
async def _node2vec_embed(self):
print ("this is never called. checking to be sure.")
# async def _node2vec_embed(self):
with self._driver.session() as session:
#Define the Cypher query

View File

@@ -102,8 +102,8 @@ class LightRAG:
# module = importlib.import_module('kg.neo4j')
# Neo4JStorage = getattr(module, 'GraphStorage')
if True==True:
print ("using KG")
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
else:
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage

View File

@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
async def _node2vec_embed(self):
from graspologic import embed
print ("is this ever called?")
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],