From aaa8194423e18db4503b4b04fa5543bd63980b41 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 5 Mar 2025 15:32:39 +0100 Subject: [PATCH 01/54] Update document_routes.py --- lightrag/api/routers/document_routes.py | 114 +++++++++++++++--------- 1 file changed, 73 insertions(+), 41 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index d9dfe913..9d161f6c 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -16,7 +16,7 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import get_api_key_dependency +from lightrag.api.utils_api import get_api_key_dependency, global_args router = APIRouter(prefix="/documents", tags=["documents"]) @@ -237,54 +237,86 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) return False case ".pdf": - if not pm.is_installed("pypdf2"): # type: ignore - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO + if global_args["main_args"].document_loading_tool=="DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("pypdf2"): # type: ignore + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" case ".docx": - if not pm.is_installed("python-docx"): # type: ignore - pm.install("docx") - from docx import Document # type: ignore - from io import BytesIO + if global_args["main_args"].document_loading_tool=="DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("python-docx"): # type: ignore + pm.install("docx") + from docx import Document # type: ignore + from io import BytesIO - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) case ".pptx": - if not pm.is_installed("python-pptx"): # type: ignore - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO + if global_args["main_args"].document_loading_tool=="DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("python-pptx"): # type: ignore + pm.install("pptx") + from pptx import Presentation # type: ignore + from io import BytesIO - pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += 
shape.text + "\n" + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" case ".xlsx": - if not pm.is_installed("openpyxl"): # type: ignore - pm.install("openpyxl") - from openpyxl import load_workbook # type: ignore - from io import BytesIO + if global_args["main_args"].document_loading_tool=="DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("openpyxl"): # type: ignore + pm.install("openpyxl") + from openpyxl import load_workbook # type: ignore + from io import BytesIO - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" for cell in row + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += ( + "\t".join( + str(cell) if cell is not None else "" for cell in row + ) + + "\n" ) - + "\n" - ) - content += "\n" + content += "\n" case _: logger.error( f"Unsupported file type: {file_path.name} (extension {ext})" From 95a6a274ca7d0588e72e76f8eb445870a128f868 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 5 Mar 2025 15:33:06 +0100 Subject: [PATCH 02/54] Update ollama_api.py --- lightrag/api/routers/ollama_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 9688d073..37d7354e 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -11,7 +11,7 @@ import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from ..utils_api import ollama_server_infos +from lightrag.api.utils_api import ollama_server_infos # query mode according to query prefix (bypass is not LightRAG quer mode) From c62422eadee4ac9f666c55fa04706ce52812fd32 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 5 Mar 2025 15:33:54 +0100 Subject: [PATCH 03/54] Update utils_api.py --- lightrag/api/utils_api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index ed1250d4..39b2950f 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -17,6 +17,10 @@ from starlette.status import HTTP_403_FORBIDDEN # Load environment variables load_dotenv(override=True) +global_args={ + "main_args":None +} + class OllamaServerInfos: # Constants for emulated Ollama model information @@ -340,9 +344,13 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) + + # Select Document loading tool + args.document_loading_tool = get_env_value("DOCUMENT_LOADING_TOOL", "DOCLING") ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name + global_args["main_args"]= args return args From 39c24f4a597c9e82e45975e322ee28156a6fb202 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 5 Mar 2025 15:36:17 
+0100 Subject: [PATCH 04/54] Update utils_api.py --- lightrag/api/utils_api.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 39b2950f..8ba4565f 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -17,9 +17,7 @@ from starlette.status import HTTP_403_FORBIDDEN # Load environment variables load_dotenv(override=True) -global_args={ - "main_args":None -} +global_args = {"main_args": None} class OllamaServerInfos: @@ -344,13 +342,13 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) - + # Select Document loading tool args.document_loading_tool = get_env_value("DOCUMENT_LOADING_TOOL", "DOCLING") ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name - global_args["main_args"]= args + global_args["main_args"] = args return args From 6e4daea056940b17f6773c59e492bd8a5eb5d308 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 5 Mar 2025 15:36:47 +0100 Subject: [PATCH 05/54] Linting --- lightrag/api/routers/document_routes.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 9d161f6c..a6830389 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -237,10 +237,11 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) return False case ".pdf": - if global_args["main_args"].document_loading_tool=="DOCLING": + if global_args["main_args"].document_loading_tool == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter + converter = DocumentConverter() result = converter.convert(file_path) content = result.document.export_to_markdown() @@ -255,10 +256,11 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: for page in reader.pages: content += page.extract_text() + "\n" case ".docx": - if global_args["main_args"].document_loading_tool=="DOCLING": + if global_args["main_args"].document_loading_tool == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter + converter = DocumentConverter() result = converter.convert(file_path) content = result.document.export_to_markdown() @@ -270,12 +272,15 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: docx_file = BytesIO(file) doc = Document(docx_file) - content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) case ".pptx": - if global_args["main_args"].document_loading_tool=="DOCLING": + if global_args["main_args"].document_loading_tool == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter + converter = DocumentConverter() result = converter.convert(file_path) content = result.document.export_to_markdown() @@ -292,10 +297,11 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: if hasattr(shape, "text"): content += shape.text + "\n" case ".xlsx": - if global_args["main_args"].document_loading_tool=="DOCLING": + if global_args["main_args"].document_loading_tool == 
"DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter + converter = DocumentConverter() result = converter.convert(file_path) content = result.document.export_to_markdown() @@ -312,7 +318,8 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: for row in sheet.iter_rows(values_only=True): content += ( "\t".join( - str(cell) if cell is not None else "" for cell in row + str(cell) if cell is not None else "" + for cell in row ) + "\n" ) From 00f3c6c6ddce60687d25c4c7022efb5bba1e4b5d Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Thu, 6 Mar 2025 01:11:48 +0100 Subject: [PATCH 06/54] Upgraded document loading engine --- lightrag/api/routers/document_routes.py | 8 ++++---- lightrag/api/utils_api.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index a6830389..dcb8f961 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -237,7 +237,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) return False case ".pdf": - if global_args["main_args"].document_loading_tool == "DOCLING": + if global_args["main_args"].document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter @@ -256,7 +256,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: for page in reader.pages: content += page.extract_text() + "\n" case ".docx": - if global_args["main_args"].document_loading_tool == "DOCLING": + if global_args["main_args"].document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter @@ -276,7 +276,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: [paragraph.text for paragraph in doc.paragraphs] ) case ".pptx": - if global_args["main_args"].document_loading_tool == "DOCLING": + if global_args["main_args"].document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter @@ -297,7 +297,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: if hasattr(shape, "text"): content += shape.text + "\n" case ".xlsx": - if global_args["main_args"].document_loading_tool == "DOCLING": + if global_args["main_args"].document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 8ba4565f..ae674968 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -344,7 +344,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Select Document loading tool - args.document_loading_tool = get_env_value("DOCUMENT_LOADING_TOOL", "DOCLING") + args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DOCLING") ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name From 6e3b23069c0a76a5bfa7e27189faac57ff7d0691 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 7 Mar 2025 16:43:18 +0800 Subject: [PATCH 07/54] - Remove useless `_label_exists` method --- lightrag/kg/neo4j_impl.py 
| 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index fec39138..2498341d 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -164,23 +164,13 @@ class Neo4JStorage(BaseGraphStorage):
         # Neo4j handles persistence automatically
         pass
 
-    async def _label_exists(self, label: str) -> bool:
-        """Check if a label exists in the Neo4j database."""
-        query = "CALL db.labels() YIELD label RETURN label"
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                result = await session.run(query)
-                labels = [record["label"] for record in await result.data()]
-                return label in labels
-        except Exception as e:
-            logger.error(f"Error checking label existence: {e}")
-            return False
-
     async def _ensure_label(self, label: str) -> str:
-        """Ensure a label exists by validating it."""
+        """Ensure a label is valid
+
+        Args:
+            label: The label to validate
+        """
         clean_label = label.strip('"')
-        if not await self._label_exists(clean_label):
-            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
         return clean_label
 
     async def has_node(self, node_id: str) -> bool:
@@ -290,7 +280,7 @@ class Neo4JStorage(BaseGraphStorage):
         if record:
             try:
                 result = dict(record["edge_properties"])
-                logger.info(f"Result: {result}")
+                logger.debug(f"Result: {result}")
                 # Ensure required keys exist with defaults
                 required_keys = {
                     "weight": 0.0,

From 0ee2e7fd4800050ef2d1819c157a196ed66cf4fa Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 7 Mar 2025 16:56:48 +0800
Subject: [PATCH 08/54] Suppress Neo4j warning logs by setting logger level.

---
 lightrag/kg/neo4j_impl.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 2498341d..265c0347 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -15,6 +15,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 
+import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -37,6 +38,8 @@ config.read("config.ini", "utf-8")
 # Get maximum number of graph nodes from environment variable, default is 1000
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 
+# Set neo4j logger level to ERROR to suppress warning logs
+logging.getLogger("neo4j").setLevel(logging.ERROR)
 
 @final
 @dataclass

From af803f4e7ad3267fcd184fd6c3914b4c6b2c6bef Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 01:20:36 +0800
Subject: [PATCH 09/54] Refactor Neo4J graph query with min_degree and
 inclusive match support

---
 lightrag/kg/neo4j_impl.py | 434 ++++++++++++++++++++++++--------------
 1 file changed, 275 insertions(+), 159 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 265c0347..f6567249 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -41,6 +41,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 # Set neo4j logger level to ERROR to suppress warning logs
 logging.getLogger("neo4j").setLevel(logging.ERROR)
 
+
 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
@@ -63,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=800),
+                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", 
"connection_timeout", fallback=60.0), + config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0 ), ) CONNECTION_ACQUISITION_TIMEOUT = float( os.environ.get( "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", - config.get("neo4j", "connection_acquisition_timeout", fallback=60.0), + config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0 + ), + ) + MAX_TRANSACTION_RETRY_TIME = float( + os.environ.get( + "NEO4J_MAX_TRANSACTION_RETRY_TIME", + config.get("neo4j", "max_transaction_retry_time", fallback=30.0), ), ) DATABASE = os.environ.get( @@ -88,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage): max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, connection_timeout=CONNECTION_TIMEOUT, connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, + max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, ) # Try to connect to the database @@ -169,21 +177,24 @@ class Neo4JStorage(BaseGraphStorage): async def _ensure_label(self, label: str) -> str: """Ensure a label is valid - + Args: label: The label to validate """ clean_label = label.strip('"') + if not clean_label: + raise ValueError("Neo4j: Label cannot be empty") return clean_label async def has_node(self, node_id: str) -> bool: entity_name_label = await self._ensure_label(node_id) - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) result = await session.run(query) single_result = await result.single() + await result.consume() # Ensure result is fully consumed logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}" ) @@ -193,13 +204,14 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" ) result = await session.run(query) single_result = await result.single() + await result.consume() # Ensure result is fully consumed logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" ) @@ -215,13 +227,16 @@ class Neo4JStorage(BaseGraphStorage): dict: Node properties if found None: If node not found """ - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) - record = await result.single() - if record: - node = record["n"] + records = await result.fetch(2) # Get up to 2 records to check for duplicates + await result.consume() # Ensure result is fully consumed + if len(records) > 1: + logger.warning(f"Multiple nodes found with label '{entity_name_label}'. 
Using first node.") + if records: + node = records[0]["n"] node_dict = dict(node) logger.debug( f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" @@ -230,23 +245,40 @@ class Neo4JStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + """ entity_name_label = node_id.strip('"') - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: query = f""" MATCH (n:`{entity_name_label}`) - RETURN COUNT{{ (n)--() }} AS totalEdgeCount + OPTIONAL MATCH (n)-[r]-() + RETURN n, COUNT(r) AS degree """ result = await session.run(query) - record = await result.single() - if record: - edge_count = record["totalEdgeCount"] - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}" - ) - return edge_count - else: - return None + records = await result.fetch(100) + await result.consume() # Ensure result is fully consumed + + if not records: + logger.warning(f"No node found with label '{entity_name_label}'") + return 0 + + if len(records) > 1: + logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") + + degree = records[0]["degree"] + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" + ) + return degree async def edge_degree(self, src_id: str, tgt_id: str) -> int: entity_name_label_source = src_id.strip('"') @@ -264,6 +296,31 @@ class Neo4JStorage(BaseGraphStorage): ) return degrees + async def check_duplicate_nodes(self) -> list[tuple[str, int]]: + """Find all labels that have multiple nodes + + Returns: + list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes + """ + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n) + WITH labels(n) as nodeLabels + UNWIND nodeLabels as label + WITH label, count(*) as node_count + WHERE node_count > 1 + RETURN label, node_count + ORDER BY node_count DESC + """ + result = await session.run(query) + duplicates = [] + async for record in result: + label = record["label"] + count = record["node_count"] + logger.info(f"Found {count} nodes with label: {label}") + duplicates.append((label, count)) + return duplicates + async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -271,18 +328,21 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: query = f""" - MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) + MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties - LIMIT 1 """ result = await session.run(query) - record = await result.single() - if record: + records = await result.fetch(2) # Get up to 2 records to check for duplicates + if len(records) > 
1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + ) + if records: try: - result = dict(record["edge_properties"]) + result = dict(records[0]["edge_properties"]) logger.debug(f"Result: {result}") # Ensure required keys exist with defaults required_keys = { @@ -349,24 +409,27 @@ class Neo4JStorage(BaseGraphStorage): query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: results = await session.run(query) edges = [] - async for record in results: - source_node = record["n"] - connected_node = record["connected"] + try: + async for record in results: + source_node = record["n"] + connected_node = record["connected"] - source_label = ( - list(source_node.labels)[0] if source_node.labels else None - ) - target_label = ( - list(connected_node.labels)[0] - if connected_node and connected_node.labels - else None - ) + source_label = ( + list(source_node.labels)[0] if source_node.labels else None + ) + target_label = ( + list(connected_node.labels)[0] + if connected_node and connected_node.labels + else None + ) - if source_label and target_label: - edges.append((source_label, target_label)) + if source_label and target_label: + edges.append((source_label, target_label)) + finally: + await results.consume() # Ensure results are consumed even if processing fails return edges @@ -427,30 +490,46 @@ class Neo4JStorage(BaseGraphStorage): ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. + Checks if both source and target nodes exist before creating the edge. 
Args: source_node_id (str): Label of the source node (used as identifier) target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge + + Raises: + ValueError: If either source or target node does not exist """ source_label = await self._ensure_label(source_node_id) target_label = await self._ensure_label(target_node_id) edge_properties = edge_data + # Check if both nodes exist + source_exists = await self.has_node(source_label) + target_exists = await self.has_node(target_label) + + if not source_exists: + raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist") + if not target_exists: + raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist") + async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" MATCH (source:`{source_label}`) WITH source MATCH (target:`{target_label}`) - MERGE (source)-[r:DIRECTED]->(target) + MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r """ result = await tx.run(query, properties=edge_properties) - record = await result.single() - logger.debug( - f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" - ) + try: + record = await result.single() + logger.debug( + f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" + ) + finally: + await result.consume() # Ensure result is consumed try: async with self._driver.session(database=self._DATABASE) as session: @@ -463,145 +542,179 @@ class Neo4JStorage(BaseGraphStorage): print("Implemented but never called.") async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 + self, + node_label: str, + max_depth: int = 3, + min_degree: int = 0, + inclusive: bool = False, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence (nodes containing the specified label string) - 2. Followed by nodes directly connected to the matching nodes - 3. Finally, the degree of the nodes + 1. min_degree does not affect nodes directly connected to the matching nodes + 2. Label matching nodes take precedence + 3. Followed by nodes directly connected to the matching nodes + 4. Finally, the degree of the nodes Args: - node_label (str): String to match in node labels (will match any node containing this string in its label) - max_depth (int, optional): Maximum depth of the graph. Defaults to 5. + node_label: Label of the starting node + max_depth: Maximum depth of the subgraph + min_degree: Minimum degree of nodes to include. 
Defaults to 0 + inclusive: Do an inclusive search if true Returns: KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') - # Escape single quotes to prevent injection attacks - escaped_label = label.replace("'", "\\'") result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: try: if label == "*": main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() WITH n, count(r) AS degree + WHERE degree >= $min_degree ORDER BY degree DESC LIMIT $max_nodes - WITH collect(n) AS nodes - MATCH (a)-[r]->(b) - WHERE a IN nodes AND b IN nodes - RETURN nodes, collect(DISTINCT r) AS relationships + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships """ result_set = await session.run( - main_query, {"max_nodes": MAX_GRAPH_NODES} + main_query, + {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree}, ) else: - validate_query = f""" - MATCH (n) - WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}') - RETURN n LIMIT 1 - """ - validate_result = await session.run(validate_query) - if not await validate_result.single(): - logger.warning( - f"No nodes containing '{label}' in their labels found!" - ) - return result - # Main query uses partial matching - main_query = f""" + main_query = """ MATCH (start) - WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}') + WHERE any(label IN labels(start) WHERE + CASE + WHEN $inclusive THEN label CONTAINS $label + ELSE label = $label + END + ) WITH start - CALL apoc.path.subgraphAll(start, {{ - relationshipFilter: '>', + CALL apoc.path.subgraphAll(start, { + relationshipFilter: '', minLevel: 0, - maxLevel: {max_depth}, + maxLevel: $max_depth, bfs: true - }}) + }) YIELD nodes, relationships WITH start, nodes, relationships UNWIND nodes AS node OPTIONAL MATCH (node)-[r]-() - WITH node, count(r) AS degree, start, nodes, relationships, - CASE - WHEN id(node) = id(start) THEN 2 - WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1 - ELSE 0 - END AS priority - ORDER BY priority DESC, degree DESC + WITH node, count(r) AS degree, start, nodes, relationships + WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree + ORDER BY + CASE + WHEN node = start THEN 3 + WHEN EXISTS((start)--(node)) THEN 2 + ELSE 1 + END DESC, + degree DESC LIMIT $max_nodes - WITH collect(node) AS filtered_nodes, nodes, relationships - RETURN filtered_nodes AS nodes, - [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships + WITH collect({node: node}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships """ result_set = await session.run( - main_query, {"max_nodes": MAX_GRAPH_NODES} + main_query, + { + "max_nodes": MAX_GRAPH_NODES, + "label": label, + "inclusive": inclusive, + "max_depth": max_depth, + "min_degree": min_degree, + }, ) - record = await result_set.single() + try: + record = await result_set.single() - if record: - # Handle nodes (compatible 
with multi-label cases) - for node in record["nodes"]: - # Use node ID + label combination as unique identifier - node_id = node.id - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=list(node.labels), - properties=dict(node), + if record: + # Handle nodes (compatible with multi-label cases) + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=list(node.labels), + properties=dict(node), + ) ) - ) - seen_nodes.add(node_id) + seen_nodes.add(node_id) - # Handle relationships (including direction information) - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) ) - ) - seen_edges.add(edge_id) + seen_edges.add(edge_id) - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + finally: + await result_set.consume() # Ensure result set is consumed except neo4jExceptions.ClientError as e: - logger.error(f"APOC query failed: {str(e)}") - return await self._robust_fallback(label, max_depth) + logger.warning( + f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation" + ) + if inclusive: + logger.warning( + "Inclusive search mode is not supported in recursive query, using exact matching" + ) + return await self._robust_fallback(label, max_depth, min_degree) return result async def _robust_fallback( - self, label: str, max_depth: int + self, label: str, max_depth: int, min_degree: int = 0 ) -> Dict[str, List[Dict]]: - """Enhanced fallback query solution""" + """ + Fallback implementation when APOC plugin is not available or incompatible. + This method implements the same functionality as get_knowledge_graph but uses + only basic Cypher queries and recursive traversal instead of APOC procedures. 
+ """ result = {"nodes": [], "edges": []} visited_nodes = set() visited_edges = set() async def traverse(current_label: str, current_depth: int): + # Check traversal limits if current_depth > max_depth: + logger.debug(f"Reached max depth: {max_depth}") + return + if len(visited_nodes) >= MAX_GRAPH_NODES: + logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") return # Get current node details @@ -614,46 +727,46 @@ class Neo4JStorage(BaseGraphStorage): return visited_nodes.add(node_id) - # Add node data (with complete labels) - node_data = {k: v for k, v in node.items()} - node_data["labels"] = [ - current_label - ] # Assume get_node method returns label information - result["nodes"].append(node_data) + # Add node data with label as ID + result["nodes"].append({ + "id": current_label, + "labels": current_label, + "properties": node + }) - # Get all outgoing and incoming edges + # Get connected nodes that meet the degree requirement + # Note: We don't need to check a's degree since it's the current node + # and was already validated in the previous iteration query = f""" - MATCH (a)-[r]-(b) - WHERE a:`{current_label}` OR b:`{current_label}` - RETURN a, r, b, - CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction + MATCH (a:`{current_label}`)-[r]-(b) + WITH r, b, + COUNT((b)--()) AS b_degree + WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) + RETURN r, b """ - async with self._driver.session(database=self._DATABASE) as session: - results = await session.run(query) + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + results = await session.run(query, {"min_degree": min_degree}) async for record in results: # Handle edges rel = record["r"] edge_id = f"{rel.id}_{rel.type}" if edge_id not in visited_edges: - edge_data = dict(rel) - edge_data.update( - { - "source": list(record["a"].labels)[0], - "target": list(record["b"].labels)[0], + b_node = record["b"] + if b_node.labels: # Only process if target node has labels + target_label = list(b_node.labels)[0] + result["edges"].append({ + "id": f"{current_label}_{target_label}", "type": rel.type, - "direction": record["direction"], - } - ) - result["edges"].append(edge_data) - visited_edges.add(edge_id) + "source": current_label, + "target": target_label, + "properties": dict(rel) + }) + visited_edges.add(edge_id) - # Recursively traverse adjacent nodes - next_label = ( - list(record["b"].labels)[0] - if record["direction"] == "OUTGOING" - else list(record["a"].labels)[0] - ) - await traverse(next_label, current_depth + 1) + # Continue traversal + await traverse(target_label, current_depth + 1) + else: + logger.warning(f"Skipping edge {edge_id} due to missing labels on target node") await traverse(label, 0) return result @@ -664,7 +777,7 @@ class Neo4JStorage(BaseGraphStorage): Returns: ["Person", "Company", ...] 
# Alphabetically sorted label list """ - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: # Method 1: Direct metadata query (Available for Neo4j 4.3+) # query = "CALL db.labels() YIELD label RETURN label" @@ -679,8 +792,11 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) labels = [] - async for record in result: - labels.append(record["label"]) + try: + async for record in result: + labels.append(record["label"]) + finally: + await result.consume() # Ensure results are consumed even if processing fails return labels @retry( @@ -763,7 +879,7 @@ class Neo4JStorage(BaseGraphStorage): async def _do_delete_edge(tx: AsyncManagedTransaction): query = f""" - MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`) + MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`) DELETE r """ await tx.run(query) From c07b592e1bfe73cde40c46f46e06f1dc9c3ae292 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 02:39:51 +0800 Subject: [PATCH 10/54] Add missing await consume --- lightrag/kg/neo4j_impl.py | 250 ++++++++++++++++++++------------------ 1 file changed, 130 insertions(+), 120 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f6567249..ea316d0f 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage): MAX_CONNECTION_POOL_SIZE = int( os.environ.get( "NEO4J_MAX_CONNECTION_POOL_SIZE", - config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800 + config.get("neo4j", "connection_pool_size", fallback=50), ) ) CONNECTION_TIMEOUT = float( os.environ.get( "NEO4J_CONNECTION_TIMEOUT", - config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0 + config.get("neo4j", "connection_timeout", fallback=30.0), ), ) CONNECTION_ACQUISITION_TIMEOUT = float( os.environ.get( "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", - config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0 + config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), ), ) MAX_TRANSACTION_RETRY_TIME = float( @@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = await self._ensure_label(node_id) - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) result = await session.run(query) single_result = await result.single() await result.consume() # Ensure result is fully consumed - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}" - ) return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage): result = await 
session.run(query) single_result = await result.single() await result.consume() # Ensure result is fully consumed - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" - ) return single_result["edgeExists"] async def get_node(self, node_id: str) -> dict[str, str] | None: @@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage): dict: Node properties if found None: If node not found """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates await result.consume() # Ensure result is fully consumed if len(records) > 1: - logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.") + logger.warning( + f"Multiple nodes found with label '{entity_name_label}'. Using first node." + ) if records: node = records[0]["n"] node_dict = dict(node) @@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage): """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. If no node is found, returns 0. - + Args: node_id: The label of the node - + Returns: int: The number of relationships the node has, or 0 if no node found """ entity_name_label = node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = f""" MATCH (n:`{entity_name_label}`) OPTIONAL MATCH (n)-[r]-() @@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) records = await result.fetch(100) await result.consume() # Ensure result is fully consumed - + if not records: logger.warning(f"No node found with label '{entity_name_label}'") return 0 - + if len(records) > 1: - logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") - + logger.warning( + f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + ) + degree = records[0]["degree"] logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" @@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage): ) return degrees - async def check_duplicate_nodes(self) -> list[tuple[str, int]]: - """Find all labels that have multiple nodes - - Returns: - list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes - """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: - query = """ - MATCH (n) - WITH labels(n) as nodeLabels - UNWIND nodeLabels as label - WITH label, count(*) as node_count - WHERE node_count > 1 - RETURN label, node_count - ORDER BY node_count DESC - """ - result = await session.run(query) - duplicates = [] - async for record in result: - label = record["label"] - count = record["node_count"] - logger.info(f"Found {count} nodes with label: {label}") - duplicates.append((label, count)) - return duplicates async def get_edge( self, 
source_node_id: str, target_node_id: str @@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = f""" MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties """ result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates - if len(records) > 1: - logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + try: + records = await result.fetch(2) # Get up to 2 records to check for duplicates + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + ) + if records: + try: + result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in result: + result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + ) + return result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" ) - if records: - try: - result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in result: - result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" - ) - return result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" - ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + # Return default edge properties when no edge found + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": 
None, + } + finally: + await result.consume() # Ensure result is fully consumed except Exception as e: logger.error( @@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage): query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: results = await session.run(query) edges = [] try: @@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage): if source_label and target_label: edges.append((source_label, target_label)) finally: - await results.consume() # Ensure results are consumed even if processing fails + await ( + results.consume() + ) # Ensure results are consumed even if processing fails return edges @@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage): MERGE (n:`{label}`) SET n += $properties """ - await tx.run(query, properties=properties) + result = await tx.run(query, properties=properties) logger.debug( f"Upserted node with label '{label}' and properties: {properties}" ) + await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: @@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage): target_exists = await self.has_node(target_label) if not source_exists: - raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist") + raise ValueError( + f"Neo4j: source node with label '{source_label}' does not exist" + ) if not target_exists: - raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist") + raise ValueError( + f"Neo4j: target node with label '{target_label}' does not exist" + ) async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" @@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: try: if label == "*": main_query = """ @@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage): visited_nodes.add(node_id) # Add node data with label as ID - result["nodes"].append({ - "id": current_label, - "labels": current_label, - "properties": node - }) + result["nodes"].append( + {"id": current_label, "labels": current_label, "properties": node} + ) # Get connected nodes that meet the degree requirement # Note: We don't need to check a's degree since it's the current node @@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage): WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) RETURN r, b """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: results = await session.run(query, {"min_degree": min_degree}) async for record in results: # Handle edges @@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage): b_node = record["b"] if b_node.labels: # Only process if target node has labels target_label = list(b_node.labels)[0] - result["edges"].append({ - "id": f"{current_label}_{target_label}", - "type": rel.type, - "source": current_label, - "target": target_label, - "properties": dict(rel) - }) + result["edges"].append( + { + "id": f"{current_label}_{target_label}", + "type": rel.type, + "source": current_label, + "target": 
target_label,
+                                "properties": dict(rel),
+                            }
+                        )
                         visited_edges.add(edge_id)
 
                         # Continue traversal
                         await traverse(target_label, current_depth + 1)
                     else:
-                        logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")
+                        logger.warning(
+                            f"Skipping edge {edge_id} due to missing labels on target node"
+                        )
 
         await traverse(label, 0)
         return result
@@ -777,7 +781,9 @@
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             # Method 1: Direct metadata query (Available for Neo4j 4.3+)
             # query = "CALL db.labels() YIELD label RETURN label"
@@ -796,7 +802,9 @@
             result = await session.run(query)
             labels = []
             try:
                 async for record in result:
                     labels.append(record["label"])
             finally:
-                await result.consume()  # Ensure results are consumed even if processing fails
+                await (
+                    result.consume()
+                )  # Ensure results are consumed even if processing fails
         return labels
 
     @retry(
@@ -824,8 +832,9 @@
             MATCH (n:`{label}`)
             DETACH DELETE n
             """
-            await tx.run(query)
+            result = await tx.run(query)
             logger.debug(f"Deleted node with label '{label}'")
+            await result.consume()  # Ensure result is fully consumed
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
@@ -882,8 +891,9 @@
             MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
             DELETE r
             """
-            await tx.run(query)
+            result = await tx.run(query)
             logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
+            await result.consume()  # Ensure result is fully consumed
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:

From fcb04e47e5f1beda21c9304ba3c07d90e2e07fc1 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 04:28:54 +0800
Subject: [PATCH 11/54] Refactor Neo4j APOC fallback retrieval implementation

---
 lightrag/kg/neo4j_impl.py | 255 ++++++++++++++++++++------------------
 1 file changed, 149 insertions(+), 106 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index ea316d0f..60e8982e 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -3,7 +3,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, List, Dict, final
+from typing import Any, final, Optional
 import numpy as np
 import configparser
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
             )
             return degrees
 
-
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
             """
             result = await session.run(query)
 
-            try:
-                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
-                if len(records) > 1:
-                    logger.warning(
-                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
- ) - if records: - try: - result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in result: - result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" - ) - return result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + records = await result.fetch(2) # Get up to 2 records to check for duplicates + await result.consume() # Ensure result is fully consumed before processing records + + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - finally: - await result.consume() # Ensure result is fully consumed + if records: + try: + edge_result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {edge_result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" + ) + return edge_result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + ) + # Return default edge properties when no edge found + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } except Exception as e: logger.error( @@ -685,30 +683,36 @@ class Neo4JStorage(BaseGraphStorage): await result_set.consume() # Ensure result set is consumed except neo4jExceptions.ClientError as e: - logger.warning( - f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation" - ) - if inclusive: + logger.warning(f"APOC plugin error: {str(e)}") + if label != "*": logger.warning( - "Inclusive search mode is not supported in recursive query, using exact matching" + "Neo4j: falling back to basic Cypher recursive search..." 
) - return await self._robust_fallback(label, max_depth, min_degree) + if inclusive: + logger.warning( + "Neo4j: inclusive search mode is not supported in recursive query, using exact matching" + ) + return await self._robust_fallback(label, max_depth, min_degree) return result async def _robust_fallback( self, label: str, max_depth: int, min_degree: int = 0 - ) -> Dict[str, List[Dict]]: + ) -> KnowledgeGraph: """ Fallback implementation when APOC plugin is not available or incompatible. This method implements the same functionality as get_knowledge_graph but uses only basic Cypher queries and recursive traversal instead of APOC procedures. """ - result = {"nodes": [], "edges": []} + result = KnowledgeGraph() visited_nodes = set() visited_edges = set() - async def traverse(current_label: str, current_depth: int): + async def traverse( + node: KnowledgeGraphNode, + edge: Optional[KnowledgeGraphEdge], + current_depth: int, + ): # Check traversal limits if current_depth > max_depth: logger.debug(f"Reached max depth: {max_depth}") @@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage): logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") return - # Get current node details - node = await self.get_node(current_label) - if not node: + # Check if node already visited + if node.id in visited_nodes: return - node_id = f"{current_label}" - if node_id in visited_nodes: - return - visited_nodes.add(node_id) - - # Add node data with label as ID - result["nodes"].append( - {"id": current_label, "labels": current_label, "properties": node} - ) - - # Get connected nodes that meet the degree requirement - # Note: We don't need to check a's degree since it's the current node - # and was already validated in the previous iteration - query = f""" - MATCH (a:`{current_label}`)-[r]-(b) - WITH r, b, - COUNT((b)--()) AS b_degree - WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) - RETURN r, b - """ + # Get all edges and target nodes async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - results = await session.run(query, {"min_degree": min_degree}) - async for record in results: - # Handle edges + query = """ + MATCH (a)-[r]-(b) + WHERE id(a) = toInteger($node_id) + WITH r, b, id(r) as edge_id, id(b) as target_id + RETURN r, b, edge_id, target_id + """ + results = await session.run(query, {"node_id": node.id}) + + # Get all records and release database connection + records = await results.fetch() + await results.consume() # Ensure results are consumed + + # Nodes not connected to start node need to check degree + if current_depth > 1 and len(records) < min_degree: + return + + # Add current node to result + result.nodes.append(node) + visited_nodes.add(node.id) + + # Add edge to result if it exists and not already added + if edge and edge.id not in visited_edges: + result.edges.append(edge) + visited_edges.add(edge.id) + + # Prepare nodes and edges for recursive processing + nodes_to_process = [] + for record in records: rel = record["r"] - edge_id = f"{rel.id}_{rel.type}" + edge_id = str(record["edge_id"]) if edge_id not in visited_edges: b_node = record["b"] - if b_node.labels: # Only process if target node has labels - target_label = list(b_node.labels)[0] - result["edges"].append( - { - "id": f"{current_label}_{target_label}", - "type": rel.type, - "source": current_label, - "target": target_label, - "properties": dict(rel), - } - ) - visited_edges.add(edge_id) + target_id = str(record["target_id"]) - # Continue traversal - await 
traverse(target_label, current_depth + 1) + if b_node.labels: # Only process if target node has labels + # Create KnowledgeGraphNode for target + target_node = KnowledgeGraphNode( + id=target_id, + labels=list(b_node.labels), + properties=dict(b_node), + ) + + # Create KnowledgeGraphEdge + target_edge = KnowledgeGraphEdge( + id=edge_id, + type=rel.type, + source=node.id, + target=target_id, + properties=dict(rel), + ) + + nodes_to_process.append((target_node, target_edge)) else: logger.warning( f"Skipping edge {edge_id} due to missing labels on target node" ) - await traverse(label, 0) + # Process nodes after releasing database connection + for target_node, target_edge in nodes_to_process: + await traverse(target_node, target_edge, current_depth + 1) + + # Get the starting node's data + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{label}`) + RETURN id(n) as node_id, n + """ + node_result = await session.run(query) + try: + node_record = await node_result.single() + if not node_record: + return result + + # Create initial KnowledgeGraphNode + start_node = KnowledgeGraphNode( + id=str(node_record["node_id"]), + labels=list(node_record["n"].labels), + properties=dict(node_record["n"]), + ) + finally: + await node_result.consume() # Ensure results are consumed + + # Start traversal with the initial node + await traverse(start_node, None, 0) + return result async def get_all_labels(self) -> list[str]: From 84222b8b76bb077b144463af8acfde8df188d505 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 10:19:20 +0800 Subject: [PATCH 12/54] Refactor Neo4JStorage methods for robustness and clarity. - Add error handling and resource cleanup - Improve method documentation - Optimize result consumption --- lightrag/kg/neo4j_impl.py | 412 +++++++++++++++++++++++--------------- 1 file changed, 255 insertions(+), 157 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 60e8982e..082b4bf2 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -163,13 +163,14 @@ class Neo4JStorage(BaseGraphStorage): } async def close(self): + """Close the Neo4j driver and release all resources""" if self._driver: await self._driver.close() self._driver = None async def __aexit__(self, exc_type, exc, tb): - if self._driver: - await self._driver.close() + """Ensure driver is closed when context manager exits""" + await self.close() async def index_done_callback(self) -> None: # Noe4J handles persistence automatically @@ -187,33 +188,72 @@ class Neo4JStorage(BaseGraphStorage): return clean_label async def has_node(self, node_id: str) -> bool: + """ + Check if a node with the given label exists in the database + + Args: + node_id: Label of the node to check + + Returns: + bool: True if node exists, False otherwise + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query + """ entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" - ) - result = await session.run(query) - single_result = await result.single() - await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] + try: + query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" + result = await session.run(query) + single_result = await result.single() 
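# Every method touched by this refactor lands on the same discipline shown at
# this point in the hunk: read what you need from the AsyncResult, then always
# consume it so the connection returns to the pool even when processing fails.
# The bare shape of that idiom (the helper name is illustrative):
async def run_and_consume(session, query: str):
    result = await session.run(query)
    try:
        return await result.single()
    finally:
        await result.consume()  # releases the result even on error paths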
+ await result.consume() # Ensure result is fully consumed + return single_result["node_exists"] + except Exception as e: + logger.error( + f"Error checking node existence for {entity_name_label}: {str(e)}" + ) + await result.consume() # Ensure results are consumed even on error + raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + """ + Check if an edge exists between two nodes + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + bool: True if edge exists, False otherwise + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ + entity_name_label_source = await self._ensure_label(source_node_id) + entity_name_label_target = await self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " - "RETURN COUNT(r) > 0 AS edgeExists" - ) - result = await session.run(query) - single_result = await result.single() - await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] + try: + query = ( + f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run(query) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["edgeExists"] + except Exception as e: + logger.error( + f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}" + ) + await result.consume() # Ensure results are consumed even on error + raise async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier. @@ -224,29 +264,40 @@ class Neo4JStorage(BaseGraphStorage): Returns: dict: Node properties if found None: If node not found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query """ + entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - entity_name_label = await self._ensure_label(node_id) - query = f"MATCH (n:`{entity_name_label}`) RETURN n" - result = await session.run(query) - records = await result.fetch( - 2 - ) # Get up to 2 records to check for duplicates - await result.consume() # Ensure result is fully consumed - if len(records) > 1: - logger.warning( - f"Multiple nodes found with label '{entity_name_label}'. Using first node." - ) - if records: - node = records[0]["n"] - node_dict = dict(node) - logger.debug( - f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" - ) - return node_dict - return None + try: + query = f"MATCH (n:`{entity_name_label}`) RETURN n" + result = await session.run(query) + try: + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{entity_name_label}'. Using first node." 
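# fetch(2) recurs through these hunks as a cheap uniqueness probe: when a
# label is supposed to match a single node, a second record is enough to
# prove duplication without streaming the whole result. A sketch under that
# assumption (illustrative helper, not the repo's API):
async def label_is_unique(session, label: str) -> bool:
    result = await session.run(f"MATCH (n:`{label}`) RETURN n")
    records = await result.fetch(2)  # at most two records are ever needed
    await result.consume()
    return len(records) <= 1  # False means the label matches multiple nodes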
+ ) + if records: + node = records[0]["n"] + node_dict = dict(node) + logger.debug( + f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" + ) + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {entity_name_label}: {str(e)}") + raise async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. @@ -258,39 +309,63 @@ class Neo4JStorage(BaseGraphStorage): Returns: int: The number of relationships the node has, or 0 if no node found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query """ - entity_name_label = node_id.strip('"') + entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = f""" - MATCH (n:`{entity_name_label}`) - OPTIONAL MATCH (n)-[r]-() - RETURN n, COUNT(r) AS degree - """ - result = await session.run(query) - records = await result.fetch(100) - await result.consume() # Ensure result is fully consumed + try: + query = f""" + MATCH (n:`{entity_name_label}`) + OPTIONAL MATCH (n)-[r]-() + RETURN n, COUNT(r) AS degree + """ + result = await session.run(query) + try: + records = await result.fetch(100) - if not records: - logger.warning(f"No node found with label '{entity_name_label}'") - return 0 + if not records: + logger.warning( + f"No node found with label '{entity_name_label}'" + ) + return 0 - if len(records) > 1: - logger.warning( - f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + if len(records) > 1: + logger.warning( + f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + ) + + degree = records[0]["degree"] + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" + ) + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error( + f"Error getting node degree for {entity_name_label}: {str(e)}" ) - - degree = records[0]["degree"] - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" - ) - return degree + raise async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ + entity_name_label_source = await self._ensure_label(src_id) + entity_name_label_target = await self._ensure_label(tgt_id) + src_degree = await self.node_degree(entity_name_label_source) trg_degree = await self.node_degree(entity_name_label_target) @@ -299,17 +374,27 @@ class Neo4JStorage(BaseGraphStorage): trg_degree = 0 if trg_degree is None else trg_degree degrees = int(src_degree) + int(trg_degree) - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}" - ) return degrees async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes. 
+ + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ try: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + entity_name_label_source = await self._ensure_label(source_node_id) + entity_name_label_target = await self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -320,109 +405,123 @@ class Neo4JStorage(BaseGraphStorage): """ result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates - await result.consume() # Ensure result is fully consumed before processing records - - if len(records) > 1: - logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + try: + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates + + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + ) + if records: + try: + edge_result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {edge_result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" + ) + return edge_result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" ) - if records: - try: - edge_result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {edge_result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in edge_result: - edge_result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" - ) - return edge_result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and 
{entity_name_label_target}" - ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + # Return default edge properties when no edge found + return { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + finally: + await result.consume() # Ensure result is fully consumed except Exception as e: logger.error( f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - node_label = source_node_id.strip('"') + """Retrieves all edges (relationships) for a particular node identified by its label. + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + ValueError: If source_node_id is invalid + Exception: If there is an error executing the query """ - Retrieves all edges (relationships) for a particular node identified by its label. - :return: List of dictionaries containing edge information - """ - query = f"""MATCH (n:`{node_label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - results = await session.run(query) - edges = [] - try: - async for record in results: - source_node = record["n"] - connected_node = record["connected"] + try: + node_label = await self._ensure_label(source_node_id) - source_label = ( - list(source_node.labels)[0] if source_node.labels else None - ) - target_label = ( - list(connected_node.labels)[0] - if connected_node and connected_node.labels - else None - ) + query = f"""MATCH (n:`{node_label}`) + OPTIONAL MATCH (n)-[r]-(connected) + RETURN n, r, connected""" - if source_label and target_label: - edges.append((source_label, target_label)) - finally: - await ( - results.consume() - ) # Ensure results are consumed even if processing fails + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + results = await session.run(query) + edges = [] - return edges + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + source_label = ( + list(source_node.labels)[0] if source_node.labels else None + ) + target_label = ( + list(connected_node.labels)[0] + if connected_node and connected_node.labels + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges if edges else None + except Exception as e: + logger.error(f"Error getting edges for node {node_label}: {str(e)}") + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise @retry( stop=stop_after_attempt(3), @@ -838,7 +937,6 @@ class Neo4JStorage(BaseGraphStorage): RETURN DISTINCT label ORDER BY label """ - result = await session.run(query) labels = [] try: From 78f8d7a1ce1186ce3398afb946f3da79bad50df7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 10:20:10 +0800 Subject: [PATCH 13/54] Convert node and edge IDs to f-strings for 
consistency. - Use f-strings for node IDs - Use f-strings for edge IDs - Ensure consistent ID formatting --- lightrag/kg/neo4j_impl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 082b4bf2..05deb0a9 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -865,17 +865,17 @@ class Neo4JStorage(BaseGraphStorage): if b_node.labels: # Only process if target node has labels # Create KnowledgeGraphNode for target target_node = KnowledgeGraphNode( - id=target_id, + id=f"{target_id}", labels=list(b_node.labels), properties=dict(b_node), ) # Create KnowledgeGraphEdge target_edge = KnowledgeGraphEdge( - id=edge_id, + id=f"{edge_id}", type=rel.type, - source=node.id, - target=target_id, + source=f"{node.id}", + target=f"{target_id}", properties=dict(rel), ) @@ -905,7 +905,7 @@ class Neo4JStorage(BaseGraphStorage): # Create initial KnowledgeGraphNode start_node = KnowledgeGraphNode( - id=str(node_record["node_id"]), + id=f"{node_record['node_id']}", labels=list(node_record["n"].labels), properties=dict(node_record["n"]), ) From af26d656985e0d9dd722c1cc8ea0d65f6348dc79 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 10:23:27 +0800 Subject: [PATCH 14/54] Convert _ensure_label method from async to sync --- lightrag/kg/neo4j_impl.py | 40 ++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 05deb0a9..cf3c024f 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -176,11 +176,17 @@ class Neo4JStorage(BaseGraphStorage): # Noe4J handles persistence automatically pass - async def _ensure_label(self, label: str) -> str: + def _ensure_label(self, label: str) -> str: """Ensure a label is valid Args: label: The label to validate + + Returns: + str: The cleaned label + + Raises: + ValueError: If label is empty after cleaning """ clean_label = label.strip('"') if not clean_label: @@ -201,7 +207,7 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = await self._ensure_label(node_id) + entity_name_label = self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -233,8 +239,8 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If either node_id is invalid Exception: If there is an error executing the query """ - entity_name_label_source = await self._ensure_label(source_node_id) - entity_name_label_target = await self._ensure_label(target_node_id) + entity_name_label_source = self._ensure_label(source_node_id) + entity_name_label_target = self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -269,7 +275,7 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = await self._ensure_label(node_id) + entity_name_label = self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -314,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = await self._ensure_label(node_id) + entity_name_label = self._ensure_label(node_id) async with self._driver.session( 
database=self._DATABASE, default_access_mode="READ" @@ -363,8 +369,8 @@ class Neo4JStorage(BaseGraphStorage): Returns: int: Sum of the degrees of both nodes """ - entity_name_label_source = await self._ensure_label(src_id) - entity_name_label_target = await self._ensure_label(tgt_id) + entity_name_label_source = self._ensure_label(src_id) + entity_name_label_target = self._ensure_label(tgt_id) src_degree = await self.node_degree(entity_name_label_source) trg_degree = await self.node_degree(entity_name_label_target) @@ -393,8 +399,8 @@ class Neo4JStorage(BaseGraphStorage): Exception: If there is an error executing the query """ try: - entity_name_label_source = await self._ensure_label(source_node_id) - entity_name_label_target = await self._ensure_label(target_node_id) + entity_name_label_source = self._ensure_label(source_node_id) + entity_name_label_target = self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -484,7 +490,7 @@ class Neo4JStorage(BaseGraphStorage): Exception: If there is an error executing the query """ try: - node_label = await self._ensure_label(source_node_id) + node_label = self._ensure_label(source_node_id) query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" @@ -543,7 +549,7 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = await self._ensure_label(node_id) + label = self._ensure_label(node_id) properties = node_data async def _do_upsert(tx: AsyncManagedTransaction): @@ -591,8 +597,8 @@ class Neo4JStorage(BaseGraphStorage): Raises: ValueError: If either source or target node does not exist """ - source_label = await self._ensure_label(source_node_id) - target_label = await self._ensure_label(target_node_id) + source_label = self._ensure_label(source_node_id) + target_label = self._ensure_label(target_node_id) edge_properties = edge_data # Check if both nodes exist @@ -966,7 +972,7 @@ class Neo4JStorage(BaseGraphStorage): Args: node_id: The label of the node to delete """ - label = await self._ensure_label(node_id) + label = self._ensure_label(node_id) async def _do_delete(tx: AsyncManagedTransaction): query = f""" @@ -1024,8 +1030,8 @@ class Neo4JStorage(BaseGraphStorage): edges: List of edges to be deleted, each edge is a (source, target) tuple """ for source, target in edges: - source_label = await self._ensure_label(source) - target_label = await self._ensure_label(target) + source_label = self._ensure_label(source) + target_label = self._ensure_label(target) async def _do_delete_edge(tx: AsyncManagedTransaction): query = f""" From 887f6ed81a2cb6036163105433b160e1343daf98 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 11:20:22 +0800 Subject: [PATCH 15/54] Fix: return empty list when no edges are found --- lightrag/kg/neo4j_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index cf3c024f..34226df7 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -520,7 +520,7 @@ class Neo4JStorage(BaseGraphStorage): edges.append((source_label, target_label)) await results.consume() # Ensure results are consumed - return edges if edges else None + return edges except Exception as e: logger.error(f"Error getting edges for node {node_label}: {str(e)}") await results.consume() # Ensure results are consumed even on error raise From 22a93fb717b7a66dda345fbacdb2e6d5df874707 Mon Sep 17
00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 11:29:08 +0800 Subject: [PATCH 16/54] Limit neighbor nodes fetch to 1000 in Neo4JStorage. --- lightrag/kg/neo4j_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 34226df7..7e1007b9 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -843,7 +843,7 @@ class Neo4JStorage(BaseGraphStorage): results = await session.run(query, {"node_id": node.id}) # Get all records and release database connection - records = await results.fetch() + records = await results.fetch(1000) # Max neighbour nodes we can handled await results.consume() # Ensure results are consumed # Nodes not connected to start node need to check degree From fb4a4c736edca76f8ab5968c0b4d8869bec94bf2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 11:36:24 +0800 Subject: [PATCH 17/54] Add duplicate edge upsert checking and logging --- lightrag/kg/neo4j_impl.py | 78 ++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 7e1007b9..1e46798a 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -412,9 +412,7 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) try: - records = await result.fetch( - 2 - ) # Get up to 2 records to check for duplicates + records = await result.fetch(2) if len(records) > 1: logger.warning( @@ -552,20 +550,20 @@ class Neo4JStorage(BaseGraphStorage): label = self._ensure_label(node_id) properties = node_data - async def _do_upsert(tx: AsyncManagedTransaction): - query = f""" - MERGE (n:`{label}`) - SET n += $properties - """ - result = await tx.run(query, properties=properties) - logger.debug( - f"Upserted node with label '{label}' and properties: {properties}" - ) - await result.consume() # Ensure result is fully consumed - try: async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_upsert) + async def execute_upsert(tx: AsyncManagedTransaction): + query = f""" + MERGE (n:`{label}`) + SET n += $properties + """ + result = await tx.run(query, properties=properties) + logger.debug( + f"Upserted node with label '{label}' and properties: {properties}" + ) + await result.consume() # Ensure result is fully consumed + + await session.execute_write(execute_upsert) except Exception as e: logger.error(f"Error during upsert: {str(e)}") raise @@ -614,27 +612,39 @@ class Neo4JStorage(BaseGraphStorage): f"Neo4j: target node with label '{target_label}' does not exist" ) - async def _do_upsert_edge(tx: AsyncManagedTransaction): - query = f""" - MATCH (source:`{source_label}`) - WITH source - MATCH (target:`{target_label}`) - MERGE (source)-[r:DIRECTED]-(target) - SET r += $properties - RETURN r - """ - result = await tx.run(query, properties=edge_properties) - try: - record = await result.single() - logger.debug( - f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" - ) - finally: - await result.consume() # Ensure result is consumed - try: async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_upsert_edge) + async def execute_upsert(tx: AsyncManagedTransaction): + query = f""" + MATCH (source:`{source_label}`) + WITH source + MATCH (target:`{target_label}`) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, 
target + """ + result = await tx.run(query, properties=edge_properties) + try: + records = await result.fetch(100) + if len(records) > 1: + source_nodes = [dict(r['source']) for r in records] + target_nodes = [dict(r['target']) for r in records] + logger.warning( + f"Multiple edges created: found {len(records)} results for edge between " + f"source label '{source_label}' and target label '{target_label}'. " + f"Source nodes: {source_nodes}, " + f"Target nodes: {target_nodes}. " + "Using first edge only." + ) + if records: + logger.debug( + f"Upserted edge from '{source_label}' to '{target_label}' " + f"with properties: {edge_properties}" + ) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) except Exception as e: logger.error(f"Error during edge upsert: {str(e)}") raise From 59a2202e7c354333a02ff7d907a53306dea3c41e Mon Sep 17 00:00:00 2001 From: baoheping <1340473515@qq.com> Date: Sat, 8 Mar 2025 09:26:21 +0000 Subject: [PATCH 18/54] Added Minimum Degree --- lightrag_webui/bun.lock | 10 + lightrag_webui/package.json | 2 + lightrag_webui/src/components/ThemeToggle.tsx | 6 +- .../documents/ClearDocumentsDialog.tsx | 18 +- .../documents/UploadDocumentsDialog.tsx | 22 +- .../components/graph/FullScreenControl.tsx | 6 +- .../src/components/graph/GraphLabels.tsx | 10 +- .../src/components/graph/GraphSearch.tsx | 6 +- .../src/components/graph/LayoutsControl.tsx | 9 +- .../src/components/graph/PropertiesView.tsx | 27 +- .../src/components/graph/Settings.tsx | 37 +-- .../src/components/graph/StatusCard.tsx | 38 +-- .../src/components/graph/StatusIndicator.tsx | 4 +- .../src/components/graph/ZoomControl.tsx | 8 +- .../src/components/retrieval/ChatMessage.tsx | 7 +- .../components/retrieval/QuerySettings.tsx | 84 ++++--- .../src/features/DocumentManager.tsx | 46 ++-- .../src/features/RetrievalTesting.tsx | 14 +- lightrag_webui/src/features/SiteHeader.tsx | 13 +- lightrag_webui/src/i18n.js | 21 ++ lightrag_webui/src/locales/en.json | 234 +++++++++++++++++ lightrag_webui/src/locales/zh.json | 236 ++++++++++++++++++ lightrag_webui/src/main.tsx | 2 + 23 files changed, 705 insertions(+), 155 deletions(-) create mode 100644 lightrag_webui/src/i18n.js create mode 100644 lightrag_webui/src/locales/en.json create mode 100644 lightrag_webui/src/locales/zh.json diff --git a/lightrag_webui/bun.lock b/lightrag_webui/bun.lock index 6157e38c..3ca0d887 100644 --- a/lightrag_webui/bun.lock +++ b/lightrag_webui/bun.lock @@ -34,11 +34,13 @@ "cmdk": "^1.0.4", "graphology": "^0.26.0", "graphology-generators": "^0.11.2", + "i18next": "^24.2.2", "lucide-react": "^0.475.0", "minisearch": "^7.1.2", "react": "^19.0.0", "react-dom": "^19.0.0", "react-dropzone": "^14.3.6", + "react-i18next": "^15.4.1", "react-markdown": "^9.1.0", "react-number-format": "^5.4.3", "react-syntax-highlighter": "^15.6.1", @@ -765,8 +767,12 @@ "hoist-non-react-statics": ["hoist-non-react-statics@3.3.2", "", { "dependencies": { "react-is": "^16.7.0" } }, "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw=="], + "html-parse-stringify": ["html-parse-stringify@3.0.1", "", { "dependencies": { "void-elements": "3.1.0" } }, "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg=="], + "html-url-attributes": ["html-url-attributes@3.0.1", "", {}, "sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ=="], + "i18next": ["i18next@24.2.2", "", { "dependencies": { 
"@babel/runtime": "^7.23.2" }, "peerDependencies": { "typescript": "^5" }, "optionalPeers": ["typescript"] }, "sha512-NE6i86lBCKRYZa5TaUDkU5S4HFgLIEJRLr3Whf2psgaxBleQ2LC1YW1Vc+SCgkAW7VEzndT6al6+CzegSUHcTQ=="], + "ignore": ["ignore@5.3.2", "", {}, "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g=="], "import-fresh": ["import-fresh@3.3.1", "", { "dependencies": { "parent-module": "^1.0.0", "resolve-from": "^4.0.0" } }, "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ=="], @@ -1093,6 +1099,8 @@ "react-dropzone": ["react-dropzone@14.3.6", "", { "dependencies": { "attr-accept": "^2.2.4", "file-selector": "^2.1.0", "prop-types": "^15.8.1" }, "peerDependencies": { "react": ">= 16.8 || 18.0.0" } }, "sha512-U792j+x0rcwH/U/Slv/OBNU/LGFYbDLHKKiJoPhNaOianayZevCt4Y5S0CraPssH/6/wT6xhKDfzdXUgCBS0HQ=="], + "react-i18next": ["react-i18next@15.4.1", "", { "dependencies": { "@babel/runtime": "^7.25.0", "html-parse-stringify": "^3.0.1" }, "peerDependencies": { "i18next": ">= 23.2.3", "react": ">= 16.8.0" } }, "sha512-ahGab+IaSgZmNPYXdV1n+OYky95TGpFwnKRflX/16dY04DsYYKHtVLjeny7sBSCREEcoMbAgSkFiGLF5g5Oofw=="], + "react-is": ["react-is@16.13.1", "", {}, "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="], "react-markdown": ["react-markdown@9.1.0", "", { "dependencies": { "@types/hast": "^3.0.0", "@types/mdast": "^4.0.0", "devlop": "^1.0.0", "hast-util-to-jsx-runtime": "^2.0.0", "html-url-attributes": "^3.0.0", "mdast-util-to-hast": "^13.0.0", "remark-parse": "^11.0.0", "remark-rehype": "^11.0.0", "unified": "^11.0.0", "unist-util-visit": "^5.0.0", "vfile": "^6.0.0" }, "peerDependencies": { "@types/react": ">=18", "react": ">=18" } }, "sha512-xaijuJB0kzGiUdG7nc2MOMDUDBWPyGAjZtUrow9XxUeua8IqeP+VlIfAZ3bphpcLTnSZXz6z9jcVC/TCwbfgdw=="], @@ -1271,6 +1279,8 @@ "vite": ["vite@6.1.1", "", { "dependencies": { "esbuild": "^0.24.2", "postcss": "^8.5.2", "rollup": "^4.30.1" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", "jiti": ">=1.21.0", "less": "*", "lightningcss": "^1.21.0", "sass": "*", "sass-embedded": "*", "stylus": "*", "sugarss": "*", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "jiti", "less", "lightningcss", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-4GgM54XrwRfrOp297aIYspIti66k56v16ZnqHvrIM7mG+HjDlAwS7p+Srr7J6fGvEdOJ5JcQ/D9T7HhtdXDTzA=="], + "void-elements": ["void-elements@3.1.0", "", {}, "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w=="], + "which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], "which-boxed-primitive": ["which-boxed-primitive@1.1.1", "", { "dependencies": { "is-bigint": "^1.1.0", "is-boolean-object": "^1.2.1", "is-number-object": "^1.1.1", "is-string": "^1.1.1", "is-symbol": "^1.1.1" } }, "sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA=="], diff --git a/lightrag_webui/package.json b/lightrag_webui/package.json index 578ee36f..97fba74d 100644 --- a/lightrag_webui/package.json +++ b/lightrag_webui/package.json @@ -43,11 +43,13 @@ "cmdk": "^1.0.4", "graphology": "^0.26.0", "graphology-generators": "^0.11.2", + 
"i18next": "^24.2.2", "lucide-react": "^0.475.0", "minisearch": "^7.1.2", "react": "^19.0.0", "react-dom": "^19.0.0", "react-dropzone": "^14.3.6", + "react-i18next": "^15.4.1", "react-markdown": "^9.1.0", "react-number-format": "^5.4.3", "react-syntax-highlighter": "^15.6.1", diff --git a/lightrag_webui/src/components/ThemeToggle.tsx b/lightrag_webui/src/components/ThemeToggle.tsx index 8e92d862..ff333ff0 100644 --- a/lightrag_webui/src/components/ThemeToggle.tsx +++ b/lightrag_webui/src/components/ThemeToggle.tsx @@ -3,6 +3,7 @@ import useTheme from '@/hooks/useTheme' import { MoonIcon, SunIcon } from 'lucide-react' import { useCallback } from 'react' import { controlButtonVariant } from '@/lib/constants' +import { useTranslation } from 'react-i18next' /** * Component that toggles the theme between light and dark. @@ -11,13 +12,14 @@ export default function ThemeToggle() { const { theme, setTheme } = useTheme() const setLight = useCallback(() => setTheme('light'), [setTheme]) const setDark = useCallback(() => setTheme('dark'), [setTheme]) + const { t } = useTranslation() if (theme === 'dark') { return ( e.preventDefault()}> - Clear documents - Do you really want to clear all documents? + {t('documentPanel.clearDocuments.title')} + {t('documentPanel.clearDocuments.confirm')} diff --git a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx index 7149eb28..7f17393c 100644 --- a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx +++ b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx @@ -14,8 +14,10 @@ import { errorMessage } from '@/lib/utils' import { uploadDocument } from '@/api/lightrag' import { UploadIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' export default function UploadDocumentsDialog() { + const { t } = useTranslation() const [open, setOpen] = useState(false) const [isUploading, setIsUploading] = useState(false) const [progresses, setProgresses] = useState>({}) @@ -29,24 +31,24 @@ export default function UploadDocumentsDialog() { filesToUpload.map(async (file) => { try { const result = await uploadDocument(file, (percentCompleted: number) => { - console.debug(`Uploading ${file.name}: ${percentCompleted}%`) + console.debug(t('documentPanel.uploadDocuments.uploading', { name: file.name, percent: percentCompleted })) setProgresses((pre) => ({ ...pre, [file.name]: percentCompleted })) }) if (result.status === 'success') { - toast.success(`Upload Success:\n${file.name} uploaded successfully`) + toast.success(t('documentPanel.uploadDocuments.success', { name: file.name })) } else { - toast.error(`Upload Failed:\n${file.name}\n${result.message}`) + toast.error(t('documentPanel.uploadDocuments.failed', { name: file.name, message: result.message })) } } catch (err) { - toast.error(`Upload Failed:\n${file.name}\n${errorMessage(err)}`) + toast.error(t('documentPanel.uploadDocuments.error', { name: file.name, error: errorMessage(err) })) } }) ) } catch (err) { - toast.error('Upload Failed\n' + errorMessage(err)) + toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) })) } finally { setIsUploading(false) // setOpen(false) @@ -66,21 +68,21 @@ export default function UploadDocumentsDialog() { }} > - e.preventDefault()}> - Upload documents + {t('documentPanel.uploadDocuments.title')} - Drag and drop your documents here or click to browse. 
+ {t('documentPanel.uploadDocuments.description')} { const { isFullScreen, toggle } = useFullScreen() + const { t } = useTranslation() return ( <> {isFullScreen ? ( - ) : ( - )} diff --git a/lightrag_webui/src/components/graph/GraphLabels.tsx b/lightrag_webui/src/components/graph/GraphLabels.tsx index a3849e1f..7bc26c88 100644 --- a/lightrag_webui/src/components/graph/GraphLabels.tsx +++ b/lightrag_webui/src/components/graph/GraphLabels.tsx @@ -5,6 +5,7 @@ import { useSettingsStore } from '@/stores/settings' import { useGraphStore } from '@/stores/graph' import { labelListLimit } from '@/lib/constants' import MiniSearch from 'minisearch' +import { useTranslation } from 'react-i18next' const lastGraph: any = { graph: null, @@ -13,6 +14,7 @@ const lastGraph: any = { } const GraphLabels = () => { + const { t } = useTranslation() const label = useSettingsStore.use.queryLabel() const graph = useGraphStore.use.sigmaGraph() @@ -69,7 +71,7 @@ const GraphLabels = () => { return result.length <= labelListLimit ? result - : [...result.slice(0, labelListLimit), `And ${result.length - labelListLimit} others`] + : [...result.slice(0, labelListLimit), t('graphLabels.andOthers', { count: result.length - labelListLimit })] }, [getSearchEngine] ) @@ -84,14 +86,14 @@ const GraphLabels = () => { className="ml-2" triggerClassName="max-h-8" searchInputClassName="max-h-8" - triggerTooltip="Select query label" + triggerTooltip={t('graphPanel.graphLabels.selectTooltip')} fetcher={fetchData} renderOption={(item) =>
{item}
} getOptionValue={(item) => item} getDisplayValue={(item) =>
{item}
} notFound={
No labels found
} - label="Label" - placeholder="Search labels..." + label={t('graphPanel.graphLabels.label')} + placeholder={t('graphPanel.graphLabels.placeholder')} value={label !== null ? label : ''} onChange={setQueryLabel} /> diff --git a/lightrag_webui/src/components/graph/GraphSearch.tsx b/lightrag_webui/src/components/graph/GraphSearch.tsx index 3edc3ede..bbb8cb5b 100644 --- a/lightrag_webui/src/components/graph/GraphSearch.tsx +++ b/lightrag_webui/src/components/graph/GraphSearch.tsx @@ -9,6 +9,7 @@ import { AsyncSearch } from '@/components/ui/AsyncSearch' import { searchResultLimit } from '@/lib/constants' import { useGraphStore } from '@/stores/graph' import MiniSearch from 'minisearch' +import { useTranslation } from 'react-i18next' interface OptionItem { id: string @@ -44,6 +45,7 @@ export const GraphSearchInput = ({ onFocus?: GraphSearchInputProps['onFocus'] value?: GraphSearchInputProps['value'] }) => { + const { t } = useTranslation() const graph = useGraphStore.use.sigmaGraph() const searchEngine = useMemo(() => { @@ -97,7 +99,7 @@ export const GraphSearchInput = ({ { type: 'message', id: messageId, - message: `And ${result.length - searchResultLimit} others` + message: t('graphPanel.search.message', { count: result.length - searchResultLimit }) } ] }, @@ -118,7 +120,7 @@ export const GraphSearchInput = ({ if (id !== messageId && onFocus) onFocus(id ? { id, type: 'nodes' } : null) }} label={'item'} - placeholder="Search nodes..." + placeholder={t('graphPanel.search.placeholder')} /> ) } diff --git a/lightrag_webui/src/components/graph/LayoutsControl.tsx b/lightrag_webui/src/components/graph/LayoutsControl.tsx index c57b371a..0ed97f2f 100644 --- a/lightrag_webui/src/components/graph/LayoutsControl.tsx +++ b/lightrag_webui/src/components/graph/LayoutsControl.tsx @@ -16,6 +16,7 @@ import { controlButtonVariant } from '@/lib/constants' import { useSettingsStore } from '@/stores/settings' import { GripIcon, PlayIcon, PauseIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' type LayoutName = | 'Circular' @@ -28,6 +29,7 @@ type LayoutName = const WorkerLayoutControl = ({ layout, autoRunFor }: WorkerLayoutControlProps) => { const sigma = useSigma() const { stop, start, isRunning } = layout + const { t } = useTranslation() /** * Init component when Sigma or component settings change. @@ -61,7 +63,7 @@ const WorkerLayoutControl = ({ layout, autoRunFor }: WorkerLayoutControlProps) = @@ -166,7 +169,7 @@ const LayoutsControl = () => { key={name} className="cursor-pointer text-xs" > - {name} + {t(`graphPanel.sideBar.layoutsControl.layouts.${name}`)} ))} diff --git a/lightrag_webui/src/components/graph/PropertiesView.tsx b/lightrag_webui/src/components/graph/PropertiesView.tsx index dec80460..4571b02b 100644 --- a/lightrag_webui/src/components/graph/PropertiesView.tsx +++ b/lightrag_webui/src/components/graph/PropertiesView.tsx @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react' import { useGraphStore, RawNodeType, RawEdgeType } from '@/stores/graph' import Text from '@/components/ui/Text' import useLightragGraph from '@/hooks/useLightragGraph' +import { useTranslation } from 'react-i18next' /** * Component that view properties of elements in graph. @@ -147,21 +148,22 @@ const PropertyRow = ({ } const NodePropertiesView = ({ node }: { node: NodeType }) => { + const { t } = useTranslation() return (
- +
- + { useGraphStore.getState().setSelectedNode(node.id, true) }} /> - +
- +
{Object.keys(node.properties) .sort() @@ -172,7 +174,7 @@ const NodePropertiesView = ({ node }: { node: NodeType }) => { {node.relationships.length > 0 && ( <>
{node.relationships.map(({ type, id, label }) => { @@ -195,28 +197,29 @@ const NodePropertiesView = ({ node }: { node: NodeType }) => { } const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => { + const { t } = useTranslation() return (
- +
- - {edge.type && } + + {edge.type && } { useGraphStore.getState().setSelectedNode(edge.source, true) }} /> { useGraphStore.getState().setSelectedNode(edge.target, true) }} />
- +
{Object.keys(edge.properties) .sort() diff --git a/lightrag_webui/src/components/graph/Settings.tsx b/lightrag_webui/src/components/graph/Settings.tsx index 4d2b998d..ddf05d40 100644 --- a/lightrag_webui/src/components/graph/Settings.tsx +++ b/lightrag_webui/src/components/graph/Settings.tsx @@ -10,6 +10,7 @@ import { useSettingsStore } from '@/stores/settings' import { useBackendState } from '@/stores/state' import { SettingsIcon } from 'lucide-react' +import { useTranslation } from "react-i18next"; /** * Component that displays a checkbox with a label. @@ -195,10 +196,12 @@ export default function Settings() { [setTempApiKey] ) + const { t } = useTranslation(); + return ( - @@ -212,7 +215,7 @@ export default function Settings() { @@ -220,12 +223,12 @@ export default function Settings() { @@ -233,12 +236,12 @@ export default function Settings() { @@ -246,28 +249,34 @@ export default function Settings() { +
- +
e.preventDefault()}>
@@ -295,7 +304,7 @@ export default function Settings() { size="sm" className="max-h-full shrink-0" > - Save + {t("graphPanel.sideBar.settings.save")}
diff --git a/lightrag_webui/src/components/graph/StatusCard.tsx b/lightrag_webui/src/components/graph/StatusCard.tsx index 3084d103..e67cbd30 100644 --- a/lightrag_webui/src/components/graph/StatusCard.tsx +++ b/lightrag_webui/src/components/graph/StatusCard.tsx @@ -1,58 +1,60 @@ import { LightragStatus } from '@/api/lightrag' +import { useTranslation } from 'react-i18next' const StatusCard = ({ status }: { status: LightragStatus | null }) => { + const { t } = useTranslation() if (!status) { - return
Status information unavailable
+ return
{t('graphPanel.statusCard.unavailable')}
} return (
-

Storage Info

+

{t('graphPanel.statusCard.storageInfo')}

- Working Directory: + {t('graphPanel.statusCard.workingDirectory')}: {status.working_directory} - Input Directory: + {t('graphPanel.statusCard.inputDirectory')}: {status.input_directory}
-

LLM Configuration

+

{t('graphPanel.statusCard.llmConfig')}

- LLM Binding: + {t('graphPanel.statusCard.llmBinding')}: {status.configuration.llm_binding} - LLM Binding Host: + {t('graphPanel.statusCard.llmBindingHost')}: {status.configuration.llm_binding_host} - LLM Model: + {t('graphPanel.statusCard.llmModel')}: {status.configuration.llm_model} - Max Tokens: + {t('graphPanel.statusCard.maxTokens')}: {status.configuration.max_tokens}
-

Embedding Configuration

+

{t('graphPanel.statusCard.embeddingConfig')}

- Embedding Binding: + {t('graphPanel.statusCard.embeddingBinding')}: {status.configuration.embedding_binding} - Embedding Binding Host: + {t('graphPanel.statusCard.embeddingBindingHost')}: {status.configuration.embedding_binding_host} - Embedding Model: + {t('graphPanel.statusCard.embeddingModel')}: {status.configuration.embedding_model}
-

Storage Configuration

+

{t('graphPanel.statusCard.storageConfig')}

- KV Storage: + {t('graphPanel.statusCard.kvStorage')}: {status.configuration.kv_storage} - Doc Status Storage: + {t('graphPanel.statusCard.docStatusStorage')}: {status.configuration.doc_status_storage} - Graph Storage: + {t('graphPanel.statusCard.graphStorage')}: {status.configuration.graph_storage} - Vector Storage: + {t('graphPanel.statusCard.vectorStorage')}: {status.configuration.vector_storage}
diff --git a/lightrag_webui/src/components/graph/StatusIndicator.tsx b/lightrag_webui/src/components/graph/StatusIndicator.tsx index 3272d9fa..d7a1831f 100644 --- a/lightrag_webui/src/components/graph/StatusIndicator.tsx +++ b/lightrag_webui/src/components/graph/StatusIndicator.tsx @@ -3,8 +3,10 @@ import { useBackendState } from '@/stores/state' import { useEffect, useState } from 'react' import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover' import StatusCard from '@/components/graph/StatusCard' +import { useTranslation } from 'react-i18next' const StatusIndicator = () => { + const { t } = useTranslation() const health = useBackendState.use.health() const lastCheckTime = useBackendState.use.lastCheckTime() const status = useBackendState.use.status() @@ -33,7 +35,7 @@ const StatusIndicator = () => { )} /> - {health ? 'Connected' : 'Disconnected'} + {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')}
diff --git a/lightrag_webui/src/components/graph/ZoomControl.tsx b/lightrag_webui/src/components/graph/ZoomControl.tsx index 790b4423..0aa55416 100644 --- a/lightrag_webui/src/components/graph/ZoomControl.tsx +++ b/lightrag_webui/src/components/graph/ZoomControl.tsx @@ -3,12 +3,14 @@ import { useCallback } from 'react' import Button from '@/components/ui/Button' import { ZoomInIcon, ZoomOutIcon, FullscreenIcon } from 'lucide-react' import { controlButtonVariant } from '@/lib/constants' +import { useTranslation } from "react-i18next"; /** * Component that provides zoom controls for the graph viewer. */ const ZoomControl = () => { const { zoomIn, zoomOut, reset } = useCamera({ duration: 200, factor: 1.5 }) + const { t } = useTranslation(); const handleZoomIn = useCallback(() => zoomIn(), [zoomIn]) const handleZoomOut = useCallback(() => zoomOut(), [zoomOut]) @@ -16,16 +18,16 @@ const ZoomControl = () => { return ( <> - -
@@ -98,29 +100,29 @@ export default function DocumentManager() { - Uploaded documents - view the uploaded documents here + {t('documentPanel.documentManager.uploadedTitle')} + {t('documentPanel.documentManager.uploadedDescription')} {!docs && ( )} {docs && ( - ID - Summary - Status - Length - Chunks - Created - Updated - Metadata + {t('documentPanel.documentManager.columns.id')} + {t('documentPanel.documentManager.columns.summary')} + {t('documentPanel.documentManager.columns.status')} + {t('documentPanel.documentManager.columns.length')} + {t('documentPanel.documentManager.columns.chunks')} + {t('documentPanel.documentManager.columns.created')} + {t('documentPanel.documentManager.columns.updated')} + {t('documentPanel.documentManager.columns.metadata')} @@ -137,13 +139,13 @@ export default function DocumentManager() { {status === 'processed' && ( - Completed + {t('documentPanel.documentManager.status.completed')} )} {status === 'processing' && ( - Processing + {t('documentPanel.documentManager.status.processing')} )} - {status === 'pending' && Pending} - {status === 'failed' && Failed} + {status === 'pending' && {t('documentPanel.documentManager.status.pending')}} + {status === 'failed' && {t('documentPanel.documentManager.status.failed')}} {doc.error && ( ⚠️ diff --git a/lightrag_webui/src/features/RetrievalTesting.tsx b/lightrag_webui/src/features/RetrievalTesting.tsx index 340255a2..84955aa1 100644 --- a/lightrag_webui/src/features/RetrievalTesting.tsx +++ b/lightrag_webui/src/features/RetrievalTesting.tsx @@ -8,8 +8,10 @@ import { useDebounce } from '@/hooks/useDebounce' import QuerySettings from '@/components/retrieval/QuerySettings' import { ChatMessage, MessageWithError } from '@/components/retrieval/ChatMessage' import { EraserIcon, SendIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' export default function RetrievalTesting() { + const { t } = useTranslation() const [messages, setMessages] = useState( () => useSettingsStore.getState().retrievalHistory || [] ) @@ -89,7 +91,7 @@ export default function RetrievalTesting() { } } catch (err) { // Handle error - updateAssistantMessage(`Error: Failed to get response\n${errorMessage(err)}`, true) + updateAssistantMessage(`${t('retrievePanel.retrieval.error')}\n${errorMessage(err)}`, true) } finally { // Clear loading and add messages to state setIsLoading(false) @@ -98,7 +100,7 @@ export default function RetrievalTesting() { .setRetrievalHistory([...prevMessages, userMessage, assistantMessage]) } }, - [inputValue, isLoading, messages, setMessages] + [inputValue, isLoading, messages, setMessages, t] ) const debouncedMessages = useDebounce(messages, 100) @@ -117,7 +119,7 @@ export default function RetrievalTesting() {
{messages.length === 0 ? (
- Start a retrieval by typing your query below + {t('retrievePanel.retrieval.startPrompt')}
) : ( messages.map((message, idx) => ( @@ -143,18 +145,18 @@ export default function RetrievalTesting() { size="sm" > - Clear + {t('retrievePanel.retrieval.clear')} setInputValue(e.target.value)} - placeholder="Type your query..." + placeholder={t('retrievePanel.retrieval.placeholder')} disabled={isLoading} />
diff --git a/lightrag_webui/src/features/SiteHeader.tsx b/lightrag_webui/src/features/SiteHeader.tsx index c09ce089..ac3bdd70 100644 --- a/lightrag_webui/src/features/SiteHeader.tsx +++ b/lightrag_webui/src/features/SiteHeader.tsx @@ -4,6 +4,7 @@ import ThemeToggle from '@/components/ThemeToggle' import { TabsList, TabsTrigger } from '@/components/ui/Tabs' import { useSettingsStore } from '@/stores/settings' import { cn } from '@/lib/utils' +import { useTranslation } from 'react-i18next' import { ZapIcon, GithubIcon } from 'lucide-react' @@ -29,21 +30,22 @@ function NavigationTab({ value, currentTab, children }: NavigationTabProps) { function TabsNavigation() { const currentTab = useSettingsStore.use.currentTab() + const { t } = useTranslation() return (
- Documents + {t('header.documents')} - Knowledge Graph + {t('header.knowledgeGraph')} - Retrieval + {t('header.retrieval')} - API + {t('header.api')}
@@ -51,6 +53,7 @@ function TabsNavigation() { } export default function SiteHeader() { + const { t } = useTranslation() return (
@@ -64,7 +67,7 @@ export default function SiteHeader() {