Merge pull request #221 from wiltshirek/main

Fix Event Loop conflict
This commit is contained in:
zrguo
2024-11-07 14:52:10 +08:00
committed by GitHub
10 changed files with 187 additions and 4777 deletions

View File

@@ -1,22 +1,28 @@
import networkx as nx import networkx as nx
G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml') G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml")
def get_all_edges_and_nodes(G): def get_all_edges_and_nodes(G):
# Get all edges and their properties # Get all edges and their properties
edges_with_properties = [] edges_with_properties = []
for u, v, data in G.edges(data=True): for u, v, data in G.edges(data=True):
edges_with_properties.append({ edges_with_properties.append(
'start': u, {
'end': v, "start": u,
'label': data.get('label', ''), # Assuming 'label' is used for edge type "end": v,
'properties': data, "label": data.get(
'start_node_properties': G.nodes[u], "label", ""
'end_node_properties': G.nodes[v] ), # Assuming 'label' is used for edge type
}) "properties": data,
"start_node_properties": G.nodes[u],
"end_node_properties": G.nodes[v],
}
)
return edges_with_properties return edges_with_properties
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
# Assume G is your NetworkX graph loaded from Neo4j # Assume G is your NetworkX graph loaded from Neo4j

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,16 @@
import asyncio import asyncio
import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast, Tuple, List, Dict from typing import Any, Union, Tuple, List, Dict
import numpy as np
import inspect import inspect
from lightrag.utils import load_json, logger, write_json from lightrag.utils import logger
from ..base import ( from ..base import BaseGraphStorage
BaseGraphStorage from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
) )
from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import ( from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production") print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config): def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config) super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"] URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"] USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None return None
def __post_init__(self): def __post_init__(self):
@@ -43,37 +44,35 @@ class Neo4JStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def close(self): async def close(self):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
self._driver = None self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
async def index_done_callback(self): async def index_done_callback(self):
print ("KG successfully indexed.") print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
) )
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = ( query = (
@@ -83,19 +82,16 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
) )
return single_result["edgeExists"] return single_result["edgeExists"]
def close(self): def close(self):
self._driver.close() self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session: async with self._driver.session() as session:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()
@@ -103,17 +99,15 @@ class Neo4JStorage(BaseGraphStorage):
node = record["n"] node = record["n"]
node_dict = dict(node) node_dict = dict(node)
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}' f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
) )
return node_dict return node_dict
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('\"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session() as session:
query = f""" query = f"""
MATCH (n:`{entity_name_label}`) MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -123,16 +117,15 @@ class Neo4JStorage(BaseGraphStorage):
if record: if record:
edge_count = record["totalEdgeCount"] edge_count = record["totalEdgeCount"]
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
) )
return edge_count return edge_count
else: else:
return None return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('\"') entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('\"') entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source) src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target) trg_degree = await self.node_degree(entity_name_label_target)
@@ -142,15 +135,15 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}' f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
) )
return degrees return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: ) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('\"') entity_name_label_target = target_node_id.strip('"')
""" """
Find all edges between nodes of two given labels Find all edges between nodes of two given labels
@@ -161,27 +154,29 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
async with self._driver.session() as session: async with self._driver.session() as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
LIMIT 1 LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target) """.format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()
if record: if record:
result = dict(record["edge_properties"]) result = dict(record["edge_properties"])
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
) )
return result return result
else: else:
return None return None
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]: node_label = source_node_id.strip('"')
node_label = source_node_id.strip('\"')
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session() as session: async with self._driver.session() as session:
results = await session.run(query) results = await session.run(query)
edges = [] edges = []
async for record in results: async for record in results:
source_node = record['n'] source_node = record["n"]
connected_node = record['connected'] connected_node = record["connected"]
source_label = list(source_node.labels)[0] if source_node.labels else None source_label = (
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None list(source_node.labels)[0] if source_node.labels else None
)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
if source_label and target_label: if source_label and target_label:
edges.append((source_label, target_label)) edges.append((source_label, target_label))
return edges return edges
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
""" """
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = node_id.strip('\"') label = node_id.strip('"')
properties = node_data properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction): async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
SET n += $properties SET n += $properties
""" """
await tx.run(query, properties=properties) await tx.run(query, properties=properties)
logger.debug(f"Upserted node with label '{label}' and properties: {properties}") logger.debug(
f"Upserted node with label '{label}' and properties: {properties}"
)
try: try:
async with self._driver.session() as session: async with self._driver.session() as session:
@@ -240,9 +248,17 @@ class Neo4JStorage(BaseGraphStorage):
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)), retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
) )
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]): async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = source_node_id.strip('\"') source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('\"') target_node_label = target_node_id.strip('"')
edge_properties = edge_data edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction): async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
RETURN r RETURN r
""" """
await tx.run(query, properties=edge_properties) await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}") logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
)
try: try:
async with self._driver.session() as session: async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e: except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}") logger.error(f"Error during edge upsert: {str(e)}")
raise raise
async def _node2vec_embed(self):
print ("Implemented but never called.")
async def _node2vec_embed(self):
print("Implemented but never called.")

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import importlib
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@@ -25,17 +24,14 @@ from .storage import (
NetworkXStorage, NetworkXStorage,
) )
from .kg.neo4j_impl import ( from .kg.neo4j_impl import Neo4JStorage
Neo4JStorage # future KG integrations
)
#future KG integrations
# from .kg.ArangoDB_impl import ( # from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage # GraphStorage as ArangoDBStorage
# ) # )
from .utils import ( from .utils import (
EmbeddingFunc, EmbeddingFunc,
compute_mdhash_id, compute_mdhash_id,
@@ -55,18 +51,16 @@ from .base import (
def always_get_an_event_loop() -> asyncio.AbstractEventLoop: def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_event_loop()
except RuntimeError: except RuntimeError:
logger.info("Creating a new event loop in main thread.") logger.info("Creating a new event loop in main thread.")
# loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
return loop return loop
@dataclass @dataclass
class LightRAG: class LightRAG:
working_dir: str = field( working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
) )
@@ -76,8 +70,6 @@ class LightRAG:
current_log_level = logger.level current_log_level = logger.level
log_level: str = field(default=current_log_level) log_level: str = field(default=current_log_level)
# text chunking # text chunking
chunk_token_size: int = 1200 chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100 chunk_overlap_token_size: int = 100
@@ -132,8 +124,10 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
#@TODO: should move all storage setup here to leverage initial start params attached to self. # @TODO: should move all storage setup here to leverage initial start params attached to self.
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg] self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.kg
]
if not os.path.exists(self.working_dir): if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
@@ -187,6 +181,7 @@ class LightRAG:
**self.llm_model_kwargs, **self.llm_model_kwargs,
) )
) )
def _get_storage_class(self) -> Type[BaseGraphStorage]: def _get_storage_class(self) -> Type[BaseGraphStorage]:
return { return {
"Neo4JStorage": Neo4JStorage, "Neo4JStorage": Neo4JStorage,

View File

@@ -466,7 +466,6 @@ async def _build_local_query_context(
text_chunks_db: BaseKVStorage[TextChunkSchema], text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam, query_param: QueryParam,
): ):
results = await entities_vdb.query(query, top_k=query_param.top_k) results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
@@ -483,7 +482,7 @@ async def _build_local_query_context(
{**n, "entity_name": k["entity_name"], "rank": d} {**n, "entity_name": k["entity_name"], "rank": d}
for k, n, d in zip(results, node_datas, node_degrees) for k, n, d in zip(results, node_datas, node_degrees)
if n is not None if n is not None
]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram. ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
use_text_units = await _find_most_related_text_unit_from_entities( use_text_units = await _find_most_related_text_unit_from_entities(
node_datas, query_param, text_chunks_db, knowledge_graph_inst node_datas, query_param, text_chunks_db, knowledge_graph_inst
) )
@@ -928,7 +927,6 @@ async def hybrid_query(
query_param, query_param,
) )
if hl_keywords: if hl_keywords:
high_level_context = await _build_global_query_context( high_level_context = await _build_global_query_context(
hl_keywords, hl_keywords,
@@ -939,7 +937,6 @@ async def hybrid_query(
query_param, query_param,
) )
context = combine_contexts(high_level_context, low_level_context) context = combine_contexts(high_level_context, low_level_context)
if query_param.only_need_context: if query_param.only_need_context:
@@ -1010,7 +1007,9 @@ def combine_contexts(high_level_context, low_level_context):
combined_entities = process_combine_contexts(hl_entities, ll_entities) combined_entities = process_combine_contexts(hl_entities, ll_entities)
# Combine and deduplicate the relationships # Combine and deduplicate the relationships
combined_relationships = process_combine_contexts(hl_relationships, ll_relationships) combined_relationships = process_combine_contexts(
hl_relationships, ll_relationships
)
# Combine and deduplicate the sources # Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources) combined_sources = process_combine_contexts(hl_sources, ll_sources)
@@ -1046,7 +1045,6 @@ async def naive_query(
chunks_ids = [r["id"] for r in results] chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids) chunks = await text_chunks_db.get_by_ids(chunks_ids)
maybe_trun_chunks = truncate_list_by_token_size( maybe_trun_chunks = truncate_list_by_token_size(
chunks, chunks,
key=lambda x: x["content"], key=lambda x: x["content"],

View File

@@ -233,8 +233,7 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
#@TODO: NOT USED
async def _node2vec_embed(self): async def _node2vec_embed(self):
from graspologic import embed from graspologic import embed

View File

@@ -9,7 +9,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union,List from typing import Any, Union, List
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import numpy as np import numpy as np
@@ -176,19 +176,20 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i] return list_data[:i]
return list_data return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str: def list_of_list_to_csv(data: List[List[str]]) -> str:
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerows(data) writer.writerows(data)
return output.getvalue() return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]: def csv_string_to_list(csv_string: str) -> List[List[str]]:
output = io.StringIO(csv_string) output = io.StringIO(csv_string)
reader = csv.reader(output) reader = csv.reader(output)
return [row for row in reader] return [row for row in reader]
def save_data_to_file(data, file_name): def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f: with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4) json.dump(data, f, ensure_ascii=False, indent=4)
@@ -253,13 +254,14 @@ def xml_to_json(xml_file):
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def process_combine_contexts(hl, ll): def process_combine_contexts(hl, ll):
header = None header = None
list_hl = csv_string_to_list(hl.strip()) list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip()) list_ll = csv_string_to_list(ll.strip())
if list_hl: if list_hl:
header=list_hl[0] header = list_hl[0]
list_hl = list_hl[1:] list_hl = list_hl[1:]
if list_ll: if list_ll:
header = list_ll[0] header = list_ll[0]
@@ -268,13 +270,11 @@ def process_combine_contexts(hl, ll):
return "" return ""
if list_hl: if list_hl:
list_hl = [','.join(item[1:]) for item in list_hl if item] list_hl = [",".join(item[1:]) for item in list_hl if item]
if list_ll: if list_ll:
list_ll = [','.join(item[1:]) for item in list_ll if item] list_ll = [",".join(item[1:]) for item in list_ll if item]
combined_sources_set = set( combined_sources_set = set(filter(None, list_hl + list_ll))
filter(None, list_hl + list_ll)
)
combined_sources = [",\t".join(header)] combined_sources = [",\t".join(header)]

19
test.py
View File

@@ -1,7 +1,6 @@
import os import os
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio
@@ -15,7 +14,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
) )
@@ -23,13 +22,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -18,7 +18,7 @@ rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
kg="Neo4JStorage", kg="Neo4JStorage",
log_level="INFO" log_level="INFO",
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
) )
@@ -26,13 +26,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)