index to neo4j working and tested. check queries next.

This commit is contained in:
Ken Wiltshire
2024-10-30 17:48:14 -04:00
parent 9ab7312ecc
commit e4509327dd
13 changed files with 32177 additions and 9668 deletions

5
.gitignore vendored
View File

@@ -5,4 +5,7 @@ book.txt
lightrag-dev/ lightrag-dev/
.idea/ .idea/
dist/ dist/
env/ env/
local_neo4jWorkDir/
local_neo4jWorkDir.bak/
neo4jWorkDir/

34
get_all_edges_nx.py Normal file
View File

@@ -0,0 +1,34 @@
import networkx as nx
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
def get_all_edges_and_nodes(G):
    """Collect every edge of *G* together with its endpoints' properties.

    Args:
        G: a networkx graph (e.g. loaded from a GraphML file).

    Returns:
        list[dict]: one record per edge with keys ``start``, ``end``,
        ``label``, ``properties``, ``start_node_properties`` and
        ``end_node_properties``.
    """
    return [
        {
            'start': src,
            'end': dst,
            # 'label' is assumed to carry the edge type; '' when absent
            'label': attrs.get('label', ''),
            'properties': attrs,
            'start_node_properties': G.nodes[src],
            'end_node_properties': G.nodes[dst],
        }
        for src, dst, attrs in G.edges(data=True)
    ]
# Example usage: dump every edge of the module-level graph G loaded above.
if __name__ == "__main__":
    edge_records = get_all_edges_and_nodes(G)
    for record in edge_records:
        # Print each edge with its endpoint properties, separated by "---".
        print("Edge Label: {}".format(record['label']))
        print("Edge Properties: {}".format(record['properties']))
        print("Start Node: {}".format(record['start']))
        print("Start Node Properties: {}".format(record['start_node_properties']))
        print("End Node: {}".format(record['end']))
        print("End Node Properties: {}".format(record['end_node_properties']))
        print("---")

File diff suppressed because it is too large Load Diff

View File

