index to neo4j working and tested. check queries next.

This commit is contained in:
Ken Wiltshire
2024-10-30 17:48:14 -04:00
parent 9ab7312ecc
commit e4509327dd
13 changed files with 32177 additions and 9668 deletions

.gitignore (vendored): 5 lines changed

@@ -5,4 +5,7 @@ book.txt
lightrag-dev/
.idea/
dist/
env/
env/
local_neo4jWorkDir/
local_neo4jWorkDir.bak/
neo4jWorkDir/

get_all_edges_nx.py (new file): 34 lines

@@ -0,0 +1,34 @@
import networkx as nx

G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')


def get_all_edges_and_nodes(G):
    # Get all edges and their properties
    edges_with_properties = []
    for u, v, data in G.edges(data=True):
        edges_with_properties.append({
            'start': u,
            'end': v,
            'label': data.get('label', ''),  # Assuming 'label' is used for edge type
            'properties': data,
            'start_node_properties': G.nodes[u],
            'end_node_properties': G.nodes[v]
        })
    return edges_with_properties


# Example usage
if __name__ == "__main__":
    # Assume G is your NetworkX graph loaded from Neo4j
    all_edges = get_all_edges_and_nodes(G)

    # Print all edges and node properties
    for edge in all_edges:
        print(f"Edge Label: {edge['label']}")
        print(f"Edge Properties: {edge['properties']}")
        print(f"Start Node: {edge['start']}")
        print(f"Start Node Properties: {edge['start_node_properties']}")
        print(f"End Node: {edge['end']}")
        print(f"End Node Properties: {edge['end_node_properties']}")
        print("---")

File diff suppressed because it is too large.


@@ -4,24 +4,42 @@ import os
from dataclasses import dataclass
from typing import Any, Union, cast
import numpy as np
from nano_vectordb import NanoVectorDB
import inspect
# import package.common.utils as utils
# import package.common.utils as utils
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
)
from neo4j import GraphDatabase
from neo4j import GraphDatabase, exceptions as neo4jExceptions
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
# during indexing.
# Replace with your actual URI, username, and password
#local
URI = "neo4j://localhost:7687"
USERNAME = "neo4j"
PASSWORD = "your_password"
PASSWORD = "password"
#aura
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
# USERNAME = "neo4j"
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
# Create a driver object
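The hunk above only wires in the connection constants and retry imports; the sketch below (not part of this commit) shows one way to fail fast on a bad connection before indexing starts. verify_connectivity is a real method on the neo4j Python driver; the rest is illustrative.

from neo4j import GraphDatabase

def check_connection(uri: str = "neo4j://localhost:7687",
                     user: str = "neo4j",
                     password: str = "password") -> bool:
    # Open a driver, verify the connection, and report the outcome
    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        driver.verify_connectivity()
        return True
    except Exception as exc:
        print(f"Neo4j not reachable: {exc}")
        return False
    finally:
        driver.close()

Run once at startup, this surfaces connection problems up front rather than as the mid-index write failures mentioned in the TODO above.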
@@ -33,7 +51,7 @@ class GraphStorage(BaseGraphStorage):
def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph()
self._driver = GraphDatabase.driver("neo4j+s://91fbae6c.databases.neo4j.io", auth=("neo4j", "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"))
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
@@ -129,13 +147,11 @@ class GraphStorage(BaseGraphStorage):
# degree = session.read_transaction(get_edge_degree, 1, 2)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name__label_source = src_id.strip('\"')
entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"')
with self._driver.session() as session:
query = """MATCH (n1:`{node_label1}`)-[r]-(n2:`{node_label2}`)
RETURN count(r) AS degree""".format(entity_name__label_source=entity_name__label_source,
entity_name_label_target=entity_name_label_target)
query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
RETURN count(r) AS degree"""
result = session.run(query)
record = result.single()
logger.info(
@@ -144,7 +160,7 @@ class GraphStorage(BaseGraphStorage):
return record["degree"]
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name__label_source = source_node_id.strip('\"')
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
"""
Find all edges between nodes of two given labels
@@ -156,28 +172,25 @@ class GraphStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
with self._driver.session() as session:
with self._driver.session() as session:
query = f"""
MATCH (source:`{entity_name__label_source}`)-[r]-(target:`{entity_name_label_target}`)
RETURN r
"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = session.run(query)
for logrecord in result:
result = session.run(query)
record = result.single()
if record:
result = dict(record["edge_properties"])
logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{logrecord["r"]}'
)
return [record["r"] for record in result]
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
)
return result
else:
return None
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def get_node_edges(self, source_node_id: str):
node_label = source_node_id.strip('\"')
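A hedged usage sketch for the reworked get_edge above; the storage instance, entity labels, and the asyncio wrapper are illustrative assumptions.

import asyncio

async def show_edge(storage):
    # Entity names are stored as quoted labels, so the quotes are passed through
    props = await storage.get_edge('"SCROOGE"', '"MARLEY"')
    if props is None:
        print("no edge between these entities")
    else:
        print("edge properties:", props)

# asyncio.run(show_edge(graph_storage))  # graph_storage: a GraphStorage instance (assumed)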
@@ -208,8 +221,8 @@ class GraphStorage(BaseGraphStorage):
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label:
print (f"appending: {[source_label, target_label]}")
edges.append([source_label, target_label])
print (f"appending: {(source_label, target_label)}")
edges.append((source_label, target_label))
return edges
@@ -218,57 +231,54 @@ class GraphStorage(BaseGraphStorage):
return edges
# try:
# with self._driver.session() as session:
# if self.has_node(node_label):
# edges = session.read_transaction(fetch_edges,node_label)
# return list(edges)
# return edges
# finally:
# print ("consider closign driver here")
# # driver.close()
from typing import List, Tuple
async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
def run_query(tx):
query = f"""
MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
"""
results = tx.run(query)
# from typing import List, Tuple
# async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
# def get_connections_for_node(tx):
# query = f"""
# MATCH (n:`{label}`)
# OPTIONAL MATCH (n)-[r]-(connected)
# RETURN n, r, connected
# """
# results = tx.run(query)
connections = []
for record in results:
source_node = record['n']
connected_node = record['connected']
source_label = list(source_node.labels)[0] if source_node.labels else None
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label:
connections.append((source_label, target_label))
return connections
# connections = []
# for record in results:
# source_node = record['n']
# connected_node = record['connected']
# source_label = list(source_node.labels)[0] if source_node.labels else None
# target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
# if source_label and target_label:
# connections.append((source_label, target_label))
with driver.session() as session:
return session.read_transaction(run_query)
# logger.info(
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
# )
# return connections
# with driver.session() as session:
# return session.read_transaction(get_connections_for_node)
#upsert_node
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
label = node_id.strip('\"')
properties = node_data
"""
Upsert a node with the given label and properties within a transaction.
If a node with the same label exists, it will:
- Update existing properties with new values
- Add new properties that don't exist
If no node exists, creates a new node with all properties.
Args:
label: The node label to search for and apply
properties: Dictionary of node properties
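The hunk cuts off before the upsert query itself, so the following is only a sketch of a label-based MERGE consistent with the docstring above (update existing properties, add missing ones, create the node if absent); the Cypher text is an assumption, not the committed implementation.

def _merge_node(tx, label: str, properties: dict):
    # MERGE on the label-only pattern, then SET n += to update/extend properties
    query = f"""
    MERGE (n:`{label}`)
    SET n += $properties
    RETURN n
    """
    return tx.run(query, properties=properties).single()

# Inside upsert_node (names assumed):
# with self._driver.session() as session:
#     session.write_transaction(_merge_node, label, properties)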
@@ -355,7 +365,7 @@ class GraphStorage(BaseGraphStorage):
result = tx.run(query, properties=edge_properties)
logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{None}'
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
)
return result.single()
@@ -369,6 +379,8 @@ class GraphStorage(BaseGraphStorage):
# return result
async def _node2vec_embed(self):
print ("this is never called. checking to be sure.")
# async def _node2vec_embed(self):
with self._driver.session() as session:
#Define the Cypher query


@@ -102,8 +102,8 @@ class LightRAG:
# module = importlib.import_module('kg.neo4j')
# Neo4JStorage = getattr(module, 'GraphStorage')
if True==True:
print ("using KG")
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
else:
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage


@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
async def _node2vec_embed(self):
from graspologic import embed
print ("is this ever called?")
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

test.py: 23 lines changed

@@ -1,16 +1,35 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
WORKING_DIR = "./dickensTestEmbedcall"
# G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
# nx.write_gexf(G, "graph_chunk_entity_relation.gefx")
import networkx as nx
from networkx_query import search_nodes, search_edges
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
query = {} # Empty query matches all nodes
result = search_nodes(G, query)
# Extract node IDs from the result
node_ids = sorted([node for node in result])
print("All node IDs in the graph:")
pprint(node_ids)
raise Exception
# raise Exception
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)


@@ -8,8 +8,7 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
# nest_asyncio.apply()
#########
WORKING_DIR = "./neo4jWorkDir"
WORKING_DIR = "./local_neo4jWorkDir"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)