From a5cd2b19588b16fb3afa63786db048edc3c531b9 Mon Sep 17 00:00:00 2001
From: destiny <1170513392@qq.com>
Date: Fri, 14 Feb 2025 11:00:54 +0800
Subject: [PATCH 1/5] Fix embedding type conversion in the chroma_impl query
 function; support a local persistent client (PersistentClient)

---
 examples/test_chromadb.py  | 70 +++++++++++++++++++++++++++----------
 lightrag/kg/chroma_impl.py | 60 ++++++++++++++++++--------------
 2 files changed, 82 insertions(+), 48 deletions(-)

diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py
index 0e6361ed..5293f05d 100644
--- a/examples/test_chromadb.py
+++ b/examples/test_chromadb.py
@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

 # ChromaDB Configuration
+CHROMADB_USE_LOCAL_PERSISTENT = False
+# Local PersistentClient Configuration
+CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
+# Remote HttpClient Configuration
 CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
 CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
 CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():

 async def initialize_rag():
     embedding_func_instance = await create_embedding_function_instance()
-
-    return LightRAG(
-        working_dir=WORKING_DIR,
-        llm_model_func=gpt_4o_mini_complete,
-        embedding_func=embedding_func_instance,
-        vector_storage="ChromaVectorDBStorage",
-        log_level="DEBUG",
-        embedding_batch_num=32,
-        vector_db_storage_cls_kwargs={
-            "host": CHROMADB_HOST,
-            "port": CHROMADB_PORT,
-            "auth_token": CHROMADB_AUTH_TOKEN,
-            "auth_provider": CHROMADB_AUTH_PROVIDER,
-            "auth_header_name": CHROMADB_AUTH_HEADER,
-            "collection_settings": {
-                "hnsw:space": "cosine",
-                "hnsw:construction_ef": 128,
-                "hnsw:search_ef": 128,
-                "hnsw:M": 16,
-                "hnsw:batch_size": 100,
-                "hnsw:sync_threshold": 1000,
+    if CHROMADB_USE_LOCAL_PERSISTENT:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "local_path": CHROMADB_LOCAL_PATH,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
             },
-        },
-    )
+        )
+    else:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "host": CHROMADB_HOST,
+                "port": CHROMADB_PORT,
+                "auth_token": CHROMADB_AUTH_TOKEN,
+                "auth_provider": CHROMADB_AUTH_PROVIDER,
+                "auth_header_name": CHROMADB_AUTH_HEADER,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
+            },
+        )


 # Run the initialization

diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py
index 72a2627a..a2fbf674 100644
--- a/lightrag/kg/chroma_impl.py
+++ b/lightrag/kg/chroma_impl.py
@@ -3,7 +3,7 @@ import asyncio
 from dataclasses import dataclass
 from typing import Union
 import numpy as np
-from chromadb import HttpClient
+from chromadb import HttpClient, PersistentClient
 from chromadb.config import Settings
 from lightrag.base import
BaseVectorStorage from lightrag.utils import logger @@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage): **user_collection_settings, } - auth_provider = config.get( - "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" - ) - auth_credentials = config.get("auth_token", "secret-token") - headers = {} + local_path = config.get("local_path", None) + if local_path: + self._client = PersistentClient( + path=local_path, + settings=Settings( + allow_reset=True, + anonymized_telemetry=False, + ), + ) + else: + auth_provider = config.get( + "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" + ) + auth_credentials = config.get("auth_token", "secret-token") + headers = {} - if "token_authn" in auth_provider: - headers = { - config.get("auth_header_name", "X-Chroma-Token"): auth_credentials - } - elif "basic_authn" in auth_provider: - auth_credentials = config.get("auth_credentials", "admin:admin") + if "token_authn" in auth_provider: + headers = { + config.get("auth_header_name", "X-Chroma-Token"): auth_credentials + } + elif "basic_authn" in auth_provider: + auth_credentials = config.get("auth_credentials", "admin:admin") - self._client = HttpClient( - host=config.get("host", "localhost"), - port=config.get("port", 8000), - headers=headers, - settings=Settings( - chroma_api_impl="rest", - chroma_client_auth_provider=auth_provider, - chroma_client_auth_credentials=auth_credentials, - allow_reset=True, - anonymized_telemetry=False, - ), - ) + self._client = HttpClient( + host=config.get("host", "localhost"), + port=config.get("port", 8000), + headers=headers, + settings=Settings( + chroma_api_impl="rest", + chroma_client_auth_provider=auth_provider, + chroma_client_auth_credentials=auth_credentials, + allow_reset=True, + anonymized_telemetry=False, + ), + ) self._collection = self._client.get_or_create_collection( name=self.namespace, @@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) results = self._collection.query( - query_embeddings=embedding.tolist(), + query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding, n_results=top_k * 2, # Request more results to allow for filtering include=["metadatas", "distances", "documents"], ) From cd81312659630cde494b34bf26f73720187f80fc Mon Sep 17 00:00:00 2001 From: Pankaj Kaushal Date: Fri, 14 Feb 2025 16:04:06 +0100 Subject: [PATCH 2/5] Enhance Neo4j graph storage with error handling and label validation - Add label existence check and validation methods in Neo4j implementation - Improve error handling in get_node, get_edge, and upsert methods - Add default values and logging for missing edge properties - Ensure consistent label processing across graph storage methods --- lightrag/kg/neo4j_impl.py | 134 ++++++++++++++++++++++++++++---------- lightrag/operate.py | 62 ++++++++++++++---- 2 files changed, 150 insertions(+), 46 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index e9a53110..15525375 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage): async def index_done_callback(self): print("KG successfully indexed.") - async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + async def _label_exists(self, label: str) -> bool: + """Check if a label exists in the Neo4j database.""" + query = "CALL db.labels() YIELD label RETURN label" + try: + async with 
self._driver.session(database=self._DATABASE) as session: + result = await session.run(query) + labels = [record["label"] for record in await result.data()] + return label in labels + except Exception as e: + logger.error(f"Error checking label existence: {e}") + return False + async def _ensure_label(self, label: str) -> str: + """Ensure a label exists by validating it.""" + clean_label = label.strip('"') + if not await self._label_exists(clean_label): + logger.warning(f"Label '{clean_label}' does not exist in Neo4j") + return clean_label + + async def has_node(self, node_id: str) -> bool: + entity_name_label = await self._ensure_label(node_id) async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" @@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> Union[dict, None]: + """Get node by its label identifier. + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + """ async with self._driver.session(database=self._DATABASE) as session: - entity_name_label = node_id.strip('"') + entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) record = await result.single() @@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - """ - Find all edges between nodes of two given labels + """Find edge between two nodes identified by their labels. 
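        Both labels are stripped of surrounding double quotes before matching.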
Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes + source_node_id (str): Label of the source node + target_node_id (str): Label of the target node Returns: - list: List of all relationships/edges found + dict: Edge properties if found, with at least {"weight": 0.0} + None: If error occurs """ - async with self._driver.session(database=self._DATABASE) as session: - query = f""" - MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """.format( - entity_name_label_source=entity_name_label_source, - entity_name_label_target=entity_name_label_target, - ) + try: + entity_name_label_source = source_node_id.strip('"') + entity_name_label_target = target_node_id.strip('"') - result = await session.run(query) - record = await result.single() - if record: - result = dict(record["edge_properties"]) - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + async with self._driver.session(database=self._DATABASE) as session: + query = f""" + MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) + RETURN properties(r) as edge_properties + LIMIT 1 + """.format( + entity_name_label_source=entity_name_label_source, + entity_name_label_target=entity_name_label_target, ) - return result - else: - return None + + result = await session.run(query) + record = await result.single() + if record and "edge_properties" in record: + try: + result = dict(record["edge_properties"]) + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "target_id": None, + } + for key, default_value in required_keys.items(): + if key not in result: + result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + ) + return result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + ) + # Return default edge properties when no edge found + return {"weight": 0.0, "source_id": None, "target_id": None} + + except Exception as e: + logger.error( + f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: node_label = source_node_id.strip('"') @@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = node_id.strip('"') + label = await self._ensure_label(node_id) properties = node_data async def _do_upsert(tx: AsyncManagedTransaction): @@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage): neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, ) ), ) @@ -352,22 +415,23 @@ class 
Neo4JStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') + source_label = await self._ensure_label(source_node_id) + target_label = await self._ensure_label(target_node_id) edge_properties = edge_data async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" - MATCH (source:`{source_node_label}`) + MATCH (source:`{source_label}`) WITH source - MATCH (target:`{target_node_label}`) + MATCH (target:`{target_label}`) MERGE (source)-[r:DIRECTED]->(target) SET r += $properties RETURN r """ - await tx.run(query, properties=edge_properties) + result = await tx.run(query, properties=edge_properties) + record = await result.single() logger.debug( - f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}" + f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" ) try: diff --git a/lightrag/operate.py b/lightrag/operate.py index 04aad0d4..8cf77f57 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -237,25 +237,65 @@ async def _merge_edges_then_upsert( if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) - already_weights.append(already_edge["weight"]) - already_source_ids.extend( - split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) - ) - already_description.append(already_edge["description"]) - already_keywords.extend( - split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) - ) + # Handle the case where get_edge returns None or missing fields + if already_edge: + # Get weight with default 0.0 if missing + if "weight" in already_edge: + already_weights.append(already_edge["weight"]) + else: + logger.warning( + f"Edge between {src_id} and {tgt_id} missing weight field" + ) + already_weights.append(0.0) + # Get source_id with empty string default if missing or None + if "source_id" in already_edge and already_edge["source_id"] is not None: + already_source_ids.extend( + split_string_by_multi_markers( + already_edge["source_id"], [GRAPH_FIELD_SEP] + ) + ) + + # Get description with empty string default if missing or None + if ( + "description" in already_edge + and already_edge["description"] is not None + ): + already_description.append(already_edge["description"]) + + # Get keywords with empty string default if missing or None + if "keywords" in already_edge and already_edge["keywords"] is not None: + already_keywords.extend( + split_string_by_multi_markers( + already_edge["keywords"], [GRAPH_FIELD_SEP] + ) + ) + + # Process edges_data with None checks weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in edges_data] + already_description)) + sorted( + set( + [dp["description"] for dp in edges_data if dp.get("description")] + + already_description + ) + ) ) keywords = GRAPH_FIELD_SEP.join( - sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) + sorted( + set( + [dp["keywords"] for dp in edges_data if dp.get("keywords")] + + already_keywords + ) + ) ) source_id = GRAPH_FIELD_SEP.join( - set([dp["source_id"] for dp in edges_data] + already_source_ids) + set( + [dp["source_id"] for dp in edges_data if dp.get("source_id")] + + 
already_source_ids + ) ) + for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): await knowledge_graph_inst.upsert_node( From 70fc4cbfb0e769dcaea3823b0d79bac6e693410c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 00:34:38 +0800 Subject: [PATCH 3/5] handle missing edge types in graph data --- lightrag/types.py | 4 ++-- lightrag_webui/src/components/PropertiesView.tsx | 2 +- lightrag_webui/src/hooks/useLightragGraph.tsx | 12 ++++++++++-- lightrag_webui/src/stores/graph.ts | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/lightrag/types.py b/lightrag/types.py index 9c8e0099..d2670ddc 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional class GPTKeywordExtractionFormat(BaseModel): @@ -15,7 +15,7 @@ class KnowledgeGraphNode(BaseModel): class KnowledgeGraphEdge(BaseModel): id: str - type: str + type: Optional[str] source: str # id of source node target: str # id of target node properties: Dict[str, Any] # anything else goes here diff --git a/lightrag_webui/src/components/PropertiesView.tsx b/lightrag_webui/src/components/PropertiesView.tsx index 078420e6..dec80460 100644 --- a/lightrag_webui/src/components/PropertiesView.tsx +++ b/lightrag_webui/src/components/PropertiesView.tsx @@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
- + {edge.type && } { } for (const edge of graph.edges) { - if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) { + if (!edge.id || !edge.source || !edge.target) { return false } } @@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => { if (source !== undefined && source !== undefined) { const sourceNode = rawData.nodes[source] const targetNode = rawData.nodes[target] + if (!sourceNode) { + console.error(`Source node ${edge.source} is undefined`) + continue + } + if (!targetNode) { + console.error(`Target node ${edge.target} is undefined`) + continue + } sourceNode.degree += 1 targetNode.degree += 1 } @@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => { for (const rawEdge of rawGraph?.edges ?? []) { rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, { - label: rawEdge.type + label: rawEdge.type || undefined }) } diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts index b78e9bf8..b7c2120c 100644 --- a/lightrag_webui/src/stores/graph.ts +++ b/lightrag_webui/src/stores/graph.ts @@ -19,7 +19,7 @@ export type RawEdgeType = { id: string source: string target: string - type: string + type?: string properties: Record dynamicId: string From a600beb619c8b784bb309c7f3dcec94e14573570 Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 00:38:41 +0800 Subject: [PATCH 4/5] implement MongoDB support for VectorDB storage. optimize existing MongoDB implementations --- lightrag/api/README.md | 3 +- lightrag/kg/mongo_impl.py | 510 +++++++++++++++++++++++++++++++++----- lightrag/lightrag.py | 3 + 3 files changed, 456 insertions(+), 60 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 8e5a61d5..18ab3594 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB PGVectorStorage Postgres FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant -OracleVectorDBStorag Oracle +OracleVectorDBStorage Oracle +MongoVectorDBStorage MongoDB ``` * DOC_STATUS_STORAGE:supported implement-name diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 226aecf2..c216e7be 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -4,6 +4,7 @@ import numpy as np import pipmaster as pm import configparser from tqdm.asyncio import tqdm as tqdm_async +import asyncio if not pm.is_installed("pymongo"): pm.install("pymongo") @@ -14,16 +15,20 @@ if not pm.is_installed("motor"): from typing import Any, List, Tuple, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient +from pymongo.operations import SearchIndexModel +from pymongo.errors import PyMongoError from ..base import ( BaseGraphStorage, BaseKVStorage, + BaseVectorStorage, DocProcessingStatus, DocStatus, DocStatusStorage, ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge config = configparser.ConfigParser() @@ -33,56 +38,66 @@ config.read("config.ini", "utf-8") @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): - client = MongoClient( - os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) + client = AsyncIOMotorClient(uri) database = client.get_database( os.environ.get( "MONGO_DATABASE", 
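                # The MONGO_DATABASE env var takes precedence; config.ini supplies the fallback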
config.get("mongodb", "database", fallback="LightRAG"), ) ) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as KV {self.namespace}") + + self._collection_name = self.namespace + + self._data = database.get_collection(self._collection_name) + logger.debug(f"Use MongoDB as KV {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + update_tasks = [] for mode, items in data.items(): - for k, v in tqdm_async(items.items(), desc="Upserting"): + for k, v in items.items(): key = f"{mode}_{k}" - result = self._data.update_one( - {"_id": key}, {"$setOnInsert": v}, upsert=True + data[mode][k]["_id"] = f"{mode}_{k}" + update_tasks.append( + self._data.update_one( + {"_id": key}, {"$setOnInsert": v}, upsert=True + ) ) - if result.upserted_id: - logger.debug(f"\nInserted new document with key: {key}") - data[mode][k]["_id"] = key + await asyncio.gather(*update_tasks) else: - for k, v in tqdm_async(data.items(), desc="Upserting"): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + update_tasks = [] + for k, v in data.items(): data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): res = {} - v = self._data.find_one({"_id": mode + "_" + id}) + v = await self._data.find_one({"_id": mode + "_" + id}) if v: res[id] = v logger.debug(f"llm_response_cache find one by:{id}") @@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage): @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): - client = MongoClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as doc status {self.namespace}") + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as doc status {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, 
id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + update_tasks = [] for k, v in data.items(): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def drop(self) -> None: """Drop the collection""" @@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] - result = list(self._data.aggregate(pipeline)) + cursor = self._data.aggregate(pipeline) + result = await cursor.to_list() counts = {} for doc in result: counts[doc["_id"]] = doc["count"] @@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents by status""" - result = list(self._data.find({"status": status.value})) + cursor = self._data.find({"status": status.value}) + result = await cursor.to_list() return { doc["_id"]: DocProcessingStatus( content=doc["content"], @@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage): global_config=global_config, embedding_func=embedding_func, ) - self.client = AsyncIOMotorClient( - os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - self.db = self.client[ + client = AsyncIOMotorClient(uri) + database = client.get_database( os.environ.get( "MONGO_DATABASE", - mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + config.get("mongodb", "database", fallback="LightRAG"), ) - ] - self.collection = self.db[ - os.environ.get( - "MONGO_KG_COLLECTION", - config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"), - ) - ] + ) + + self._collection_name = self.namespace + self.collection = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as KG {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) # # ------------------------------------------------------------------------- @@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage): self, source_node_id: str ) -> Union[List[Tuple[str, str]], None]: """ - Return a list of (target_id, relation) for direct edges from source_node_id. + Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. 
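        Returns None if the source node does not exist.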
""" pipeline = [ @@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage): return None edges = result[0].get("edges", []) - return [(e["target"], e["relation"]) for e in edges] + return [(source_node_id, e["target"]) for e in edges] # # ------------------------------------------------------------------------- @@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str): """ - 1) Remove node’s doc entirely. + 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. """ # Remove inbound edges from all other docs @@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage): Placeholder for demonstration, raises NotImplementedError. """ raise NotImplementedError("Node embedding is not used in lightrag.") + + # + # ------------------------------------------------------------------------- + # QUERY + # ------------------------------------------------------------------------- + # + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node _id in the database + Returns: + [id1, id2, ...] # Alphabetically sorted id list + """ + # Use MongoDB's distinct and aggregation to get all unique labels + pipeline = [ + {"$group": {"_id": "$_id"}}, # Group by _id + {"$sort": {"_id": 1}}, # Sort alphabetically + ] + + cursor = self.collection.aggregate(pipeline) + labels = [] + async for doc in cursor: + labels.append(doc["_id"]) + return labels + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Args: + node_label: Label of the nodes to start from + max_depth: Maximum depth of traversal (default: 5) + + Returns: + KnowledgeGraph object containing nodes and edges of the subgraph + """ + label = node_label + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if label == "*": + # Get all nodes and edges + async for node_doc in self.collection.find({}): + node_id = str(node_doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_doc.get("_id")], + properties={ + k: v + for k, v in node_doc.items() + if k not in ["_id", "edges"] + }, + ) + ) + seen_nodes.add(node_id) + + # Process edges + for edge in node_doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + else: + # Verify if starting node exists + start_nodes = self.collection.find({"_id": label}) + start_nodes_exist = await start_nodes.to_list(length=1) + if not start_nodes_exist: + logger.warning(f"Starting node with label {label} does not exist!") + return result + + # Use $graphLookup for traversal + pipeline = [ + { + "$match": {"_id": label} + }, # Start with nodes having the specified label + { + "$graphLookup": { + "from": self._collection_name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_nodes", + } + }, + ] + + async for doc in self.collection.aggregate(pipeline): + # Add the start node + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + 
id=node_id, + labels=[ + doc.get( + "_id", + ) + ], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "edges", + "connected_nodes", + "depth", + ] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from start node + for edge in doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + # Add connected nodes and their edges + for connected in doc.get("connected_nodes", []): + node_id = str(connected["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[connected.get("_id")], + properties={ + k: v + for k, v in connected.items() + if k not in ["_id", "edges", "depth"] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from connected nodes + for edge in connected.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except PyMongoError as e: + logger.error(f"MongoDB query failed: {str(e)}") + + return result + + +@dataclass +class MongoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = None + + def __post_init__(self): + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) + self.cosine_better_than_threshold = cosine_threshold + + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + self._max_batch_size = self.global_config["embedding_batch_num"] + + logger.debug(f"Use MongoDB as VDB {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) + + # Ensure vector index exists + self.create_vector_index(uri, database.name, self._collection_name) + + def create_vector_index(self, uri: str, database_name: str, collection_name: str): + """Creates an Atlas Vector Search index.""" + client = MongoClient(uri) + collection = client.get_database(database_name).get_collection( + self._collection_name + ) + + try: + search_index_model = SearchIndexModel( + definition={ + "fields": [ + { + "type": "vector", + "numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions + "path": "vector", + "similarity": "cosine", # Options: euclidean, cosine, dotProduct + } + ] + }, + name="vector_knn_index", + type="vectorSearch", + ) + + collection.create_search_index(search_index_model) + logger.info("Vector index created successfully.") + + 
except PyMongoError as _: + logger.debug("vector index already exist") + + async def upsert(self, data: dict[str, dict]): + logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") + if not data: + logger.warning("You are inserting an empty data set to vector DB") + return [] + + list_data = [ + { + "_id": k, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async( + total=len(embedding_tasks), desc="Generating embeddings", unit="batch" + ) + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist() + + update_tasks = [] + for doc in list_data: + update_tasks.append( + self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True) + ) + await asyncio.gather(*update_tasks) + + return list_data + + async def query(self, query, top_k=5): + """Queries the vector database using Atlas Vector Search.""" + # Generate the embedding + embedding = await self.embedding_func([query]) + + # Convert numpy array to a list to ensure compatibility with MongoDB + query_vector = embedding[0].tolist() + + # Define the aggregation pipeline with the converted query vector + pipeline = [ + { + "$vectorSearch": { + "index": "vector_knn_index", # Ensure this matches the created index name + "path": "vector", + "queryVector": query_vector, + "numCandidates": 100, # Adjust for performance + "limit": top_k, + } + }, + {"$addFields": {"score": {"$meta": "vectorSearchScore"}}}, + {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}}, + {"$project": {"vector": 0}}, + ] + + # Execute the aggregation pipeline + cursor = self._data.aggregate(pipeline) + results = await cursor.to_list() + + # Format and return the results + return [ + {**doc, "id": doc["_id"], "distance": doc.get("score", None)} + for doc in results + ] + + +def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): + """Check if the collection exists. 
if not, create it.""" + client = MongoClient(uri) + database = client.get_database(database_name) + + collection_names = database.list_collection_names() + + if collection_name not in collection_names: + database.create_collection(collection_name) + logger.info(f"Created collection: {collection_name}") + else: + logger.debug(f"Collection '{collection_name}' already exists.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9f74c917..ed0dec29 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -76,6 +76,7 @@ STORAGE_IMPLEMENTATIONS = { "FaissVectorDBStorage", "QdrantVectorDBStorage", "OracleVectorDBStorage", + "MongoVectorDBStorage", ], "required_methods": ["query", "upsert"], }, @@ -140,6 +141,7 @@ STORAGE_ENV_REQUIREMENTS = { "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR", ], + "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], @@ -160,6 +162,7 @@ STORAGES = { "MongoKVStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl", + "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", "TiDBKVStorage": ".kg.tidb_impl", From 50919442e906623974a9888c12ec7b6c6ad3335c Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sat, 15 Feb 2025 00:56:45 +0100 Subject: [PATCH 5/5] Improve git and docker ignore --- .dockerignore | 64 ++++++++++++++++++++++++++++++++++++++++- .gitignore | 79 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 120 insertions(+), 23 deletions(-) diff --git a/.dockerignore b/.dockerignore index 4c49bd78..f1a82ffa 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,63 @@ -.env +# Python-related files and directories +__pycache__ +.cache + +# Virtual environment directories +*.venv + +# Env +env/ +*.env* +.env_example + +# Distribution / build files +site +dist/ +build/ +.eggs/ +*.egg-info/ +*.tgz +*.tar.gz + +# Exclude siles and folders +*.yml +.dockerignore +Dockerfile +Makefile + +# Exclude other projects +/tests +/scripts + +# Python version manager file +.python-version + +# Reports +*.coverage/ +*.log +log/ +*.logfire + +# Cache +.cache/ +.mypy_cache +.pytest_cache +.ruff_cache +.gradio +.logfire +temp/ + +# MacOS-related files +.DS_Store + +# VS Code settings (local configuration files) +.vscode + +# file +TODO.md + +# Exclude Git-related files +.git +.github +.gitignore +.pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index 2d9a41f3..2d074372 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,61 @@ -__pycache__ -*.egg-info +# Python-related files +__pycache__/ +*.py[cod] +*.egg-info/ +.eggs/ +*.tgz +*.tar.gz +*.ini # Remove config.ini from repo + +# Virtual Environment +.venv/ +env/ +venv/ +*.env* +.env_example + +# Build / Distribution +dist/ +build/ +site/ + +# Logs / Reports +*.log +*.logfire +*.coverage/ +log/ + +# Caches +.cache/ +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ +.gradio/ +temp/ + +# IDE / Editor Files +.idea/ +.vscode/ +.vscode/settings.json + +# Framework-specific files +local_neo4jWorkDir/ +neo4jWorkDir/ + +# Data & Storage +inputs/ +rag_storage/ +examples/input/ +examples/output/ + +# Miscellaneous +.DS_Store +TODO.md +ignore_this.txt +*.ignore.* + +# Project-specific files dickens/ book.txt lightrag-dev/ -.idea/ -dist/ -env/ -local_neo4jWorkDir/ -neo4jWorkDir/ -ignore_this.txt -.venv/ -*.ignore.* -.ruff_cache/ gui/ -*.log -.vscode -inputs -rag_storage 
-.env -venv/ -examples/input/ -examples/output/ -.DS_Store -#Remove config.ini from repo -*.ini