Index to Neo4j working and tested. Check queries next.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -6,3 +6,6 @@ lightrag-dev/
|
||||
.idea/
|
||||
dist/
|
||||
env/
|
||||
local_neo4jWorkDir/
|
||||
local_neo4jWorkDir.bak/
|
||||
neo4jWorkDir/
|
34
get_all_edges_nx.py
Normal file
34
get_all_edges_nx.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import networkx as nx
|
||||
|
||||
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
|
||||
|
||||
def get_all_edges_and_nodes(G):
    """Collect every edge of ``G`` together with its endpoint node attributes.

    Args:
        G: A NetworkX-style graph exposing ``edges(data=True)`` (yielding
            ``(u, v, attribute_dict)`` triples) and a ``nodes`` mapping from
            node id to that node's attributes.

    Returns:
        A list with one dict per edge, holding:

        * ``start`` / ``end`` - the endpoint node ids as yielded by
          ``G.edges``.
        * ``label`` - the edge's ``'label'`` attribute, or ``''`` when the
          edge has none.
        * ``properties`` - a copy of the edge's attribute dict.
        * ``start_node_properties`` / ``end_node_properties`` - copies of the
          attribute mappings of the two endpoint nodes.

        An edgeless graph yields an empty list.
    """
    # dict(...) snapshots each attribute mapping. Handing out the graph's own
    # dicts would let a caller that edits a record silently modify the graph.
    return [
        {
            'start': u,
            'end': v,
            'label': data.get('label', ''),  # Assuming 'label' is used for edge type
            'properties': dict(data),
            'start_node_properties': dict(G.nodes[u]),
            'end_node_properties': dict(G.nodes[v]),
        }
        for u, v, data in G.edges(data=True)
    ]
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # G is the module-level graph read from the GraphML file near the top of
    # this script; gather one record per edge from it.
    all_edges = get_all_edges_and_nodes(G)

    # Dump each edge record field by field, with "---" closing every record.
    for edge in all_edges:
        print(f"Edge Label: {edge['label']}")
        print(f"Edge Properties: {edge['properties']}")
        print(f"Start Node: {edge['start']}")
        print(f"Start Node Properties: {edge['start_node_properties']}")
        print(f"End Node: {edge['end']}")
        print(f"End Node Properties: {edge['end_node_properties']}")
        print("---")
