fix event loop conflict

Ken Wiltshire
2024-11-06 11:18:14 -05:00
parent 8420cd1c77
commit 3d5d083f42
9 changed files with 185 additions and 152 deletions
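For context: the conflict named in the commit title arises when asyncio.get_event_loop() is called in a thread that has no event loop (or inside an environment that already runs one). The always_get_an_event_loop() helper defined later in this diff guards against that. A minimal sketch of the pattern follows; the except branch is an assumption, since the diff truncates the function body:

import asyncio


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # Reuse the current thread's loop if one is already set.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Assumed recovery path: no loop in this thread, so create and install one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop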

View File

@@ -1,28 +1,34 @@
import networkx as nx
G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml")

def get_all_edges_and_nodes(G):
    # Get all edges and their properties
    edges_with_properties = []
    for u, v, data in G.edges(data=True):
        edges_with_properties.append(
            {
                "start": u,
                "end": v,
                "label": data.get(
                    "label", ""
                ),  # Assuming 'label' is used for edge type
                "properties": data,
                "start_node_properties": G.nodes[u],
                "end_node_properties": G.nodes[v],
            }
        )

    return edges_with_properties

# Example usage
if __name__ == "__main__":
    # Assume G is your NetworkX graph loaded from Neo4j
    all_edges = get_all_edges_and_nodes(G)

    # Print all edges and node properties
    for edge in all_edges:
        print(f"Edge Label: {edge['label']}")
@@ -31,4 +37,4 @@ if __name__ == "__main__":
print(f"Start Node Properties: {edge['start_node_properties']}")
print(f"End Node: {edge['end']}")
print(f"End Node Properties: {edge['end_node_properties']}")
print("---")
print("---")

View File

@@ -1,17 +1,16 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import numpy as np
import inspect

from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import (
    AsyncGraphDatabase,
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
)
from contextlib import asynccontextmanager
from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
class Neo4JStorage(BaseGraphStorage):
    @staticmethod
    def load_nx_graph(file_name):
        print("no preloading of graph with neo4j in production")

    def __init__(self, namespace, global_config):
        super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
def __post_init__(self):
@@ -43,59 +44,54 @@ class Neo4JStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print ("KG successfully indexed.")
print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
result = await session.run(query)
async with self._driver.session() as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
async with self._driver.session() as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query)
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
return single_result["edgeExists"]
def close(self):
self._driver.close()
def close(self):
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -103,54 +99,51 @@ class Neo4JStorage(BaseGraphStorage):
node = record["n"]
node_dict = dict(node)
logger.debug(
f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session() as session:
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
"""
result = await session.run(query)
record = await result.single()
result = await session.run(query)
record = await result.single()
if record:
edge_count = record["totalEdgeCount"]
edge_count = record["totalEdgeCount"]
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
)
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
)
return edge_count
else:
else:
return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"')
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
)
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
)
return degrees
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
@@ -161,28 +154,30 @@ class Neo4JStorage(BaseGraphStorage):
        Returns:
            list: List of all relationships/edges found
        """
        async with self._driver.session() as session:
            query = f"""
            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
            RETURN properties(r) as edge_properties
            LIMIT 1
            """.format(
                entity_name_label_source=entity_name_label_source,
                entity_name_label_target=entity_name_label_target,
            )
            result = await session.run(query)
            record = await result.single()
            if record:
                result = dict(record["edge_properties"])
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
                )
                return result
            else:
                return None

    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
        node_label = source_node_id.strip('"')

        """
        Retrieves all edges (relationships) for a particular node identified by its label.
        :return: List of dictionaries containing edge information
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
async with self._driver.session() as session:
results = await session.run(query)
edges = []
async for record in results:
source_node = record['n']
connected_node = record['connected']
source_label = list(source_node.labels)[0] if source_node.labels else None
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
source_node = record["n"]
connected_node = record["connected"]
source_label = (
list(source_node.labels)[0] if source_node.labels else None
)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
return edges
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
        label = node_id.strip('"')
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
            SET n += $properties
            """
            await tx.run(query, properties=properties)
            logger.debug(
                f"Upserted node with label '{label}' and properties: {properties}"
            )

        try:
            async with self._driver.session() as session:
@@ -236,13 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
    ):
        """
        Upsert an edge and its properties between two nodes identified by their labels.
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge
        """
        source_node_label = source_node_id.strip('"')
        target_node_label = target_node_id.strip('"')
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
            logger.debug(
                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
            )

        try:
            async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def _node2vec_embed(self):
        print("Implemented but never called.")

View File

@@ -1,6 +1,5 @@
import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
@@ -24,18 +23,15 @@ from .storage import (
    NanoVectorDBStorage,
    NetworkXStorage,
)
from .kg.neo4j_impl import Neo4JStorage

# future KG integrations
# from .kg.ArangoDB_impl import (
#     GraphStorage as ArangoDBStorage
# )

from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
@@ -52,6 +48,7 @@ from .base import (
    QueryParam,
)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
@@ -64,7 +61,6 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@dataclass
class LightRAG:
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
@@ -74,8 +70,6 @@ class LightRAG:
    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
@@ -130,8 +124,10 @@ class LightRAG:
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.kg
        ]

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
@@ -185,6 +181,7 @@ class LightRAG:
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> Type[BaseGraphStorage]:
        return {
            "Neo4JStorage": Neo4JStorage,
@@ -328,4 +325,4 @@ class LightRAG:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

View File

@@ -798,4 +798,4 @@ if __name__ == "__main__":
        result = await gpt_4o_mini_complete("How are you?")
        print(result)

    asyncio.run(main())

View File

@@ -466,7 +466,6 @@ async def _build_local_query_context(
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    results = await entities_vdb.query(query, top_k=query_param.top_k)

    if not len(results):
@@ -483,7 +482,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst
)
@@ -928,7 +927,6 @@ async def hybrid_query(
            query_param,
        )

    if hl_keywords:
        high_level_context = await _build_global_query_context(
            hl_keywords,
@@ -939,7 +937,6 @@ async def hybrid_query(
            query_param,
        )

    context = combine_contexts(high_level_context, low_level_context)

    if query_param.only_need_context:
@@ -1008,9 +1005,11 @@ def combine_contexts(high_level_context, low_level_context):
    # Combine and deduplicate the entities
    combined_entities = process_combine_contexts(hl_entities, ll_entities)

    # Combine and deduplicate the relationships
    combined_relationships = process_combine_contexts(
        hl_relationships, ll_relationships
    )

    # Combine and deduplicate the sources
    combined_sources = process_combine_contexts(hl_sources, ll_sources)
@@ -1046,7 +1045,6 @@ async def naive_query(
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
@@ -1077,4 +1075,4 @@ async def naive_query(
        .strip()
    )

    return response

View File

@@ -233,8 +233,7 @@ class NetworkXStorage(BaseGraphStorage):
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    # @TODO: NOT USED
    async def _node2vec_embed(self):
        from graspologic import embed

View File

@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List
import xml.etree.ElementTree as ET

import numpy as np
@@ -176,19 +176,20 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
            return list_data[:i]
    return list_data


def list_of_list_to_csv(data: List[List[str]]) -> str:
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerows(data)
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> List[List[str]]:
    output = io.StringIO(csv_string)
    reader = csv.reader(output)
    return [row for row in reader]


def save_data_to_file(data, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
@@ -253,13 +254,14 @@ def xml_to_json(xml_file):
print(f"An error occurred: {e}")
return None
def process_combine_contexts(hl, ll):
header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
if list_hl:
header=list_hl[0]
header = list_hl[0]
list_hl = list_hl[1:]
if list_ll:
header = list_ll[0]
@@ -268,19 +270,17 @@ def process_combine_contexts(hl, ll):
return ""
if list_hl:
list_hl = [','.join(item[1:]) for item in list_hl if item]
list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll:
list_ll = [','.join(item[1:]) for item in list_ll if item]
list_ll = [",".join(item[1:]) for item in list_ll if item]
combined_sources_set = set(
filter(None, list_hl + list_ll)
)
combined_sources_set = set(filter(None, list_hl + list_ll))
combined_sources = [",\t".join(header)]
for i, item in enumerate(combined_sources_set, start=1):
combined_sources.append(f"{i},\t{item}")
combined_sources = "\n".join(combined_sources)
return combined_sources
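A small usage sketch of process_combine_contexts with invented CSV inputs: each argument is a CSV string whose first row is a header; the first column of every data row is dropped, the remaining columns are deduplicated across both inputs, and the survivors are renumbered:

from lightrag.utils import process_combine_contexts

hl = "id,source\n1,chunk-a\n2,chunk-b"
ll = "id,source\n1,chunk-b\n2,chunk-c"

# Prints the header row "id,\tsource", then chunk-a, chunk-b, chunk-c
# renumbered 1..3 (set-based dedup, so row order is not guaranteed).
print(process_combine_contexts(hl, ll))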

test.py
View File

@@ -1,11 +1,10 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete

#########
# Uncomment the two lines below when running in a Jupyter notebook, to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./dickens"
@@ -15,7 +14,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)
@@ -23,13 +22,21 @@ with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -5,8 +5,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the two lines below when running in a Jupyter notebook, to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./local_neo4jWorkDir"
@@ -18,7 +18,7 @@ rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
    kg="Neo4JStorage",
    log_level="INFO",
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)
@@ -26,13 +26,21 @@ with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)