@@ -4,24 +4,42 @@ import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast from typing import Any, Union, cast
import numpy as np import numpy as np
from nano_vectordb import NanoVectorDB
import inspect import inspect
# import package.common.utils as utils
# import package.common.utils as utils
from lightrag.utils import load_json, logger, write_json from lightrag.utils import load_json, logger, write_json
from ..base import ( from ..base import (
BaseGraphStorage BaseGraphStorage
) )
from neo4j import GraphDatabase from neo4j import GraphDatabase, exceptions as neo4jExceptions
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
# during indexing.
# Replace with your actual URI, username, and password # Replace with your actual URI, username, and password
#local
URI = "neo4j://localhost:7687" URI = "neo4j://localhost:7687"
USERNAME = "neo4j" USERNAME = "neo4j"
PASSWORD = "your_password" PASSWORD = "password"
#aura
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
# USERNAME = "neo4j"
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
# Create a driver object # Create a driver object
@@ -33,7 +51,7 @@ class GraphStorage(BaseGraphStorage):
def __post_init__(self): def __post_init__(self):
# self._graph = preloaded_graph or nx.Graph() # self._graph = preloaded_graph or nx.Graph()
self._driver = GraphDatabase.driver("neo4j+s://91fbae6c.databases.neo4j.io", auth=("neo4j", "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw")) self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
@@ -129,13 +147,11 @@ class GraphStorage(BaseGraphStorage):
# degree = session.read_transaction(get_edge_degree, 1, 2) # degree = session.read_transaction(get_edge_degree, 1, 2)
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name__label_source = src_id.strip('\"') entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"') entity_name_label_target = tgt_id.strip('\"')
with self._driver.session() as session: with self._driver.session() as session:
query = """MATCH (n1:`{node_label1}`)-[r]-(n2:`{node_label2}`) query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
RETURN count(r) AS degree""".format(entity_name__label_source=entity_name__label_source, RETURN count(r) AS degree"""
entity_name_label_target=entity_name_label_target)
result = session.run(query) result = session.run(query)
record = result.single() record = result.single()
logger.info( logger.info(
@@ -144,7 +160,7 @@ class GraphStorage(BaseGraphStorage):
return record["degree"] return record["degree"]
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name__label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('\"')
""" """
Find all edges between nodes of two given labels Find all edges between nodes of two given labels
@@ -156,28 +172,25 @@ class GraphStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
with self._driver.session() as session: with self._driver.session() as session:
query = f""" query = f"""
MATCH (source:`{entity_name__label_source}`)-[r]-(target:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN r RETURN properties(r) as edge_properties
""" LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = session.run(query) result = session.run(query)
for logrecord in result: record = result.single()
if record:
result = dict(record["edge_properties"])
logger.info( logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{logrecord["r"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
) )
return result
else:
return [record["r"] for record in result] return None
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str):
node_label = source_node_id.strip('\"') node_label = source_node_id.strip('\"')
@@ -208,8 +221,8 @@ class GraphStorage(BaseGraphStorage):
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label: if source_label and target_label:
print (f"appending: {[source_label, target_label]}") print (f"appending: {(source_label, target_label)}")
edges.append([source_label, target_label]) edges.append((source_label, target_label))
return edges return edges
@@ -218,57 +231,54 @@ class GraphStorage(BaseGraphStorage):
return edges return edges
# try:
# with self._driver.session() as session:
# if self.has_node(node_label):
# edges = session.read_transaction(fetch_edges,node_label)
# return list(edges)
# return edges
# finally:
# print ("consider closign driver here")
# # driver.close()
from typing import List, Tuple # from typing import List, Tuple
async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]: # async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
def run_query(tx): # def get_connections_for_node(tx):
query = f""" # query = f"""
MATCH (n:`{label}`) # MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected) # OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected # RETURN n, r, connected
""" # """
results = tx.run(query) # results = tx.run(query)
connections = []
for record in results:
source_node = record['n']
connected_node = record['connected']
source_label = list(source_node.labels)[0] if source_node.labels else None
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
if source_label and target_label:
connections.append((source_label, target_label))
return connections # connections = []
# for record in results:
# source_node = record['n']
# connected_node = record['connected']
# source_label = list(source_node.labels)[0] if source_node.labels else None
# target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
# if source_label and target_label:
# connections.append((source_label, target_label))
with driver.session() as session: # logger.info(
return session.read_transaction(run_query) # f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
# )
# return connections
# with driver.session() as session:
# return session.read_transaction(get_connections_for_node)
#upsert_node #upsert_node
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]):
label = node_id.strip('\"') label = node_id.strip('\"')
properties = node_data properties = node_data
""" """
Upsert a node with the given label and properties within a transaction. Upsert a node with the given label and properties within a transaction.
If a node with the same label exists, it will:
- Update existing properties with new values
- Add new properties that don't exist
If no node exists, creates a new node with all properties.
Args: Args:
label: The node label to search for and apply label: The node label to search for and apply
properties: Dictionary of node properties properties: Dictionary of node properties
@@ -355,7 +365,7 @@ class GraphStorage(BaseGraphStorage):
result = tx.run(query, properties=edge_properties) result = tx.run(query, properties=edge_properties)
logger.info( logger.info(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{None}' f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
) )
return result.single() return result.single()
@@ -369,6 +379,8 @@ class GraphStorage(BaseGraphStorage):
# return result # return result
async def _node2vec_embed(self): async def _node2vec_embed(self):
print ("this is never called. checking to be sure.")
# async def _node2vec_embed(self): # async def _node2vec_embed(self):
with self._driver.session() as session: with self._driver.session() as session:
#Define the Cypher query #Define the Cypher query

View File

@@ -102,8 +102,8 @@ class LightRAG:
# module = importlib.import_module('kg.neo4j') # module = importlib.import_module('kg.neo4j')
# Neo4JStorage = getattr(module, 'GraphStorage') # Neo4JStorage = getattr(module, 'GraphStorage')
if True==True: if True==True:
print ("using KG")
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
else: else:
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage

View File

@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
from graspologic import embed from graspologic import embed
print ("is this ever called?")
embeddings, nodes = embed.node2vec_embed( embeddings, nodes = embed.node2vec_embed(
self._graph, self._graph,
**self.global_config["node2vec_params"], **self.global_config["node2vec_params"],

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

23
test.py
View File

@@ -1,16 +1,35 @@
import os import os
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio
# nest_asyncio.apply() # nest_asyncio.apply()
######### #########
WORKING_DIR = "./dickens" WORKING_DIR = "./dickensTestEmbedcall"
# G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
# nx.write_gexf(G, "graph_chunk_entity_relation.gefx")
import networkx as nx
from networkx_query import search_nodes, search_edges
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
query = {} # Empty query matches all nodes
result = search_nodes(G, query)
# Extract node IDs from the result
node_ids = sorted([node for node in result])
print("All node IDs in the graph:")
pprint(node_ids)
raise Exception
# raise Exception
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)

View File

@@ -8,8 +8,7 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
# nest_asyncio.apply() # nest_asyncio.apply()
######### #########
WORKING_DIR = "./neo4jWorkDir" WORKING_DIR = "./local_neo4jWorkDir"
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)