diff --git a/.dockerignore b/.dockerignore
index 4c49bd78..f1a82ffa 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1 +1,63 @@
-.env
+# Python-related files and directories
+__pycache__
+.cache
+
+# Virtual environment directories
+*.venv
+
+# Env
+env/
+*.env*
+.env_example
+
+# Distribution / build files
+site
+dist/
+build/
+.eggs/
+*.egg-info/
+*.tgz
+*.tar.gz
+
+# Exclude files and folders
+*.yml
+.dockerignore
+Dockerfile
+Makefile
+
+# Exclude tests and scripts
+/tests
+/scripts
+
+# Python version manager file
+.python-version
+
+# Reports
+*.coverage/
+*.log
+log/
+*.logfire
+
+# Cache
+.cache/
+.mypy_cache
+.pytest_cache
+.ruff_cache
+.gradio
+.logfire
+temp/
+
+# macOS-related files
+.DS_Store
+
+# VS Code settings (local configuration files)
+.vscode
+
+# Project TODO file
+TODO.md
+
+# Exclude Git-related files
+.git
+.github
+.gitignore
+.pre-commit-config.yaml
diff --git a/.gitignore b/.gitignore
index 83246d18..58c9e17e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,23 +35,27 @@ temp/
# IDE / Editor Files
.idea/
-dist/
-env/
+.vscode/
+.vscode/settings.json
+
+# Framework-specific files
local_neo4jWorkDir/
neo4jWorkDir/
-ignore_this.txt
-.venv/
-*.ignore.*
-.ruff_cache/
-gui/
-*.log
-.vscode
-inputs
-rag_storage
-.env
-venv/
+
+# Data & Storage
+inputs/
+rag_storage/
examples/input/
examples/output/
+
+# Miscellaneous
.DS_Store
-#Remove config.ini from repo
-*.ini
+TODO.md
+ignore_this.txt
+*.ignore.*
+
+# Project-specific files
+dickens/
+book.txt
+lightrag-dev/
+gui/
diff --git a/README.md b/README.md
index 62f21a65..3ccdef08 100644
--- a/README.md
+++ b/README.md
@@ -237,7 +237,7 @@ rag = LightRAG(
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
```python
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm import hf_model_complete, hf_embed
from transformers import AutoModel, AutoTokenizer
from lightrag.utils import EmbeddingFunc
@@ -250,7 +250,7 @@ rag = LightRAG(
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
- func=lambda texts: hf_embedding(
+ func=lambda texts: hf_embed(
texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@@ -428,9 +428,9 @@ And using a routine to process news documents.
```python
rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
# Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
```
### Separate Keyword Extraction
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index f5269fae..9c90424e 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -113,7 +113,24 @@ async def main():
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
- rag.set_storage_client(db_client=oracle_db)
+
+ for storage in [
+ rag.vector_db_storage_cls,
+ rag.graph_storage_cls,
+ rag.doc_status,
+ rag.full_docs,
+ rag.text_chunks,
+ rag.llm_response_cache,
+ rag.key_string_value_json_storage_cls,
+ rag.chunks_vdb,
+ rag.relationships_vdb,
+ rag.entities_vdb,
+ rag.chunk_entity_relation_graph,
+ ]:
+ # Share the Oracle connection pool with every storage backend
+ storage.db = oracle_db
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py
index 0e6361ed..99090a6d 100644
--- a/examples/test_chromadb.py
+++ b/examples/test_chromadb.py
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# ChromaDB Configuration
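+# Toggle: True uses a local on-disk PersistentClient, False a remote HttpClient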
+CHROMADB_USE_LOCAL_PERSISTENT = False
+# Local PersistentClient Configuration
+CHROMADB_LOCAL_PATH = os.environ.get(
+ "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
+)
+# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +66,50 @@ async def create_embedding_function_instance():
async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()
-
- return LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=gpt_4o_mini_complete,
- embedding_func=embedding_func_instance,
- vector_storage="ChromaVectorDBStorage",
- log_level="DEBUG",
- embedding_batch_num=32,
- vector_db_storage_cls_kwargs={
- "host": CHROMADB_HOST,
- "port": CHROMADB_PORT,
- "auth_token": CHROMADB_AUTH_TOKEN,
- "auth_provider": CHROMADB_AUTH_PROVIDER,
- "auth_header_name": CHROMADB_AUTH_HEADER,
- "collection_settings": {
- "hnsw:space": "cosine",
- "hnsw:construction_ef": 128,
- "hnsw:search_ef": 128,
- "hnsw:M": 16,
- "hnsw:batch_size": 100,
- "hnsw:sync_threshold": 1000,
+ if CHROMADB_USE_LOCAL_PERSISTENT:
+ return LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=gpt_4o_mini_complete,
+ embedding_func=embedding_func_instance,
+ vector_storage="ChromaVectorDBStorage",
+ log_level="DEBUG",
+ embedding_batch_num=32,
+ vector_db_storage_cls_kwargs={
+ "local_path": CHROMADB_LOCAL_PATH,
+ "collection_settings": {
+ "hnsw:space": "cosine",
+ "hnsw:construction_ef": 128,
+ "hnsw:search_ef": 128,
+ "hnsw:M": 16,
+ "hnsw:batch_size": 100,
+ "hnsw:sync_threshold": 1000,
+ },
},
- },
- )
+ )
+ else:
+ return LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=gpt_4o_mini_complete,
+ embedding_func=embedding_func_instance,
+ vector_storage="ChromaVectorDBStorage",
+ log_level="DEBUG",
+ embedding_batch_num=32,
+ vector_db_storage_cls_kwargs={
+ "host": CHROMADB_HOST,
+ "port": CHROMADB_PORT,
+ "auth_token": CHROMADB_AUTH_TOKEN,
+ "auth_provider": CHROMADB_AUTH_PROVIDER,
+ "auth_header_name": CHROMADB_AUTH_HEADER,
+ "collection_settings": {
+ "hnsw:space": "cosine",
+ "hnsw:construction_ef": 128,
+ "hnsw:search_ef": 128,
+ "hnsw:M": 16,
+ "hnsw:batch_size": 100,
+ "hnsw:sync_threshold": 1000,
+ },
+ },
+ )
# Run the initialization
diff --git a/external_bindings/OpenWebuiTool/openwebui_tool.py b/external_bindings/OpenWebuiTool/openwebui_tool.py
deleted file mode 100644
index 8df3109c..00000000
--- a/external_bindings/OpenWebuiTool/openwebui_tool.py
+++ /dev/null
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
- - Twitter: @ParisNeo_AI
- - Reddit: r/lollms
- - Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
- - Python 3.8+
- - OpenWebui
- - Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
- self.event_emitter = event_emitter
-
- async def emit(self, description="Unknown State", status="in_progress", done=False):
- if self.event_emitter:
- await self.event_emitter(
- {
- "type": "status",
- "data": {
- "status": status,
- "description": description,
- "done": done,
- },
- }
- )
-
-
-class MessageEventEmitter:
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
- self.event_emitter = event_emitter
-
- async def emit(self, content="Some message"):
- if self.event_emitter:
- await self.event_emitter(
- {
- "type": "message",
- "data": {
- "content": content,
- },
- }
- )
-
-
-class Tools:
- class Valves(BaseModel):
- LIGHTRAG_SERVER_URL: str = Field(
- default="http://localhost:9621/query",
- description="The base URL for the LightRag server",
- )
- MODE: Literal["naive", "local", "global", "hybrid"] = Field(
- default="hybrid",
- description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
- )
- ONLY_NEED_CONTEXT: bool = Field(
- default=False,
- description="If True, only the context is needed from the LightRag response",
- )
- DEBUG_MODE: bool = Field(
- default=False,
- description="If True, debugging information will be emitted",
- )
- KEY: str = Field(
- default="",
- description="Optional Bearer Key for authentication",
- )
- MAX_ENTITIES: int = Field(
- default=5,
- description="Maximum number of entities to keep",
- )
- MAX_RELATIONSHIPS: int = Field(
- default=5,
- description="Maximum number of relationships to keep",
- )
- MAX_SOURCES: int = Field(
- default=3,
- description="Maximum number of sources to keep",
- )
-
- def __init__(self):
- self.valves = self.Valves()
- self.headers = {
- "Content-Type": "application/json",
- "User-Agent": "LightRag-Tool/1.0",
- }
-
- async def query_lightrag(
- self,
- query: str,
- __event_emitter__: Callable[[dict], Any] = None,
- ) -> str:
- """
- Query the LightRag server and retrieve information.
- This function must be called before answering the user question
- :params query: The query string to send to the LightRag server.
- :return: The response from the LightRag server in Markdown format or raw response.
- """
- self.status_emitter = StatusEventEmitter(__event_emitter__)
- self.message_emitter = MessageEventEmitter(__event_emitter__)
-
- lightrag_url = self.valves.LIGHTRAG_SERVER_URL
- payload = {
- "query": query,
- "mode": str(self.valves.MODE),
- "stream": False,
- "only_need_context": self.valves.ONLY_NEED_CONTEXT,
- }
- await self.status_emitter.emit("Initializing Lightrag query..")
-
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "### Debug Mode Active\n\nDebugging information will be displayed.\n"
- )
- await self.message_emitter.emit(
- "#### Payload Sent to LightRag Server\n```json\n"
- + json.dumps(payload, indent=4)
- + "\n```\n"
- )
-
- # Add Bearer Key to headers if provided
- if self.valves.KEY:
- self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
- try:
- await self.status_emitter.emit("Sending request to LightRag server")
-
- response = requests.post(
- lightrag_url, json=payload, headers=self.headers, timeout=120
- )
- response.raise_for_status()
- data = response.json()
- await self.status_emitter.emit(
- status="complete",
- description="LightRag query Succeeded",
- done=True,
- )
-
- # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
- if self.valves.ONLY_NEED_CONTEXT:
- try:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### LightRag Server Response\n```json\n"
- + data["response"]
- + "\n```\n"
- )
- except Exception as ex:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### Exception\n" + str(ex) + "\n"
- )
- return f"Exception: {ex}"
- return data["response"]
- else:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### LightRag Server Response\n```json\n"
- + data["response"]
- + "\n```\n"
- )
- await self.status_emitter.emit("Lightrag query success")
- return data["response"]
-
- except requests.exceptions.RequestException as e:
- await self.status_emitter.emit(
- status="error",
- description=f"Error during LightRag query: {str(e)}",
- done=True,
- )
- return json.dumps({"error": str(e)})
-
- def extract_code_blocks(
- self, text: str, return_remaining_text: bool = False
- ) -> Union[List[dict], Tuple[List[dict], str]]:
- """
- This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
- Parameters:
- text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
- return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
- Returns:
- Union[List[dict], Tuple[List[dict], str]]:
- - If return_remaining_text is False: Returns only the list of code block dictionaries
- - If return_remaining_text is True: Returns a tuple containing:
- * List of code block dictionaries
- * String containing the text with all code blocks removed
-
- Each code block dictionary contains:
- - 'index' (int): The index of the code block in the text
- - 'file_name' (str): The name of the file extracted from the preceding line, if available
- - 'content' (str): The content of the code block
- - 'type' (str): The type of the code block
- - 'is_complete' (bool): True if the block has a closing tag, False otherwise
- """
- remaining = text
- bloc_index = 0
- first_index = 0
- indices = []
- text_without_blocks = text
-
- # Find all code block delimiters
- while len(remaining) > 0:
- try:
- index = remaining.index("```")
- indices.append(index + first_index)
- remaining = remaining[index + 3 :]
- first_index += index + 3
- bloc_index += 1
- except Exception:
- if bloc_index % 2 == 1:
- index = len(remaining)
- indices.append(index)
- remaining = ""
-
- code_blocks = []
- is_start = True
-
- # Process code blocks and build text without blocks if requested
- if return_remaining_text:
- text_parts = []
- last_end = 0
-
- for index, code_delimiter_position in enumerate(indices):
- if is_start:
- block_infos = {
- "index": len(code_blocks),
- "file_name": "",
- "section": "",
- "content": "",
- "type": "",
- "is_complete": False,
- }
-
- # Store text before code block if returning remaining text
- if return_remaining_text:
- text_parts.append(text[last_end:code_delimiter_position].strip())
-
- # Check the preceding line for file name
- preceding_text = text[:code_delimiter_position].strip().splitlines()
- if preceding_text:
- last_line = preceding_text[-1].strip()
- if last_line.startswith("") and last_line.endswith(
- ""
- ):
- file_name = last_line[
- len("") : -len("")
- ].strip()
- block_infos["file_name"] = file_name
- elif last_line.startswith("## filename:"):
- file_name = last_line[len("## filename:") :].strip()
- block_infos["file_name"] = file_name
- if last_line.startswith("") and last_line.endswith(
- ""
- ):
- section = last_line[
- len("")
- ].strip()
- block_infos["section"] = section
-
- sub_text = text[code_delimiter_position + 3 :]
- if len(sub_text) > 0:
- try:
- find_space = sub_text.index(" ")
- except Exception:
- find_space = int(1e10)
- try:
- find_return = sub_text.index("\n")
- except Exception:
- find_return = int(1e10)
- next_index = min(find_return, find_space)
- if "{" in sub_text[:next_index]:
- next_index = 0
- start_pos = next_index
-
- if code_delimiter_position + 3 < len(text) and text[
- code_delimiter_position + 3
- ] in ["\n", " ", "\t"]:
- block_infos["type"] = "language-specific"
- else:
- block_infos["type"] = sub_text[:next_index]
-
- if index + 1 < len(indices):
- next_pos = indices[index + 1] - code_delimiter_position
- if (
- next_pos - 3 < len(sub_text)
- and sub_text[next_pos - 3] == "`"
- ):
- block_infos["content"] = sub_text[
- start_pos : next_pos - 3
- ].strip()
- block_infos["is_complete"] = True
- else:
- block_infos["content"] = sub_text[
- start_pos:next_pos
- ].strip()
- block_infos["is_complete"] = False
-
- if return_remaining_text:
- last_end = indices[index + 1] + 3
- else:
- block_infos["content"] = sub_text[start_pos:].strip()
- block_infos["is_complete"] = False
-
- if return_remaining_text:
- last_end = len(text)
-
- code_blocks.append(block_infos)
- is_start = False
- else:
- is_start = True
-
- if return_remaining_text:
- # Add any remaining text after the last code block
- if last_end < len(text):
- text_parts.append(text[last_end:].strip())
- # Join all non-code parts with newlines
- text_without_blocks = "\n".join(filter(None, text_parts))
- return code_blocks, text_without_blocks
-
- return code_blocks
-
- def clean(self, csv_content: str):
- lines = csv_content.splitlines()
- if lines:
- # Remove spaces around headers and ensure no spaces between commas
- header = ",".join([col.strip() for col in lines[0].split(",")])
- lines[0] = header # Replace the first line with the cleaned header
- csv_content = "\n".join(lines)
- return csv_content
diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index 06510618..d48b6732 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -185,7 +185,8 @@ TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
-OracleVectorDBStorag Oracle
+OracleVectorDBStorage Oracle
+MongoVectorDBStorage MongoDB
```
* DOC_STATUS_STORAGE: supported implementation names
diff --git a/lightrag/base.py b/lightrag/base.py
index ca1fac7f..aafa97b2 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
from dotenv import load_dotenv
from dataclasses import dataclass, field
@@ -5,10 +7,8 @@ from enum import Enum
from typing import (
Any,
Literal,
- Optional,
TypedDict,
TypeVar,
- Union,
)
import numpy as np
from .utils import EmbeddingFunc
@@ -72,7 +72,7 @@ class QueryParam:
ll_keywords: list[str] = field(default_factory=list)
"""List of low-level keywords to refine retrieval focus."""
- conversation_history: list[dict[str, Any]] = field(default_factory=list)
+ conversation_history: list[dict[str, str]] = field(default_factory=list)
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
@@ -86,19 +86,15 @@ class StorageNameSpace:
namespace: str
global_config: dict[str, Any]
- async def index_done_callback(self):
+ async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
pass
- async def query_done_callback(self):
- """Commit the storage operations after querying"""
- pass
-
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
- meta_fields: set = field(default_factory=set)
+ meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
raise NotImplementedError
@@ -109,12 +105,20 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
+ async def delete_entity(self, entity_name: str) -> None:
+ """Delete a single entity by its name"""
+ raise NotImplementedError
+
+ async def delete_entity_relation(self, entity_name: str) -> None:
+ """Delete relations for a given entity by scanning metadata"""
+ raise NotImplementedError
+
@dataclass
class BaseKVStorage(StorageNameSpace):
- embedding_func: EmbeddingFunc
+ embedding_func: EmbeddingFunc | None = None
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
raise NotImplementedError
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -133,50 +137,75 @@ class BaseKVStorage(StorageNameSpace):
@dataclass
class BaseGraphStorage(StorageNameSpace):
- embedding_func: EmbeddingFunc = None
+ embedding_func: EmbeddingFunc | None = None
+ """Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
+ """Check if an edge exists in the graph."""
+
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
+ """Get the degree of a node."""
+
async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
+ """Get the degree of an edge."""
+
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
- async def get_node(self, node_id: str) -> Union[dict, None]:
+ """Get a node by its id."""
+
+ async def get_node(self, node_id: str) -> dict[str, str] | None:
raise NotImplementedError
+ """Get an edge by its source and target node ids."""
+
async def get_edge(
self, source_node_id: str, target_node_id: str
- ) -> Union[dict, None]:
+ ) -> dict[str, str] | None:
raise NotImplementedError
- async def get_node_edges(
- self, source_node_id: str
- ) -> Union[list[tuple[str, str]], None]:
+ """Get all edges connected to a node."""
+
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
raise NotImplementedError
- async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+ """Upsert a node into the graph."""
+
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError
+ """Upsert an edge into the graph."""
+
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
- ):
+ ) -> None:
raise NotImplementedError
- async def delete_node(self, node_id: str):
+ """Delete a node from the graph."""
+
+ async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
- async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+ """Embed nodes using an algorithm."""
+
+ async def embed_nodes(
+ self, algorithm: str
+ ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
+ """Get all labels in the graph."""
+
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
+ """Get a knowledge graph of a node."""
+
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
@@ -208,9 +237,9 @@ class DocProcessingStatus:
"""ISO format timestamp when document was created"""
updated_at: str
"""ISO format timestamp when document was last updated"""
- chunks_count: Optional[int] = None
+ chunks_count: int | None = None
"""Number of chunks after splitting, used for processing"""
- error: Optional[str] = None
+ error: str | None = None
"""Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict)
"""Additional metadata"""
diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
index 5de6b334..ae756f85 100644
--- a/lightrag/exceptions.py
+++ b/lightrag/exceptions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import httpx
from typing import Literal
diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py
index 82e723a1..cb3b59f1 100644
--- a/lightrag/kg/chroma_impl.py
+++ b/lightrag/kg/chroma_impl.py
@@ -2,7 +2,7 @@ import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
-from chromadb import HttpClient
+from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@@ -49,31 +49,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
**user_collection_settings,
}
- auth_provider = config.get(
- "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
- )
- auth_credentials = config.get("auth_token", "secret-token")
- headers = {}
+ local_path = config.get("local_path", None)
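+ # A local path selects an embedded on-disk Chroma client; otherwise connect to a remote server over HTTP.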
+ if local_path:
+ self._client = PersistentClient(
+ path=local_path,
+ settings=Settings(
+ allow_reset=True,
+ anonymized_telemetry=False,
+ ),
+ )
+ else:
+ auth_provider = config.get(
+ "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
+ )
+ auth_credentials = config.get("auth_token", "secret-token")
+ headers = {}
- if "token_authn" in auth_provider:
- headers = {
- config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
- }
- elif "basic_authn" in auth_provider:
- auth_credentials = config.get("auth_credentials", "admin:admin")
+ if "token_authn" in auth_provider:
+ headers = {
+ config.get(
+ "auth_header_name", "X-Chroma-Token"
+ ): auth_credentials
+ }
+ elif "basic_authn" in auth_provider:
+ auth_credentials = config.get("auth_credentials", "admin:admin")
- self._client = HttpClient(
- host=config.get("host", "localhost"),
- port=config.get("port", 8000),
- headers=headers,
- settings=Settings(
- chroma_api_impl="rest",
- chroma_client_auth_provider=auth_provider,
- chroma_client_auth_credentials=auth_credentials,
- allow_reset=True,
- anonymized_telemetry=False,
- ),
- )
+ self._client = HttpClient(
+ host=config.get("host", "localhost"),
+ port=config.get("port", 8000),
+ headers=headers,
+ settings=Settings(
+ chroma_api_impl="rest",
+ chroma_client_auth_provider=auth_provider,
+ chroma_client_auth_credentials=auth_credentials,
+ allow_reset=True,
+ anonymized_telemetry=False,
+ ),
+ )
self._collection = self._client.get_or_create_collection(
name=self.namespace,
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
results = self._collection.query(
- query_embeddings=embedding.tolist(),
+ query_embeddings=embedding.tolist()
+ if not isinstance(embedding, list)
+ else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index b2090d78..ba94e39e 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
await self.delete([entity_id])
- async def delete_entity_relation(self, entity_name: str):
+ async def delete_entity_relation(self, entity_name: str) -> None:
"""
Delete relations for a given entity by scanning metadata.
"""
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index cfd67367..3ab5b966 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self) -> None:
self._data = {}
+
+ async def delete(self, ids: list[str]) -> None:
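+ # Drop each id if present, then persist the JSON file via the index-done callback.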
+ for doc_id in ids:
+ self._data.pop(doc_id, None)
+ await self.index_done_callback()
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 226aecf2..c216e7be 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -4,6 +4,7 @@ import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
+import asyncio
if not pm.is_installed("pymongo"):
pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel
+from pymongo.errors import PyMongoError
from ..base import (
BaseGraphStorage,
BaseKVStorage,
+ BaseVectorStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
- client = MongoClient(
- os.environ.get(
- "MONGO_URI",
- config.get(
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
- ),
- )
+ uri = os.environ.get(
+ "MONGO_URI",
+ config.get(
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+ ),
)
+ client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
- self._data = database.get_collection(self.namespace)
- logger.info(f"Use MongoDB as KV {self.namespace}")
+
+ self._collection_name = self.namespace
+
+ self._data = database.get_collection(self._collection_name)
+ logger.debug(f"Use MongoDB as KV {self._collection_name}")
+
+ # Ensure collection exists
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
- return self._data.find_one({"_id": id})
+ return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
- return list(self._data.find({"_id": {"$in": ids}}))
+ cursor = self._data.find({"_id": {"$in": ids}})
+ return await cursor.to_list()
async def filter_keys(self, data: set[str]) -> set[str]:
- existing_ids = [
- str(x["_id"])
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
- ]
- return set([s for s in data if s not in existing_ids])
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+ existing_ids = {str(x["_id"]) async for x in cursor}
+ return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
+ update_tasks = []
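+ # $setOnInsert only writes when the key is new, so existing cached LLM responses are never overwritten.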
for mode, items in data.items():
- for k, v in tqdm_async(items.items(), desc="Upserting"):
+ for k, v in items.items():
key = f"{mode}_{k}"
- result = self._data.update_one(
- {"_id": key}, {"$setOnInsert": v}, upsert=True
+ data[mode][k]["_id"] = f"{mode}_{k}"
+ update_tasks.append(
+ self._data.update_one(
+ {"_id": key}, {"$setOnInsert": v}, upsert=True
+ )
)
- if result.upserted_id:
- logger.debug(f"\nInserted new document with key: {key}")
- data[mode][k]["_id"] = key
+ await asyncio.gather(*update_tasks)
else:
- for k, v in tqdm_async(data.items(), desc="Upserting"):
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+ update_tasks = []
+ for k, v in data.items():
data[k]["_id"] = k
+ update_tasks.append(
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+ )
+ await asyncio.gather(*update_tasks)
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
- v = self._data.find_one({"_id": mode + "_" + id})
+ v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
- client = MongoClient(
- os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+ uri = os.environ.get(
+ "MONGO_URI",
+ config.get(
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+ ),
)
- database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
- self._data = database.get_collection(self.namespace)
- logger.info(f"Use MongoDB as doc status {self.namespace}")
+ client = AsyncIOMotorClient(uri)
+ database = client.get_database(
+ os.environ.get(
+ "MONGO_DATABASE",
+ config.get("mongodb", "database", fallback="LightRAG"),
+ )
+ )
+
+ self._collection_name = self.namespace
+ self._data = database.get_collection(self._collection_name)
+
+ logger.debug(f"Use MongoDB as doc status {self._collection_name}")
+
+ # Ensure collection exists
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
- return self._data.find_one({"_id": id})
+ return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
- return list(self._data.find({"_id": {"$in": ids}}))
+ cursor = self._data.find({"_id": {"$in": ids}})
+ return await cursor.to_list()
async def filter_keys(self, data: set[str]) -> set[str]:
- existing_ids = [
- str(x["_id"])
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
- ]
- return set([s for s in data if s not in existing_ids])
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+ existing_ids = {str(x["_id"]) async for x in cursor}
+ return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+ update_tasks = []
for k, v in data.items():
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
data[k]["_id"] = k
+ update_tasks.append(
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+ )
+ await asyncio.gather(*update_tasks)
async def drop(self) -> None:
"""Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
- result = list(self._data.aggregate(pipeline))
+ cursor = self._data.aggregate(pipeline)
+ result = await cursor.to_list()
counts = {}
for doc in result:
counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
- result = list(self._data.find({"status": status.value}))
+ cursor = self._data.find({"status": status.value})
+ result = await cursor.to_list()
return {
doc["_id"]: DocProcessingStatus(
content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
global_config=global_config,
embedding_func=embedding_func,
)
- self.client = AsyncIOMotorClient(
- os.environ.get(
- "MONGO_URI",
- config.get(
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
- ),
- )
+ uri = os.environ.get(
+ "MONGO_URI",
+ config.get(
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+ ),
)
- self.db = self.client[
+ client = AsyncIOMotorClient(uri)
+ database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
- mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
+ config.get("mongodb", "database", fallback="LightRAG"),
)
- ]
- self.collection = self.db[
- os.environ.get(
- "MONGO_KG_COLLECTION",
- config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
- )
- ]
+ )
+
+ self._collection_name = self.namespace
+ self.collection = database.get_collection(self._collection_name)
+
+ logger.debug(f"Use MongoDB as KG {self._collection_name}")
+
+ # Ensure collection exists
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
#
# -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
"""
- Return a list of (target_id, relation) for direct edges from source_node_id.
+ Return a list of (source_id, target_id) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
"""
pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
return None
edges = result[0].get("edges", [])
- return [(e["target"], e["relation"]) for e in edges]
+ return [(source_node_id, e["target"]) for e in edges]
#
# -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str):
"""
- 1) Remove node’s doc entirely.
+ 1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
"""
# Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
Placeholder for demonstration, raises NotImplementedError.
"""
raise NotImplementedError("Node embedding is not used in lightrag.")
+
+ #
+ # -------------------------------------------------------------------------
+ # QUERY
+ # -------------------------------------------------------------------------
+ #
+
+ async def get_all_labels(self) -> list[str]:
+ """
+ Get all existing node _ids in the database
+ Returns:
+ [id1, id2, ...] # Alphabetically sorted id list
+ """
+ # Aggregate to collect every distinct node _id (used as the label)
+ pipeline = [
+ {"$group": {"_id": "$_id"}}, # Group by _id
+ {"$sort": {"_id": 1}}, # Sort alphabetically
+ ]
+
+ cursor = self.collection.aggregate(pipeline)
+ labels = []
+ async for doc in cursor:
+ labels.append(doc["_id"])
+ return labels
+
+ async def get_knowledge_graph(
+ self, node_label: str, max_depth: int = 5
+ ) -> KnowledgeGraph:
+ """
+ Get complete connected subgraph for specified node (including the starting node itself)
+
+ Args:
+ node_label: Label of the nodes to start from
+ max_depth: Maximum depth of traversal (default: 5)
+
+ Returns:
+ KnowledgeGraph object containing nodes and edges of the subgraph
+ """
+ label = node_label
+ result = KnowledgeGraph()
+ seen_nodes = set()
+ seen_edges = set()
+
+ try:
+ if label == "*":
+ # Get all nodes and edges
+ async for node_doc in self.collection.find({}):
+ node_id = str(node_doc["_id"])
+ if node_id not in seen_nodes:
+ result.nodes.append(
+ KnowledgeGraphNode(
+ id=node_id,
+ labels=[node_doc.get("_id")],
+ properties={
+ k: v
+ for k, v in node_doc.items()
+ if k not in ["_id", "edges"]
+ },
+ )
+ )
+ seen_nodes.add(node_id)
+
+ # Process edges
+ for edge in node_doc.get("edges", []):
+ edge_id = f"{node_id}-{edge['target']}"
+ if edge_id not in seen_edges:
+ result.edges.append(
+ KnowledgeGraphEdge(
+ id=edge_id,
+ type=edge.get("relation", ""),
+ source=node_id,
+ target=edge["target"],
+ properties={
+ k: v
+ for k, v in edge.items()
+ if k not in ["target", "relation"]
+ },
+ )
+ )
+ seen_edges.add(edge_id)
+ else:
+ # Verify if starting node exists
+ start_nodes = self.collection.find({"_id": label})
+ start_nodes_exist = await start_nodes.to_list(length=1)
+ if not start_nodes_exist:
+ logger.warning(f"Starting node with label {label} does not exist!")
+ return result
+
+ # Use $graphLookup for traversal
+ pipeline = [
+ {
+ "$match": {"_id": label}
+ }, # Start with nodes having the specified label
+ {
+ "$graphLookup": {
+ "from": self._collection_name,
+ "startWith": "$edges.target",
+ "connectFromField": "edges.target",
+ "connectToField": "_id",
+ "maxDepth": max_depth,
+ "depthField": "depth",
+ "as": "connected_nodes",
+ }
+ },
+ ]
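+ # $graphLookup follows edges.target -> _id for up to max_depth hops and returns every reachable doc.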
+
+ async for doc in self.collection.aggregate(pipeline):
+ # Add the start node
+ node_id = str(doc["_id"])
+ if node_id not in seen_nodes:
+ result.nodes.append(
+ KnowledgeGraphNode(
+ id=node_id,
+ labels=[doc.get("_id")],
+ properties={
+ k: v
+ for k, v in doc.items()
+ if k
+ not in [
+ "_id",
+ "edges",
+ "connected_nodes",
+ "depth",
+ ]
+ },
+ )
+ )
+ seen_nodes.add(node_id)
+
+ # Add edges from start node
+ for edge in doc.get("edges", []):
+ edge_id = f"{node_id}-{edge['target']}"
+ if edge_id not in seen_edges:
+ result.edges.append(
+ KnowledgeGraphEdge(
+ id=edge_id,
+ type=edge.get("relation", ""),
+ source=node_id,
+ target=edge["target"],
+ properties={
+ k: v
+ for k, v in edge.items()
+ if k not in ["target", "relation"]
+ },
+ )
+ )
+ seen_edges.add(edge_id)
+
+ # Add connected nodes and their edges
+ for connected in doc.get("connected_nodes", []):
+ node_id = str(connected["_id"])
+ if node_id not in seen_nodes:
+ result.nodes.append(
+ KnowledgeGraphNode(
+ id=node_id,
+ labels=[connected.get("_id")],
+ properties={
+ k: v
+ for k, v in connected.items()
+ if k not in ["_id", "edges", "depth"]
+ },
+ )
+ )
+ seen_nodes.add(node_id)
+
+ # Add edges from connected nodes
+ for edge in connected.get("edges", []):
+ edge_id = f"{node_id}-{edge['target']}"
+ if edge_id not in seen_edges:
+ result.edges.append(
+ KnowledgeGraphEdge(
+ id=edge_id,
+ type=edge.get("relation", ""),
+ source=node_id,
+ target=edge["target"],
+ properties={
+ k: v
+ for k, v in edge.items()
+ if k not in ["target", "relation"]
+ },
+ )
+ )
+ seen_edges.add(edge_id)
+
+ logger.info(
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+ )
+
+ except PyMongoError as e:
+ logger.error(f"MongoDB query failed: {str(e)}")
+
+ return result
+
+
+@dataclass
+class MongoVectorDBStorage(BaseVectorStorage):
+ cosine_better_than_threshold: float | None = None
+
+ def __post_init__(self):
+ kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+ cosine_threshold = kwargs.get("cosine_better_than_threshold")
+ if cosine_threshold is None:
+ raise ValueError(
+ "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+ )
+ self.cosine_better_than_threshold = cosine_threshold
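+ # query() filters out results whose similarity score falls below this threshold.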
+
+ uri = os.environ.get(
+ "MONGO_URI",
+ config.get(
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+ ),
+ )
+ client = AsyncIOMotorClient(uri)
+ database = client.get_database(
+ os.environ.get(
+ "MONGO_DATABASE",
+ config.get("mongodb", "database", fallback="LightRAG"),
+ )
+ )
+
+ self._collection_name = self.namespace
+ self._data = database.get_collection(self._collection_name)
+ self._max_batch_size = self.global_config["embedding_batch_num"]
+
+ logger.debug(f"Use MongoDB as VDB {self._collection_name}")
+
+ # Ensure collection exists
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
+
+ # Ensure vector index exists
+ self.create_vector_index(uri, database.name, self._collection_name)
+
+ def create_vector_index(self, uri: str, database_name: str, collection_name: str):
+ """Creates an Atlas Vector Search index."""
+ client = MongoClient(uri)
+ collection = client.get_database(database_name).get_collection(
+ collection_name
+ )
+
+ try:
+ search_index_model = SearchIndexModel(
+ definition={
+ "fields": [
+ {
+ "type": "vector",
+ "numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions
+ "path": "vector",
+ "similarity": "cosine", # Options: euclidean, cosine, dotProduct
+ }
+ ]
+ },
+ name="vector_knn_index",
+ type="vectorSearch",
+ )
+
+ collection.create_search_index(search_index_model)
+ logger.info("Vector index created successfully.")
+
+ except PyMongoError:
+ logger.debug("vector index already exists")
+
+ async def upsert(self, data: dict[str, dict]):
+ logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
+ if not data:
+ logger.warning("You are inserting an empty data set to vector DB")
+ return []
+
+ list_data = [
+ {
+ "_id": k,
+ **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+ }
+ for k, v in data.items()
+ ]
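+ # Only _id and the meta fields are stored here; the embedding vector is attached after batch embedding below.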
+ contents = [v["content"] for v in data.values()]
+ batches = [
+ contents[i : i + self._max_batch_size]
+ for i in range(0, len(contents), self._max_batch_size)
+ ]
+
+ async def wrapped_task(batch):
+ result = await self.embedding_func(batch)
+ pbar.update(1)
+ return result
+
+ embedding_tasks = [wrapped_task(batch) for batch in batches]
+ pbar = tqdm_async(
+ total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+ )
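+ # The batch coroutines only start running inside asyncio.gather, by which point pbar exists.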
+ embeddings_list = await asyncio.gather(*embedding_tasks)
+
+ embeddings = np.concatenate(embeddings_list)
+ for i, d in enumerate(list_data):
+ d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
+
+ update_tasks = []
+ for doc in list_data:
+ update_tasks.append(
+ self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
+ )
+ await asyncio.gather(*update_tasks)
+
+ return list_data
+
+ async def query(self, query, top_k=5):
+ """Queries the vector database using Atlas Vector Search."""
+ # Generate the embedding
+ embedding = await self.embedding_func([query])
+
+ # Convert numpy array to a list to ensure compatibility with MongoDB
+ query_vector = embedding[0].tolist()
+
+ # Define the aggregation pipeline with the converted query vector
+ pipeline = [
+ {
+ "$vectorSearch": {
+ "index": "vector_knn_index", # Ensure this matches the created index name
+ "path": "vector",
+ "queryVector": query_vector,
+ "numCandidates": 100, # Adjust for performance
+ "limit": top_k,
+ }
+ },
+ {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
+ {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
+ {"$project": {"vector": 0}},
+ ]
+
+ # Execute the aggregation pipeline
+ cursor = self._data.aggregate(pipeline)
+ results = await cursor.to_list()
+
+ # Format and return the results
+ return [
+ {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
+ for doc in results
+ ]
+
+
+def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
+ """Check if the collection exists. if not, create it."""
+ client = MongoClient(uri)
+ database = client.get_database(database_name)
+
+ collection_names = database.list_collection_names()
+
+ if collection_name not in collection_names:
+ database.create_collection(collection_name)
+ logger.info(f"Created collection: {collection_name}")
+ else:
+ logger.debug(f"Collection '{collection_name}' already exists.")
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 60eed3dc..96c5a8dd 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
- async def delete_entity_relation(self, entity_name: str):
+ async def delete_entity_relation(self, entity_name: str) -> None:
try:
relations = [
dp
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index e9a53110..15525375 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
async def index_done_callback(self):
print("KG successfully indexed.")
- async def has_node(self, node_id: str) -> bool:
- entity_name_label = node_id.strip('"')
+ async def _label_exists(self, label: str) -> bool:
+ """Check if a label exists in the Neo4j database."""
+ query = "CALL db.labels() YIELD label RETURN label"
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ result = await session.run(query)
+ labels = [record["label"] for record in await result.data()]
+ return label in labels
+ except Exception as e:
+ logger.error(f"Error checking label existence: {e}")
+ return False
+ async def _ensure_label(self, label: str) -> str:
+ """Ensure a label exists by validating it."""
+ clean_label = label.strip('"')
+ if not await self._label_exists(clean_label):
+ logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
+ return clean_label
+
+ async def has_node(self, node_id: str) -> bool:
+ entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
+ """Get node by its label identifier.
+
+ Args:
+ node_id: The node label to look up
+
+ Returns:
+ dict: Node properties if found
+ None: If node not found
+ """
async with self._driver.session(database=self._DATABASE) as session:
- entity_name_label = node_id.strip('"')
+ entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
- entity_name_label_source = source_node_id.strip('"')
- entity_name_label_target = target_node_id.strip('"')
- """
- Find all edges between nodes of two given labels
+ """Find edge between two nodes identified by their labels.
Args:
- source_node_label (str): Label of the source nodes
- target_node_label (str): Label of the target nodes
+ source_node_id (str): Label of the source node
+ target_node_id (str): Label of the target node
Returns:
- list: List of all relationships/edges found
+ dict: Edge properties if found, with at least {"weight": 0.0}
+ dict: Default properties ({"weight": 0.0, "source_id": None, "target_id": None}) when no edge exists or an error occurs
"""
- async with self._driver.session(database=self._DATABASE) as session:
- query = f"""
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
- RETURN properties(r) as edge_properties
- LIMIT 1
- """.format(
- entity_name_label_source=entity_name_label_source,
- entity_name_label_target=entity_name_label_target,
- )
+ try:
+ entity_name_label_source = source_node_id.strip('"')
+ entity_name_label_target = target_node_id.strip('"')
- result = await session.run(query)
- record = await result.single()
- if record:
- result = dict(record["edge_properties"])
- logger.debug(
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+ async with self._driver.session(database=self._DATABASE) as session:
+ query = f"""
+ MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+ RETURN properties(r) as edge_properties
+ LIMIT 1
+ """.format(
+ entity_name_label_source=entity_name_label_source,
+ entity_name_label_target=entity_name_label_target,
)
- return result
- else:
- return None
+
+ result = await session.run(query)
+ record = await result.single()
+ if record and "edge_properties" in record:
+ try:
+ result = dict(record["edge_properties"])
+ # Ensure required keys exist with defaults
+ required_keys = {
+ "weight": 0.0,
+ "source_id": None,
+ "target_id": None,
+ }
+ for key, default_value in required_keys.items():
+ if key not in result:
+ result[key] = default_value
+ logger.warning(
+ f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+ f"missing {key}, using default: {default_value}"
+ )
+
+ logger.debug(
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+ )
+ return result
+ except (KeyError, TypeError, ValueError) as e:
+ logger.error(
+ f"Error processing edge properties between {entity_name_label_source} "
+ f"and {entity_name_label_target}: {str(e)}"
+ )
+ # Return default edge properties on error
+ return {"weight": 0.0, "source_id": None, "target_id": None}
+
+ logger.debug(
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+ )
+ # Return default edge properties when no edge found
+ return {"weight": 0.0, "source_id": None, "target_id": None}
+
+ except Exception as e:
+ logger.error(
+ f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
+ )
+ # Return default edge properties on error
+ return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
- label = node_id.strip('"')
+ label = await self._ensure_label(node_id)
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
+ neo4jExceptions.ClientError,
)
),
)
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
- source_node_label = source_node_id.strip('"')
- target_node_label = target_node_id.strip('"')
+ source_label = await self._ensure_label(source_node_id)
+ target_label = await self._ensure_label(target_node_id)
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f"""
- MATCH (source:`{source_node_label}`)
+ MATCH (source:`{source_label}`)
WITH source
- MATCH (target:`{target_node_label}`)
+ MATCH (target:`{target_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += $properties
RETURN r
"""
- await tx.run(query, properties=edge_properties)
+ result = await tx.run(query, properties=edge_properties)
+ record = await result.single()
logger.debug(
- f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
+ f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
)
try:
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 554cba22..529336e9 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast
from .base import (
BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
+ "MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
}
# Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
+ "MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
+ "MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
@@ -176,7 +181,7 @@ STORAGES = {
}
-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
- def import_class(*args, **kwargs):
+ def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
@@ -302,7 +307,7 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
- embedding_func: EmbeddingFunc = None
+ embedding_func: EmbeddingFunc | None = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
- llm_model_func: callable = None
+ llm_model_func: Callable[..., object] | None = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
- """Dictionary for additional parameters and extensions."""
- # extension
- addon_params: dict[str, Any] = field(default_factory=dict)
+ """Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
)
@@ -354,7 +357,7 @@ class LightRAG:
chunking_func: Callable[
[
str,
- Optional[str],
+ str | None,
bool,
int,
int,
@@ -443,77 +446,74 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+ self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
- )
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+ ) # type: ignore
+ self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
- )
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+ ) # type: ignore
+ self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
- )
-
- self.key_string_value_json_storage_cls = partial(
+ ) # type: ignore
+ self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
-
- self.vector_db_storage_cls = partial(
+ self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
-
- self.graph_storage_cls = partial(
+ self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
- self.llm_response_cache = self.key_string_value_json_storage_cls(
+ self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
- self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+ self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
- self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+ self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
- self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+ self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
- self.entities_vdb = self.vector_db_storage_cls(
+ self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
- self.relationships_vdb = self.vector_db_storage_cls(
+ self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
- self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+ self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -527,13 +527,12 @@ class LightRAG:
embedding_func=None,
)
- # What's for, Is this nessisary ?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
- hashing_kv = self.key_string_value_json_storage_cls(
+ hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
@@ -542,7 +541,7 @@ class LightRAG:
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
- self.llm_model_func,
+ self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
@@ -559,68 +558,45 @@ class LightRAG:
node_label=nodel_label, max_depth=max_depth
)
- def _get_storage_class(self, storage_name: str) -> dict:
+ def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
- def set_storage_client(self, db_client):
- # Deprecated, seting correct value to *_storage of LightRAG insteaded
- # Inject db to storage implementation (only tested on Oracle Database)
- for storage in [
- self.vector_db_storage_cls,
- self.graph_storage_cls,
- self.doc_status,
- self.full_docs,
- self.text_chunks,
- self.llm_response_cache,
- self.key_string_value_json_storage_cls,
- self.chunks_vdb,
- self.relationships_vdb,
- self.entities_vdb,
- self.graph_storage_cls,
- self.chunk_entity_relation_graph,
- self.llm_response_cache,
- ]:
- # set client
- storage.db = db_client
-
def insert(
self,
- string_or_strings: Union[str, list[str]],
+ input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Sync Insert documents with checkpoint support
Args:
- string_or_strings: Single document string or list of document strings
+ input: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
            chunk_size, split the sub chunk by token size.
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
- self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+ self.ainsert(input, split_by_character, split_by_character_only)
)
async def ainsert(
self,
- string_or_strings: Union[str, list[str]],
+ input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Async Insert documents with checkpoint support
Args:
- string_or_strings: Single document string or list of document strings
+ input: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
            chunk_size, split the sub chunk by token size.
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
"""
- await self.apipeline_enqueue_documents(string_or_strings)
+ await self.apipeline_enqueue_documents(input)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
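
Reviewer note: a hedged usage sketch of the renamed parameter (`input` replaces `string_or_strings`); it assumes an already-configured `rag` instance, since the constructor arguments are out of scope here.

```python
from lightrag import LightRAG


async def ingest(rag: LightRAG) -> None:
    # A single document string...
    await rag.ainsert("A single document.")
    # ...or a list of documents, optionally pre-split on a character.
    await rag.ainsert(["doc one", "doc two"], split_by_character="\n")

# The sync wrapper keeps the same signature:
# rag.insert(["doc one", "doc two"])
```
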
@@ -677,7 +653,7 @@ class LightRAG:
if update_storage:
await self._insert_done()
- async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+ async def apipeline_enqueue_documents(self, input: str | list[str]):
"""
Pipeline for Processing Documents
@@ -686,11 +662,11 @@ class LightRAG:
3. Filter out already processed documents
4. Enqueue document in status
"""
- if isinstance(string_or_strings, str):
- string_or_strings = [string_or_strings]
+ if isinstance(input, str):
+ input = [input]
# 1. Remove duplicate contents from the list
- unique_contents = list(set(doc.strip() for doc in string_or_strings))
+ unique_contents = list(set(doc.strip() for doc in input))
# 2. Generate document IDs and initial status
new_docs: dict[str, Any] = {
@@ -857,32 +833,32 @@ class LightRAG:
raise e
async def _insert_done(self):
- tasks = []
- for storage_inst in [
- self.full_docs,
- self.text_chunks,
- self.llm_response_cache,
- self.entities_vdb,
- self.relationships_vdb,
- self.chunks_vdb,
- self.chunk_entity_relation_graph,
- ]:
- if storage_inst is None:
- continue
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+ tasks = [
+ cast(StorageNameSpace, storage_inst).index_done_callback()
+ for storage_inst in [ # type: ignore
+ self.full_docs,
+ self.text_chunks,
+ self.llm_response_cache,
+ self.entities_vdb,
+ self.relationships_vdb,
+ self.chunks_vdb,
+ self.chunk_entity_relation_graph,
+ ]
+ if storage_inst is not None
+ ]
await asyncio.gather(*tasks)
- def insert_custom_kg(self, custom_kg: dict):
+ def insert_custom_kg(self, custom_kg: dict[str, Any]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
- async def ainsert_custom_kg(self, custom_kg: dict):
+ async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
update_storage = False
try:
# Insert chunks into vector storage
- all_chunks_data = {}
- chunk_to_source_map = {}
- for chunk_data in custom_kg.get("chunks", []):
+ all_chunks_data: dict[str, dict[str, str]] = {}
+ chunk_to_source_map: dict[str, str] = {}
+        for chunk_data in custom_kg.get("chunks", []):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
chunk_to_source_map[source_id] = chunk_id
update_storage = True
- if self.chunks_vdb is not None and all_chunks_data:
+ if all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
- if self.text_chunks is not None and all_chunks_data:
+ if all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
- all_entities_data = []
+ all_entities_data: list[dict[str, str]] = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
)
# Prepare node data
- node_data = {
+ node_data: dict[str, str] = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
update_storage = True
# Insert relationships into knowledge graph
- all_relationships_data = []
+ all_relationships_data: list[dict[str, str]] = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
"source_id": source_id,
},
)
- edge_data = {
+ edge_data: dict[str, str] = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
@@ -980,41 +956,68 @@ class LightRAG:
update_storage = True
# Insert entities into vector storage if needed
- if self.entities_vdb is not None:
- data_for_vdb = {
- compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
- "content": dp["entity_name"] + dp["description"],
- "entity_name": dp["entity_name"],
- }
- for dp in all_entities_data
+ data_for_vdb = {
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+ "content": dp["entity_name"] + dp["description"],
+ "entity_name": dp["entity_name"],
}
- await self.entities_vdb.upsert(data_for_vdb)
+ for dp in all_entities_data
+ }
+ await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
- if self.relationships_vdb is not None:
- data_for_vdb = {
- compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
- "src_id": dp["src_id"],
- "tgt_id": dp["tgt_id"],
- "content": dp["keywords"]
- + dp["src_id"]
- + dp["tgt_id"]
- + dp["description"],
- }
- for dp in all_relationships_data
+ data_for_vdb = {
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+ "src_id": dp["src_id"],
+ "tgt_id": dp["tgt_id"],
+ "content": dp["keywords"]
+ + dp["src_id"]
+ + dp["tgt_id"]
+ + dp["description"],
}
- await self.relationships_vdb.upsert(data_for_vdb)
+ for dp in all_relationships_data
+ }
+ await self.relationships_vdb.upsert(data_for_vdb)
+
finally:
if update_storage:
await self._insert_done()
- def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+ def query(
+ self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
+ ) -> str | Iterator[str]:
+ """
+        Perform a synchronous query.
+
+ Args:
+ query (str): The query to be executed.
+ param (QueryParam): Configuration parameters for query execution.
+            prompt (str | None): Custom prompt for fine-grained control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+ Returns:
+            str | Iterator[str]: The result of the query execution.
+ """
loop = always_get_an_event_loop()
- return loop.run_until_complete(self.aquery(query, prompt, param))
+
+ return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
async def aquery(
- self, query: str, prompt: str = "", param: QueryParam = QueryParam()
- ):
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ prompt: str | None = None,
+ ) -> str | AsyncIterator[str]:
+ """
+        Perform an asynchronous query.
+
+ Args:
+ query (str): The query to be executed.
+ param (QueryParam): Configuration parameters for query execution.
+            prompt (str | None): Custom prompt for fine-grained control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+ Returns:
+            str | AsyncIterator[str]: The result of the query execution.
+ """
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
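
Reviewer note: since the parameter order of `query`/`aquery` changed (`param` now precedes `prompt`, and `prompt` defaults to `None`), a call-site sketch under the same assumption of a configured instance:

```python
from lightrag import LightRAG, QueryParam


def ask(rag: LightRAG) -> None:
    # prompt=None falls back to PROMPTS["rag_response"].
    answer = rag.query(
        "What are the top themes in this corpus?",
        param=QueryParam(mode="hybrid"),
        prompt=None,
    )
    print(answer)
```

Positional callers of the old `query(query, prompt, param)` order need updating; keyword arguments sidestep the issue entirely.
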
@@ -1094,7 +1097,7 @@ class LightRAG:
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
- ):
+ ) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
),
)
- param.hl_keywords = (hl_keywords,)
- param.ll_keywords = (ll_keywords,)
+ param.hl_keywords = hl_keywords
+ param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
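
Reviewer note: the dropped trailing commas above are a real fix, not cosmetics — `(hl_keywords,)` built a one-element tuple wrapping the list, so any consumer expecting a list of keyword strings received the wrong shape. A standalone illustration:

```python
hl_keywords = ["theme", "motif"]

buggy = (hl_keywords,)  # (['theme', 'motif'],) -- a tuple containing one list
fixed = hl_keywords     # ['theme', 'motif']    -- the list itself

assert ", ".join(fixed) == "theme, motif"
# ", ".join(buggy) would raise TypeError: expected str instance, list found
```
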
@@ -1146,7 +1149,7 @@ class LightRAG:
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
- embedding_func=self.embedding_funcne,
+ embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
return response
async def _query_done(self):
- tasks = []
- for storage_inst in [self.llm_response_cache]:
- if storage_inst is None:
- continue
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
- await asyncio.gather(*tasks)
+ await self.llm_response_cache.index_done_callback()
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
- tasks = []
- for storage_inst in [
- self.entities_vdb,
- self.relationships_vdb,
- self.chunk_entity_relation_graph,
- ]:
- if storage_inst is None:
- continue
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
- await asyncio.gather(*tasks)
+ await asyncio.gather(
+ *[
+ cast(StorageNameSpace, storage_inst).index_done_callback()
+ for storage_inst in [ # type: ignore
+ self.entities_vdb,
+ self.relationships_vdb,
+ self.chunk_entity_relation_graph,
+ ]
+ ]
+ )
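
Reviewer note: the refactor above keeps the same semantics — skip unset storages, then flush all callbacks concurrently — in comprehension form. A runnable reduction of the pattern, with illustrative names:

```python
import asyncio
from typing import Optional


class Storage:
    async def index_done_callback(self) -> None:
        await asyncio.sleep(0)  # stand-in for a flush/commit


async def flush_all(storages: list[Optional[Storage]]) -> None:
    # Filter out unset storages, then run all callbacks concurrently
    # instead of awaiting them one by one.
    await asyncio.gather(
        *[s.index_done_callback() for s in storages if s is not None]
    )


asyncio.run(flush_all([Storage(), None, Storage()]))
```
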
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
"""
return await self.doc_status.get_status_counts()
- async def adelete_by_doc_id(self, doc_id: str):
+ async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
Args:
@@ -1273,6 +1271,9 @@ class LightRAG:
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_id)
+ if not chunks:
+ return
+
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
@@ -1443,13 +1444,9 @@ class LightRAG:
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
- def delete_by_doc_id(self, doc_id: str):
- """Synchronous version of adelete"""
- return asyncio.run(self.adelete_by_doc_id(doc_id))
-
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
- ):
+ ) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity
Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
- result = {
+ result: dict[str, str | None | dict[str, str]] = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:
return result
- def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
- """Synchronous version of getting entity information
-
- Args:
- entity_name: Entity name (no need for quotes)
- include_vector_data: Whether to include data from the vector database
- """
- try:
- import tracemalloc
-
- tracemalloc.start()
- return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
- finally:
- tracemalloc.stop()
-
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
@@ -1525,7 +1507,7 @@ class LightRAG:
)
source_id = edge_data.get("source_id") if edge_data else None
- result = {
+ result: dict[str, str | None | dict[str, str]] = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None
return result
-
- def get_relation_info_sync(
- self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
- ):
- """Synchronous version of getting relationship information
-
- Args:
- src_entity: Source entity name (no need for quotes)
- tgt_entity: Target entity name (no need for quotes)
- include_vector_data: Whether to include data from the vector database
- """
- try:
- import tracemalloc
-
- tracemalloc.start()
- return asyncio.run(
- self.get_relation_info(src_entity, tgt_entity, include_vector_data)
- )
- finally:
- tracemalloc.stop()
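
Reviewer note: with `delete_by_doc_id`, `get_entity_info_sync`, and `get_relation_info_sync` removed, callers that need a blocking call can wrap the async API themselves. A sketch, assuming a configured instance and no event loop already running:

```python
import asyncio
from typing import Any

from lightrag import LightRAG


def get_entity_info_blocking(rag: LightRAG, name: str) -> dict[str, Any]:
    # Replacement for the removed get_entity_info_sync wrapper.
    return asyncio.run(rag.get_entity_info(name, include_vector_data=False))


def delete_doc_blocking(rag: LightRAG, doc_id: str) -> None:
    # Replacement for the removed delete_by_doc_id wrapper.
    asyncio.run(rag.adelete_by_doc_id(doc_id))
```
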
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 3ca17725..e5f98cf8 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
from pydantic import BaseModel, Field
@@ -23,7 +25,7 @@ class Model(BaseModel):
...,
description="A function that generates the response from the llm. The response must be a string",
)
- kwargs: Dict[str, Any] = Field(
+ kwargs: dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
@@ -57,7 +59,7 @@ class MultiModel:
```
"""
- def __init__(self, models: List[Model]):
+ def __init__(self, models: list[Model]):
self._models = models
self._current_model = 0
@@ -66,7 +68,11 @@ class MultiModel:
return self._models[self._current_model]
async def llm_model_func(
- self, prompt, system_prompt=None, history_messages=[], **kwargs
+ self,
+ prompt: str,
+ system_prompt: str | None = None,
+ history_messages: list[dict[str, Any]] = [],
+ **kwargs: Any,
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("keyword_extraction", None)
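
Reviewer note: a construction sketch for `MultiModel` based on the `Model` fields visible above (`gen_func`, `kwargs`); the `gpt_4o_mini_complete` helper is an assumption drawn from the module's other exports, and any callable with the same shape would do. The API keys are obvious dummies.

```python
from lightrag.llm import Model, MultiModel, gpt_4o_mini_complete

models = [
    Model(gen_func=gpt_4o_mini_complete, kwargs={"api_key": "dummy-key-1"}),
    Model(gen_func=gpt_4o_mini_complete, kwargs={"api_key": "dummy-key-2"}),
]
multi = MultiModel(models)  # rotates across the configured models per call

# The bound method can then be passed wherever llm_model_func is expected:
# rag = LightRAG(..., llm_model_func=multi.llm_model_func)
```
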
diff --git a/lightrag/namespace.py b/lightrag/namespace.py
index ba8e3072..77e04c9e 100644
--- a/lightrag/namespace.py
+++ b/lightrag/namespace.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import Iterable
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 04d06e6b..fa52c55a 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import asyncio
import json
import re
from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, Union
+from typing import Any, AsyncIterator
from collections import Counter, defaultdict
from .utils import (
logger,
@@ -36,7 +38,7 @@ import time
def chunking_by_token_size(
content: str,
- split_by_character: Union[str, None] = None,
+ split_by_character: str | None = None,
split_by_character_only: bool = False,
overlap_token_size: int = 128,
max_token_size: int = 1024,
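
Reviewer note: the signature change here is annotation-only (`Union[str, None]` to `str | None`); behavior is unchanged. A usage sketch — the returned chunk-dict keys (`tokens`, `content`, `chunk_order_index`) are assumptions from the rest of the module, and token counting requires tiktoken:

```python
from lightrag.operate import chunking_by_token_size

text = "word " * 5000
chunks = chunking_by_token_size(
    text,
    split_by_character=None,        # no character pre-split
    split_by_character_only=False,
    overlap_token_size=128,
    max_token_size=1024,
)
for chunk in chunks[:2]:
    print(chunk["tokens"], chunk["chunk_order_index"])
```
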
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
- already_weights.append(already_edge["weight"])
- already_source_ids.extend(
- split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
- )
- already_description.append(already_edge["description"])
- already_keywords.extend(
- split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
- )
+ # Handle the case where get_edge returns None or missing fields
+ if already_edge:
+ # Get weight with default 0.0 if missing
+ if "weight" in already_edge:
+ already_weights.append(already_edge["weight"])
+ else:
+ logger.warning(
+ f"Edge between {src_id} and {tgt_id} missing weight field"
+ )
+ already_weights.append(0.0)
+                # Append source_id only when it is present and not None
+ if "source_id" in already_edge and already_edge["source_id"] is not None:
+ already_source_ids.extend(
+ split_string_by_multi_markers(
+ already_edge["source_id"], [GRAPH_FIELD_SEP]
+ )
+ )
+
+                # Append description only when it is present and not None
+ if (
+ "description" in already_edge
+ and already_edge["description"] is not None
+ ):
+ already_description.append(already_edge["description"])
+
+                # Append keywords only when they are present and not None
+ if "keywords" in already_edge and already_edge["keywords"] is not None:
+ already_keywords.extend(
+ split_string_by_multi_markers(
+ already_edge["keywords"], [GRAPH_FIELD_SEP]
+ )
+ )
+
+ # Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
- sorted(set([dp["description"] for dp in edges_data] + already_description))
+ sorted(
+ set(
+ [dp["description"] for dp in edges_data if dp.get("description")]
+ + already_description
+ )
+ )
)
keywords = GRAPH_FIELD_SEP.join(
- sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
+ sorted(
+ set(
+ [dp["keywords"] for dp in edges_data if dp.get("keywords")]
+ + already_keywords
+ )
+ )
)
source_id = GRAPH_FIELD_SEP.join(
- set([dp["source_id"] for dp in edges_data] + already_source_ids)
+ set(
+ [dp["source_id"] for dp in edges_data if dp.get("source_id")]
+ + already_source_ids
+ )
)
+
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
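
Reviewer note: the defensive shape introduced above — tolerate a missing or `None` field on a previously stored edge instead of raising `KeyError` — reduced to a runnable sketch with illustrative names:

```python
import logging

logger = logging.getLogger(__name__)


def merge_weights(existing_edge: dict | None, new_weights: list[float]) -> float:
    """Sum new weights with the stored weight, defaulting a missing field to 0.0."""
    weights = list(new_weights)
    if existing_edge:
        if "weight" in existing_edge:
            weights.append(existing_edge["weight"])
        else:
            logger.warning("stored edge missing weight field; defaulting to 0.0")
            weights.append(0.0)
    return sum(weights)


print(merge_weights({"weight": 2.0}, [1.0]))  # 3.0
print(merge_weights({}, [1.0]))               # 1.0, with a warning logged
```
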
@@ -295,9 +337,9 @@ async def extract_entities(
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
- global_config: dict,
- llm_response_cache: BaseKVStorage = None,
-) -> Union[BaseGraphStorage, None]:
+    global_config: dict[str, Any],
+ llm_response_cache: BaseKVStorage | None = None,
+) -> BaseGraphStorage | None:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
@@ -563,15 +605,15 @@ async def extract_entities(
async def kg_query(
- query,
+ query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
- prompt: str = "",
+    global_config: dict[str, Any],
+ hashing_kv: BaseKVStorage | None = None,
+ prompt: str | None = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -684,8 +726,8 @@ async def kg_query(
async def extract_keywords_only(
text: str,
param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, Any],
+ hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, Any],
+ hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):
async def naive_query(
- query,
+ query: str,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
-):
+    global_config: dict[str, Any],
+ hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
-) -> str:
+    global_config: dict[str, Any],
+ hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 160663d9..f4f5e38a 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
diff --git a/lightrag/types.py b/lightrag/types.py
index 9c8e0099..5e3d2948 100644
--- a/lightrag/types.py
+++ b/lightrag/types.py
@@ -1,26 +1,28 @@
+from __future__ import annotations
+
from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any, Optional
class GPTKeywordExtractionFormat(BaseModel):
- high_level_keywords: List[str]
- low_level_keywords: List[str]
+ high_level_keywords: list[str]
+ low_level_keywords: list[str]
class KnowledgeGraphNode(BaseModel):
id: str
- labels: List[str]
- properties: Dict[str, Any] # anything else goes here
+ labels: list[str]
+ properties: dict[str, Any] # anything else goes here
class KnowledgeGraphEdge(BaseModel):
id: str
- type: str
+ type: Optional[str]
source: str # id of source node
target: str # id of target node
- properties: Dict[str, Any] # anything else goes here
+ properties: dict[str, Any] # anything else goes here
class KnowledgeGraph(BaseModel):
- nodes: List[KnowledgeGraphNode] = []
- edges: List[KnowledgeGraphEdge] = []
+ nodes: list[KnowledgeGraphNode] = []
+ edges: list[KnowledgeGraphEdge] = []
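
Reviewer note: a quick construction check of the re-annotated models; making `type` `Optional[str]` means `None` now validates where it previously failed:

```python
from lightrag.types import KnowledgeGraph, KnowledgeGraphEdge, KnowledgeGraphNode

node_a = KnowledgeGraphNode(id="a", labels=["Entity"], properties={"name": "Ada"})
node_b = KnowledgeGraphNode(id="b", labels=["Entity"], properties={"name": "Babbage"})
edge = KnowledgeGraphEdge(id="e1", type=None, source="a", target="b", properties={})
kg = KnowledgeGraph(nodes=[node_a, node_b], edges=[edge])
print(len(kg.nodes), len(kg.edges))  # 2 1
```

This mirrors the webui change below, where `RawEdgeType.type` becomes optional and edge validation stops requiring it.
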
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9df325ca..c8786e7b 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import html
import io
@@ -9,7 +11,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
-from typing import Any, Union, List, Optional
+from typing import Any, Callable
import xml.etree.ElementTree as ET
import bs4
@@ -67,12 +69,12 @@ class EmbeddingFunc:
@dataclass
class ReasoningResponse:
- reasoning_content: str
+ reasoning_content: str | None
response_content: str
tag: str
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
"""Locate the JSON string body from a string"""
try:
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None
-def compute_args_hash(*args, cache_type: str = None) -> str:
+def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
return hashlib.md5(args_str.encode()).hexdigest()
-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+ """
+ Compute a unique ID for a given content string.
+
+ The ID is a combination of the given prefix and the MD5 hash of the content string.
+ """
return prefix + md5(content.encode()).hexdigest()
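
Reviewer note: the new docstring covers it, but for completeness a tiny usage sketch — the ID is deterministic, so identical (stripped) content always maps to the same key:

```python
from lightrag.utils import compute_mdhash_id

doc_id = compute_mdhash_id("Some document text".strip(), prefix="doc-")
same_id = compute_mdhash_id("Some document text".strip(), prefix="doc-")
assert doc_id == same_id  # stable across calls
print(doc_id)             # doc-<32 hex chars>
```
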
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
-def is_float_regex(value):
+def is_float_regex(value: str) -> bool:
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
-def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
+def truncate_list_by_token_size(
+ list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+) -> list[Any]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
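
Reviewer note: with the return annotation corrected to `list[Any]` (the function returns the truncated `list_data`, not a list of ints), a usage sketch; token counting relies on the library's tiktoken-based encoder:

```python
from lightrag.utils import truncate_list_by_token_size

rows = [{"text": "alpha " * 50}, {"text": "beta " * 50}, {"text": "gamma " * 50}]
kept = truncate_list_by_token_size(
    rows, key=lambda row: row["text"], max_token_size=80
)
print(len(kept))  # only the leading rows whose cumulative tokens fit
```
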
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data
-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
output = io.StringIO()
writer = csv.writer(
output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
return output.getvalue()
-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
# Clean the string by removing NUL characters
cleaned_string = csv_string.replace("\0", "")
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
return None
-def process_combine_contexts(hl, ll):
+def process_combine_contexts(hl: str, ll: str) -> str:
header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
llm_func=None,
original_prompt=None,
cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2)
-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
@@ -570,9 +579,9 @@ class CacheData:
args_hash: str
content: str
prompt: str
- quantized: Optional[np.ndarray] = None
- min_val: Optional[float] = None
- max_val: Optional[float] = None
+ quantized: np.ndarray | None = None
+ min_val: float | None = None
+ max_val: float | None = None
mode: str = "default"
cache_type: str = "query"
@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
return False
-def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
+def get_conversation_turns(
+ conversation_history: list[dict[str, Any]], num_turns: int
+) -> str:
"""
Process conversation history to get the specified number of complete turns.
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
Formatted string of the conversation history
"""
# Group messages into turns
- turns = []
- messages = []
+ turns: list[list[dict[str, Any]]] = []
+ messages: list[dict[str, Any]] = []
# First, filter out keyword extraction messages
for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
turns = turns[-num_turns:]
# Format the turns into a string
- formatted_turns = []
+ formatted_turns: list[str] = []
for turn in turns:
formatted_turns.extend(
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
diff --git a/lightrag_webui/src/components/PropertiesView.tsx b/lightrag_webui/src/components/PropertiesView.tsx
index 078420e6..dec80460 100644
--- a/lightrag_webui/src/components/PropertiesView.tsx
+++ b/lightrag_webui/src/components/PropertiesView.tsx
@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
-          <PropertyRow name={'Type'} value={edge.type} />
+          {edge.type && <PropertyRow name={'Type'} value={edge.type} />}
{
}
for (const edge of graph.edges) {
- if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
+ if (!edge.id || !edge.source || !edge.target) {
return false
}
}
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
    if (source !== undefined && target !== undefined) {
const sourceNode = rawData.nodes[source]
const targetNode = rawData.nodes[target]
+ if (!sourceNode) {
+ console.error(`Source node ${edge.source} is undefined`)
+ continue
+ }
+ if (!targetNode) {
+ console.error(`Target node ${edge.target} is undefined`)
+ continue
+ }
sourceNode.degree += 1
targetNode.degree += 1
}
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
for (const rawEdge of rawGraph?.edges ?? []) {
rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
- label: rawEdge.type
+ label: rawEdge.type || undefined
})
}
diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts
index b78e9bf8..b7c2120c 100644
--- a/lightrag_webui/src/stores/graph.ts
+++ b/lightrag_webui/src/stores/graph.ts
@@ -19,7 +19,7 @@ export type RawEdgeType = {
id: string
source: string
target: string
- type: string
+ type?: string
  properties: Record<string, any>
  dynamicId: string
dynamicId: string