From 70fc4cbfb0e769dcaea3823b0d79bac6e693410c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sat, 15 Feb 2025 00:34:38 +0800 Subject: [PATCH 1/2] handle missing edge types in graph data --- lightrag/types.py | 4 ++-- lightrag_webui/src/components/PropertiesView.tsx | 2 +- lightrag_webui/src/hooks/useLightragGraph.tsx | 12 ++++++++++-- lightrag_webui/src/stores/graph.ts | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/lightrag/types.py b/lightrag/types.py index 9c8e0099..d2670ddc 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional class GPTKeywordExtractionFormat(BaseModel): @@ -15,7 +15,7 @@ class KnowledgeGraphNode(BaseModel): class KnowledgeGraphEdge(BaseModel): id: str - type: str + type: Optional[str] source: str # id of source node target: str # id of target node properties: Dict[str, Any] # anything else goes here diff --git a/lightrag_webui/src/components/PropertiesView.tsx b/lightrag_webui/src/components/PropertiesView.tsx index 078420e6..dec80460 100644 --- a/lightrag_webui/src/components/PropertiesView.tsx +++ b/lightrag_webui/src/components/PropertiesView.tsx @@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
-          <PropertyRow name={'Type'} value={edge.type} />
+          {edge.type && <PropertyRow name={'Type'} value={edge.type} />}

diff --git a/lightrag_webui/src/hooks/useLightragGraph.tsx b/lightrag_webui/src/hooks/useLightragGraph.tsx
--- a/lightrag_webui/src/hooks/useLightragGraph.tsx
+++ b/lightrag_webui/src/hooks/useLightragGraph.tsx
@@ ... @@ const validateGraph = (graph: RawGraph) => {
     }
   }
   for (const edge of graph.edges) {
-    if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
+    if (!edge.id || !edge.source || !edge.target) {
       return false
     }
   }
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
       if (source !== undefined && target !== undefined) {
         const sourceNode = rawData.nodes[source]
         const targetNode = rawData.nodes[target]
+        if (!sourceNode) {
+          console.error(`Source node ${edge.source} is undefined`)
+          continue
+        }
+        if (!targetNode) {
+          console.error(`Target node ${edge.target} is undefined`)
+          continue
+        }
         sourceNode.degree += 1
         targetNode.degree += 1
       }
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {

   for (const rawEdge of rawGraph?.edges ?? []) {
     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
-      label: rawEdge.type
+      label: rawEdge.type || undefined
     })
   }
diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts
index b78e9bf8..b7c2120c 100644
--- a/lightrag_webui/src/stores/graph.ts
+++ b/lightrag_webui/src/stores/graph.ts
@@ -19,7 +19,7 @@ export type RawEdgeType = {
   id: string
   source: string
   target: string
-  type: string
+  type?: string
   properties: Record<string, any>
   dynamicId: string

From a600beb619c8b784bb309c7f3dcec94e14573570 Mon Sep 17 00:00:00 2001
From: ArnoChen 
Date: Sat, 15 Feb 2025 00:38:41 +0800
Subject: [PATCH 2/2] implement MongoDB support for VectorDB storage. optimize
 existing MongoDB implementations

---
 lightrag/api/README.md    |   3 +-
 lightrag/kg/mongo_impl.py | 510 +++++++++++++++++++++++++++++++++-----
 lightrag/lightrag.py      |   3 +
 3 files changed, 456 insertions(+), 60 deletions(-)

diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index 8e5a61d5..18ab3594 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -177,7 +177,8 @@ TiDBVectorDBStorage    TiDB
 PGVectorStorage        Postgres
 FaissVectorDBStorage   Faiss
 QdrantVectorDBStorage  Qdrant
-OracleVectorDBStorag   Oracle
+OracleVectorDBStorage  Oracle
+MongoVectorDBStorage   MongoDB
 ```
 * DOC_STATUS_STORAGE:supported implement-name
 ```
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 226aecf2..c216e7be 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -4,6 +4,7 @@ import numpy as np
 import pipmaster as pm
 import configparser
 from tqdm.asyncio import tqdm as tqdm_async
+import asyncio

 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError

 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
+    BaseVectorStorage,
     DocProcessingStatus,
     DocStatus,
     DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 config = configparser.ConfigParser()

@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
-        client = MongoClient(
-            os.environ.get(
-                "MONGO_URI",
-                config.get(
-                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-                ),
-            )
+        uri = os.environ.get(
+            "MONGO_URI",
+            config.get(
+                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+            ),
         )
+        client = AsyncIOMotorClient(uri)
         database = client.get_database(
             os.environ.get(
                 "MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"), ) ) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as KV {self.namespace}") + + self._collection_name = self.namespace + + self._data = database.get_collection(self._collection_name) + logger.debug(f"Use MongoDB as KV {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + update_tasks = [] for mode, items in data.items(): - for k, v in tqdm_async(items.items(), desc="Upserting"): + for k, v in items.items(): key = f"{mode}_{k}" - result = self._data.update_one( - {"_id": key}, {"$setOnInsert": v}, upsert=True + data[mode][k]["_id"] = f"{mode}_{k}" + update_tasks.append( + self._data.update_one( + {"_id": key}, {"$setOnInsert": v}, upsert=True + ) ) - if result.upserted_id: - logger.debug(f"\nInserted new document with key: {key}") - data[mode][k]["_id"] = key + await asyncio.gather(*update_tasks) else: - for k, v in tqdm_async(data.items(), desc="Upserting"): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + update_tasks = [] + for k, v in data.items(): data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): res = {} - v = self._data.find_one({"_id": mode + "_" + id}) + v = await self._data.find_one({"_id": mode + "_" + id}) if v: res[id] = v logger.debug(f"llm_response_cache find one by:{id}") @@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage): @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): - client = MongoClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) - self._data = database.get_collection(self.namespace) - logger.info(f"Use MongoDB as doc status {self.namespace}") + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as doc status {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) async def get_by_id(self, 
id: str) -> Union[dict[str, Any], None]: - return self._data.find_one({"_id": id}) + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return list(self._data.find({"_id": {"$in": ids}})) + cursor = self._data.find({"_id": {"$in": ids}}) + return await cursor.to_list() async def filter_keys(self, data: set[str]) -> set[str]: - existing_ids = [ - str(x["_id"]) - for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) - ] - return set([s for s in data if s not in existing_ids]) + cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + existing_ids = {str(x["_id"]) async for x in cursor} + return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + update_tasks = [] for k, v in data.items(): - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k + update_tasks.append( + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + ) + await asyncio.gather(*update_tasks) async def drop(self) -> None: """Drop the collection""" @@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] - result = list(self._data.aggregate(pipeline)) + cursor = self._data.aggregate(pipeline) + result = await cursor.to_list() counts = {} for doc in result: counts[doc["_id"]] = doc["count"] @@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents by status""" - result = list(self._data.find({"status": status.value})) + cursor = self._data.find({"status": status.value}) + result = await cursor.to_list() return { doc["_id"]: DocProcessingStatus( content=doc["content"], @@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage): global_config=global_config, embedding_func=embedding_func, ) - self.client = AsyncIOMotorClient( - os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), ) - self.db = self.client[ + client = AsyncIOMotorClient(uri) + database = client.get_database( os.environ.get( "MONGO_DATABASE", - mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + config.get("mongodb", "database", fallback="LightRAG"), ) - ] - self.collection = self.db[ - os.environ.get( - "MONGO_KG_COLLECTION", - config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"), - ) - ] + ) + + self._collection_name = self.namespace + self.collection = database.get_collection(self._collection_name) + + logger.debug(f"Use MongoDB as KG {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) # # ------------------------------------------------------------------------- @@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage): self, source_node_id: str ) -> Union[List[Tuple[str, str]], None]: """ - Return a list of (target_id, relation) for direct edges from source_node_id. + Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. 
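+        Example (illustrative, assuming a node "A" with a single edge to "B"):
+            await self.get_node_edges("A")  # -> [("A", "B")]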
""" pipeline = [ @@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage): return None edges = result[0].get("edges", []) - return [(e["target"], e["relation"]) for e in edges] + return [(source_node_id, e["target"]) for e in edges] # # ------------------------------------------------------------------------- @@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str): """ - 1) Remove node’s doc entirely. + 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. """ # Remove inbound edges from all other docs @@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage): Placeholder for demonstration, raises NotImplementedError. """ raise NotImplementedError("Node embedding is not used in lightrag.") + + # + # ------------------------------------------------------------------------- + # QUERY + # ------------------------------------------------------------------------- + # + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node _id in the database + Returns: + [id1, id2, ...] # Alphabetically sorted id list + """ + # Use MongoDB's distinct and aggregation to get all unique labels + pipeline = [ + {"$group": {"_id": "$_id"}}, # Group by _id + {"$sort": {"_id": 1}}, # Sort alphabetically + ] + + cursor = self.collection.aggregate(pipeline) + labels = [] + async for doc in cursor: + labels.append(doc["_id"]) + return labels + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Args: + node_label: Label of the nodes to start from + max_depth: Maximum depth of traversal (default: 5) + + Returns: + KnowledgeGraph object containing nodes and edges of the subgraph + """ + label = node_label + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if label == "*": + # Get all nodes and edges + async for node_doc in self.collection.find({}): + node_id = str(node_doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_doc.get("_id")], + properties={ + k: v + for k, v in node_doc.items() + if k not in ["_id", "edges"] + }, + ) + ) + seen_nodes.add(node_id) + + # Process edges + for edge in node_doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + else: + # Verify if starting node exists + start_nodes = self.collection.find({"_id": label}) + start_nodes_exist = await start_nodes.to_list(length=1) + if not start_nodes_exist: + logger.warning(f"Starting node with label {label} does not exist!") + return result + + # Use $graphLookup for traversal + pipeline = [ + { + "$match": {"_id": label} + }, # Start with nodes having the specified label + { + "$graphLookup": { + "from": self._collection_name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_nodes", + } + }, + ] + + async for doc in self.collection.aggregate(pipeline): + # Add the start node + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + 
id=node_id, + labels=[ + doc.get( + "_id", + ) + ], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "edges", + "connected_nodes", + "depth", + ] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from start node + for edge in doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + # Add connected nodes and their edges + for connected in doc.get("connected_nodes", []): + node_id = str(connected["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[connected.get("_id")], + properties={ + k: v + for k, v in connected.items() + if k not in ["_id", "edges", "depth"] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from connected nodes + for edge in connected.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except PyMongoError as e: + logger.error(f"MongoDB query failed: {str(e)}") + + return result + + +@dataclass +class MongoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = None + + def __post_init__(self): + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = kwargs.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) + self.cosine_better_than_threshold = cosine_threshold + + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) + client = AsyncIOMotorClient(uri) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + ) + + self._collection_name = self.namespace + self._data = database.get_collection(self._collection_name) + self._max_batch_size = self.global_config["embedding_batch_num"] + + logger.debug(f"Use MongoDB as VDB {self._collection_name}") + + # Ensure collection exists + create_collection_if_not_exists(uri, database.name, self._collection_name) + + # Ensure vector index exists + self.create_vector_index(uri, database.name, self._collection_name) + + def create_vector_index(self, uri: str, database_name: str, collection_name: str): + """Creates an Atlas Vector Search index.""" + client = MongoClient(uri) + collection = client.get_database(database_name).get_collection( + self._collection_name + ) + + try: + search_index_model = SearchIndexModel( + definition={ + "fields": [ + { + "type": "vector", + "numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions + "path": "vector", + "similarity": "cosine", # Options: euclidean, cosine, dotProduct + } + ] + }, + name="vector_knn_index", + type="vectorSearch", + ) + + collection.create_search_index(search_index_model) + logger.info("Vector index created successfully.") + + 
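
        # create_search_index raises a PyMongoError (e.g. OperationFailure) if
        # an index with this name already exists; the handler below treats any
        # failure as that case. Note that on MongoDB Atlas the index is built
        # asynchronously, so it may take a moment to become queryable.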
        except PyMongoError:
            logger.debug("vector index already exists")

    async def upsert(self, data: dict[str, dict]):
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        # the coroutines below are not scheduled until asyncio.gather runs,
        # so pbar is guaranteed to exist by the time wrapped_task touches it
        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()

        update_tasks = []
        for doc in list_data:
            update_tasks.append(
                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            )
        await asyncio.gather(*update_tasks)

        return list_data

    async def query(self, query, top_k=5):
        """Queries the vector database using Atlas Vector Search."""
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        # Define the aggregation pipeline with the converted query vector
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Ensure this matches the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    "numCandidates": 100,  # Adjust for performance
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline
        cursor = self._data.aggregate(pipeline)
        results = await cursor.to_list()

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]


def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists.
if not, create it.""" + client = MongoClient(uri) + database = client.get_database(database_name) + + collection_names = database.list_collection_names() + + if collection_name not in collection_names: + database.create_collection(collection_name) + logger.info(f"Created collection: {collection_name}") + else: + logger.debug(f"Collection '{collection_name}' already exists.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9f74c917..ed0dec29 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -76,6 +76,7 @@ STORAGE_IMPLEMENTATIONS = { "FaissVectorDBStorage", "QdrantVectorDBStorage", "OracleVectorDBStorage", + "MongoVectorDBStorage", ], "required_methods": ["query", "upsert"], }, @@ -140,6 +141,7 @@ STORAGE_ENV_REQUIREMENTS = { "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR", ], + "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], @@ -160,6 +162,7 @@ STORAGES = { "MongoKVStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl", + "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", "TiDBKVStorage": ".kg.tidb_impl",
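
Usage sketch (illustrative, not part of the patch): with both commits applied, the
Mongo backends are selected by the storage names registered in lightrag.py above.
The LightRAG constructor fields used below (working_dir, kv_storage, vector_storage,
graph_storage, doc_status_storage, vector_db_storage_cls_kwargs) are assumed from
lightrag.py; the connection settings come from the MONGO_URI and MONGO_DATABASE
environment variables (or config.ini) read by mongo_impl.py. LLM and embedding
configuration is omitted for brevity.

    import os
    from lightrag import LightRAG

    os.environ["MONGO_URI"] = "mongodb://root:root@localhost:27017/"
    os.environ["MONGO_DATABASE"] = "LightRAG"

    rag = LightRAG(
        working_dir="./rag_storage",
        kv_storage="MongoKVStorage",
        doc_status_storage="MongoDocStatusStorage",
        graph_storage="MongoGraphStorage",
        # MongoVectorDBStorage requires Atlas Vector Search and a score cutoff
        vector_storage="MongoVectorDBStorage",
        vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2},
    )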