|
4621
graph_chunk_entity_relation.gefx
Normal file
4621
graph_chunk_entity_relation.gefx
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,24 +4,42 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
import inspect
|
||||
|
||||
|
||||
|
||||
|
||||
# import package.common.utils as utils
|
||||
|
||||
|
||||
from lightrag.utils import load_json, logger, write_json
|
||||
from ..base import (
|
||||
BaseGraphStorage
|
||||
)
|
||||
from neo4j import GraphDatabase
|
||||
from neo4j import GraphDatabase, exceptions as neo4jExceptions
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
|
||||
# during indexing.
|
||||
|
||||
|
||||
|
||||
|
||||
# Replace with your actual URI, username, and password
|
||||
#local
|
||||
URI = "neo4j://localhost:7687"
|
||||
USERNAME = "neo4j"
|
||||
PASSWORD = "your_password"
|
||||
PASSWORD = "password"
|
||||
|
||||
#aura
|
||||
# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
|
||||
# USERNAME = "neo4j"
|
||||
# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
|
||||
# Create a driver object
|
||||
|
||||
|
||||
@@ -33,7 +51,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
def __post_init__(self):
|
||||
# self._graph = preloaded_graph or nx.Graph()
|
||||
self._driver = GraphDatabase.driver("neo4j+s://91fbae6c.databases.neo4j.io", auth=("neo4j", "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"))
|
||||
self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
@@ -129,13 +147,11 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
# degree = session.read_transaction(get_edge_degree, 1, 2)
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
entity_name__label_source = src_id.strip('\"')
|
||||
entity_name_label_source = src_id.strip('\"')
|
||||
entity_name_label_target = tgt_id.strip('\"')
|
||||
with self._driver.session() as session:
|
||||
query = """MATCH (n1:`{node_label1}`)-[r]-(n2:`{node_label2}`)
|
||||
RETURN count(r) AS degree""".format(entity_name__label_source=entity_name__label_source,
|
||||
entity_name_label_target=entity_name_label_target)
|
||||
|
||||
query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
|
||||
RETURN count(r) AS degree"""
|
||||
result = session.run(query)
|
||||
record = result.single()
|
||||
logger.info(
|
||||
@@ -144,7 +160,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
return record["degree"]
|
||||
|
||||
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
|
||||
entity_name__label_source = source_node_id.strip('\"')
|
||||
entity_name_label_source = source_node_id.strip('\"')
|
||||
entity_name_label_target = target_node_id.strip('\"')
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
@@ -158,26 +174,23 @@ class GraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
with self._driver.session() as session:
|
||||
query = f"""
|
||||
MATCH (source:`{entity_name__label_source}`)-[r]-(target:`{entity_name_label_target}`)
|
||||
RETURN r
|
||||
"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
|
||||
|
||||
result = session.run(query)
|
||||
for logrecord in result:
|
||||
record = result.single()
|
||||
if record:
|
||||
result = dict(record["edge_properties"])
|
||||
logger.info(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{logrecord["r"]}'
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
|
||||
)
|
||||
|
||||
|
||||
return [record["r"] for record in result]
|
||||
|
||||
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
node_label = source_node_id.strip('\"')
|
||||
|
||||
@@ -208,8 +221,8 @@ class GraphStorage(BaseGraphStorage):
|
||||
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
|
||||
if source_label and target_label:
|
||||
print (f"appending: {[source_label, target_label]}")
|
||||
edges.append([source_label, target_label])
|
||||
print (f"appending: {(source_label, target_label)}")
|
||||
edges.append((source_label, target_label))
|
||||
|
||||
return edges
|
||||
|
||||
@@ -218,57 +231,54 @@ class GraphStorage(BaseGraphStorage):
|
||||
return edges
|
||||
|
||||
|
||||
# try:
|
||||
# with self._driver.session() as session:
|
||||
# if self.has_node(node_label):
|
||||
# edges = session.read_transaction(fetch_edges,node_label)
|
||||
# return list(edges)
|
||||
# return edges
|
||||
# finally:
|
||||
# print ("consider closign driver here")
|
||||
# # driver.close()
|
||||
|
||||
from typing import List, Tuple
|
||||
async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
|
||||
def run_query(tx):
|
||||
query = f"""
|
||||
MATCH (n:`{label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected
|
||||
"""
|
||||
results = tx.run(query)
|
||||
# from typing import List, Tuple
|
||||
# async def get_node_connections(driver: GraphDatabase.driver, label: str) -> List[Tuple[str, str]]:
|
||||
# def get_connections_for_node(tx):
|
||||
# query = f"""
|
||||
# MATCH (n:`{label}`)
|
||||
# OPTIONAL MATCH (n)-[r]-(connected)
|
||||
# RETURN n, r, connected
|
||||
# """
|
||||
# results = tx.run(query)
|
||||
|
||||
connections = []
|
||||
for record in results:
|
||||
source_node = record['n']
|
||||
connected_node = record['connected']
|
||||
|
||||
source_label = list(source_node.labels)[0] if source_node.labels else None
|
||||
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
# connections = []
|
||||
# for record in results:
|
||||
# source_node = record['n']
|
||||
# connected_node = record['connected']
|
||||
|
||||
if source_label and target_label:
|
||||
connections.append((source_label, target_label))
|
||||
# source_label = list(source_node.labels)[0] if source_node.labels else None
|
||||
# target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
|
||||
|
||||
return connections
|
||||
# if source_label and target_label:
|
||||
# connections.append((source_label, target_label))
|
||||
|
||||
with driver.session() as session:
|
||||
return session.read_transaction(run_query)
|
||||
# logger.info(
|
||||
# f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
|
||||
# )
|
||||
# return connections
|
||||
|
||||
# with driver.session() as session:
|
||||
|
||||
# return session.read_transaction(get_connections_for_node)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#upsert_node
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
label = node_id.strip('\"')
|
||||
properties = node_data
|
||||
"""
|
||||
Upsert a node with the given label and properties within a transaction.
|
||||
If a node with the same label exists, it will:
|
||||
- Update existing properties with new values
|
||||
- Add new properties that don't exist
|
||||
If no node exists, creates a new node with all properties.
|
||||
|
||||
Args:
|
||||
label: The node label to search for and apply
|
||||
properties: Dictionary of node properties
|
||||
@@ -355,7 +365,7 @@ class GraphStorage(BaseGraphStorage):
|
||||
|
||||
result = tx.run(query, properties=edge_properties)
|
||||
logger.info(
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{None}'
|
||||
f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
|
||||
)
|
||||
return result.single()
|
||||
|
||||
@@ -369,6 +379,8 @@ class GraphStorage(BaseGraphStorage):
|
||||
# return result
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print ("this is never called. checking to be sure.")
|
||||
|
||||
# async def _node2vec_embed(self):
|
||||
with self._driver.session() as session:
|
||||
#Define the Cypher query
|
||||
|
@@ -102,8 +102,8 @@ class LightRAG:
|
||||
|
||||
# module = importlib.import_module('kg.neo4j')
|
||||
# Neo4JStorage = getattr(module, 'GraphStorage')
|
||||
|
||||
if True==True:
|
||||
print ("using KG")
|
||||
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
|
||||
else:
|
||||
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
|
||||
|
@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
from graspologic import embed
|
||||
|
||||
print ("is this ever called?")
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
23
test.py
23
test.py
@@ -1,16 +1,35 @@
|
||||
import os
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||
|
||||
from pprint import pprint
|
||||
#########
|
||||
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
#########
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
WORKING_DIR = "./dickensTestEmbedcall"
|
||||
|
||||
|
||||
# G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
|
||||
# nx.write_gexf(G, "graph_chunk_entity_relation.gefx")
|
||||
|
||||
import networkx as nx
|
||||
from networkx_query import search_nodes, search_edges
|
||||
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
|
||||
query = {} # Empty query matches all nodes
|
||||
result = search_nodes(G, query)
|
||||
|
||||
# Extract node IDs from the result
|
||||
node_ids = sorted([node for node in result])
|
||||
|
||||
print("All node IDs in the graph:")
|
||||
pprint(node_ids)
|
||||
raise Exception
|
||||
|
||||
|
||||
# raise Exception
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
Reference in New Issue
Block a user