diff --git a/examples/lightrag_openai_mongodb_graph_demo.py b/examples/lightrag_openai_mongodb_graph_demo.py new file mode 100644 index 00000000..775eb296 --- /dev/null +++ b/examples/lightrag_openai_mongodb_graph_demo.py @@ -0,0 +1,73 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed +from lightrag.utils import EmbeddingFunc +import numpy as np + +######### +# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() +# import nest_asyncio +# nest_asyncio.apply() +######### +WORKING_DIR = "./mongodb_test_dir" +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +os.environ["OPENAI_API_KEY"] = "sk-" +os.environ["MONGO_URI"] = "mongodb://0.0.0.0:27017/?directConnection=true" +os.environ["MONGO_DATABASE"] = "LightRAG" +os.environ["MONGO_KG_COLLECTION"] = "MDB_KG" + +# Embedding Configuration and Functions +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts, + model=EMBEDDING_MODEL, + ) + + +async def get_embedding_dimension(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + return embedding.shape[1] + + +async def create_embedding_function_instance(): + # Get embedding dimension + embedding_dimension = await get_embedding_dimension() + # Create embedding function instance + return EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=EMBEDDING_MAX_TOKEN_SIZE, + func=embedding_func, + ) + + +async def initialize_rag(): + embedding_func_instance = await create_embedding_function_instance() + + return LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_func_instance, + graph_storage="MongoGraphStorage", + 
log_level="DEBUG", + ) + + +# Run the initialization +rag = asyncio.run(initialize_rag()) + +with open("book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 65b1a39e..e162f5ec 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -48,18 +48,23 @@ def estimate_tokens(text: str) -> int: return int(tokens) -# Constants for emulated Ollama model information -LIGHTRAG_NAME = "lightrag" -LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") -LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" -LIGHTRAG_SIZE = 7365960935 # it's a dummy value -LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" -LIGHTRAG_DIGEST = "sha256:lightrag" +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" -KV_STORAGE = "JsonKVStorage" -DOC_STATUS_STORAGE = "JsonDocStatusStorage" -GRAPH_STORAGE = "NetworkXStorage" -VECTOR_STORAGE = "NanoVectorDBStorage" + KV_STORAGE = "JsonKVStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + GRAPH_STORAGE = "NetworkXStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + + +# Add infos +ollama_server_infos = OllamaServerInfos() # read config.ini config = configparser.ConfigParser() @@ -68,8 +73,8 @@ config.read("config.ini", "utf-8") redis_uri = config.get("redis", "uri", fallback=None) if redis_uri: os.environ["REDIS_URI"] = redis_uri - KV_STORAGE = "RedisKVStorage" - DOC_STATUS_STORAGE = "RedisKVStorage" + ollama_server_infos.KV_STORAGE = "RedisKVStorage" + ollama_server_infos.DOC_STATUS_STORAGE = 
"RedisKVStorage" # Neo4j config neo4j_uri = config.get("neo4j", "uri", fallback=None) @@ -79,7 +84,7 @@ if neo4j_uri: os.environ["NEO4J_URI"] = neo4j_uri os.environ["NEO4J_USERNAME"] = neo4j_username os.environ["NEO4J_PASSWORD"] = neo4j_password - GRAPH_STORAGE = "Neo4JStorage" + ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage" # Milvus config milvus_uri = config.get("milvus", "uri", fallback=None) @@ -91,7 +96,7 @@ if milvus_uri: os.environ["MILVUS_USER"] = milvus_user os.environ["MILVUS_PASSWORD"] = milvus_password os.environ["MILVUS_DB_NAME"] = milvus_db_name - VECTOR_STORAGE = "MilvusVectorDBStorge" + ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge" # MongoDB config mongo_uri = config.get("mongodb", "uri", fallback=None) @@ -99,8 +104,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None) if mongo_uri: os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_DATABASE"] = mongo_database - KV_STORAGE = "MongoKVStorage" - DOC_STATUS_STORAGE = "MongoKVStorage" + ollama_server_infos.KV_STORAGE = "MongoKVStorage" + ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage" def get_default_host(binding_type: str) -> str: @@ -217,7 +222,7 @@ def display_splash_screen(args: argparse.Namespace) -> None: # System Configuration ASCIIColors.magenta("\n🛠️ System Configuration:") ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") - ASCIIColors.yellow(f"{LIGHTRAG_MODEL}") + ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") ASCIIColors.white(" ├─ Log Level: ", end="") ASCIIColors.yellow(f"{args.log_level}") ASCIIColors.white(" ├─ Timeout: ", end="") @@ -502,8 +507,19 @@ def parse_args() -> argparse.Namespace: help="Cosine similarity threshold (default: from env or 0.4)", ) + parser.add_argument( + "--simulated-model-name", + type=str, + default=get_env_value( + "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL + ), + help="Model name reported by the emulated Ollama API (default: from env or lightrag:latest)", + ) + args = 
parser.parse_args() + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name + return args @@ -556,7 +572,7 @@ class OllamaMessage(BaseModel): class OllamaChatRequest(BaseModel): - model: str = LIGHTRAG_MODEL + model: str = ollama_server_infos.LIGHTRAG_MODEL messages: List[OllamaMessage] stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None @@ -571,7 +587,7 @@ class OllamaChatResponse(BaseModel): class OllamaGenerateRequest(BaseModel): - model: str = LIGHTRAG_MODEL + model: str = ollama_server_infos.LIGHTRAG_MODEL prompt: str system: Optional[str] = None stream: bool = False @@ -860,10 +876,10 @@ def create_app(args): if args.llm_binding == "lollms" or args.llm_binding == "ollama" else {}, embedding_func=embedding_func, - kv_storage=KV_STORAGE, - graph_storage=GRAPH_STORAGE, - vector_storage=VECTOR_STORAGE, - doc_status_storage=DOC_STATUS_STORAGE, + kv_storage=ollama_server_infos.KV_STORAGE, + graph_storage=ollama_server_infos.GRAPH_STORAGE, + vector_storage=ollama_server_infos.VECTOR_STORAGE, + doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -883,10 +899,10 @@ def create_app(args): llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, embedding_func=embedding_func, - kv_storage=KV_STORAGE, - graph_storage=GRAPH_STORAGE, - vector_storage=VECTOR_STORAGE, - doc_status_storage=DOC_STATUS_STORAGE, + kv_storage=ollama_server_infos.KV_STORAGE, + graph_storage=ollama_server_infos.GRAPH_STORAGE, + vector_storage=ollama_server_infos.VECTOR_STORAGE, + doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -1452,16 +1468,16 @@ def create_app(args): return OllamaTagResponse( models=[ { - "name": LIGHTRAG_MODEL, - "model": LIGHTRAG_MODEL, - "size": LIGHTRAG_SIZE, - "digest": LIGHTRAG_DIGEST, - "modified_at": 
LIGHTRAG_CREATED_AT, + "name": ollama_server_infos.LIGHTRAG_MODEL, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "size": ollama_server_infos.LIGHTRAG_SIZE, + "digest": ollama_server_infos.LIGHTRAG_DIGEST, + "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "details": { "parent_model": "", "format": "gguf", - "family": LIGHTRAG_NAME, - "families": [LIGHTRAG_NAME], + "family": ollama_server_infos.LIGHTRAG_NAME, + "families": [ollama_server_infos.LIGHTRAG_NAME], "parameter_size": "13B", "quantization_level": "Q4_0", }, @@ -1524,8 +1540,8 @@ def create_app(args): total_response = response data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": response, "done": False, } @@ -1537,8 +1553,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "done": True, "total_duration": total_time, "load_duration": 0, @@ -1558,8 +1574,8 @@ def create_app(args): total_response += chunk data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": chunk, "done": False, } @@ -1571,8 +1587,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "done": True, "total_duration": total_time, "load_duration": 0, @@ -1616,8 +1632,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time return { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": 
ollama_server_infos.LIGHTRAG_CREATED_AT, "response": str(response_text), "done": True, "total_duration": total_time, @@ -1690,8 +1706,8 @@ def create_app(args): total_response = response data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": response, @@ -1707,8 +1723,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "done": True, "total_duration": total_time, "load_duration": 0, @@ -1728,8 +1744,8 @@ def create_app(args): total_response += chunk data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": chunk, @@ -1745,8 +1761,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "done": True, "total_duration": total_time, "load_duration": 0, @@ -1801,8 +1817,8 @@ def create_app(args): eval_time = last_chunk_time - first_chunk_time return { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": str(response_text), @@ -1845,10 +1861,10 @@ def create_app(args): "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, "max_tokens": args.max_tokens, - "kv_storage": KV_STORAGE, - "doc_status_storage": DOC_STATUS_STORAGE, - "graph_storage": GRAPH_STORAGE, - 
"vector_storage": VECTOR_STORAGE, + "kv_storage": ollama_server_infos.KV_STORAGE, + "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE, + "graph_storage": ollama_server_infos.GRAPH_STORAGE, + "vector_storage": ollama_server_infos.VECTOR_STORAGE, }, } diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 21365a70..66331520 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -2,15 +2,18 @@ import os from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm +import numpy as np if not pm.is_installed("pymongo"): pm.install("pymongo") from pymongo import MongoClient -from typing import Union +from motor.motor_asyncio import AsyncIOMotorClient +from typing import Union, List, Tuple from lightrag.utils import logger from lightrag.base import BaseKVStorage +from lightrag.base import BaseGraphStorage @dataclass @@ -78,3 +81,360 @@ class MongoKVStorage(BaseKVStorage): async def drop(self): """ """ pass + + +@dataclass +class MongoGraphStorage(BaseGraphStorage): + """ + A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries. 
+ """ + + def __init__(self, namespace, global_config, embedding_func): + super().__init__( + namespace=namespace, + global_config=global_config, + embedding_func=embedding_func, + ) + self.client = AsyncIOMotorClient( + os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + ) + self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")] + self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")] + + # + # ------------------------------------------------------------------------- + # HELPER: $graphLookup pipeline + # ------------------------------------------------------------------------- + # + + async def _graph_lookup( + self, start_node_id: str, max_depth: int = None + ) -> List[dict]: + """ + Performs a $graphLookup starting from 'start_node_id' and returns + all reachable documents (including the start node itself). + + Pipeline Explanation: + - 1) $match: We match the start node document by _id = start_node_id. + - 2) $graphLookup: + "from": same collection, + "startWith": "$edges.target" (the immediate neighbors in 'edges'), + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "reachableNodes", + "maxDepth": max_depth (if provided), + "depthField": "depth" (used for debugging or filtering). + - 3) We add an $project or $unwind as needed to extract data. + """ + pipeline = [ + {"$match": {"_id": start_node_id}}, + { + "$graphLookup": { + "from": self.collection.name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "reachableNodes", + "depthField": "depth", + } + }, + ] + + # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth + if max_depth is not None: + pipeline[1]["$graphLookup"]["maxDepth"] = max_depth + + # Return the matching doc plus a field "reachableNodes" + cursor = self.collection.aggregate(pipeline) + results = await cursor.to_list(None) + + # If there's no matching node, results = []. 
+ # Otherwise, results[0] is the start node doc, + # plus results[0]["reachableNodes"] is the array of connected docs. + return results + + # + # ------------------------------------------------------------------------- + # BASIC QUERIES + # ------------------------------------------------------------------------- + # + + async def has_node(self, node_id: str) -> bool: + """ + Check if node_id is present in the collection by looking up its doc. + No real need for $graphLookup here, but let's keep it direct. + """ + doc = await self.collection.find_one({"_id": node_id}, {"_id": 1}) + return doc is not None + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if there's a direct single-hop edge from source_node_id to target_node_id. + + We'll do a $graphLookup with maxDepth=0 from the source node—meaning + “Look up zero expansions.” Actually, for a direct edge check, we can do maxDepth=1 + and then see if the target node is in the "reachableNodes" at depth=0. + + But typically for a direct edge, we might just do a find_one. + Below is a demonstration approach. + """ + + # We can do a single-hop graphLookup (maxDepth=0 or 1). + # Then check if the target_node appears among the edges array. 
+ pipeline = [ + {"$match": {"_id": source_node_id}}, + { + "$graphLookup": { + "from": self.collection.name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "reachableNodes", + "depthField": "depth", + "maxDepth": 0, # means: do not follow beyond immediate edges + } + }, + { + "$project": { + "_id": 0, + "reachableNodes._id": 1, # only keep the _id from the subdocs + } + }, + ] + cursor = self.collection.aggregate(pipeline) + results = await cursor.to_list(None) + if not results: + return False + + # results[0]["reachableNodes"] are the immediate neighbors + reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])] + return target_node_id in reachable_ids + + # + # ------------------------------------------------------------------------- + # DEGREES + # ------------------------------------------------------------------------- + # + + async def node_degree(self, node_id: str) -> int: + """ + Returns the total number of edges connected to node_id (both inbound and outbound). + The easiest approach is typically two queries: + - count of edges array in node_id's doc + - count of how many other docs have node_id in their edges.target. + + But we'll do a $graphLookup demonstration for inbound edges: + 1) Outbound edges: direct from node's edges array + 2) Inbound edges: we can do a special $graphLookup from all docs + or do an explicit match. + + For demonstration, let's do this in two steps (with second step $graphLookup). + """ + # --- 1) Outbound edges (direct from doc) --- + doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) + if not doc: + return 0 + outbound_count = len(doc.get("edges", [])) + + # --- 2) Inbound edges: + # A simple way is: find all docs where "edges.target" == node_id. + # But let's do a $graphLookup from `node_id` in REVERSE. + # There's a trick to do "reverse" graphLookups: you'd store + # reversed edges or do a more advanced pipeline. 
Typically you'd do + # a direct match. We'll just do a direct match for inbound. + inbound_count_pipeline = [ + {"$match": {"edges.target": node_id}}, + { + "$project": { + "matchingEdgesCount": { + "$size": { + "$filter": { + "input": "$edges", + "as": "edge", + "cond": {"$eq": ["$$edge.target", node_id]}, + } + } + } + } + }, + {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}}, + ] + inbound_cursor = self.collection.aggregate(inbound_count_pipeline) + inbound_result = await inbound_cursor.to_list(None) + inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0 + + return outbound_count + inbound_count + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """ + If your graph can hold multiple edges from the same src to the same tgt + (e.g. different 'relation' values), you can sum them. If it's always + one edge, this is typically 1 or 0. + + We'll do a single-hop $graphLookup from src_id, + then count how many edges reference tgt_id at depth=0. + """ + pipeline = [ + {"$match": {"_id": src_id}}, + { + "$graphLookup": { + "from": self.collection.name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "neighbors", + "depthField": "depth", + "maxDepth": 0, + } + }, + {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}}, + ] + cursor = self.collection.aggregate(pipeline) + results = await cursor.to_list(None) + if not results: + return 0 + + # We can simply count how many edges in `results[0].edges` have target == tgt_id. 
+ edges = results[0].get("edges", []) + count = sum(1 for e in edges if e.get("target") == tgt_id) + return count + + # + # ------------------------------------------------------------------------- + # GETTERS + # ------------------------------------------------------------------------- + # + + async def get_node(self, node_id: str) -> Union[dict, None]: + """ + Return the full node document (including "edges"), or None if missing. + """ + return await self.collection.find_one({"_id": node_id}) + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + """ + Return the first edge dict from source_node_id to target_node_id if it exists. + Uses a single-hop $graphLookup as demonstration, though a direct find is simpler. + """ + pipeline = [ + {"$match": {"_id": source_node_id}}, + { + "$graphLookup": { + "from": self.collection.name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "neighbors", + "depthField": "depth", + "maxDepth": 0, + } + }, + {"$project": {"edges": 1}}, + ] + cursor = self.collection.aggregate(pipeline) + docs = await cursor.to_list(None) + if not docs: + return None + + for e in docs[0].get("edges", []): + if e.get("target") == target_node_id: + return e + return None + + async def get_node_edges( + self, source_node_id: str + ) -> Union[List[Tuple[str, str]], None]: + """ + Return a list of (target_id, relation) for direct edges from source_node_id. + Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. 
+ """ + pipeline = [ + {"$match": {"_id": source_node_id}}, + { + "$graphLookup": { + "from": self.collection.name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "as": "neighbors", + "depthField": "depth", + "maxDepth": 0, + } + }, + {"$project": {"_id": 0, "edges": 1}}, + ] + cursor = self.collection.aggregate(pipeline) + result = await cursor.to_list(None) + if not result: + return None + + edges = result[0].get("edges", []) + return [(e["target"], e.get("relation", "")) for e in edges] + + # + # ------------------------------------------------------------------------- + # UPSERTS + # ------------------------------------------------------------------------- + # + + async def upsert_node(self, node_id: str, node_data: dict): + """ + Insert or update a node document. If new, create an empty edges array. + """ + # By default, preserve existing 'edges'. + # We'll only set 'edges' to [] on insert (no overwrite). + update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} + await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict + ): + """ + Upsert an edge from source_node_id -> target_node_id with optional 'relation'. + If an edge with the same target exists, we remove it and re-insert with updated data. 
+ """ + # Ensure source node exists + await self.upsert_node(source_node_id, {}) + + # Remove existing edge (if any) + await self.collection.update_one( + {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} + ) + + # Insert new edge + new_edge = {"target": target_node_id} + new_edge.update(edge_data) + await self.collection.update_one( + {"_id": source_node_id}, {"$push": {"edges": new_edge}} + ) + + # + # ------------------------------------------------------------------------- + # DELETION + # ------------------------------------------------------------------------- + # + + async def delete_node(self, node_id: str): + """ + 1) Remove node’s doc entirely. + 2) Remove inbound edges from any doc that references node_id. + """ + # Remove inbound edges from all other docs + await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) + + # Remove the node doc + await self.collection.delete_one({"_id": node_id}) + + # + # ------------------------------------------------------------------------- + # EMBEDDINGS (NOT IMPLEMENTED) + # ------------------------------------------------------------------------- + # + + async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: + """ + Placeholder for demonstration, raises NotImplementedError. + """ + raise NotImplementedError("Node embedding is not used in lightrag.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b40eecaa..acad9295 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -48,6 +48,7 @@ STORAGES = { "OracleVectorDBStorage": ".kg.oracle_impl", "MilvusVectorDBStorge": ".kg.milvus_impl", "MongoKVStorage": ".kg.mongo_impl", + "MongoGraphStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", "TiDBKVStorage": ".kg.tidb_impl",