From a5cd2b19588b16fb3afa63786db048edc3c531b9 Mon Sep 17 00:00:00 2001 From: destiny <1170513392@qq.com> Date: Fri, 14 Feb 2025 11:00:54 +0800 Subject: [PATCH 01/28] Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient --- examples/test_chromadb.py | 70 +++++++++++++++++++++++++------------- lightrag/kg/chroma_impl.py | 60 ++++++++++++++++++-------------- 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py index 0e6361ed..5293f05d 100644 --- a/examples/test_chromadb.py +++ b/examples/test_chromadb.py @@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) # ChromaDB Configuration +CHROMADB_USE_LOCAL_PERSISTENT = False +# Local PersistentClient Configuration +CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")) +# Remote HttpClient Configuration CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost") CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000)) CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token") @@ -60,30 +64,50 @@ async def create_embedding_function_instance(): async def initialize_rag(): embedding_func_instance = await create_embedding_function_instance() - - return LightRAG( - working_dir=WORKING_DIR, - llm_model_func=gpt_4o_mini_complete, - embedding_func=embedding_func_instance, - vector_storage="ChromaVectorDBStorage", - log_level="DEBUG", - embedding_batch_num=32, - vector_db_storage_cls_kwargs={ - "host": CHROMADB_HOST, - "port": CHROMADB_PORT, - "auth_token": CHROMADB_AUTH_TOKEN, - "auth_provider": CHROMADB_AUTH_PROVIDER, - "auth_header_name": CHROMADB_AUTH_HEADER, - "collection_settings": { - "hnsw:space": "cosine", - "hnsw:construction_ef": 128, - "hnsw:search_ef": 128, - "hnsw:M": 16, - "hnsw:batch_size": 100, - "hnsw:sync_threshold": 1000, + if CHROMADB_USE_LOCAL_PERSISTENT: + return LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_func_instance, + vector_storage="ChromaVectorDBStorage", + log_level="DEBUG", + embedding_batch_num=32, + vector_db_storage_cls_kwargs={ + "local_path": CHROMADB_LOCAL_PATH, + "collection_settings": { + "hnsw:space": "cosine", + "hnsw:construction_ef": 128, + "hnsw:search_ef": 128, + "hnsw:M": 16, + "hnsw:batch_size": 100, + "hnsw:sync_threshold": 1000, + }, }, - }, - ) + ) + else: + return LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_func_instance, + vector_storage="ChromaVectorDBStorage", + log_level="DEBUG", + embedding_batch_num=32, + vector_db_storage_cls_kwargs={ + "host": CHROMADB_HOST, + "port": CHROMADB_PORT, + "auth_token": CHROMADB_AUTH_TOKEN, + "auth_provider": CHROMADB_AUTH_PROVIDER, + "auth_header_name": CHROMADB_AUTH_HEADER, + "collection_settings": { + "hnsw:space": "cosine", + "hnsw:construction_ef": 128, + "hnsw:search_ef": 128, + "hnsw:M": 16, + "hnsw:batch_size": 100, + "hnsw:sync_threshold": 1000, + }, + }, + ) # Run the initialization diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 72a2627a..a2fbf674 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -3,7 +3,7 @@ import asyncio from dataclasses import dataclass from typing import Union import numpy as np -from chromadb import HttpClient +from chromadb import HttpClient, PersistentClient from chromadb.config import Settings from lightrag.base import 
BaseVectorStorage from lightrag.utils import logger @@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage): **user_collection_settings, } - auth_provider = config.get( - "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" - ) - auth_credentials = config.get("auth_token", "secret-token") - headers = {} + local_path = config.get("local_path", None) + if local_path: + self._client = PersistentClient( + path=local_path, + settings=Settings( + allow_reset=True, + anonymized_telemetry=False, + ), + ) + else: + auth_provider = config.get( + "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" + ) + auth_credentials = config.get("auth_token", "secret-token") + headers = {} - if "token_authn" in auth_provider: - headers = { - config.get("auth_header_name", "X-Chroma-Token"): auth_credentials - } - elif "basic_authn" in auth_provider: - auth_credentials = config.get("auth_credentials", "admin:admin") + if "token_authn" in auth_provider: + headers = { + config.get("auth_header_name", "X-Chroma-Token"): auth_credentials + } + elif "basic_authn" in auth_provider: + auth_credentials = config.get("auth_credentials", "admin:admin") - self._client = HttpClient( - host=config.get("host", "localhost"), - port=config.get("port", 8000), - headers=headers, - settings=Settings( - chroma_api_impl="rest", - chroma_client_auth_provider=auth_provider, - chroma_client_auth_credentials=auth_credentials, - allow_reset=True, - anonymized_telemetry=False, - ), - ) + self._client = HttpClient( + host=config.get("host", "localhost"), + port=config.get("port", 8000), + headers=headers, + settings=Settings( + chroma_api_impl="rest", + chroma_client_auth_provider=auth_provider, + chroma_client_auth_credentials=auth_credentials, + allow_reset=True, + anonymized_telemetry=False, + ), + ) self._collection = self._client.get_or_create_collection( name=self.namespace, @@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) results = self._collection.query( - query_embeddings=embedding.tolist(), + query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding, n_results=top_k * 2, # Request more results to allow for filtering include=["metadatas", "distances", "documents"], ) From 2c56141bfd5ab8a1f8d52b77f08dbab23a067ee2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 14 Feb 2025 12:34:26 +0800 Subject: [PATCH 02/28] Standardize variable names with other vector database implementations (without functional modifications) --- lightrag/kg/faiss_impl.py | 4 ++-- lightrag/kg/nano_vector_db_impl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0dca9e4c..b2090d78 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -27,8 +27,8 @@ class FaissVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Grab config values if available - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: raise ValueError( "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 2db8f72a..60eed3dc 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ 
b/lightrag/kg/nano_vector_db_impl.py
@@ -79,8 +79,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
-        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        cosine_threshold = config.get("cosine_better_than_threshold")
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
         if cosine_threshold is None:
             raise ValueError(
                 "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"

From 258c7596e6a49eb1533c5e41280bbab89a818902 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 14 Feb 2025 12:50:43 +0800
Subject: [PATCH 03/28] fix: Improve file path handling and logging for
 document scanning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Convert relative paths to absolute paths
• Add logging for file scanning progress
• Log total number of new files found
• Enhance file scanning feedback
• Improve path resolution safety
---
 lightrag/api/lightrag_server.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 1aeff264..ce182bc1 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -564,6 +564,10 @@ def parse_args() -> argparse.Namespace:

     args = parser.parse_args()

+    # convert relative path to absolute path
+    args.working_dir = os.path.abspath(args.working_dir)
+    args.input_dir = os.path.abspath(args.input_dir)
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

     return args
@@ -595,6 +599,7 @@ class DocumentManager:
         """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
+            logger.info(f"Scanning for {ext} files in {self.input_dir}")
             for file_path in self.input_dir.rglob(f"*{ext}"):
                 if file_path not in self.indexed_files:
                     new_files.append(file_path)
@@ -1198,6 +1203,7 @@ def create_app(args):
             new_files = doc_manager.scan_directory_for_new_files()
             scan_progress["total_files"] = len(new_files)
+            logger.info(f"Found {len(new_files)} new files to index.")

             for file_path in new_files:
                 try:
                     with progress_lock:

From f6058b79b643e8d52386f435b6d9bf4830d06038 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 14 Feb 2025 13:26:19 +0800
Subject: [PATCH 04/28] Update .env.example with absolute path placeholders

---
 .env.example | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.env.example b/.env.example
index 4b64ecb4..022bd63d 100644
--- a/.env.example
+++ b/.env.example
@@ -12,8 +12,8 @@
 # LIGHTRAG_API_KEY=your-secure-api-key-here

 ### Directory Configuration
-# WORKING_DIR=./rag_storage
-# INPUT_DIR=./inputs
+# WORKING_DIR=
+# INPUT_DIR=

 ### Logging level
 LOG_LEVEL=INFO

From cd81312659630cde494b34bf26f73720187f80fc Mon Sep 17 00:00:00 2001
From: Pankaj Kaushal
Date: Fri, 14 Feb 2025 16:04:06 +0100
Subject: [PATCH 05/28] Enhance Neo4j graph storage with error handling and
 label validation

- Add label existence check and validation methods in Neo4j implementation
- Improve error handling in get_node, get_edge, and upsert methods
- Add default values and logging for missing edge properties
- Ensure consistent label processing across graph storage methods
---
 lightrag/kg/neo4j_impl.py | 134 ++++++++++++++++++++++++++++----------
 lightrag/operate.py       |  62 ++++++++++++++----
 2 files changed, 150 insertions(+), 46 deletions(-)

diff --git
a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index e9a53110..15525375 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage): async def index_done_callback(self): print("KG successfully indexed.") - async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + async def _label_exists(self, label: str) -> bool: + """Check if a label exists in the Neo4j database.""" + query = "CALL db.labels() YIELD label RETURN label" + try: + async with self._driver.session(database=self._DATABASE) as session: + result = await session.run(query) + labels = [record["label"] for record in await result.data()] + return label in labels + except Exception as e: + logger.error(f"Error checking label existence: {e}") + return False + async def _ensure_label(self, label: str) -> str: + """Ensure a label exists by validating it.""" + clean_label = label.strip('"') + if not await self._label_exists(clean_label): + logger.warning(f"Label '{clean_label}' does not exist in Neo4j") + return clean_label + + async def has_node(self, node_id: str) -> bool: + entity_name_label = await self._ensure_label(node_id) async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" @@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> Union[dict, None]: + """Get node by its label identifier. + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + """ async with self._driver.session(database=self._DATABASE) as session: - entity_name_label = node_id.strip('"') + entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) record = await result.single() @@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - """ - Find all edges between nodes of two given labels + """Find edge between two nodes identified by their labels. 
Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes + source_node_id (str): Label of the source node + target_node_id (str): Label of the target node Returns: - list: List of all relationships/edges found + dict: Edge properties if found, with at least {"weight": 0.0} + None: If error occurs """ - async with self._driver.session(database=self._DATABASE) as session: - query = f""" - MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """.format( - entity_name_label_source=entity_name_label_source, - entity_name_label_target=entity_name_label_target, - ) + try: + entity_name_label_source = source_node_id.strip('"') + entity_name_label_target = target_node_id.strip('"') - result = await session.run(query) - record = await result.single() - if record: - result = dict(record["edge_properties"]) - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + async with self._driver.session(database=self._DATABASE) as session: + query = f""" + MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) + RETURN properties(r) as edge_properties + LIMIT 1 + """.format( + entity_name_label_source=entity_name_label_source, + entity_name_label_target=entity_name_label_target, ) - return result - else: - return None + + result = await session.run(query) + record = await result.single() + if record and "edge_properties" in record: + try: + result = dict(record["edge_properties"]) + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "target_id": None, + } + for key, default_value in required_keys.items(): + if key not in result: + result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + ) + return result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + ) + # Return default edge properties when no edge found + return {"weight": 0.0, "source_id": None, "target_id": None} + + except Exception as e: + logger.error( + f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: node_label = source_node_id.strip('"') @@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = node_id.strip('"') + label = await self._ensure_label(node_id) properties = node_data async def _do_upsert(tx: AsyncManagedTransaction): @@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage): neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, ) ), ) @@ -352,22 +415,23 @@ class 
Neo4JStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') + source_label = await self._ensure_label(source_node_id) + target_label = await self._ensure_label(target_node_id) edge_properties = edge_data async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" - MATCH (source:`{source_node_label}`) + MATCH (source:`{source_label}`) WITH source - MATCH (target:`{target_node_label}`) + MATCH (target:`{target_label}`) MERGE (source)-[r:DIRECTED]->(target) SET r += $properties RETURN r """ - await tx.run(query, properties=edge_properties) + result = await tx.run(query, properties=edge_properties) + record = await result.single() logger.debug( - f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}" + f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" ) try: diff --git a/lightrag/operate.py b/lightrag/operate.py index 04aad0d4..8cf77f57 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -237,25 +237,65 @@ async def _merge_edges_then_upsert( if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) - already_weights.append(already_edge["weight"]) - already_source_ids.extend( - split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) - ) - already_description.append(already_edge["description"]) - already_keywords.extend( - split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) - ) + # Handle the case where get_edge returns None or missing fields + if already_edge: + # Get weight with default 0.0 if missing + if "weight" in already_edge: + already_weights.append(already_edge["weight"]) + else: + logger.warning( + f"Edge between {src_id} and {tgt_id} missing weight field" + ) + already_weights.append(0.0) + # Get source_id with empty string default if missing or None + if "source_id" in already_edge and already_edge["source_id"] is not None: + already_source_ids.extend( + split_string_by_multi_markers( + already_edge["source_id"], [GRAPH_FIELD_SEP] + ) + ) + + # Get description with empty string default if missing or None + if ( + "description" in already_edge + and already_edge["description"] is not None + ): + already_description.append(already_edge["description"]) + + # Get keywords with empty string default if missing or None + if "keywords" in already_edge and already_edge["keywords"] is not None: + already_keywords.extend( + split_string_by_multi_markers( + already_edge["keywords"], [GRAPH_FIELD_SEP] + ) + ) + + # Process edges_data with None checks weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in edges_data] + already_description)) + sorted( + set( + [dp["description"] for dp in edges_data if dp.get("description")] + + already_description + ) + ) ) keywords = GRAPH_FIELD_SEP.join( - sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) + sorted( + set( + [dp["keywords"] for dp in edges_data if dp.get("keywords")] + + already_keywords + ) + ) ) source_id = GRAPH_FIELD_SEP.join( - set([dp["source_id"] for dp in edges_data] + already_source_ids) + set( + [dp["source_id"] for dp in edges_data if dp.get("source_id")] + + 
already_source_ids + ) ) + for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): await knowledge_graph_inst.upsert_node( From 70fc4cbfb0e769dcaea3823b0d79bac6e693410c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 00:34:38 +0800 Subject: [PATCH 06/28] handle missing edge types in graph data --- lightrag/types.py | 4 ++-- lightrag_webui/src/components/PropertiesView.tsx | 2 +- lightrag_webui/src/hooks/useLightragGraph.tsx | 12 ++++++++++-- lightrag_webui/src/stores/graph.ts | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/lightrag/types.py b/lightrag/types.py index 9c8e0099..d2670ddc 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional class GPTKeywordExtractionFormat(BaseModel): @@ -15,7 +15,7 @@ class KnowledgeGraphNode(BaseModel): class KnowledgeGraphEdge(BaseModel): id: str - type: str + type: Optional[str] source: str # id of source node target: str # id of target node properties: Dict[str, Any] # anything else goes here diff --git a/lightrag_webui/src/components/PropertiesView.tsx b/lightrag_webui/src/components/PropertiesView.tsx index 078420e6..dec80460 100644 --- a/lightrag_webui/src/components/PropertiesView.tsx +++ b/lightrag_webui/src/components/PropertiesView.tsx @@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
- + {edge.type && } { } for (const edge of graph.edges) { - if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) { + if (!edge.id || !edge.source || !edge.target) { return false } } @@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => { if (source !== undefined && source !== undefined) { const sourceNode = rawData.nodes[source] const targetNode = rawData.nodes[target] + if (!sourceNode) { + console.error(`Source node ${edge.source} is undefined`) + continue + } + if (!targetNode) { + console.error(`Target node ${edge.target} is undefined`) + continue + } sourceNode.degree += 1 targetNode.degree += 1 } @@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => { for (const rawEdge of rawGraph?.edges ?? []) { rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, { - label: rawEdge.type + label: rawEdge.type || undefined }) } diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts index b78e9bf8..b7c2120c 100644 --- a/lightrag_webui/src/stores/graph.ts +++ b/lightrag_webui/src/stores/graph.ts @@ -19,7 +19,7 @@ export type RawEdgeType = { id: string source: string target: string - type: string + type?: string properties: Record dynamicId: string From a600beb619c8b784bb309c7f3dcec94e14573570 Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 00:38:41 +0800 Subject: [PATCH 07/28] implement MongoDB support for VectorDB storage. optimize existing MongoDB implementations --- lightrag/api/README.md | 3 +- lightrag/kg/mongo_impl.py | 510 +++++++++++++++++++++++++++++++++----- lightrag/lightrag.py | 3 + 3 files changed, 456 insertions(+), 60 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8e5a61d5..18ab3594 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB PGVectorStorage Postgres FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant -OracleVectorDBStorag Oracle +OracleVectorDBStorage Oracle +MongoVectorDBStorage MongoDB ``` * DOC_STATUS_STORAGE:supported implement-name diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 226aecf2..c216e7be 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -4,6 +4,7 @@ import numpy as np import pipmaster as pm import configparser from tqdm.asyncio import tqdm as tqdm_async +import asyncio if not pm.is_installed("pymongo"): pm.install("pymongo") @@ -14,16 +15,20 @@ if not pm.is_installed("motor"): from typing import Any, List, Tuple, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient +from pymongo.operations import SearchIndexModel +from pymongo.errors import PyMongoError from ..base import ( BaseGraphStorage, BaseKVStorage, + BaseVectorStorage, DocProcessingStatus, DocStatus, DocStatusStorage, ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge config = configparser.ConfigParser() @@ -33,56 +38,66 @@ config.read("config.ini", "utf-8") @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): - client = MongoClient( - os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) + client = AsyncIOMotorClient(uri) database = client.get_database( os.environ.get( "MONGO_DATABASE", 
config.get("mongodb", "database", fallback="LightRAG"), ) ) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as KV {self.namespace}") + + self._collection_name = self.namespace + + self._data = database.get_collection(self._collection_name) + logger.debug(f"Use MongoDB as KV {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + update_tasks = [] for mode, items in data.items(): - for k, v in tqdm_async(items.items(), desc="Upserting"): + for k, v in items.items(): key = f"{mode}_{k}" - result = self._data.update_one( - {"_id": key}, {"$setOnInsert": v}, upsert=True + data[mode][k]["_id"] = f"{mode}_{k}" + update_tasks.append( + self._data.update_one( + {"_id": key}, {"$setOnInsert": v}, upsert=True + ) ) - if result.upserted_id: - logger.debug(f"\nInserted new document with key: {key}") - data[mode][k]["_id"] = key + await asyncio.gather(*update_tasks) else: - for k, v in tqdm_async(data.items(), desc="Upserting"): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + update_tasks = [] + for k, v in data.items(): data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): res = {} - v = self._data.find_one({"_id": mode + "_" + id}) + v = await self._data.find_one({"_id": mode + "_" + id}) if v: res[id] = v logger.debug(f"llm_response_cache find one by:{id}") @@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage): @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): - client = MongoClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as doc status {self.namespace}") + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as doc status {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, 
id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + update_tasks = [] for k, v in data.items(): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def drop(self) -> None: """Drop the collection""" @@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] - result = list(self._data.aggregate(pipeline)) + cursor = self._data.aggregate(pipeline) + result = await cursor.to_list() counts = {} for doc in result: counts[doc["_id"]] = doc["count"] @@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents by status""" - result = list(self._data.find({"status": status.value})) + cursor = self._data.find({"status": status.value}) + result = await cursor.to_list() return { doc["_id"]: DocProcessingStatus( content=doc["content"], @@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage): global_config=global_config, embedding_func=embedding_func, ) - self.client = AsyncIOMotorClient( - os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - self.db = self.client[ + client = AsyncIOMotorClient(uri) + database = client.get_database( os.environ.get( "MONGO_DATABASE", - mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + config.get("mongodb", "database", fallback="LightRAG"), ) - ] - self.collection = self.db[ - os.environ.get( - "MONGO_KG_COLLECTION", - config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"), - ) - ] + ) + + self._collection_name = self.namespace + self.collection = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as KG {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) # # ------------------------------------------------------------------------- @@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage): self, source_node_id: str ) -> Union[List[Tuple[str, str]], None]: """ - Return a list of (target_id, relation) for direct edges from source_node_id. + Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. 
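
        Example (illustrative; node ids are hypothetical, return shape matches
        the tuple construction in the code below):
            >>> await storage.get_node_edges("PERSON_A")
            [("PERSON_A", "PERSON_B"), ("PERSON_A", "EVENT_X")]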
""" pipeline = [ @@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage): return None edges = result[0].get("edges", []) - return [(e["target"], e["relation"]) for e in edges] + return [(source_node_id, e["target"]) for e in edges] # # ------------------------------------------------------------------------- @@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str): """ - 1) Remove node’s doc entirely. + 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. """ # Remove inbound edges from all other docs @@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage): Placeholder for demonstration, raises NotImplementedError. """ raise NotImplementedError("Node embedding is not used in lightrag.") + + # + # ------------------------------------------------------------------------- + # QUERY + # ------------------------------------------------------------------------- + # + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node _id in the database + Returns: + [id1, id2, ...] # Alphabetically sorted id list + """ + # Use MongoDB's distinct and aggregation to get all unique labels + pipeline = [ + {"$group": {"_id": "$_id"}}, # Group by _id + {"$sort": {"_id": 1}}, # Sort alphabetically + ] + + cursor = self.collection.aggregate(pipeline) + labels = [] + async for doc in cursor: + labels.append(doc["_id"]) + return labels + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Args: + node_label: Label of the nodes to start from + max_depth: Maximum depth of traversal (default: 5) + + Returns: + KnowledgeGraph object containing nodes and edges of the subgraph + """ + label = node_label + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if label == "*": + # Get all nodes and edges + async for node_doc in self.collection.find({}): + node_id = str(node_doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_doc.get("_id")], + properties={ + k: v + for k, v in node_doc.items() + if k not in ["_id", "edges"] + }, + ) + ) + seen_nodes.add(node_id) + + # Process edges + for edge in node_doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + else: + # Verify if starting node exists + start_nodes = self.collection.find({"_id": label}) + start_nodes_exist = await start_nodes.to_list(length=1) + if not start_nodes_exist: + logger.warning(f"Starting node with label {label} does not exist!") + return result + + # Use $graphLookup for traversal + pipeline = [ + { + "$match": {"_id": label} + }, # Start with nodes having the specified label + { + "$graphLookup": { + "from": self._collection_name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_nodes", + } + }, + ] + + async for doc in self.collection.aggregate(pipeline): + # Add the start node + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + 
id=node_id, + labels=[ + doc.get( + "_id", + ) + ], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "edges", + "connected_nodes", + "depth", + ] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from start node + for edge in doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + # Add connected nodes and their edges + for connected in doc.get("connected_nodes", []): + node_id = str(connected["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[connected.get("_id")], + properties={ + k: v + for k, v in connected.items() + if k not in ["_id", "edges", "depth"] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from connected nodes + for edge in connected.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except PyMongoError as e: + logger.error(f"MongoDB query failed: {str(e)}") + + return result + + +@dataclass +class MongoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = None + + def __post_init__(self): + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) + self.cosine_better_than_threshold = cosine_threshold + + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + self._max_batch_size = self.global_config["embedding_batch_num"] + + logger.debug(f"Use MongoDB as VDB {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) + + # Ensure vector index exists + self.create_vector_index(uri, database.name, self._collection_name) + + def create_vector_index(self, uri: str, database_name: str, collection_name: str): + """Creates an Atlas Vector Search index.""" + client = MongoClient(uri) + collection = client.get_database(database_name).get_collection( + self._collection_name + ) + + try: + search_index_model = SearchIndexModel( + definition={ + "fields": [ + { + "type": "vector", + "numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions + "path": "vector", + "similarity": "cosine", # Options: euclidean, cosine, dotProduct + } + ] + }, + name="vector_knn_index", + type="vectorSearch", + ) + + collection.create_search_index(search_index_model) + logger.info("Vector index created successfully.") + + 
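+            # Assumption: search index creation needs an Atlas-compatible
+            # deployment (an Atlas cluster or the Atlas local dev image); a
+            # plain community mongod raises a PyMongoError here, which the
+            # except branch below then logs as the index already existing.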
except PyMongoError as _:
+            logger.debug("vector index already exists")
+
+    async def upsert(self, data: dict[str, dict]):
+        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You are inserting an empty data set to vector DB")
+            return []
+
+        list_data = [
+            {
+                "_id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        async def wrapped_task(batch):
+            result = await self.embedding_func(batch)
+            pbar.update(1)
+            return result
+
+        embedding_tasks = [wrapped_task(batch) for batch in batches]
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
+        embeddings_list = await asyncio.gather(*embedding_tasks)
+
+        embeddings = np.concatenate(embeddings_list)
+        for i, d in enumerate(list_data):
+            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
+
+        update_tasks = []
+        for doc in list_data:
+            update_tasks.append(
+                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
+            )
+        await asyncio.gather(*update_tasks)
+
+        return list_data
+
+    async def query(self, query, top_k=5):
+        """Queries the vector database using Atlas Vector Search."""
+        # Generate the embedding
+        embedding = await self.embedding_func([query])
+
+        # Convert numpy array to a list to ensure compatibility with MongoDB
+        query_vector = embedding[0].tolist()
+
+        # Define the aggregation pipeline with the converted query vector
+        pipeline = [
+            {
+                "$vectorSearch": {
+                    "index": "vector_knn_index",  # Ensure this matches the created index name
+                    "path": "vector",
+                    "queryVector": query_vector,
+                    "numCandidates": 100,  # Adjust for performance
+                    "limit": top_k,
+                }
+            },
+            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
+            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
+            {"$project": {"vector": 0}},
+        ]
+
+        # Execute the aggregation pipeline
+        cursor = self._data.aggregate(pipeline)
+        results = await cursor.to_list()
+
+        # Format and return the results
+        return [
+            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
+            for doc in results
+        ]
+
+
+def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
+    """Check if the collection exists.
if not, create it.""" + client = MongoClient(uri) + database = client.get_database(database_name) + + collection_names = database.list_collection_names() + + if collection_name not in collection_names: + database.create_collection(collection_name) + logger.info(f"Created collection: {collection_name}") + else: + logger.debug(f"Collection '{collection_name}' already exists.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9f74c917..ed0dec29 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -76,6 +76,7 @@ STORAGE_IMPLEMENTATIONS = { "FaissVectorDBStorage", "QdrantVectorDBStorage", "OracleVectorDBStorage", + "MongoVectorDBStorage", ], "required_methods": ["query", "upsert"], }, @@ -140,6 +141,7 @@ STORAGE_ENV_REQUIREMENTS = { "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR", ], + "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], @@ -160,6 +162,7 @@ STORAGES = { "MongoKVStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl", + "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", "TiDBKVStorage": ".kg.tidb_impl", From 28c8443ff2e3688ba244f126c892354511ad7c6b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 22:50:49 +0100 Subject: [PATCH 08/28] cleaning the mess --- lightrag/lightrag.py | 93 +++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9f74c917..fcea2c57 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import os import configparser @@ -91,7 +93,7 @@ STORAGE_IMPLEMENTATIONS = { } # Storage implementation environment variable without default value -STORAGE_ENV_REQUIREMENTS = { +STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { # KV Storage Implementations "JsonKVStorage": [], "MongoKVStorage": [], @@ -176,7 +178,7 @@ STORAGES = { } -def lazy_external_import(module_name: str, class_name: str): +def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: """Lazily import a class from an external module based on the package of the caller.""" # Get the caller's module and package import inspect @@ -185,7 +187,7 @@ def lazy_external_import(module_name: str, class_name: str): module = inspect.getmodule(caller_frame) package = module.__package__ if module else None - def import_class(*args, **kwargs): + def import_class(*args: Any, **kwargs: Any): import importlib module = importlib.import_module(module_name, package=package) @@ -302,7 +304,7 @@ class LightRAG: - random_seed: Seed value for reproducibility. """ - embedding_func: EmbeddingFunc = None + embedding_func: Union[EmbeddingFunc, None] = None """Function for computing text embeddings. Must be set before use.""" embedding_batch_num: int = 32 @@ -312,7 +314,7 @@ class LightRAG: """Maximum number of concurrent embedding function calls.""" # LLM Configuration - llm_model_func: callable = None + llm_model_func: Union[Callable[..., object], None] = None """Function for interacting with the large language model (LLM). 
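    A rough sketch of the expected shape, inferred from the bundled
    implementations rather than a strict interface:
    async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str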
Must be set before use.""" llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" @@ -443,77 +445,77 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # Init LLM - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( + self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore self.embedding_func ) # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore self._get_storage_class(self.kv_storage) ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore self.vector_storage ) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore self.graph_storage ) - self.key_string_value_json_storage_cls = partial( + self.key_string_value_json_storage_cls = partial( # type: ignore self.key_string_value_json_storage_cls, global_config=global_config ) - self.vector_db_storage_cls = partial( + self.vector_db_storage_cls = partial( # type: ignore self.vector_db_storage_cls, global_config=global_config ) - self.graph_storage_cls = partial( + self.graph_storage_cls = partial( # type: ignore self.graph_storage_cls, global_config=global_config ) # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - self.llm_response_cache = self.key_string_value_json_storage_cls( + self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), embedding_func=self.embedding_func, ) - self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( + self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS ), embedding_func=self.embedding_func, ) - self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( + self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS ), embedding_func=self.embedding_func, ) - self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION ), embedding_func=self.embedding_func, ) - self.entities_vdb = self.vector_db_storage_cls( + self.entities_vdb = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES ), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) - self.relationships_vdb = self.vector_db_storage_cls( + self.relationships_vdb = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS ), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) - self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( + self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS ), @@ -533,7 +535,7 @@ 
class LightRAG: ): hashing_kv = self.llm_response_cache else: - hashing_kv = self.key_string_value_json_storage_cls( + hashing_kv = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), @@ -542,7 +544,7 @@ class LightRAG: self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( - self.llm_model_func, + self.llm_model_func, # type: ignore hashing_kv=hashing_kv, **self.llm_model_kwargs, ) @@ -559,68 +561,45 @@ class LightRAG: node_label=nodel_label, max_depth=max_depth ) - def _get_storage_class(self, storage_name: str) -> dict: + def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) return storage_class - def set_storage_client(self, db_client): - # Deprecated, seting correct value to *_storage of LightRAG insteaded - # Inject db to storage implementation (only tested on Oracle Database) - for storage in [ - self.vector_db_storage_cls, - self.graph_storage_cls, - self.doc_status, - self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.key_string_value_json_storage_cls, - self.chunks_vdb, - self.relationships_vdb, - self.entities_vdb, - self.graph_storage_cls, - self.chunk_entity_relation_graph, - self.llm_response_cache, - ]: - # set client - storage.db = db_client - def insert( self, - string_or_strings: Union[str, list[str]], + input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ): """Sync Insert documents with checkpoint support Args: - string_or_strings: Single document string or list of document strings + input: Single document string or list of document strings split_by_character: if split_by_character is not None, split the string by character, if chunk longer than - chunk_size, split the sub chunk by token size. split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. """ loop = always_get_an_event_loop() return loop.run_until_complete( - self.ainsert(string_or_strings, split_by_character, split_by_character_only) + self.ainsert(input, split_by_character, split_by_character_only) ) async def ainsert( self, - string_or_strings: Union[str, list[str]], + input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ): """Async Insert documents with checkpoint support Args: - string_or_strings: Single document string or list of document strings + input: Single document string or list of document strings split_by_character: if split_by_character is not None, split the string by character, if chunk longer than - chunk_size, split the sub chunk by token size. split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. """ - await self.apipeline_enqueue_documents(string_or_strings) + await self.apipeline_enqueue_documents(input) await self.apipeline_process_enqueue_documents( split_by_character, split_by_character_only ) @@ -677,7 +656,7 @@ class LightRAG: if update_storage: await self._insert_done() - async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]): + async def apipeline_enqueue_documents(self, input: str | list[str]): """ Pipeline for Processing Documents @@ -686,11 +665,11 @@ class LightRAG: 3. 
Filter out already processed documents 4. Enqueue document in status """ - if isinstance(string_or_strings, str): - string_or_strings = [string_or_strings] + if isinstance(input, str): + input = [input] # 1. Remove duplicate contents from the list - unique_contents = list(set(doc.strip() for doc in string_or_strings)) + unique_contents = list(set(doc.strip() for doc in input)) # 2. Generate document IDs and initial status new_docs: dict[str, Any] = { @@ -872,11 +851,11 @@ class LightRAG: tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) await asyncio.gather(*tasks) - def insert_custom_kg(self, custom_kg: dict): + def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]): loop = always_get_an_event_loop() return loop.run_until_complete(self.ainsert_custom_kg(custom_kg)) - async def ainsert_custom_kg(self, custom_kg: dict): + async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]): update_storage = False try: # Insert chunks into vector storage From 66f555677a378c98efa5c7ce0d6bf7af2cd62345 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:31:27 +0100 Subject: [PATCH 09/28] cleaning the message and project no needed --- README.md | 4 +- examples/lightrag_oracle_demo.py | 19 +- .../OpenWebuiTool/openwebui_tool.py | 358 ------------------ lightrag/base.py | 4 +- lightrag/lightrag.py | 148 +++++--- lightrag/operate.py | 30 +- lightrag/utils.py | 7 +- 7 files changed, 129 insertions(+), 441 deletions(-) delete mode 100644 external_bindings/OpenWebuiTool/openwebui_tool.py diff --git a/README.md b/README.md index 62f21a65..487c65f5 100644 --- a/README.md +++ b/README.md @@ -428,9 +428,9 @@ And using a routine to process news documents. ```python rag = LightRAG(..) -await rag.apipeline_enqueue_documents(string_or_strings) +await rag.apipeline_enqueue_documents(input) # Your routine in loop -await rag.apipeline_process_enqueue_documents(string_or_strings) +await rag.apipeline_process_enqueue_documents(input) ``` ### Separate Keyword Extraction diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index f5269fae..9c90424e 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -113,7 +113,24 @@ async def main(): ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.set_storage_client(db_client=oracle_db) + + for storage in [ + rag.vector_db_storage_cls, + rag.graph_storage_cls, + rag.doc_status, + rag.full_docs, + rag.text_chunks, + rag.llm_response_cache, + rag.key_string_value_json_storage_cls, + rag.chunks_vdb, + rag.relationships_vdb, + rag.entities_vdb, + rag.graph_storage_cls, + rag.chunk_entity_relation_graph, + rag.llm_response_cache, + ]: + # set client + storage.db = oracle_db # Extract and Insert into LightRAG storage with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: diff --git a/external_bindings/OpenWebuiTool/openwebui_tool.py b/external_bindings/OpenWebuiTool/openwebui_tool.py deleted file mode 100644 index 8df3109c..00000000 --- a/external_bindings/OpenWebuiTool/openwebui_tool.py +++ /dev/null @@ -1,358 +0,0 @@ -""" -OpenWebui Lightrag Integration Tool -================================== - -This tool enables the integration and use of Lightrag within the OpenWebui environment, -providing a seamless interface for RAG (Retrieval-Augmented Generation) operations. 
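
For anyone who scripted against this tool, a minimal sketch of the same
request it issued (the endpoint, payload shape, and default URL are taken
from the deleted code below; adjust the URL and add a Bearer token header
per deployment):

    import requests

    payload = {
        "query": "What does LightRAG do?",
        "mode": "hybrid",  # naive | local | global | hybrid
        "stream": False,
        "only_need_context": False,
    }
    resp = requests.post("http://localhost:9621/query", json=payload, timeout=120)
    resp.raise_for_status()
    print(resp.json()["response"])
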
- -Author: ParisNeo (parisneoai@gmail.com) -Social: - - Twitter: @ParisNeo_AI - - Reddit: r/lollms - - Instagram: https://www.instagram.com/parisneo_ai/ - -License: Apache 2.0 -Copyright (c) 2024-2025 ParisNeo - -This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems). -For more information, visit: https://github.com/ParisNeo/lollms - -Requirements: - - Python 3.8+ - - OpenWebui - - Lightrag -""" - -# Tool version -__version__ = "1.0.0" -__author__ = "ParisNeo" -__author_email__ = "parisneoai@gmail.com" -__description__ = "Lightrag integration for OpenWebui" - - -import requests -import json -from pydantic import BaseModel, Field -from typing import Callable, Any, Literal, Union, List, Tuple - - -class StatusEventEmitter: - def __init__(self, event_emitter: Callable[[dict], Any] = None): - self.event_emitter = event_emitter - - async def emit(self, description="Unknown State", status="in_progress", done=False): - if self.event_emitter: - await self.event_emitter( - { - "type": "status", - "data": { - "status": status, - "description": description, - "done": done, - }, - } - ) - - -class MessageEventEmitter: - def __init__(self, event_emitter: Callable[[dict], Any] = None): - self.event_emitter = event_emitter - - async def emit(self, content="Some message"): - if self.event_emitter: - await self.event_emitter( - { - "type": "message", - "data": { - "content": content, - }, - } - ) - - -class Tools: - class Valves(BaseModel): - LIGHTRAG_SERVER_URL: str = Field( - default="http://localhost:9621/query", - description="The base URL for the LightRag server", - ) - MODE: Literal["naive", "local", "global", "hybrid"] = Field( - default="hybrid", - description="The mode to use for the LightRag query. Options: naive, local, global, hybrid", - ) - ONLY_NEED_CONTEXT: bool = Field( - default=False, - description="If True, only the context is needed from the LightRag response", - ) - DEBUG_MODE: bool = Field( - default=False, - description="If True, debugging information will be emitted", - ) - KEY: str = Field( - default="", - description="Optional Bearer Key for authentication", - ) - MAX_ENTITIES: int = Field( - default=5, - description="Maximum number of entities to keep", - ) - MAX_RELATIONSHIPS: int = Field( - default=5, - description="Maximum number of relationships to keep", - ) - MAX_SOURCES: int = Field( - default=3, - description="Maximum number of sources to keep", - ) - - def __init__(self): - self.valves = self.Valves() - self.headers = { - "Content-Type": "application/json", - "User-Agent": "LightRag-Tool/1.0", - } - - async def query_lightrag( - self, - query: str, - __event_emitter__: Callable[[dict], Any] = None, - ) -> str: - """ - Query the LightRag server and retrieve information. - This function must be called before answering the user question - :params query: The query string to send to the LightRag server. - :return: The response from the LightRag server in Markdown format or raw response. 
- """ - self.status_emitter = StatusEventEmitter(__event_emitter__) - self.message_emitter = MessageEventEmitter(__event_emitter__) - - lightrag_url = self.valves.LIGHTRAG_SERVER_URL - payload = { - "query": query, - "mode": str(self.valves.MODE), - "stream": False, - "only_need_context": self.valves.ONLY_NEED_CONTEXT, - } - await self.status_emitter.emit("Initializing Lightrag query..") - - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "### Debug Mode Active\n\nDebugging information will be displayed.\n" - ) - await self.message_emitter.emit( - "#### Payload Sent to LightRag Server\n```json\n" - + json.dumps(payload, indent=4) - + "\n```\n" - ) - - # Add Bearer Key to headers if provided - if self.valves.KEY: - self.headers["Authorization"] = f"Bearer {self.valves.KEY}" - - try: - await self.status_emitter.emit("Sending request to LightRag server") - - response = requests.post( - lightrag_url, json=payload, headers=self.headers, timeout=120 - ) - response.raise_for_status() - data = response.json() - await self.status_emitter.emit( - status="complete", - description="LightRag query Succeeded", - done=True, - ) - - # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response - if self.valves.ONLY_NEED_CONTEXT: - try: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### LightRag Server Response\n```json\n" - + data["response"] - + "\n```\n" - ) - except Exception as ex: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### Exception\n" + str(ex) + "\n" - ) - return f"Exception: {ex}" - return data["response"] - else: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### LightRag Server Response\n```json\n" - + data["response"] - + "\n```\n" - ) - await self.status_emitter.emit("Lightrag query success") - return data["response"] - - except requests.exceptions.RequestException as e: - await self.status_emitter.emit( - status="error", - description=f"Error during LightRag query: {str(e)}", - done=True, - ) - return json.dumps({"error": str(e)}) - - def extract_code_blocks( - self, text: str, return_remaining_text: bool = False - ) -> Union[List[dict], Tuple[List[dict], str]]: - """ - This function extracts code blocks from a given text and optionally returns the text without code blocks. - - Parameters: - text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```). - return_remaining_text (bool): If True, also returns the text with code blocks removed. 
- - Returns: - Union[List[dict], Tuple[List[dict], str]]: - - If return_remaining_text is False: Returns only the list of code block dictionaries - - If return_remaining_text is True: Returns a tuple containing: - * List of code block dictionaries - * String containing the text with all code blocks removed - - Each code block dictionary contains: - - 'index' (int): The index of the code block in the text - - 'file_name' (str): The name of the file extracted from the preceding line, if available - - 'content' (str): The content of the code block - - 'type' (str): The type of the code block - - 'is_complete' (bool): True if the block has a closing tag, False otherwise - """ - remaining = text - bloc_index = 0 - first_index = 0 - indices = [] - text_without_blocks = text - - # Find all code block delimiters - while len(remaining) > 0: - try: - index = remaining.index("```") - indices.append(index + first_index) - remaining = remaining[index + 3 :] - first_index += index + 3 - bloc_index += 1 - except Exception: - if bloc_index % 2 == 1: - index = len(remaining) - indices.append(index) - remaining = "" - - code_blocks = [] - is_start = True - - # Process code blocks and build text without blocks if requested - if return_remaining_text: - text_parts = [] - last_end = 0 - - for index, code_delimiter_position in enumerate(indices): - if is_start: - block_infos = { - "index": len(code_blocks), - "file_name": "", - "section": "", - "content": "", - "type": "", - "is_complete": False, - } - - # Store text before code block if returning remaining text - if return_remaining_text: - text_parts.append(text[last_end:code_delimiter_position].strip()) - - # Check the preceding line for file name - preceding_text = text[:code_delimiter_position].strip().splitlines() - if preceding_text: - last_line = preceding_text[-1].strip() - if last_line.startswith("") and last_line.endswith( - "" - ): - file_name = last_line[ - len("") : -len("") - ].strip() - block_infos["file_name"] = file_name - elif last_line.startswith("## filename:"): - file_name = last_line[len("## filename:") :].strip() - block_infos["file_name"] = file_name - if last_line.startswith("
") and last_line.endswith( - "
" - ): - section = last_line[ - len("
") : -len("
") - ].strip() - block_infos["section"] = section - - sub_text = text[code_delimiter_position + 3 :] - if len(sub_text) > 0: - try: - find_space = sub_text.index(" ") - except Exception: - find_space = int(1e10) - try: - find_return = sub_text.index("\n") - except Exception: - find_return = int(1e10) - next_index = min(find_return, find_space) - if "{" in sub_text[:next_index]: - next_index = 0 - start_pos = next_index - - if code_delimiter_position + 3 < len(text) and text[ - code_delimiter_position + 3 - ] in ["\n", " ", "\t"]: - block_infos["type"] = "language-specific" - else: - block_infos["type"] = sub_text[:next_index] - - if index + 1 < len(indices): - next_pos = indices[index + 1] - code_delimiter_position - if ( - next_pos - 3 < len(sub_text) - and sub_text[next_pos - 3] == "`" - ): - block_infos["content"] = sub_text[ - start_pos : next_pos - 3 - ].strip() - block_infos["is_complete"] = True - else: - block_infos["content"] = sub_text[ - start_pos:next_pos - ].strip() - block_infos["is_complete"] = False - - if return_remaining_text: - last_end = indices[index + 1] + 3 - else: - block_infos["content"] = sub_text[start_pos:].strip() - block_infos["is_complete"] = False - - if return_remaining_text: - last_end = len(text) - - code_blocks.append(block_infos) - is_start = False - else: - is_start = True - - if return_remaining_text: - # Add any remaining text after the last code block - if last_end < len(text): - text_parts.append(text[last_end:].strip()) - # Join all non-code parts with newlines - text_without_blocks = "\n".join(filter(None, text_parts)) - return code_blocks, text_without_blocks - - return code_blocks - - def clean(self, csv_content: str): - lines = csv_content.splitlines() - if lines: - # Remove spaces around headers and ensure no spaces between commas - header = ",".join([col.strip() for col in lines[0].split(",")]) - lines[0] = header # Replace the first line with the cleaned header - csv_content = "\n".join(lines) - return csv_content diff --git a/lightrag/base.py b/lightrag/base.py index e75167c4..8e6a212d 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -83,11 +83,11 @@ class StorageNameSpace: namespace: str global_config: dict[str, Any] - async def index_done_callback(self): + async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" pass - async def query_done_callback(self): + async def query_done_callback(self) -> None: """Commit the storage operations after querying""" pass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index fcea2c57..b4426cd7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -6,7 +6,7 @@ import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Any, Callable, Optional, Type, Union, cast +from typing import Any, Callable, Optional, Union, cast from .base import ( BaseGraphStorage, @@ -304,7 +304,7 @@ class LightRAG: - random_seed: Seed value for reproducibility. """ - embedding_func: Union[EmbeddingFunc, None] = None + embedding_func: EmbeddingFunc | None = None """Function for computing text embeddings. 
Must be set before use.""" embedding_batch_num: int = 32 @@ -344,10 +344,8 @@ class LightRAG: # Extensions addon_params: dict[str, Any] = field(default_factory=dict) - """Dictionary for additional parameters and extensions.""" - # extension - addon_params: dict[str, Any] = field(default_factory=dict) + """Dictionary for additional parameters and extensions.""" convert_response_to_json_func: Callable[[str], dict[str, Any]] = ( convert_response_to_json ) @@ -445,77 +443,74 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # Init LLM - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore + self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore self.embedding_func ) # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore + self.key_string_value_json_storage_cls: type[BaseKVStorage] = ( self._get_storage_class(self.kv_storage) - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore + ) # type: ignore + self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class( self.vector_storage - ) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore + ) # type: ignore + self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class( self.graph_storage - ) - - self.key_string_value_json_storage_cls = partial( # type: ignore + ) # type: ignore + self.key_string_value_json_storage_cls = partial( # type: ignore self.key_string_value_json_storage_cls, global_config=global_config ) - - self.vector_db_storage_cls = partial( # type: ignore + self.vector_db_storage_cls = partial( # type: ignore self.vector_db_storage_cls, global_config=global_config ) - - self.graph_storage_cls = partial( # type: ignore + self.graph_storage_cls = partial( # type: ignore self.graph_storage_cls, global_config=global_config ) # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore + self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), embedding_func=self.embedding_func, ) - self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore + self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS ), embedding_func=self.embedding_func, ) - self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore + self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS ), embedding_func=self.embedding_func, ) - self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION ), embedding_func=self.embedding_func, ) - self.entities_vdb = self.vector_db_storage_cls( # type: ignore + self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES 
), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) - self.relationships_vdb = self.vector_db_storage_cls( # type: ignore + self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS ), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) - self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore + self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS ), @@ -535,16 +530,16 @@ class LightRAG: ): hashing_kv = self.llm_response_cache else: - hashing_kv = self.key_string_value_json_storage_cls( # type: ignore + hashing_kv = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), embedding_func=self.embedding_func, ) - + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( - self.llm_model_func, # type: ignore + self.llm_model_func, # type: ignore hashing_kv=hashing_kv, **self.llm_model_kwargs, ) @@ -836,32 +831,32 @@ class LightRAG: raise e async def _insert_done(self): - tasks = [] - for storage_inst in [ - self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.entities_vdb, - self.relationships_vdb, - self.chunks_vdb, - self.chunk_entity_relation_graph, - ]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) + tasks = [ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.chunk_entity_relation_graph, + ] + if storage_inst is not None + ] await asyncio.gather(*tasks) - def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]): + def insert_custom_kg(self, custom_kg: dict[str, Any]): loop = always_get_an_event_loop() return loop.run_until_complete(self.ainsert_custom_kg(custom_kg)) - async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]): + async def ainsert_custom_kg(self, custom_kg: dict[str, Any]): update_storage = False try: # Insert chunks into vector storage - all_chunks_data = {} - chunk_to_source_map = {} - for chunk_data in custom_kg.get("chunks", []): + all_chunks_data: dict[str, dict[str, str]] = {} + chunk_to_source_map: dict[str, str] = {} + for chunk_data in custom_kg.get("chunks", {}): chunk_content = chunk_data["content"] source_id = chunk_data["source_id"] chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-") @@ -871,13 +866,13 @@ class LightRAG: chunk_to_source_map[source_id] = chunk_id update_storage = True - if self.chunks_vdb is not None and all_chunks_data: + if all_chunks_data: await self.chunks_vdb.upsert(all_chunks_data) - if self.text_chunks is not None and all_chunks_data: + if all_chunks_data: await self.text_chunks.upsert(all_chunks_data) # Insert entities into knowledge graph - all_entities_data = [] + all_entities_data: list[dict[str, str]] = [] for entity_data in custom_kg.get("entities", []): entity_name = f'"{entity_data["entity_name"].upper()}"' entity_type = entity_data.get("entity_type", "UNKNOWN") @@ -893,7 +888,7 @@ class LightRAG: ) # Prepare node data - node_data = { + node_data: dict[str, str] = { "entity_type": entity_type, "description": description, "source_id": 
source_id, @@ -907,7 +902,7 @@ class LightRAG: update_storage = True # Insert relationships into knowledge graph - all_relationships_data = [] + all_relationships_data: list[dict[str, str]] = [] for relationship_data in custom_kg.get("relationships", []): src_id = f'"{relationship_data["src_id"].upper()}"' tgt_id = f'"{relationship_data["tgt_id"].upper()}"' @@ -949,7 +944,7 @@ class LightRAG: "source_id": source_id, }, ) - edge_data = { + edge_data: dict[str, str] = { "src_id": src_id, "tgt_id": tgt_id, "description": description, @@ -959,19 +954,17 @@ class LightRAG: update_storage = True # Insert entities into vector storage if needed - if self.entities_vdb is not None: - data_for_vdb = { + data_for_vdb = { compute_mdhash_id(dp["entity_name"], prefix="ent-"): { "content": dp["entity_name"] + dp["description"], "entity_name": dp["entity_name"], } for dp in all_entities_data } - await self.entities_vdb.upsert(data_for_vdb) + await self.entities_vdb.upsert(data_for_vdb) # Insert relationships into vector storage if needed - if self.relationships_vdb is not None: - data_for_vdb = { + data_for_vdb = { compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], @@ -982,18 +975,49 @@ class LightRAG: } for dp in all_relationships_data } - await self.relationships_vdb.upsert(data_for_vdb) + await self.relationships_vdb.upsert(data_for_vdb) + finally: if update_storage: await self._insert_done() - def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()): + def query( + self, + query: str, + param: QueryParam = QueryParam(), + prompt: str | None = None + ) -> str: + """ + Perform a sync query. + + Args: + query (str): The query to be executed. + param (QueryParam): Configuration parameters for query execution. + prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. + + Returns: + str: The result of the query execution. + """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, prompt, param)) + return loop.run_until_complete(self.aquery(query, param, prompt)) async def aquery( - self, query: str, prompt: str = "", param: QueryParam = QueryParam() - ): + self, + query: str, + param: QueryParam = QueryParam(), + prompt: str | None = None, + ) -> str: + """ + Perform a async query. + + Args: + query (str): The query to be executed. + param (QueryParam): Configuration parameters for query execution. + prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. + + Returns: + str: The result of the query execution. 
+ """ if param.mode in ["local", "global", "hybrid"]: response = await kg_query( query, diff --git a/lightrag/operate.py b/lightrag/operate.py index 04aad0d4..a961cfd9 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -295,8 +295,8 @@ async def extract_entities( knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - global_config: dict, - llm_response_cache: BaseKVStorage = None, + global_config: dict[str, str], + llm_response_cache: BaseKVStorage | None = None, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -563,15 +563,15 @@ async def extract_entities( async def kg_query( - query, + query: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, - prompt: str = "", + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + prompt: str | None = None, ) -> str: # Handle cache use_model_func = global_config["llm_model_func"] @@ -681,8 +681,8 @@ async def kg_query( async def extract_keywords_only( text: str, param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. @@ -778,8 +778,8 @@ async def mix_kg_vector_query( chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, ) -> str: """ Hybrid retrieval implementation combining knowledge graph and vector search. @@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources): async def naive_query( - query, + query: str, chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, ): # Handle cache use_model_func = global_config["llm_model_func"] @@ -1606,8 +1606,8 @@ async def kg_query_with_keywords( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, ) -> str: """ Refactored kg_query that does NOT extract keywords by itself. diff --git a/lightrag/utils.py b/lightrag/utils.py index 9df325ca..c94e23cb 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str: return hashlib.md5(args_str.encode()).hexdigest() -def compute_mdhash_id(content, prefix: str = ""): +def compute_mdhash_id(content: str, prefix: str = "") -> str: + """ + Compute a unique ID for a given content string. + + The ID is a combination of the given prefix and the MD5 hash of the content string. 
+ """ return prefix + md5(content.encode()).hexdigest() From dfa8681924a20d479c7348f786863de62c999cfe Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:33:59 +0100 Subject: [PATCH 10/28] code clean --- lightrag/lightrag.py | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b4426cd7..15bb6cc2 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -536,7 +536,7 @@ class LightRAG: ), embedding_func=self.embedding_func, ) - + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, # type: ignore @@ -955,37 +955,34 @@ class LightRAG: # Insert entities into vector storage if needed data_for_vdb = { - compute_mdhash_id(dp["entity_name"], prefix="ent-"): { - "content": dp["entity_name"] + dp["description"], - "entity_name": dp["entity_name"], - } - for dp in all_entities_data + compute_mdhash_id(dp["entity_name"], prefix="ent-"): { + "content": dp["entity_name"] + dp["description"], + "entity_name": dp["entity_name"], } + for dp in all_entities_data + } await self.entities_vdb.upsert(data_for_vdb) # Insert relationships into vector storage if needed data_for_vdb = { - compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { - "src_id": dp["src_id"], - "tgt_id": dp["tgt_id"], - "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], - } - for dp in all_relationships_data + compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { + "src_id": dp["src_id"], + "tgt_id": dp["tgt_id"], + "content": dp["keywords"] + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], } + for dp in all_relationships_data + } await self.relationships_vdb.upsert(data_for_vdb) - + finally: if update_storage: await self._insert_done() def query( - self, - query: str, - param: QueryParam = QueryParam(), - prompt: str | None = None + self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None ) -> str: """ Perform a sync query. @@ -997,7 +994,7 @@ class LightRAG: Returns: str: The result of the query execution. 
- """ + """ loop = always_get_an_event_loop() return loop.run_until_complete(self.aquery(query, param, prompt)) From cf6e327bf4a338eace34b2e0b829d12a62aa6561 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:42:52 +0100 Subject: [PATCH 11/28] added type and cleaned code --- lightrag/base.py | 34 ++++++++++++++++++++++++---------- lightrag/lightrag.py | 16 +++++++++++----- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 8e6a212d..b1fe50a2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -95,7 +95,7 @@ class StorageNameSpace: @dataclass class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc - meta_fields: set = field(default_factory=set) + meta_fields: set[str] = field(default_factory=set) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError @@ -130,50 +130,64 @@ class BaseKVStorage(StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc = None - + embedding_func: EmbeddingFunc | None = None + """Check if a node exists in the graph.""" async def has_node(self, node_id: str) -> bool: raise NotImplementedError + """Check if an edge exists in the graph.""" async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError + """Get the degree of a node.""" async def node_degree(self, node_id: str) -> int: raise NotImplementedError + """Get the degree of an edge.""" async def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError - async def get_node(self, node_id: str) -> Union[dict, None]: + """Get a node by its id.""" + async def get_node(self, node_id: str) -> Union[dict[str, str], None]: raise NotImplementedError + """Get an edge by its source and target node ids.""" async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> Union[dict[str, str], None]: raise NotImplementedError + """Get all edges connected to a node.""" async def get_node_edges( self, source_node_id: str ) -> Union[list[tuple[str, str]], None]: raise NotImplementedError - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """Upsert a node into the graph.""" + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: raise NotImplementedError + """Upsert an edge into the graph.""" async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + self, source_node_id: str, + target_node_id: str, + edge_data: dict[str, str] + ) -> None: raise NotImplementedError - async def delete_node(self, node_id: str): + """Delete a node from the graph.""" + async def delete_node(self, node_id: str) -> None: raise NotImplementedError - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """Embed nodes using an algorithm.""" + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") + """Get all labels in the graph.""" async def get_all_labels(self) -> list[str]: raise NotImplementedError + """Get a knowledge graph of a node.""" async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 15bb6cc2..593b5734 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -6,7 +6,7 @@ import configparser from dataclasses import asdict, dataclass, field 
from datetime import datetime from functools import partial -from typing import Any, Callable, Optional, Union, cast +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast from .base import ( BaseGraphStorage, @@ -983,7 +983,7 @@ class LightRAG: def query( self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None - ) -> str: + ) -> str | Iterator[str]: """ Perform a sync query. @@ -1003,7 +1003,7 @@ class LightRAG: query: str, param: QueryParam = QueryParam(), prompt: str | None = None, - ) -> str: + ) -> str | AsyncIterator[str]: """ Perform a async query. @@ -1081,7 +1081,10 @@ class LightRAG: return response def query_with_separate_keyword_extraction( - self, query: str, prompt: str, param: QueryParam = QueryParam() + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() ): """ 1. Extract keywords from the 'query' using new function in operate.py. @@ -1093,7 +1096,10 @@ class LightRAG: ) async def aquery_with_separate_keyword_extraction( - self, query: str, prompt: str, param: QueryParam = QueryParam() + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() ): """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. From e6520ad6a22c1722b6c29454cf81273476d7118e Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:49:39 +0100 Subject: [PATCH 12/28] added typing --- lightrag/base.py | 2 +- lightrag/lightrag.py | 14 +++++++++----- lightrag/operate.py | 6 +++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index b1fe50a2..e70dddd1 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -96,7 +96,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 593b5734..8a65a46c 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -982,7 +982,10 @@ class LightRAG: await self._insert_done() def query( - self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None + self, + query: str, + param: QueryParam = QueryParam(), + prompt: str | None = None ) -> str | Iterator[str]: """ Perform a sync query. @@ -996,7 +999,8 @@ class LightRAG: str: The result of the query execution. 
""" loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param, prompt)) + + return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore async def aquery( self, @@ -1455,7 +1459,7 @@ class LightRAG: async def get_entity_info( self, entity_name: str, include_vector_data: bool = False - ): + ) -> dict[str, str | None | dict[str, str]]: """Get detailed information of an entity Args: @@ -1475,7 +1479,7 @@ class LightRAG: node_data = await self.chunk_entity_relation_graph.get_node(entity_name) source_id = node_data.get("source_id") if node_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "entity_name": entity_name, "source_id": source_id, "graph_data": node_data, @@ -1531,7 +1535,7 @@ class LightRAG: ) source_id = edge_data.get("source_id") if edge_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "src_entity": src_entity, "tgt_entity": tgt_entity, "source_id": source_id, diff --git a/lightrag/operate.py b/lightrag/operate.py index a961cfd9..d6cc9f3c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2,7 +2,7 @@ import asyncio import json import re from tqdm.asyncio import tqdm as tqdm_async -from typing import Any, Union +from typing import Any, AsyncIterator, Union from collections import Counter, defaultdict from .utils import ( logger, @@ -780,7 +780,7 @@ async def mix_kg_vector_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -) -> str: +) -> str | AsyncIterator[str]: """ Hybrid retrieval implementation combining knowledge graph and vector search. @@ -1505,7 +1505,7 @@ async def naive_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -): +) -> str | AsyncIterator[str]: # Handle cache use_model_func = global_config["llm_model_func"] args_hash = compute_args_hash(query_param.mode, query, cache_type="query") From 7e526d343696b6345147e0b9cadd8aa2ecc67d9b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:52:05 +0100 Subject: [PATCH 13/28] cleaned code --- lightrag/base.py | 23 ++++++++++++++++++----- lightrag/lightrag.py | 25 ++++++++----------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index e70dddd1..2b3e5cad 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -96,7 +96,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError @@ -132,62 +132,75 @@ class BaseKVStorage(StorageNameSpace): class BaseGraphStorage(StorageNameSpace): embedding_func: EmbeddingFunc | None = None """Check if a node exists in the graph.""" + async def has_node(self, node_id: str) -> bool: raise NotImplementedError """Check if an edge exists in the graph.""" + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError """Get the degree of a node.""" + async def node_degree(self, node_id: str) -> int: raise NotImplementedError """Get the degree of an edge.""" + async def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError """Get a node by its id.""" + async def get_node(self, node_id: str) -> Union[dict[str, str], None]: raise NotImplementedError """Get an edge by its source and target node ids.""" + async def get_edge( self, 
source_node_id: str, target_node_id: str ) -> Union[dict[str, str], None]: raise NotImplementedError """Get all edges connected to a node.""" + async def get_node_edges( self, source_node_id: str ) -> Union[list[tuple[str, str]], None]: raise NotImplementedError """Upsert a node into the graph.""" + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: raise NotImplementedError """Upsert an edge into the graph.""" + async def upsert_edge( - self, source_node_id: str, - target_node_id: str, - edge_data: dict[str, str] + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: raise NotImplementedError """Delete a node from the graph.""" + async def delete_node(self, node_id: str) -> None: raise NotImplementedError """Embed nodes using an algorithm.""" - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") """Get all labels in the graph.""" + async def get_all_labels(self) -> list[str]: raise NotImplementedError """Get a knowledge graph of a node.""" + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8a65a46c..08855d60 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -982,10 +982,7 @@ class LightRAG: await self._insert_done() def query( - self, - query: str, - param: QueryParam = QueryParam(), - prompt: str | None = None + self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None ) -> str | Iterator[str]: """ Perform a sync query. @@ -999,8 +996,8 @@ class LightRAG: str: The result of the query execution. """ loop = always_get_an_event_loop() - - return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore + + return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore async def aquery( self, @@ -1085,10 +1082,7 @@ class LightRAG: return response def query_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Extract keywords from the 'query' using new function in operate.py. @@ -1100,10 +1094,7 @@ class LightRAG: ) async def aquery_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 
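As a reference point for the signatures being compacted above, here is a minimal usage sketch of the reordered `(query, param, prompt)` call. It assumes `rag` is an already-initialized LightRAG instance and is not part of the patch itself:

```python
from lightrag import QueryParam

# `rag` is assumed to be a configured LightRAG instance.
answer = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # None falls back to the default PROMPTS["rag_response"] template
)
print(answer)
```

Note that `query_with_separate_keyword_extraction` keeps a different order, `(query, prompt, param)`, with `prompt` required rather than optional.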
@@ -1127,8 +1118,8 @@ class LightRAG: ), ) - param.hl_keywords = (hl_keywords,) - param.ll_keywords = (ll_keywords,) + param.hl_keywords = hl_keywords + param.ll_keywords = ll_keywords # --------------------- # STEP 2: Final Query Logic @@ -1156,7 +1147,7 @@ class LightRAG: self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), global_config=asdict(self), - embedding_func=self.embedding_funcne, + embedding_func=self.embedding_func, ), ) elif param.mode == "naive": From 805da7b95b75d0b955913633e97f23ab1c64b0e8 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:01:21 +0100 Subject: [PATCH 14/28] cleaned code --- lightrag/base.py | 8 +++- lightrag/kg/faiss_impl.py | 2 +- lightrag/kg/nano_vector_db_impl.py | 2 +- lightrag/lightrag.py | 68 ++++++------------------------ 4 files changed, 21 insertions(+), 59 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 2b3e5cad..29335494 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -106,10 +106,16 @@ class BaseVectorStorage(StorageNameSpace): """ raise NotImplementedError + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError + @dataclass class BaseKVStorage(StorageNameSpace): - embedding_func: EmbeddingFunc + embedding_func: EmbeddingFunc | None = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: raise NotImplementedError diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0dca9e4c..9a5f7e4e 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: """ Delete relations for a given entity by scanning metadata. """ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 2db8f72a..5d786646 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: try: relations = [ dp diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 08855d60..ce86e938 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1095,7 +1095,7 @@ class LightRAG: async def aquery_with_separate_keyword_extraction( self, query: str, prompt: str, param: QueryParam = QueryParam() - ): + ) -> str | AsyncIterator[str]: """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. 
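The `@@ -1127` hunk above, besides fixing the `embedding_funcne` typo, replaces accidental one-element tuples with the plain lists that `QueryParam` declares. An illustrative sketch of the difference (names and values are made up):

```python
from lightrag import QueryParam

param = QueryParam(mode="hybrid")
hl_keywords = ["theme", "conflict"]

# Before the fix: param.hl_keywords = (hl_keywords,) stored a 1-tuple that
# wrapped the list, so code iterating over list[str] saw one nested list.
param.hl_keywords = hl_keywords  # after the fix: a plain list[str]

print(", ".join(param.hl_keywords))  # -> "theme, conflict"
```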
@@ -1196,12 +1196,7 @@ class LightRAG: return response async def _query_done(self): - tasks = [] - for storage_inst in [self.llm_response_cache]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await self.llm_response_cache.index_done_callback() def delete_by_entity(self, entity_name: str): loop = always_get_an_event_loop() @@ -1223,16 +1218,16 @@ class LightRAG: logger.error(f"Error while deleting entity '{entity_name}': {e}") async def _delete_by_entity_done(self): - tasks = [] - for storage_inst in [ - self.entities_vdb, - self.relationships_vdb, - self.chunk_entity_relation_graph, - ]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.entities_vdb, + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content @@ -1444,10 +1439,6 @@ class LightRAG: except Exception as e: logger.error(f"Error while deleting document {doc_id}: {e}") - def delete_by_doc_id(self, doc_id: str): - """Synchronous version of adelete""" - return asyncio.run(self.adelete_by_doc_id(doc_id)) - async def get_entity_info( self, entity_name: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: @@ -1484,21 +1475,6 @@ class LightRAG: return result - def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False): - """Synchronous version of getting entity information - - Args: - entity_name: Entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run(self.get_entity_info(entity_name, include_vector_data)) - finally: - tracemalloc.stop() - async def get_relation_info( self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): @@ -1540,23 +1516,3 @@ class LightRAG: result["vector_data"] = vector_data[0] if vector_data else None return result - - def get_relation_info_sync( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False - ): - """Synchronous version of getting relationship information - - Args: - src_entity: Source entity name (no need for quotes) - tgt_entity: Target entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run( - self.get_relation_info(src_entity, tgt_entity, include_vector_data) - ) - finally: - tracemalloc.stop() From 621540a54e9dd043b2dab0a366cf214641fb5dd5 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:10:37 +0100 Subject: [PATCH 15/28] cleaned code --- lightrag/base.py | 2 ++ lightrag/lightrag.py | 6 ++++-- lightrag/llm.py | 6 +++++- lightrag/operate.py | 2 +- lightrag/utils.py | 24 ++++++++++++++---------- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 29335494..42f6d1e9 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -107,9 +107,11 @@ class BaseVectorStorage(StorageNameSpace): raise NotImplementedError async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" raise NotImplementedError 
async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" raise NotImplementedError diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ce86e938..af241f65 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -524,7 +524,6 @@ class LightRAG: embedding_func=None, ) - # What's for, Is this nessisary ? if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" ): @@ -1252,7 +1251,7 @@ class LightRAG: """ return await self.doc_status.get_status_counts() - async def adelete_by_doc_id(self, doc_id: str): + async def adelete_by_doc_id(self, doc_id: str) -> None: """Delete a document and all its related data Args: @@ -1269,6 +1268,9 @@ class LightRAG: # 2. Get all related chunks chunks = await self.text_chunks.get_by_id(doc_id) + if not chunks: + return + chunk_ids = list(chunks.keys()) logger.debug(f"Found {len(chunk_ids)} chunks to delete") diff --git a/lightrag/llm.py b/lightrag/llm.py index 3ca17725..b4baef68 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -66,7 +66,11 @@ class MultiModel: return self._models[self._current_model] async def llm_model_func( - self, prompt, system_prompt=None, history_messages=[], **kwargs + self, + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] = [], + **kwargs: Any, ) -> str: kwargs.pop("model", None) # stop from overwriting the custom model name kwargs.pop("keyword_extraction", None) diff --git a/lightrag/operate.py b/lightrag/operate.py index d6cc9f3c..f2bd6218 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1608,7 +1608,7 @@ async def kg_query_with_keywords( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -) -> str: +) -> str | AsyncIterator[str]: """ Refactored kg_query that does NOT extract keywords by itself. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. diff --git a/lightrag/utils.py b/lightrag/utils.py index c94e23cb..9b18d0c2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -9,7 +9,7 @@ import re from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List, Optional +from typing import Any, Callable, Union, List, Optional import xml.etree.ElementTree as ET import bs4 @@ -67,7 +67,7 @@ class EmbeddingFunc: @dataclass class ReasoningResponse: - reasoning_content: str + reasoning_content: str | None response_content: str tag: str @@ -109,7 +109,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]: raise e from None -def compute_args_hash(*args, cache_type: str = None) -> str: +def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: """Compute a hash for the given arguments. 
Args: *args: Arguments to hash @@ -220,11 +220,13 @@ def clean_str(input: Any) -> str: return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) -def is_float_regex(value): +def is_float_regex(value: str) -> bool: return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) -def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): +def truncate_list_by_token_size( + list_data: list[Any], key: Callable[[Any], str], max_token_size: int +) -> list[int]: """Truncate a list of data by token size""" if max_token_size <= 0: return [] @@ -334,7 +336,7 @@ def xml_to_json(xml_file): return None -def process_combine_contexts(hl, ll): +def process_combine_contexts(hl: str, ll: str): header = None list_hl = csv_string_to_list(hl.strip()) list_ll = csv_string_to_list(ll.strip()) @@ -640,7 +642,9 @@ def exists_func(obj, func_name: str) -> bool: return False -def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str: +def get_conversation_turns( + conversation_history: list[dict[str, Any]], num_turns: int +) -> str: """ Process conversation history to get the specified number of complete turns. @@ -652,8 +656,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> Formatted string of the conversation history """ # Group messages into turns - turns = [] - messages = [] + turns: list[list[dict[str, Any]]] = [] + messages: list[dict[str, Any]] = [] # First, filter out keyword extraction messages for msg in conversation_history: @@ -687,7 +691,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> turns = turns[-num_turns:] # Format the turns into a string - formatted_turns = [] + formatted_turns: list[str] = [] for turn in turns: formatted_turns.extend( [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"] From 50919442e906623974a9888c12ec7b6c6ad3335c Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:56:45 +0100 Subject: [PATCH 16/28] Improve git and docker ignore --- .dockerignore | 64 ++++++++++++++++++++++++++++++++++++++++- .gitignore | 79 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 120 insertions(+), 23 deletions(-) diff --git a/.dockerignore b/.dockerignore index 4c49bd78..f1a82ffa 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,63 @@ -.env +# Python-related files and directories +__pycache__ +.cache + +# Virtual environment directories +*.venv + +# Env +env/ +*.env* +.env_example + +# Distribution / build files +site +dist/ +build/ +.eggs/ +*.egg-info/ +*.tgz +*.tar.gz + +# Exclude siles and folders +*.yml +.dockerignore +Dockerfile +Makefile + +# Exclude other projects +/tests +/scripts + +# Python version manager file +.python-version + +# Reports +*.coverage/ +*.log +log/ +*.logfire + +# Cache +.cache/ +.mypy_cache +.pytest_cache +.ruff_cache +.gradio +.logfire +temp/ + +# MacOS-related files +.DS_Store + +# VS Code settings (local configuration files) +.vscode + +# file +TODO.md + +# Exclude Git-related files +.git +.github +.gitignore +.pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index 2d9a41f3..2d074372 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,61 @@ -__pycache__ -*.egg-info +# Python-related files +__pycache__/ +*.py[cod] +*.egg-info/ +.eggs/ +*.tgz +*.tar.gz +*.ini # Remove config.ini from repo + +# Virtual Environment +.venv/ +env/ +venv/ +*.env* +.env_example + +# Build / Distribution +dist/ +build/ +site/ + +# Logs / Reports +*.log +*.logfire +*.coverage/ +log/ + +# Caches +.cache/ +.mypy_cache/ 
+.pytest_cache/ +.ruff_cache/ +.gradio/ +temp/ + +# IDE / Editor Files +.idea/ +.vscode/ +.vscode/settings.json + +# Framework-specific files +local_neo4jWorkDir/ +neo4jWorkDir/ + +# Data & Storage +inputs/ +rag_storage/ +examples/input/ +examples/output/ + +# Miscellaneous +.DS_Store +TODO.md +ignore_this.txt +*.ignore.* + +# Project-specific files dickens/ book.txt lightrag-dev/ -.idea/ -dist/ -env/ -local_neo4jWorkDir/ -neo4jWorkDir/ -ignore_this.txt -.venv/ -*.ignore.* -.ruff_cache/ gui/ -*.log -.vscode -inputs -rag_storage -.env -venv/ -examples/input/ -examples/output/ -.DS_Store -#Remove config.ini from repo -*.ini From ad88ba03bf8e531e010f053ecce694d0f343f13a Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:07:38 +0800 Subject: [PATCH 17/28] docs: reorganize Ollama emulation API documentation for better readability --- lightrag/api/README.md | 110 ++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8e5a61d5..7e4fda7e 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -74,30 +74,38 @@ LLM_MODEL=model_name_of_azure_ai LLM_BINDING_API_KEY=api_key_of_azure_ai ``` -### About Ollama API +### 3. Install Lightrag as a Linux Service -We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily. +Create a your service file `lightrag.sevice` from the sample file : `lightrag.sevice.example`. Modified the WorkingDirectoryand EexecStart in the service file: -#### Choose Query mode in chat - -A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include: - -``` -/local -/global -/hybrid -/naive -/mix -/bypass +```text +Description=LightRAG Ollama Service +WorkingDirectory= +ExecStart=/lightrag/api/lightrag-api ``` -For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。 +Modify your service startup script: `lightrag-api`. Change you python virtual environment activation command as needed: -"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the LightRAG query results. (If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix) +```shell +#!/bin/bash + +# your python virtual environment activation +source /home/netman/lightrag-xyj/venv/bin/activate +# start lightrag api server +lightrag-server +``` + +Install LightRAG service. If your system is Ubuntu, the following commands will work: + +```shell +sudo cp lightrag.service /etc/systemd/system/ +sudo systemctl daemon-reload +sudo systemctl start lightrag.service +sudo systemctl status lightrag.service +sudo systemctl enable lightrag.service +``` -#### Connect Open WebUI to LightRAG -After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. 
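Since the Ollama-emulation docs are being reshuffled in the hunks that follow, a client-side sketch of the query-prefix routing may help. The endpoint, port, and model name follow the defaults quoted in this README; the payload is the standard Ollama chat shape:

```python
import requests

# "/mix " at the start of the message selects LightRAG's mix query mode;
# omitting the prefix falls back to the default hybrid mode.
resp = requests.post(
    "http://localhost:9621/api/chat",
    json={
        "model": "lightrag:latest",
        "messages": [{"role": "user", "content": "/mix What are the main themes?"}],
        "stream": False,
    },
    timeout=120,
)
print(resp.json()["message"]["content"])
```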
## Configuration @@ -378,7 +386,7 @@ curl -X DELETE "http://localhost:9621/documents" #### GET /api/version -Get Ollama version information +Get Ollama version information. ```bash curl http://localhost:9621/api/version @@ -386,7 +394,7 @@ curl http://localhost:9621/api/version #### GET /api/tags -Get Ollama available models +Get Ollama available models. ```bash curl http://localhost:9621/api/tags @@ -394,7 +402,7 @@ curl http://localhost:9621/api/tags #### POST /api/chat -Handle chat completion requests +Handle chat completion requests. Routes user queries through LightRAG by selecting query mode based on query prefix. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to underlying LLM. ```shell curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \ @@ -403,6 +411,10 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso > For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md) +#### POST /api/generate + +Handle generate completion requests. For compatibility purpose, the request is not processed by LightRAG, and will be handled by underlying LLM model. + ### Utility Endpoints #### GET /health @@ -412,7 +424,35 @@ Check server health and configuration. curl "http://localhost:9621/health" ``` +## Ollama Emulation + +We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily. + +### Connect Open WebUI to LightRAG + +After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. You'd better install LightRAG as service for this use case. + +Open WebUI's use LLM to do the session title and session keyword generation task. So the Ollama chat chat completion API detects and forwards OpenWebUI session-related requests directly to underlying LLM. + +### Choose Query mode in chat + +A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include: + +``` +/local +/global +/hybrid +/naive +/mix +/bypass +``` + +For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。 + +"/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the chat history. If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix. + ## Development + Contribute to the project: [Guide](contributor-readme.MD) ### Running in Development Mode @@ -471,33 +511,3 @@ This intelligent caching mechanism: - This optimization significantly reduces startup time for subsequent runs - The working directory (`--working-dir`) stores the vectorized documents database -## Install Lightrag as a Linux Service - -Create a your service file `lightrag.sevice` from the sample file : `lightrag.sevice.example`. 
Modified the WorkingDirectoryand EexecStart in the service file: - -```text -Description=LightRAG Ollama Service -WorkingDirectory= -ExecStart=/lightrag/api/lightrag-api -``` - -Modify your service startup script: `lightrag-api`. Change you python virtual environment activation command as needed: - -```shell -#!/bin/bash - -# your python virtual environment activation -source /home/netman/lightrag-xyj/venv/bin/activate -# start lightrag api server -lightrag-server -``` - -Install LightRAG service. If your system is Ubuntu, the following commands will work: - -```shell -sudo cp lightrag.service /etc/systemd/system/ -sudo systemctl daemon-reload -sudo systemctl start lightrag.service -sudo systemctl status lightrag.service -sudo systemctl enable lightrag.service -``` From 0db0419c6dc38651589db6121245acad3df74eeb Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:08:54 +0800 Subject: [PATCH 18/28] Fix linting --- lightrag/api/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 7e4fda7e..06510618 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -510,4 +510,3 @@ This intelligent caching mechanism: - Only new documents in the input directory will be processed - This optimization significantly reduces startup time for subsequent runs - The working directory (`--working-dir`) stores the vectorized documents database - From 2985d88f976ab63b6ce31d1c9929506e37c288ae Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:39:10 +0800 Subject: [PATCH 19/28] refactor: improve CORS and streaming response headers - Add configurable CORS origins - Remove duplicate CORS headers - Add X-Accel-Buffering header - Update env example file - Clean up header configurations --- .env.example | 13 +++++++------ lightrag/api/lightrag_server.py | 16 +++++++++++----- lightrag/api/ollama_api.py | 8 ++------ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/.env.example b/.env.example index 022bd63d..2701335a 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,13 @@ ### Server Configuration -#HOST=0.0.0.0 -#PORT=9621 -#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances +# HOST=0.0.0.0 +# PORT=9621 +# NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances +# CORS_ORIGINS=http://localhost:3000,http://localhost:8080 ### Optional SSL Configuration -#SSL=true -#SSL_CERTFILE=/path/to/cert.pem -#SSL_KEYFILE=/path/to/key.pem +# SSL=true +# SSL_CERTFILE=/path/to/cert.pem +# SSL_KEYFILE=/path/to/key.pem ### Security (empty for no api-key is needed) # LIGHTRAG_API_KEY=your-secure-api-key-here diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ce182bc1..19552faf 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -847,10 +847,19 @@ def create_app(args): lifespan=lifespan, ) + def get_cors_origins(): + """Get allowed origins from environment variable + Returns a list of allowed origins, defaults to ["*"] if not set + """ + origins_str = os.getenv("CORS_ORIGINS", "*") + if origins_str == "*": + return ["*"] + return [origin.strip() for origin in origins_str.split(",")] + # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=get_cors_origins(), allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -1377,10 +1386,7 @@ def create_app(args): "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - 
"Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - "X-Accel-Buffering": "no", # Disable Nginx buffering + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) except Exception as e: diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 01a883ca..94703dee 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -316,9 +316,7 @@ class OllamaAPI: "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) else: @@ -534,9 +532,7 @@ class OllamaAPI: "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应 }, ) else: From 8fdbcb0d3f749741daa57dfbd346000f1b4e652f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 15 Feb 2025 11:46:47 +0800 Subject: [PATCH 20/28] fix: reorganize server info display and add CORS origins info MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add CORS origins display • Move API key status higher in display • Fix tree symbols for better readability • Regroup related server info • Remove redundant line breaks --- lightrag/api/lightrag_server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 19552faf..97f1156f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -159,8 +159,12 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.host}") ASCIIColors.white(" ├─ Port: ", end="") ASCIIColors.yellow(f"{args.port}") - ASCIIColors.white(" └─ SSL Enabled: ", end="") + ASCIIColors.white(" ├─ CORS Origins: ", end="") + ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") + ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.yellow(f"{args.ssl}") + ASCIIColors.white(" └─ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") if args.ssl: ASCIIColors.white(" ├─ SSL Cert: ", end="") ASCIIColors.yellow(f"{args.ssl_certfile}") @@ -229,10 +233,8 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") ASCIIColors.white(" ├─ Log Level: ", end="") ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" ├─ Timeout: ", end="") + ASCIIColors.white(" └─ Timeout: ", end="") ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") - ASCIIColors.white(" └─ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") # Server Status ASCIIColors.green("\n✨ Server starting up...\n") From 875d18d80fc74c7d1b424f1963e80c5d34f96fc5 Mon Sep 17 00:00:00 2001 From: Ethan Heavey <65369063+VeiledTee@users.noreply.github.com> Date: Sat, 15 Feb 2025 12:51:24 -0400 Subject: [PATCH 21/28] Update HuggingFace example in README.md Replaced "hf_embedding" with "hf_embed" in HuggingFace example to match implementation --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 62f21a65..ad9ad819 100644 --- a/README.md +++ b/README.md @@ 
@@ -237,7 +237,7 @@ rag = LightRAG(
 * If you want to use Hugging Face models, you only need to set LightRAG as follows:
 
 ```python
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm import hf_model_complete, hf_embed
 from transformers import AutoModel, AutoTokenizer
 from lightrag.utils import EmbeddingFunc
 
@@ -250,7 +250,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
             texts,
             tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
             embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

From 49ce229af7f096d6d69bd6e7d8116ee52aab5024 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sat, 15 Feb 2025 22:23:16 +0100
Subject: [PATCH 22/28] fix wrong type

---
 lightrag/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 42f6d1e9..c638afd7 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -69,7 +69,7 @@ class QueryParam:
     ll_keywords: list[str] = field(default_factory=list)
     """List of low-level keywords to refine retrieval focus."""
 
-    conversation_history: list[dict[str, Any]] = field(default_factory=list)
+    conversation_history: list[dict[str, str]] = field(default_factory=list)
     """Stores past conversation history to maintain context.
     Format: [{"role": "user/assistant", "content": "message"}].
     """

From 8d0d8b8279bd35fc76bd17c4603ca8f26febd197 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sat, 15 Feb 2025 22:24:49 +0100
Subject: [PATCH 23/28] remove unused method

---
 lightrag/base.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index c638afd7..ae451dda 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -87,10 +87,6 @@ class StorageNameSpace:
         """Commit the storage operations after indexing"""
         pass
 
-    async def query_done_callback(self) -> None:
-        """Commit the storage operations after querying"""
-        pass
-
 
 @dataclass
 class BaseVectorStorage(StorageNameSpace):

From eaf1d553d21b5ce4613c75c9f0dea77c87caa7a2 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sat, 15 Feb 2025 22:37:12 +0100
Subject: [PATCH 24/28] improved typing

---
 lightrag/base.py       | 20 +++++++++++---------
 lightrag/exceptions.py |  2 ++
 lightrag/lightrag.py   |  6 +++---
 lightrag/llm.py        |  8 +++++---
 lightrag/namespace.py  |  2 ++
 lightrag/operate.py    |  8 +++++---
 lightrag/prompt.py     |  2 ++
 lightrag/types.py      | 19 +++++++++++--------
 lightrag/utils.py      | 20 +++++++++++---------
 9 files changed, 52 insertions(+), 35 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index ae451dda..1d7a0a98 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):
 
     """Get a node by its id."""
 
-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get an edge by its source and target node ids."""
 
     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str
+    ) -> dict[str, str] | None :
         raise NotImplementedError
 
     """Get all edges connected to a node."""
 
     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""

diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
index 5de6b334..ae756f85 100644
--- a/lightrag/exceptions.py
+++ b/lightrag/exceptions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index af241f65..fed555a2 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast
 
 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""
 
     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,

diff --git a/lightrag/llm.py b/lightrag/llm.py
index b4baef68..e5f98cf8 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 
 from pydantic import BaseModel, Field
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
 
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """
 
-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
 
diff --git a/lightrag/namespace.py b/lightrag/namespace.py
index ba8e3072..77e04c9e 100644
--- a/lightrag/namespace.py
+++ b/lightrag/namespace.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f2bd6218..37e7523f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[

diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 160663d9..f4f5e38a 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"
 
 PROMPTS = {}

diff --git a/lightrag/types.py b/lightrag/types.py
index 9c8e0099..2510bed3 100644
--- a/lightrag/types.py
+++ b/lightrag/types.py
@@ -1,16 +1,19 @@
+
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any
 
 
 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]
 
 
 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9b18d0c2..5b86ee78 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 
 import bs4
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str
 
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data
 
-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: list[list[str]]) -> str:
     return output.getvalue()
 
-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
 
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int=8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"

From 3319db0dba5b6efebeab34ea2c6a7d1b328e328a Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sat, 15 Feb 2025 22:37:32 +0100
Subject: [PATCH 25/28] cleaned code

---
 lightrag/base.py  | 10 +++-------
 lightrag/types.py |  1 -
 lightrag/utils.py |  2 +-
 3 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 1d7a0a98..3d4fc022 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -163,17 +163,13 @@ class BaseGraphStorage(StorageNameSpace):
     """Get an edge by its source and target node ids."""
 
     async def get_edge(
-        self,
-        source_node_id: str,
-        target_node_id: str
-    ) -> dict[str, str] | None :
+        self, source_node_id: str, target_node_id: str
+    ) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get all edges connected to a node."""
 
-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> list[tuple[str, str]] | None:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
     """Upsert a node into the graph."""

diff --git a/lightrag/types.py b/lightrag/types.py
index 2510bed3..40e61bc4 100644
--- a/lightrag/types.py
+++ b/lightrag/types.py
@@ -1,4 +1,3 @@
-
 from __future__ import annotations
 
 from pydantic import BaseModel

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 5b86ee78..c8786e7b 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -488,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
-def quantize_embedding(embedding: np.ndarray | list[float], bits: int=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):

From edaba428acab649489e1fac2d6c20a2e1fbf58e1 Mon Sep 17 00:00:00 2001
From: St1ve <62241277+St1veLiu@users.noreply.github.com>
Date: Sun, 16 Feb 2025 19:33:59 +0800
Subject: [PATCH 26/28] Update json_kv_impl.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

There is no delete function in chunks_vdb and text_chunks in lines
1300-1302 of the lightrag.py file:

    if chunk_ids:
        await self.chunks_vdb.delete(chunk_ids)
        await self.text_chunks.delete(chunk_ids)
---
 lightrag/kg/json_kv_impl.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index cfd67367..0bda6d42 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
 
     async def drop(self) -> None:
         self._data = {}
+
+    async def delete(self, ids: list[str]) -> None:
+        for doc_id in ids:
+            self._data.pop(doc_id, None)
+        await self.index_done_callback()

From 1051ff402d6ea5c3bf4b398b34519051ac1b84d3 Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Sun, 16 Feb 2025 12:45:27 +0100
Subject: [PATCH 27/28] fixed lint

---
 examples/test_chromadb.py  | 4 +++-
 lightrag/kg/chroma_impl.py | 8 ++++++--
 lightrag/types.py          | 2 +-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py
index 5293f05d..99090a6d 100644
--- a/examples/test_chromadb.py
+++ b/examples/test_chromadb.py
@@ -17,7 +17,9 @@ if not os.path.exists(WORKING_DIR):
 # ChromaDB Configuration
 CHROMADB_USE_LOCAL_PERSISTENT = False
 # Local PersistentClient Configuration
-CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
+CHROMADB_LOCAL_PATH = os.environ.get(
+    "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
+)
 # Remote HttpClient Configuration
 CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
 CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))

diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py
index 7b7642d6..cb3b59f1 100644
--- a/lightrag/kg/chroma_impl.py
+++ b/lightrag/kg/chroma_impl.py
@@ -67,7 +67,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
 
             if "token_authn" in auth_provider:
                 headers = {
-                    config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
+                    config.get(
+                        "auth_header_name", "X-Chroma-Token"
+                    ): auth_credentials
                 }
             elif "basic_authn" in auth_provider:
                 auth_credentials = config.get("auth_credentials", "admin:admin")
@@ -154,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
 
         results = self._collection.query(
-            query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
+            query_embeddings=embedding.tolist()
+            if not isinstance(embedding, list)
+            else embedding,
             n_results=top_k * 2,  # Request more results to allow for filtering
             include=["metadatas", "distances", "documents"],
         )

diff --git a/lightrag/types.py b/lightrag/types.py
index 35036d2f..5e3d2948 100644
--- a/lightrag/types.py
+++ b/lightrag/types.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pydantic import BaseModel
-from typing import List, Dict, Any, Optional
+from typing import Any, Optional
 
 
 class GPTKeywordExtractionFormat(BaseModel):

From ef0e81315f9496c3cba4505ca4f95a789781687e Mon Sep 17 00:00:00 2001
From: zrguo
Date: Sun, 16 Feb 2025 19:53:28 +0800
Subject: [PATCH 28/28] fix linting

---
 lightrag/kg/json_kv_impl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 0bda6d42..3ab5b966 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -47,7 +47,7 @@ class JsonKVStorage(BaseKVStorage):
 
     async def drop(self) -> None:
         self._data = {}
-    
+
     async def delete(self, ids: list[str]) -> None:
         for doc_id in ids:
             self._data.pop(doc_id, None)
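
A few editorial notes with illustrative code follow; none of the snippets below are repository code unless stated otherwise.

First, the query-prefix routing described in the POST /api/chat documentation above can be pictured with a small helper. This is a minimal sketch under stated assumptions: `parse_query_mode` is a hypothetical name, and the no-prefix default of "hybrid" follows the README text, not the server's actual implementation.

```python
# Sketch of the query-prefix routing described for POST /api/chat.
# parse_query_mode is a hypothetical helper, not LightRAG's actual code.
KNOWN_PREFIXES = {"local", "global", "hybrid", "naive", "mix", "bypass"}


def parse_query_mode(message: str) -> tuple[str, str]:
    """Split a chat message into (mode, query text)."""
    if message.startswith("/"):
        prefix, _, rest = message.partition(" ")
        mode = prefix[1:]
        if mode in KNOWN_PREFIXES:
            return mode, rest.strip()
    return "hybrid", message  # default mode when no prefix is given


print(parse_query_mode("/mix How many disciples does Tang Seng have?"))
# -> ('mix', 'How many disciples does Tang Seng have?')
print(parse_query_mode("What is LightRAG?"))
# -> ('hybrid', 'What is LightRAG?')
```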
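The `get_cors_origins` helper added in PATCH 19 is driven entirely by the `CORS_ORIGINS` environment variable. A quick usage sketch, mirroring the function shown in that diff (standard library only):

```python
import os


def get_cors_origins() -> list[str]:
    # Same logic as the helper added in PATCH 19: "*" means allow all,
    # otherwise split the comma-separated origin list.
    origins_str = os.getenv("CORS_ORIGINS", "*")
    if origins_str == "*":
        return ["*"]
    return [origin.strip() for origin in origins_str.split(",")]


print(get_cors_origins())  # ['*'] when CORS_ORIGINS is unset
os.environ["CORS_ORIGINS"] = "http://localhost:3000, http://localhost:8080"
print(get_cors_origins())  # ['http://localhost:3000', 'http://localhost:8080']
```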
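PATCH 24's switch from `Union[...]`/`Optional[...]` to PEP 604 `X | Y` unions leans on the `from __future__ import annotations` line it adds to each module: under PEP 563 that import turns annotations into lazily evaluated strings, so the `|` syntax parses even on interpreters older than Python 3.10. A stand-alone illustration (not repository code):

```python
from __future__ import annotations  # annotations become lazy strings (PEP 563)


def find_chunk(chunk_id: str) -> dict[str, str] | None:
    # The "dict[str, str] | None" annotation is never evaluated at runtime,
    # so this definition works on older interpreters despite the PEP 604 syntax.
    chunks = {"c1": {"content": "hello"}}
    return chunks.get(chunk_id)


print(find_chunk("c1"))  # {'content': 'hello'}
print(find_chunk("c2"))  # None
```

One caveat: libraries that evaluate annotations at runtime (pydantic, `typing.get_type_hints`) still need the syntax to resolve on the running interpreter, which is worth keeping in mind for the pydantic models in `lightrag/types.py`.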
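PATCH 26 supplies the `delete` method that `lightrag.py` already invokes on its KV stores: a tolerant per-id `pop` followed by the persistence callback. The toy class below imitates that contract in isolation; `MockKVStorage` is a stand-in for illustration, not the real `JsonKVStorage`.

```python
import asyncio


class MockKVStorage:
    """Toy stand-in reproducing the delete contract added in PATCH 26."""

    def __init__(self) -> None:
        self._data = {"doc-1": {"text": "a"}, "doc-2": {"text": "b"}}

    async def index_done_callback(self) -> None:
        print("persisted")  # the real class writes self._data back to JSON here

    async def delete(self, ids: list[str]) -> None:
        for doc_id in ids:
            self._data.pop(doc_id, None)  # unknown ids are ignored, as in the patch
        await self.index_done_callback()


store = MockKVStorage()
asyncio.run(store.delete(["doc-1", "missing-id"]))
print(store._data)  # {'doc-2': {'text': 'b'}}
```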
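Finally, the `query_embeddings` expression reflowed in PATCH 27 guards against the embedding function returning either a NumPy array or an already-plain list. The same idiom in isolation (toy example, not repository code):

```python
import numpy as np


def to_query_embeddings(embedding) -> list:
    # Chroma expects plain Python lists here: NumPy arrays are converted
    # with .tolist(), while lists are passed through unchanged.
    return embedding.tolist() if not isinstance(embedding, list) else embedding


print(to_query_embeddings(np.array([[0.1, 0.2]])))  # [[0.1, 0.2]]
print(to_query_embeddings([[0.1, 0.2]]))            # already a list: unchanged
```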