Merge branch 'HKUDS:main' into main

Author: Saifeddine ALOUI (committed by GitHub)
Date: 2025-03-04 08:27:53 +01:00
23 changed files with 563 additions and 177 deletions

MANIFEST.in (new file)

@@ -0,0 +1 @@
recursive-include lightrag/api/webui *


@@ -106,6 +106,9 @@ import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger
setup_logger("lightrag", level="INFO")
async def initialize_rag():
rag = LightRAG(
@@ -344,6 +347,10 @@ from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_i
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger
# Setup log handler for LightRAG
setup_logger("lightrag", level="INFO")
async def initialize_rag():
rag = LightRAG(
@@ -498,44 +505,58 @@ rag.query_with_separate_keyword_extraction(
```python
custom_kg = {
+    "chunks": [
+        {
+            "content": "Alice and Bob are collaborating on quantum computing research.",
+            "source_id": "doc-1"
+        }
+    ],
    "entities": [
        {
-            "entity_name": "CompanyA",
-            "entity_type": "Organization",
-            "description": "A major technology company",
-            "source_id": "Source1"
+            "entity_name": "Alice",
+            "entity_type": "person",
+            "description": "Alice is a researcher specializing in quantum physics.",
+            "source_id": "doc-1"
        },
        {
-            "entity_name": "ProductX",
-            "entity_type": "Product",
-            "description": "A popular product developed by CompanyA",
-            "source_id": "Source1"
+            "entity_name": "Bob",
+            "entity_type": "person",
+            "description": "Bob is a mathematician.",
+            "source_id": "doc-1"
+        },
+        {
+            "entity_name": "Quantum Computing",
+            "entity_type": "technology",
+            "description": "Quantum computing utilizes quantum mechanical phenomena for computation.",
+            "source_id": "doc-1"
        }
    ],
    "relationships": [
        {
-            "src_id": "CompanyA",
-            "tgt_id": "ProductX",
-            "description": "CompanyA develops ProductX",
-            "keywords": "develop, produce",
+            "src_id": "Alice",
+            "tgt_id": "Bob",
+            "description": "Alice and Bob are research partners.",
+            "keywords": "collaboration research",
            "weight": 1.0,
-            "source_id": "Source1"
+            "source_id": "doc-1"
+        },
+        {
+            "src_id": "Alice",
+            "tgt_id": "Quantum Computing",
+            "description": "Alice conducts research on quantum computing.",
+            "keywords": "research expertise",
+            "weight": 1.0,
+            "source_id": "doc-1"
+        },
+        {
+            "src_id": "Bob",
+            "tgt_id": "Quantum Computing",
+            "description": "Bob researches quantum computing.",
+            "keywords": "research application",
+            "weight": 1.0,
+            "source_id": "doc-1"
        }
-    ],
-    "chunks": [
-        {
-            "content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.",
-            "source_id": "Source1",
-        },
-        {
-            "content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.",
-            "source_id": "Source2",
-        },
-        {
-            "content": "None",
-            "source_id": "UNKNOWN",
-        },
-    ],
+    ]
}
rag.insert_custom_kg(custom_kg)
@@ -640,17 +661,27 @@ export NEO4J_URI="neo4j://localhost:7687"
export NEO4J_USERNAME="neo4j"
export NEO4J_PASSWORD="password"
# Setup logger for LightRAG
setup_logger("lightrag", level="INFO")
# When launching the project, be sure to override the default KG (NetworkX)
# by specifying graph_storage="Neo4JStorage".
# Initialize LightRAG with Neo4J implementation.
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
graph_storage="Neo4JStorage", #<-----------override KG default
log_level="DEBUG" #<-----------override log_level default
)
# Initialize database connections
await rag.initialize_storages()
# Initialize pipeline status for document processing
await initialize_pipeline_status()
return rag
```
See test_neo4j.py for a working example.
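For illustration, a minimal sketch of using the initialized Neo4J-backed instance (assumes the `initialize_rag` function above and the async insert/query methods; adapt the text and query to your data):

```python
import asyncio

from lightrag import QueryParam

async def main():
    # Build the Neo4J-backed instance defined above
    rag = await initialize_rag()
    # Insert some text, then query it through the graph storage
    await rag.ainsert("Some text to index ...")
    print(await rag.aquery("What are the top themes?", param=QueryParam(mode="hybrid")))

if __name__ == "__main__":
    asyncio.run(main())
```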
@@ -754,7 +785,8 @@ rag.delete_by_doc_id("doc_id")
LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
### Create Entities and Relations
<details>
<summary> <b>Create Entities and Relations</b> </summary>
```python
# Create new entity
@@ -776,8 +808,10 @@ relation = rag.create_relation("Google", "Gmail", {
"weight": 2.0
})
```
</details>
### Edit Entities and Relations
<details>
<summary> <b>Edit Entities and Relations</b> </summary>
```python
# Edit an existing entity
@@ -799,6 +833,7 @@ updated_relation = rag.edit_relation("Google", "Google Mail", {
"weight": 3.0
})
```
</details>
All operations are available in both synchronous and asynchronous versions. The asynchronous versions have the prefix "a" (e.g., `acreate_entity`, `aedit_relation`).
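For example, a minimal async sketch reusing the payloads from the synchronous examples above (assumes the same `rag` instance):

```python
# Minimal async sketch; mirrors the synchronous calls shown above.
async def update_graph(rag):
    # Async variants carry the "a" prefix
    entity = await rag.acreate_entity("Google", {
        "description": "Google is a multinational technology company",
        "entity_type": "company"
    })
    relation = await rag.aedit_relation("Google", "Google Mail", {
        "description": "Google created and maintains Gmail",
        "keywords": "creation ownership maintenance",
        "weight": 3.0
    })
    return entity, relation
```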
@@ -859,7 +894,6 @@ Valid modes are:
| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
@@ -881,7 +915,6 @@ Valid modes are:
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
| **log\_dir** | `str` | Directory to store logs. | `./` |
</details>
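For example, a minimal sketch of passing the two dictionary parameters described above to the constructor (illustrative values only; follows the initialization pattern from the earlier examples):

```python
# Illustrative values only; parameter names follow the table above.
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,
    embedding_cache_config={
        "enabled": True,               # check the cache before generating a new answer
        "similarity_threshold": 0.95,  # reuse a cached answer above this similarity
        "use_llm_check": False,        # skip the secondary LLM verification
    },
    addon_params={
        "language": "Simplified Chinese",  # output language for extraction
        "insert_batch_size": 10,           # documents processed per batch
    },
)
```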


@@ -5,6 +5,7 @@
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separates data from different LightRAG instances
# MAX_GRAPH_NODES=1000 # Max nodes returned from graph retrieval
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
### Optional SSL Configuration


@@ -81,6 +81,8 @@ asyncio.run(test_funcs())
embedding_dimension = 3072
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
@@ -91,8 +93,14 @@ rag = LightRAG(
),
)
rag.initialize_storages()
initialize_pipeline_status()
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
rag = asyncio.run(initialize_rag())
book1 = open("./book_1.txt", encoding="utf-8")
book2 = open("./book_2.txt", encoding="utf-8")
@@ -112,3 +120,7 @@ print(rag.query(query_text, param=QueryParam(mode="global")))
print("\nResult (Hybrid):")
print(rag.query(query_text, param=QueryParam(mode="hybrid")))
if __name__ == "__main__":
main()


@@ -53,3 +53,7 @@ def main():
"What are the top themes in this story?", param=QueryParam(mode=mode)
)
)
if __name__ == "__main__":
main()


@@ -125,7 +125,7 @@ async def initialize_rag():
async def main():
try:
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
# reading file
with open("./book.txt", "r", encoding="utf-8") as f:


@@ -77,7 +77,7 @@ async def initialize_rag():
async def main():
try:
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
with open("./book.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())


@@ -81,7 +81,7 @@ async def initialize_rag():
async def main():
try:
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
with open("./book.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())


@@ -107,7 +107,7 @@ async def initialize_rag():
async def main():
try:
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:


@@ -87,7 +87,7 @@ async def initialize_rag():
async def main():
try:
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())


@@ -59,7 +59,7 @@ async def initialize_rag():
async def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
# add embedding_func for the graph database; it was deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func


@@ -102,7 +102,7 @@ async def initialize_rag():
# Example function demonstrating the new query_with_separate_keyword_extraction usage
async def run_example():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
rag = await initialize_rag()
book1 = open("./book_1.txt", encoding="utf-8")
book2 = open("./book_2.txt", encoding="utf-8")


@@ -2,12 +2,15 @@
import os
import logging
from lightrag.kg.shared_storage import finalize_share_data
from lightrag.api.lightrag_server import LightragPathFilter
from lightrag.utils import setup_logger
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
@@ -108,6 +111,9 @@ def on_starting(server):
except ImportError:
print("psutil not installed, skipping memory usage reporting")
# Log the location of the LightRAG log file
print(f"LightRAG log file: {log_file_path}\n")
print("Gunicorn initialization complete, forking workers...\n")
@@ -134,51 +140,18 @@ def post_fork(server, worker):
Executed after a worker has been forked.
This is a good place to set up worker-specific configurations.
"""
# Configure formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
"""Set up a logger with console and file handlers"""
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add path filter if requested
if add_filter:
path_filter = LightragPathFilter()
logger_instance.addFilter(path_filter)
# Set up main loggers
log_level = loglevel.upper() if loglevel else "INFO"
setup_logger("uvicorn", log_level)
setup_logger("uvicorn.access", log_level, add_filter=True)
setup_logger("lightrag", log_level, add_filter=True)
setup_logger("uvicorn", log_level, add_filter=False, log_file_path=log_file_path)
setup_logger(
"uvicorn.access", log_level, add_filter=True, log_file_path=log_file_path
)
setup_logger("lightrag", log_level, add_filter=True, log_file_path=log_file_path)
# Set up lightrag submodule loggers
for name in logging.root.manager.loggerDict:
if name.startswith("lightrag."):
setup_logger(name, log_level, add_filter=True)
setup_logger(name, log_level, add_filter=True, log_file_path=log_file_path)
# Disable uvicorn.error logger
uvicorn_error_logger = logging.getLogger("uvicorn.error")


@@ -6,7 +6,6 @@ from fastapi import (
FastAPI,
Depends,
)
from fastapi.responses import FileResponse
import asyncio
import os
import logging
@@ -331,7 +330,6 @@ def create_app(args):
"similarity_threshold": 0.95,
"use_llm_check": False,
},
log_level=args.log_level,
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
)
@@ -361,7 +359,6 @@ def create_app(args):
"similarity_threshold": 0.95,
"use_llm_check": False,
},
log_level=args.log_level,
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
)
@@ -412,10 +409,6 @@ def create_app(args):
name="webui",
)
@app.get("/webui/")
async def webui_root():
return FileResponse(static_dir / "index.html")
return app
@@ -439,6 +432,9 @@ def configure_logging():
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
print(f"\nLightRAG log file: {log_file_path}\n")
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups


@@ -215,9 +215,29 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
| ".scss"
| ".less"
):
try:
# Try to decode as UTF-8
content = file.decode("utf-8")
# Validate content
if not content or len(content.strip()) == 0:
logger.error(f"Empty content in file: {file_path.name}")
return False
# Check if content looks like binary data string representation
if content.startswith("b'") or content.startswith('b"'):
logger.error(
f"File {file_path.name} appears to contain binary data representation instead of text"
)
return False
except UnicodeDecodeError:
logger.error(
f"File {file_path.name} is not valid UTF-8 encoded text. Please convert it to UTF-8 before processing."
)
return False
case ".pdf":
if not pm.is_installed("pypdf2"):
if not pm.is_installed("pypdf2"): # type: ignore
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
@@ -227,18 +247,18 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
if not pm.is_installed("python-docx"): # type: ignore
pm.install("docx")
from docx import Document
from docx import Document # type: ignore
from io import BytesIO
docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
case ".pptx":
if not pm.is_installed("pptx"):
if not pm.is_installed("python-pptx"): # type: ignore
pm.install("pptx")
from pptx import Presentation
from pptx import Presentation # type: ignore
from io import BytesIO
pptx_file = BytesIO(file)
@@ -248,9 +268,9 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
if hasattr(shape, "text"):
content += shape.text + "\n"
case ".xlsx":
if not pm.is_installed("openpyxl"):
if not pm.is_installed("openpyxl"): # type: ignore
pm.install("openpyxl")
from openpyxl import load_workbook
from openpyxl import load_workbook # type: ignore
from io import BytesIO
xlsx_file = BytesIO(file)


@@ -16,12 +16,32 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
async def get_graph_labels():
"""Get all graph labels"""
"""
Get all graph labels
Returns:
List[str]: List of graph labels
"""
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3):
"""Get knowledge graph for a specific label"""
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return router
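As an illustration, a small client-side sketch of calling these two endpoints (assumes a local server on port 9621 as in the sample .env, with no API key configured):

```python
import requests  # sketch only; endpoint paths taken from the routes above

BASE = "http://localhost:9621"  # assumed default port from the sample .env

# List available graph labels, then fetch the subgraph for the first one
labels = requests.get(f"{BASE}/graph/label/list").json()
if labels:
    graph = requests.get(
        f"{BASE}/graphs", params={"label": labels[0], "max_depth": 3}
    ).json()
    print(f"{len(graph.get('nodes', []))} nodes, {len(graph.get('edges', []))} edges")
```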


@@ -44,6 +44,15 @@ class JsonKVStorage(BaseKVStorage):
)
write_json(data_dict, self._file_name)
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._storage_lock:
return dict(self._data)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock:
return self._data.get(id)


@@ -23,7 +23,7 @@ import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
from neo4j import ( # type: ignore
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@@ -470,40 +473,61 @@ class Neo4JStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence (nodes containing the specified label string)
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Key fixes:
1. Include the starting node itself
2. Handle multi-label nodes
3. Clarify relationship directions
4. Add depth control
Args:
node_label (str): String to match in node labels (will match any node containing this string in its label)
max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
label = node_label.strip('"')
# Escape single quotes to prevent injection attacks
escaped_label = label.replace("'", "\\'")
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE) as session:
try:
main_query = ""
if label == "*":
main_query = """
MATCH (n)
WITH collect(DISTINCT n) AS nodes
MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect(n) AS nodes
MATCH (a)-[r]->(b)
WHERE a IN nodes AND b IN nodes
RETURN nodes, collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_query = f"""
MATCH (n)
WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
RETURN n LIMIT 1
"""
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
logger.warning(
f"No nodes containing '{label}' in their labels found!"
)
return result
# Optimized query (including direction handling and self-loops)
# Main query uses partial matching
main_query = f"""
MATCH (start:`{label}`)
MATCH (start)
WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
@@ -512,9 +536,25 @@ class Neo4JStorage(BaseGraphStorage):
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
WITH start, nodes, relationships
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, count(r) AS degree, start, nodes, relationships,
CASE
WHEN id(node) = id(start) THEN 2
WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
ELSE 0
END AS priority
ORDER BY priority DESC, degree DESC
LIMIT $max_nodes
WITH collect(node) AS filtered_nodes, nodes, relationships
RETURN filtered_nodes AS nodes,
[rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
"""
result_set = await session.run(main_query)
result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
record = await result_set.single()
if record:
@@ -650,8 +690,98 @@ class Neo4JStorage(BaseGraphStorage):
labels.append(record["label"])
return labels
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
"""Delete a node with the specified label
Args:
node_id: The label of the node to delete
"""
label = await self._ensure_label(node_id)
async def _do_delete(tx: AsyncManagedTransaction):
query = f"""
MATCH (n:`{label}`)
DETACH DELETE n
"""
await tx.run(query)
logger.debug(f"Deleted node with label '{label}'")
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node labels to be deleted
"""
for node in nodes:
await self.delete_node(node)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
source_label = await self._ensure_label(source)
target_label = await self._ensure_label(target)
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = f"""
MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
DELETE r
"""
await tx.run(query)
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str


@@ -24,6 +24,8 @@ from .shared_storage import (
is_multiprocess,
)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@@ -233,7 +235,12 @@ class NetworkXStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
@@ -265,22 +272,51 @@ class NetworkXStorage(BaseGraphStorage):
logger.warning(f"No nodes found with label {node_label}")
return result
# Get subgraph using ego_graph
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Get subgraph using ego_graph from all matching nodes
combined_subgraph = nx.Graph()
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
subgraph = combined_subgraph
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
if len(subgraph.nodes()) > max_graph_nodes:
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Add nodes to result
@@ -320,7 +356,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
type="RELATED",
source=str(source),
target=str(target),
properties=edge_data,


@@ -174,6 +174,14 @@ class TiDBKVStorage(BaseKVStorage):
self.db = None
################ QUERY METHODS ################
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._storage_lock:
return dict(self._data)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Fetch doc_full data by id."""


@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import configparser
import os
import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
@@ -85,14 +86,10 @@ class LightRAG:
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging
# Logging (Deprecated, use setup_logger in utils.py instead)
# ---
log_level: int = field(default=logger.level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
"""Log file path."""
log_level: int | None = field(default=None)
log_file_path: str | None = field(default=None)
# Entity extraction
# ---
@@ -266,13 +263,30 @@ class LightRAG:
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
from lightrag.kg.shared_storage import (
initialize_share_data,
)
# Handle deprecated parameters
if self.log_level is not None:
warnings.warn(
"WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
if self.log_file_path is not None:
warnings.warn(
"WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead",
UserWarning,
stacklevel=2,
)
# Remove these attributes to prevent their use
if hasattr(self, "log_level"):
delattr(self, "log_level")
if hasattr(self, "log_file_path"):
delattr(self, "log_file_path")
initialize_share_data()
if not os.path.exists(self.working_dir):
@@ -671,8 +685,24 @@ class LightRAG:
all_new_doc_ids = set(new_docs.keys())
# Exclude IDs of documents that are already in progress
unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
# Log ignored document IDs
ignored_ids = [
doc_id for doc_id in unique_new_doc_ids if doc_id not in new_docs
]
if ignored_ids:
logger.warning(
f"Ignoring {len(ignored_ids)} document IDs not found in new_docs"
)
for doc_id in ignored_ids:
logger.warning(f"Ignored document ID: {doc_id}")
# Filter new_docs to only include documents with unique IDs
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids}
new_docs = {
doc_id: new_docs[doc_id]
for doc_id in unique_new_doc_ids
if doc_id in new_docs
}
if not new_docs:
logger.info("No new unique documents were found.")
@@ -1159,7 +1189,7 @@ class LightRAG:
"""
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
@@ -1180,7 +1210,7 @@ class LightRAG:
)
elif param.mode == "naive":
response = await naive_query(
query,
query.strip(),
self.chunks_vdb,
self.text_chunks,
param,
@@ -1199,7 +1229,7 @@ class LightRAG:
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query,
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
@@ -1417,14 +1447,22 @@ class LightRAG:
logger.debug(f"Starting deletion for document {doc_id}")
doc_to_chunk_id = doc_id.replace("doc", "chunk")
# 2. Get all chunks related to this document
# Find all chunks where full_doc_id equals the current doc_id
all_chunks = await self.text_chunks.get_all()
related_chunks = {
chunk_id: chunk_data
for chunk_id, chunk_data in all_chunks.items()
if isinstance(chunk_data, dict)
and chunk_data.get("full_doc_id") == doc_id
}
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
if not chunks:
if not related_chunks:
logger.warning(f"No chunks found for document {doc_id}")
return
chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")}
# Get all related chunk IDs
chunk_ids = set(related_chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# 3. Before deleting, check the related entities and relationships for these chunks
@@ -1612,9 +1650,18 @@ class LightRAG:
logger.warning(f"Document {doc_id} still exists in full_docs")
# Verify if chunks have been deleted
remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
if remaining_chunks:
logger.warning(f"Found {len(remaining_chunks)} remaining chunks")
all_remaining_chunks = await self.text_chunks.get_all()
remaining_related_chunks = {
chunk_id: chunk_data
for chunk_id, chunk_data in all_remaining_chunks.items()
if isinstance(chunk_data, dict)
and chunk_data.get("full_doc_id") == doc_id
}
if remaining_related_chunks:
logger.warning(
f"Found {len(remaining_related_chunks)} remaining chunks"
)
# Verify entities and relationships
for chunk_id in chunk_ids:


@@ -6,6 +6,7 @@ import io
import csv
import json
import logging
import logging.handlers
import os
import re
from dataclasses import dataclass
@@ -68,6 +69,101 @@ logger.setLevel(logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
class LightragPathFilter(logging.Filter):
"""Filter for lightrag logger to filter out frequent path access logs"""
def __init__(self):
super().__init__()
# Define paths to be filtered
self.filtered_paths = ["/documents", "/health", "/webui/"]
def filter(self, record):
try:
# Check if record has the required attributes for an access log
if not hasattr(record, "args") or not isinstance(record.args, tuple):
return True
if len(record.args) < 5:
return True
# Extract method, path and status from the record args
method = record.args[1]
path = record.args[2]
status = record.args[4]
# Filter out successful GET requests to filtered paths
if (
method == "GET"
and (status == 200 or status == 304)
and path in self.filtered_paths
):
return False
return True
except Exception:
# In case of any error, let the message through
return True
def setup_logger(
logger_name: str,
level: str = "INFO",
add_filter: bool = False,
log_file_path: str = None,
):
"""Set up a logger with console and file handlers
Args:
logger_name: Name of the logger to set up
level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
add_filter: Whether to add LightragPathFilter to the logger
log_file_path: Path to the log file. If None, will use current directory/lightrag.log
"""
# Configure formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
# Get log file path
if log_file_path is None:
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add path filter if requested
if add_filter:
path_filter = LightragPathFilter()
logger_instance.addFilter(path_filter)
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""


@@ -3,7 +3,7 @@ configparser
future
# Basic modules
numpy
gensim
pipmaster
pydantic
python-dotenv