Merge branch 'HKUDS:main' into main

This commit is contained in:
Saifeddine ALOUI
2025-03-21 14:20:51 +01:00
committed by GitHub
62 changed files with 2851 additions and 820 deletions

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.2.6"
__version__ = "1.2.7"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -3,11 +3,16 @@ from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
from dotenv import load_dotenv
load_dotenv()
class TokenPayload(BaseModel):
sub: str
exp: datetime
sub: str # Username
exp: datetime # Expiration time
role: str = "user" # User role, default is regular user
metadata: dict = {} # Additional metadata
class AuthHandler:
@@ -15,13 +20,60 @@ class AuthHandler:
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
self.algorithm = "HS256"
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
self.guest_expire_hours = int(
os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2)
) # Guest token default expiration time
def create_token(
self,
username: str,
role: str = "user",
custom_expire_hours: int = None,
metadata: dict = None,
) -> str:
"""
Create JWT token
Args:
username: Username
role: User role, default is "user", guest is "guest"
custom_expire_hours: Custom expiration time (hours), if None use default value
metadata: Additional metadata
Returns:
str: Encoded JWT token
"""
# Choose default expiration time based on role
if custom_expire_hours is None:
if role == "guest":
expire_hours = self.guest_expire_hours
else:
expire_hours = self.expire_hours
else:
expire_hours = custom_expire_hours
expire = datetime.utcnow() + timedelta(hours=expire_hours)
# Create payload
payload = TokenPayload(
sub=username, exp=expire, role=role, metadata=metadata or {}
)
def create_token(self, username: str) -> str:
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
payload = TokenPayload(sub=username, exp=expire)
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
def validate_token(self, token: str) -> str:
def validate_token(self, token: str) -> dict:
"""
Validate JWT token
Args:
token: JWT token
Returns:
dict: Dictionary containing user information
Raises:
HTTPException: If token is invalid or expired
"""
try:
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
expire_timestamp = payload["exp"]
@@ -31,7 +83,14 @@ class AuthHandler:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
)
return payload["sub"]
# Return complete payload instead of just username
return {
"username": payload["sub"],
"role": payload.get("role", "user"),
"metadata": payload.get("metadata", {}),
"exp": expire_time,
}
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"

View File

@@ -29,7 +29,9 @@ preload_app = True
worker_class = "uvicorn.workers.UvicornWorker"
# Other Gunicorn configurations
timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py
timeout = int(
os.getenv("TIMEOUT", 150 * 2)
) # Default 150s *2 to match run_with_gunicorn.py
keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s
# Logging configuration

View File

@@ -10,6 +10,7 @@ import logging.config
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
@@ -48,7 +49,7 @@ from .auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
# This update allows the user to put a different.env file for each lightrag folder
load_dotenv(".env", override=True)
load_dotenv()
# Initialize config parser
config = configparser.ConfigParser()
@@ -341,25 +342,62 @@ def create_app(args):
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.post("/login")
@app.get("/")
async def redirect_to_webui():
"""Redirect root path to /webui"""
return RedirectResponse(url="/webui")
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
async def get_auth_status():
"""Get authentication status and guest token if auth is not configured"""
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"auth_configured": False,
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
}
return {"auth_configured": True, "auth_mode": "enabled"}
@app.post("/login", dependencies=[Depends(optional_api_key)])
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
username = os.getenv("AUTH_USERNAME")
password = os.getenv("AUTH_PASSWORD")
if not (username and password):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authentication not configured",
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
}
if form_data.username != username or form_data.password != password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
# Regular user login
user_token = auth_handler.create_token(
username=username, role="user", metadata={"auth_mode": "enabled"}
)
return {
"access_token": auth_handler.create_token(username),
"access_token": user_token,
"token_type": "bearer",
"auth_mode": "enabled",
}
@app.get("/health", dependencies=[Depends(optional_api_key)])

View File

@@ -405,7 +405,7 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
"""Index multiple files concurrently
"""Index multiple files sequentially to avoid high CPU load
Args:
rag: LightRAG instance
@@ -416,12 +416,12 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
try:
enqueued = False
if len(file_paths) == 1:
enqueued = await pipeline_enqueue_file(rag, file_paths[0])
else:
tasks = [pipeline_enqueue_file(rag, path) for path in file_paths]
enqueued = any(await asyncio.gather(*tasks))
# Process files sequentially
for file_path in file_paths:
if await pipeline_enqueue_file(rag, file_path):
enqueued = True
# Process the queue only if at least one file was successfully enqueued
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
@@ -472,14 +472,34 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
total_files = len(new_files)
logger.info(f"Found {total_files} new files to index.")
for idx, file_path in enumerate(new_files):
try:
await pipeline_index_file(rag, file_path)
except Exception as e:
logger.error(f"Error indexing file {file_path}: {str(e)}")
if not new_files:
return
# Get MAX_PARALLEL_INSERT from global_args
max_parallel = global_args["max_parallel_insert"]
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
batch_size = 2 * max_parallel
# Process files in batches
for i in range(0, total_files, batch_size):
batch_files = new_files[i : i + batch_size]
batch_num = i // batch_size + 1
total_batches = (total_files + batch_size - 1) // batch_size
logger.info(
f"Processing batch {batch_num}/{total_batches} with {len(batch_files)} files"
)
await pipeline_index_files(rag, batch_files)
# Log progress
processed = min(i + batch_size, total_files)
logger.info(
f"Processed {processed}/{total_files} files ({processed/total_files*100:.1f}%)"
)
except Exception as e:
logger.error(f"Error during scanning process: {str(e)}")
logger.error(traceback.format_exc())
def create_document_routes(

View File

@@ -13,7 +13,7 @@ from dotenv import load_dotenv
# Updated to use the .env that is inside the current folder
# This update allows the user to put a different.env file for each lightrag folder
load_dotenv(".env")
load_dotenv()
def check_and_install_dependencies():
@@ -140,7 +140,7 @@ def main():
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
)
# Keepalive configuration

View File

@@ -9,14 +9,14 @@ import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security, Depends, Request
from fastapi import HTTPException, Security, Depends, Request, status
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)
load_dotenv()
global_args = {"main_args": None}
@@ -35,19 +35,46 @@ ollama_server_infos = OllamaServerInfos()
def get_auth_dependency():
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
# Set default whitelist paths
whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",")
async def dependency(
request: Request,
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
):
# Check if authentication is configured
auth_configured = bool(
os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")
)
# If authentication is not configured, skip all validation
if not auth_configured:
return
# For configured auth, allow whitelist paths without token
if request.url.path in whitelist:
return
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
return
# Require token for all other paths when auth is configured
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
)
auth_handler.validate_token(token)
try:
token_info = auth_handler.validate_token(token)
# Reject guest tokens when authentication is configured
if token_info.get("role") == "guest":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required. Guest access not allowed when authentication is configured.",
)
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
return
return dependency
@@ -338,6 +365,9 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
global_args["max_parallel_insert"] = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"
@@ -414,8 +444,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" ├─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
ASCIIColors.white(" ├─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}")
ASCIIColors.white(" └─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
@@ -432,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.llm_binding}")
ASCIIColors.white(" ├─ Host: ", end="")
ASCIIColors.yellow(f"{args.llm_binding_host}")
ASCIIColors.white(" ─ Model: ", end="")
ASCIIColors.white(" ─ Model: ", end="")
ASCIIColors.yellow(f"{args.llm_model}")
ASCIIColors.white(" └─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
# Embedding Configuration
ASCIIColors.magenta("\n📊 Embedding Configuration:")
@@ -448,8 +480,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
# RAG Configuration
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Max Async Operations: ", end="")
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
ASCIIColors.yellow(f"{args.max_async}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{global_args['max_parallel_insert']}")
ASCIIColors.white(" ├─ Max Tokens: ", end="")
ASCIIColors.yellow(f"{args.max_tokens}")
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
@@ -458,8 +492,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.chunk_size}")
ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="")
ASCIIColors.yellow(f"{args.chunk_overlap_size}")
ASCIIColors.white(" ├─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}")
ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
ASCIIColors.yellow(f"{args.cosine_threshold}")
ASCIIColors.white(" ├─ Top-K: ", end="")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,11 +5,11 @@
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate" />
<meta http-equiv="Pragma" content="no-cache" />
<meta http-equiv="Expires" content="0" />
<link rel="icon" type="image/svg+xml" href="./logo.png" />
<link rel="icon" type="image/svg+xml" href="logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="./assets/index-DwcJE583.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-BV5s8k-a.css">
<script type="module" crossorigin src="/webui/assets/index-4I5HV9Fr.js"></script>
<link rel="stylesheet" crossorigin href="/webui/assets/index-BSOt8Nur.css">
</head>
<body>
<div id="root"></div>

View File

@@ -257,6 +257,8 @@ class DocProcessingStatus:
"""First 100 chars of document content, used for preview"""
content_length: int
"""Total length of document"""
file_path: str
"""File path of the document"""
status: DocStatus
"""Current processing status"""
created_at: str

View File

@@ -87,6 +87,9 @@ class JsonDocStatusStorage(DocStatusStorage):
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")

View File

@@ -373,6 +373,9 @@ class NetworkXStorage(BaseGraphStorage):
# Add edges to result
for edge in subgraph.edges():
source, target = edge
# Esure unique edge_id for undirect graph
if source > target:
source, target = target, source
edge_id = f"{source}-{target}"
if edge_id in seen_edges:
continue

View File

@@ -423,6 +423,7 @@ class PGVectorStorage(BaseVectorStorage):
"full_doc_id": item["full_doc_id"],
"content": item["content"],
"content_vector": json.dumps(item["__vector__"].tolist()),
"file_path": item["file_path"],
}
except Exception as e:
logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}")
@@ -445,6 +446,7 @@ class PGVectorStorage(BaseVectorStorage):
"content": item["content"],
"content_vector": json.dumps(item["__vector__"].tolist()),
"chunk_ids": chunk_ids,
"file_path": item["file_path"],
# TODO: add document_id
}
return upsert_sql, data
@@ -465,6 +467,7 @@ class PGVectorStorage(BaseVectorStorage):
"content": item["content"],
"content_vector": json.dumps(item["__vector__"].tolist()),
"chunk_ids": chunk_ids,
"file_path": item["file_path"],
# TODO: add document_id
}
return upsert_sql, data
@@ -732,7 +735,7 @@ class PGDocStatusStorage(DocStatusStorage):
if result is None or result == []:
return None
else:
return DocProcessingStatus(
return dict(
content=result[0]["content"],
content_length=result[0]["content_length"],
content_summary=result[0]["content_summary"],
@@ -740,11 +743,34 @@ class PGDocStatusStorage(DocStatusStorage):
chunks_count=result[0]["chunks_count"],
created_at=result[0]["created_at"],
updated_at=result[0]["updated_at"],
file_path=result[0]["file_path"],
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
raise NotImplementedError
"""Get doc_chunks data by multiple IDs."""
if not ids:
return []
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.db.workspace, "ids": ids}
results = await self.db.query(sql, params, True)
if not results:
return []
return [
{
"content": row["content"],
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"file_path": row["file_path"],
}
for row in results
]
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
@@ -774,6 +800,7 @@ class PGDocStatusStorage(DocStatusStorage):
created_at=element["created_at"],
updated_at=element["updated_at"],
chunks_count=element["chunks_count"],
file_path=element["file_path"],
)
for element in result
}
@@ -793,14 +820,15 @@ class PGDocStatusStorage(DocStatusStorage):
if not data:
return
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7)
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path)
values($1,$2,$3,$4,$5,$6,$7,$8)
on conflict(id,workspace) do update set
content = EXCLUDED.content,
content_summary = EXCLUDED.content_summary,
content_length = EXCLUDED.content_length,
chunks_count = EXCLUDED.chunks_count,
status = EXCLUDED.status,
file_path = EXCLUDED.file_path,
updated_at = CURRENT_TIMESTAMP"""
for k, v in data.items():
# chunks_count is optional
@@ -814,6 +842,7 @@ class PGDocStatusStorage(DocStatusStorage):
"content_length": v["content_length"],
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
"file_path": v["file_path"],
},
)
@@ -1058,7 +1087,6 @@ class PGGraphStorage(BaseGraphStorage):
Args:
query (str): a cypher query to be executed
params (dict): parameters for the query
Returns:
list[dict[str, Any]]: a list of dictionaries containing the result set
@@ -1549,6 +1577,7 @@ TABLES = {
tokens INTEGER,
content TEXT,
content_vector VECTOR,
file_path VARCHAR(256),
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
@@ -1563,7 +1592,8 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
chunk_id TEXT NULL,
chunk_ids VARCHAR(255)[] NULL,
file_path TEXT NULL,
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1577,7 +1607,8 @@ TABLES = {
content_vector VECTOR,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
chunk_id TEXT NULL,
chunk_ids VARCHAR(255)[] NULL,
file_path TEXT NULL,
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1602,6 +1633,7 @@ TABLES = {
content_length int4 NULL,
chunks_count int4 NULL,
status varchar(64) NULL,
file_path TEXT NULL,
created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
@@ -1650,35 +1682,38 @@ SQL_TEMPLATES = {
update_time = CURRENT_TIMESTAMP
""",
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector)
VALUES ($1, $2, $3, $4, $5, $6, $7)
chunk_order_index, full_doc_id, content, content_vector, file_path)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (workspace,id) DO UPDATE
SET tokens=EXCLUDED.tokens,
chunk_order_index=EXCLUDED.chunk_order_index,
full_doc_id=EXCLUDED.full_doc_id,
content = EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
file_path=EXCLUDED.file_path,
update_time = CURRENT_TIMESTAMP
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
content_vector, chunk_ids)
VALUES ($1, $2, $3, $4, $5, $6::varchar[])
content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
ON CONFLICT (workspace,id) DO UPDATE
SET entity_name=EXCLUDED.entity_name,
content=EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
chunk_ids=EXCLUDED.chunk_ids,
file_path=EXCLUDED.file_path,
update_time=CURRENT_TIMESTAMP
""",
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
target_id, content, content_vector, chunk_ids)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[])
target_id, content, content_vector, chunk_ids, file_path)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8)
ON CONFLICT (workspace,id) DO UPDATE
SET source_id=EXCLUDED.source_id,
target_id=EXCLUDED.target_id,
content=EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
chunk_ids=EXCLUDED.chunk_ids,
file_path=EXCLUDED.file_path,
update_time = CURRENT_TIMESTAMP
""",
# SQL for VectorStorage

View File

@@ -41,6 +41,9 @@ _pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None
# async locks for coroutine synchronization in multiprocess mode
_async_locks: Optional[Dict[str, asyncio.Lock]] = None
class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@@ -51,12 +54,14 @@ class UnifiedLock(Generic[T]):
is_async: bool,
name: str = "unnamed",
enable_logging: bool = True,
async_lock: Optional[asyncio.Lock] = None,
):
self._lock = lock
self._is_async = is_async
self._pid = os.getpid() # for debug only
self._name = name # for debug only
self._enable_logging = enable_logging # for debug only
self._async_lock = async_lock # auxiliary lock for coroutine synchronization
async def __aenter__(self) -> "UnifiedLock[T]":
try:
@@ -64,16 +69,39 @@ class UnifiedLock(Generic[T]):
f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
# If in multiprocess mode and async lock exists, acquire it first
if not self._is_async and self._async_lock is not None:
direct_log(
f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
enable_output=self._enable_logging,
)
await self._async_lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
enable_output=self._enable_logging,
)
# Then acquire the main lock
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
enable_output=self._enable_logging,
)
return self
except Exception as e:
# If main lock acquisition fails, release the async lock if it was acquired
if (
not self._is_async
and self._async_lock is not None
and self._async_lock.locked()
):
self._async_lock.release()
direct_log(
f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
level="ERROR",
@@ -82,15 +110,29 @@ class UnifiedLock(Generic[T]):
raise
async def __aexit__(self, exc_type, exc_val, exc_tb):
main_lock_released = False
try:
direct_log(
f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
# Release main lock first
if self._is_async:
self._lock.release()
else:
self._lock.release()
main_lock_released = True
# Then release async lock if in multiprocess mode
if not self._is_async and self._async_lock is not None:
direct_log(
f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'",
enable_output=self._enable_logging,
)
self._async_lock.release()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
enable_output=self._enable_logging,
@@ -101,6 +143,31 @@ class UnifiedLock(Generic[T]):
level="ERROR",
enable_output=self._enable_logging,
)
# If main lock release failed but async lock hasn't been released, try to release it
if (
not main_lock_released
and not self._is_async
and self._async_lock is not None
):
try:
direct_log(
f"== Lock == Process {self._pid}: Attempting to release async lock after main lock failure",
level="WARNING",
enable_output=self._enable_logging,
)
self._async_lock.release()
direct_log(
f"== Lock == Process {self._pid}: Successfully released async lock after main lock failure",
enable_output=self._enable_logging,
)
except Exception as inner_e:
direct_log(
f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}",
level="ERROR",
enable_output=self._enable_logging,
)
raise
def __enter__(self) -> "UnifiedLock[T]":
@@ -151,51 +218,61 @@ class UnifiedLock(Generic[T]):
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
async_lock = _async_locks.get("internal_lock") if is_multiprocess else None
return UnifiedLock(
lock=_internal_lock,
is_async=not is_multiprocess,
name="internal_lock",
enable_logging=enable_logging,
async_lock=async_lock,
)
def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
async_lock = _async_locks.get("storage_lock") if is_multiprocess else None
return UnifiedLock(
lock=_storage_lock,
is_async=not is_multiprocess,
name="storage_lock",
enable_logging=enable_logging,
async_lock=async_lock,
)
def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
async_lock = _async_locks.get("pipeline_status_lock") if is_multiprocess else None
return UnifiedLock(
lock=_pipeline_status_lock,
is_async=not is_multiprocess,
name="pipeline_status_lock",
enable_logging=enable_logging,
async_lock=async_lock,
)
def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified graph database lock for ensuring atomic operations"""
async_lock = _async_locks.get("graph_db_lock") if is_multiprocess else None
return UnifiedLock(
lock=_graph_db_lock,
is_async=not is_multiprocess,
name="graph_db_lock",
enable_logging=enable_logging,
async_lock=async_lock,
)
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
async_lock = _async_locks.get("data_init_lock") if is_multiprocess else None
return UnifiedLock(
lock=_data_init_lock,
is_async=not is_multiprocess,
name="data_init_lock",
enable_logging=enable_logging,
async_lock=async_lock,
)
@@ -229,7 +306,8 @@ def initialize_share_data(workers: int = 1):
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
_update_flags, \
_async_locks
# Check if already initialized
if _initialized:
@@ -251,6 +329,16 @@ def initialize_share_data(workers: int = 1):
_shared_dicts = _manager.dict()
_init_flags = _manager.dict()
_update_flags = _manager.dict()
# Initialize async locks for multiprocess mode
_async_locks = {
"internal_lock": asyncio.Lock(),
"storage_lock": asyncio.Lock(),
"pipeline_status_lock": asyncio.Lock(),
"graph_db_lock": asyncio.Lock(),
"data_init_lock": asyncio.Lock(),
}
direct_log(
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
)
@@ -264,6 +352,7 @@ def initialize_share_data(workers: int = 1):
_shared_dicts = {}
_init_flags = {}
_update_flags = {}
_async_locks = None # No need for async locks in single process mode
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Mark as initialized
@@ -458,7 +547,8 @@ def finalize_share_data():
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
_update_flags, \
_async_locks
# Check if already initialized
if not _initialized:
@@ -523,5 +613,6 @@ def finalize_share_data():
_graph_db_lock = None
_data_init_lock = None
_update_flags = None
_async_locks = None
direct_log(f"Process {os.getpid()} storage data finalization complete")

View File

@@ -183,10 +183,10 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = field(default=32)
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
"""Batch size for embedding computations."""
embedding_func_max_async: int = field(default=16)
embedding_func_max_async: int = field(default=int(os.getenv("EMBEDDING_FUNC_MAX_ASYNC", 16)))
"""Maximum number of concurrent embedding function calls."""
embedding_cache_config: dict[str, Any] = field(
@@ -389,20 +389,21 @@ class LightRAG:
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name", "source_id", "content"},
meta_fields={"entity_name", "source_id", "content", "file_path"},
)
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id", "source_id", "content"},
meta_fields={"src_id", "tgt_id", "source_id", "content", "file_path"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
embedding_func=self.embedding_func,
meta_fields={"full_doc_id", "content", "file_path"},
)
# Initialize document status storage
@@ -547,6 +548,7 @@ class LightRAG:
split_by_character: str | None = None,
split_by_character_only: bool = False,
ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None,
) -> None:
"""Sync Insert documents with checkpoint support
@@ -557,10 +559,13 @@ class LightRAG:
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
ids: single string of the document ID or list of unique document IDs, if not provided, MD5 hash IDs will be generated
file_paths: single string of the file path or list of file paths, used for citation
"""
loop = always_get_an_event_loop()
loop.run_until_complete(
self.ainsert(input, split_by_character, split_by_character_only, ids)
self.ainsert(
input, split_by_character, split_by_character_only, ids, file_paths
)
)
async def ainsert(
@@ -569,6 +574,7 @@ class LightRAG:
split_by_character: str | None = None,
split_by_character_only: bool = False,
ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None,
) -> None:
"""Async Insert documents with checkpoint support
@@ -579,8 +585,9 @@ class LightRAG:
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
file_paths: list of file paths corresponding to each document, used for citation
"""
await self.apipeline_enqueue_documents(input, ids)
await self.apipeline_enqueue_documents(input, ids, file_paths)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
@@ -654,7 +661,10 @@ class LightRAG:
await self._insert_done()
async def apipeline_enqueue_documents(
self, input: str | list[str], ids: list[str] | None = None
self,
input: str | list[str],
ids: list[str] | None = None,
file_paths: str | list[str] | None = None,
) -> None:
"""
Pipeline for Processing Documents
@@ -664,11 +674,30 @@ class LightRAG:
3. Generate document initial status
4. Filter out already processed documents
5. Enqueue document in status
Args:
input: Single document string or list of document strings
ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
file_paths: list of file paths corresponding to each document, used for citation
"""
if isinstance(input, str):
input = [input]
if isinstance(ids, str):
ids = [ids]
if isinstance(file_paths, str):
file_paths = [file_paths]
# If file_paths is provided, ensure it matches the number of documents
if file_paths is not None:
if isinstance(file_paths, str):
file_paths = [file_paths]
if len(file_paths) != len(input):
raise ValueError(
"Number of file paths must match the number of documents"
)
else:
# If no file paths provided, use placeholder
file_paths = ["unknown_source"] * len(input)
# 1. Validate ids if provided or generate MD5 hash IDs
if ids is not None:
@@ -681,32 +710,59 @@ class LightRAG:
raise ValueError("IDs must be unique")
# Generate contents dict of IDs provided by user and documents
contents = {id_: doc for id_, doc in zip(ids, input)}
contents = {
id_: {"content": doc, "file_path": path}
for id_, doc, path in zip(ids, input, file_paths)
}
else:
# Clean input text and remove duplicates
input = list(set(clean_text(doc) for doc in input))
# Generate contents dict of MD5 hash IDs and documents
contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
cleaned_input = [
(clean_text(doc), path) for doc, path in zip(input, file_paths)
]
unique_content_with_paths = {}
# Keep track of unique content and their paths
for content, path in cleaned_input:
if content not in unique_content_with_paths:
unique_content_with_paths[content] = path
# Generate contents dict of MD5 hash IDs and documents with paths
contents = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"file_path": path,
}
for content, path in unique_content_with_paths.items()
}
# 2. Remove duplicate contents
unique_contents = {
id_: content
for content, id_ in {
content: id_ for id_, content in contents.items()
}.items()
unique_contents = {}
for id_, content_data in contents.items():
content = content_data["content"]
file_path = content_data["file_path"]
if content not in unique_contents:
unique_contents[content] = (id_, file_path)
# Reconstruct contents with unique content
contents = {
id_: {"content": content, "file_path": file_path}
for content, (id_, file_path) in unique_contents.items()
}
# 3. Generate document initial status
new_docs: dict[str, Any] = {
id_: {
"content": content,
"content_summary": get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"content": content_data["content"],
"content_summary": get_content_summary(content_data["content"]),
"content_length": len(content_data["content"]),
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
"file_path": content_data[
"file_path"
], # Store file path in document status
}
for id_, content in unique_contents.items()
for id_, content_data in contents.items()
}
# 4. Filter out already processed documents
@@ -841,11 +897,15 @@ class LightRAG:
) -> None:
"""Process single document"""
try:
# Get file path from status document
file_path = getattr(status_doc, "file_path", "unknown_source")
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"file_path": file_path, # Add file path to each chunk
}
for dp in self.chunking_func(
status_doc.content,
@@ -856,6 +916,7 @@ class LightRAG:
self.tiktoken_model_name,
)
}
# Process document (text chunks and full docs) in parallel
# Create tasks with references for potential cancellation
doc_status_task = asyncio.create_task(
@@ -863,11 +924,13 @@ class LightRAG:
{
doc_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
)
@@ -906,6 +969,7 @@ class LightRAG:
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
)
@@ -937,6 +1001,7 @@ class LightRAG:
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
)
@@ -1063,7 +1128,10 @@ class LightRAG:
loop.run_until_complete(self.ainsert_custom_kg(custom_kg, full_doc_id))
async def ainsert_custom_kg(
self, custom_kg: dict[str, Any], full_doc_id: str = None
self,
custom_kg: dict[str, Any],
full_doc_id: str = None,
file_path: str = "custom_kg",
) -> None:
update_storage = False
try:
@@ -1093,6 +1161,7 @@ class LightRAG:
"full_doc_id": full_doc_id
if full_doc_id is not None
else source_id,
"file_path": file_path, # Add file path
"status": DocStatus.PROCESSED,
}
all_chunks_data[chunk_id] = chunk_entry
@@ -1197,6 +1266,7 @@ class LightRAG:
"source_id": dp["source_id"],
"description": dp["description"],
"entity_type": dp["entity_type"],
"file_path": file_path, # Add file path
}
for dp in all_entities_data
}
@@ -1212,6 +1282,7 @@ class LightRAG:
"keywords": dp["keywords"],
"description": dp["description"],
"weight": dp["weight"],
"file_path": file_path, # Add file path
}
for dp in all_relationships_data
}
@@ -1473,8 +1544,7 @@ class LightRAG:
"""
try:
# 1. Get the document status and related data
doc_status = await self.doc_status.get_by_id(doc_id)
if not doc_status:
if not await self.doc_status.get_by_id(doc_id):
logger.warning(f"Document {doc_id} not found")
return
@@ -1877,6 +1947,8 @@ class LightRAG:
# 2. Update entity information in the graph
new_node_data = {**node_data, **updated_data}
new_node_data["entity_id"] = new_entity_name
if "entity_name" in new_node_data:
del new_node_data[
"entity_name"
@@ -1893,7 +1965,7 @@ class LightRAG:
# Store relationships that need to be updated
relations_to_update = []
relations_to_delete = []
# Get all edges related to the original entity
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
@@ -1905,6 +1977,12 @@ class LightRAG:
source, target
)
if edge_data:
relations_to_delete.append(
compute_mdhash_id(source + target, prefix="rel-")
)
relations_to_delete.append(
compute_mdhash_id(target + source, prefix="rel-")
)
if source == entity_name:
await self.chunk_entity_relation_graph.upsert_edge(
new_entity_name, target, edge_data
@@ -1930,6 +2008,12 @@ class LightRAG:
f"Deleted old entity '{entity_name}' and its vector embedding from database"
)
# Delete old relation records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# Update relationship vector representations
for src, tgt, edge_data in relations_to_update:
description = edge_data.get("description", "")
@@ -2220,7 +2304,6 @@ class LightRAG:
"""Synchronously create a new entity.
Creates a new entity in the knowledge graph and adds it to the vector database.
Args:
entity_name: Name of the new entity
entity_data: Dictionary containing entity attributes, e.g. {"description": "description", "entity_type": "type"}
@@ -2429,39 +2512,21 @@ class LightRAG:
# 4. Get all relationships of the source entities
all_relations = []
for entity_name in source_entities:
# Get all relationships where this entity is the source
outgoing_edges = await self.chunk_entity_relation_graph.get_node_edges(
# Get all relationships of the source entities
edges = await self.chunk_entity_relation_graph.get_node_edges(
entity_name
)
if outgoing_edges:
for src, tgt in outgoing_edges:
if edges:
for src, tgt in edges:
# Ensure src is the current entity
if src == entity_name:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("outgoing", src, tgt, edge_data))
# Get all relationships where this entity is the target
incoming_edges = []
all_labels = await self.chunk_entity_relation_graph.get_all_labels()
for label in all_labels:
if label == entity_name:
continue
node_edges = await self.chunk_entity_relation_graph.get_node_edges(
label
)
for src, tgt in node_edges or []:
if tgt == entity_name:
incoming_edges.append((src, tgt))
for src, tgt in incoming_edges:
edge_data = await self.chunk_entity_relation_graph.get_edge(
src, tgt
)
all_relations.append(("incoming", src, tgt, edge_data))
all_relations.append((src, tgt, edge_data))
# 5. Create or update the target entity
merged_entity_data["entity_id"] = target_entity
if not target_exists:
await self.chunk_entity_relation_graph.upsert_node(
target_entity, merged_entity_data
@@ -2475,8 +2540,11 @@ class LightRAG:
# 6. Recreate all relationships, pointing to the target entity
relation_updates = {} # Track relationships that need to be merged
relations_to_delete = []
for rel_type, src, tgt, edge_data in all_relations:
for src, tgt, edge_data in all_relations:
relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-"))
relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-"))
new_src = target_entity if src in source_entities else src
new_tgt = target_entity if tgt in source_entities else tgt
@@ -2521,6 +2589,12 @@ class LightRAG:
f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
)
# Delete relationships records from vector database
await self.relationships_vdb.delete(relations_to_delete)
logger.info(
f"Deleted {len(relations_to_delete)} relation records for entity '{entity_name}' from vector database"
)
# 7. Update entity vector representation
description = merged_entity_data.get("description", "")
source_id = merged_entity_data.get("source_id", "")
@@ -2583,19 +2657,6 @@ class LightRAG:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
await self.entities_vdb.delete([entity_id])
# Also ensure any relationships specific to this entity are deleted from vector DB
# This is a safety check, as these should have been transformed to the target entity already
entity_relation_prefix = compute_mdhash_id(entity_name, prefix="rel-")
relations_with_entity = await self.relationships_vdb.search_by_prefix(
entity_relation_prefix
)
if relations_with_entity:
relation_ids = [r["id"] for r in relations_with_entity]
await self.relationships_vdb.delete(relation_ids)
logger.info(
f"Deleted {len(relation_ids)} relation records for entity '{entity_name}' from vector database"
)
logger.info(
f"Deleted source entity '{entity_name}' and its vector embedding from database"
)

View File

@@ -138,16 +138,31 @@ async def hf_model_complete(
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
device = next(embed_model.parameters()).device
# Detect the appropriate device
if torch.cuda.is_available():
device = next(embed_model.parameters()).device # Use CUDA if available
elif torch.backends.mps.is_available():
device = torch.device("mps") # Use MPS for Apple Silicon
else:
device = torch.device("cpu") # Fallback to CPU
# Move the model to the detected device
embed_model = embed_model.to(device)
# Tokenize the input texts and move them to the same device
encoded_texts = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).to(device)
# Perform inference
with torch.no_grad():
outputs = embed_model(
input_ids=encoded_texts["input_ids"],
attention_mask=encoded_texts["attention_mask"],
)
embeddings = outputs.last_hidden_state.mean(dim=1)
# Convert embeddings to NumPy
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:

View File

@@ -138,6 +138,7 @@ async def _handle_entity_relation_summary(
async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
file_path: str = "unknown_source",
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
@@ -171,13 +172,14 @@ async def _handle_single_entity_extraction(
entity_type=entity_type,
description=entity_description,
source_id=chunk_key,
metadata={"created_at": time.time()},
file_path=file_path,
)
async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
file_path: str = "unknown_source",
):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
@@ -199,7 +201,7 @@ async def _handle_single_relationship_extraction(
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
metadata={"created_at": time.time()},
file_path=file_path,
)
@@ -213,6 +215,7 @@ async def _merge_nodes_then_upsert(
already_entity_types = []
already_source_ids = []
already_description = []
already_file_paths = []
already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
@@ -220,6 +223,9 @@ async def _merge_nodes_then_upsert(
already_source_ids.extend(
split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
)
already_file_paths.extend(
split_string_by_multi_markers(already_node["file_path"], [GRAPH_FIELD_SEP])
)
already_description.append(already_node["description"])
entity_type = sorted(
@@ -235,6 +241,11 @@ async def _merge_nodes_then_upsert(
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in nodes_data] + already_source_ids)
)
file_path = GRAPH_FIELD_SEP.join(
set([dp["file_path"] for dp in nodes_data] + already_file_paths)
)
logger.debug(f"file_path: {file_path}")
description = await _handle_entity_relation_summary(
entity_name, description, global_config
)
@@ -243,6 +254,7 @@ async def _merge_nodes_then_upsert(
entity_type=entity_type,
description=description,
source_id=source_id,
file_path=file_path,
)
await knowledge_graph_inst.upsert_node(
entity_name,
@@ -263,6 +275,7 @@ async def _merge_edges_then_upsert(
already_source_ids = []
already_description = []
already_keywords = []
already_file_paths = []
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
@@ -279,6 +292,14 @@ async def _merge_edges_then_upsert(
)
)
# Get file_path with empty string default if missing or None
if already_edge.get("file_path") is not None:
already_file_paths.extend(
split_string_by_multi_markers(
already_edge["file_path"], [GRAPH_FIELD_SEP]
)
)
# Get description with empty string default if missing or None
if already_edge.get("description") is not None:
already_description.append(already_edge["description"])
@@ -315,6 +336,12 @@ async def _merge_edges_then_upsert(
+ already_source_ids
)
)
file_path = GRAPH_FIELD_SEP.join(
set(
[dp["file_path"] for dp in edges_data if dp.get("file_path")]
+ already_file_paths
)
)
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
@@ -325,6 +352,7 @@ async def _merge_edges_then_upsert(
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
},
)
description = await _handle_entity_relation_summary(
@@ -338,6 +366,7 @@ async def _merge_edges_then_upsert(
description=description,
keywords=keywords,
source_id=source_id,
file_path=file_path,
),
)
@@ -347,6 +376,7 @@ async def _merge_edges_then_upsert(
description=description,
keywords=keywords,
source_id=source_id,
file_path=file_path,
)
return edge_data
@@ -456,11 +486,14 @@ async def extract_entities(
else:
return await use_llm_func(input_text)
async def _process_extraction_result(result: str, chunk_key: str):
async def _process_extraction_result(
result: str, chunk_key: str, file_path: str = "unknown_source"
):
"""Process a single extraction result (either initial or gleaning)
Args:
result (str): The extraction result to process
chunk_key (str): The chunk key for source tracking
file_path (str): The file path for citation
Returns:
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
"""
@@ -482,14 +515,14 @@ async def extract_entities(
)
if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key
record_attributes, chunk_key, file_path
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key
record_attributes, chunk_key, file_path
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
@@ -508,6 +541,8 @@ async def extract_entities(
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
# Get file path from chunk data or use default
file_path = chunk_dp.get("file_path", "unknown_source")
# Get initial extraction
hint_prompt = entity_extract_prompt.format(
@@ -517,9 +552,9 @@ async def extract_entities(
final_result = await _user_llm_func_with_cache(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
# Process initial extraction
# Process initial extraction with file path
maybe_nodes, maybe_edges = await _process_extraction_result(
final_result, chunk_key
final_result, chunk_key, file_path
)
# Process additional gleaning results
@@ -530,9 +565,9 @@ async def extract_entities(
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
# Process gleaning result separately
# Process gleaning result separately with file path
glean_nodes, glean_edges = await _process_extraction_result(
glean_result, chunk_key
glean_result, chunk_key, file_path
)
# Merge results
@@ -637,9 +672,7 @@ async def extract_entities(
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time())
},
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_entities_data
}
@@ -653,9 +686,7 @@ async def extract_entities(
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time())
},
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_relationships_data
}
@@ -1232,12 +1263,17 @@ async def _get_node_data(
"description",
"rank",
"created_at",
"file_path",
]
]
for i, n in enumerate(node_datas):
created_at = n.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from node data
file_path = n.get("file_path", "unknown_source")
entites_section_list.append(
[
i,
@@ -1246,6 +1282,7 @@ async def _get_node_data(
n.get("description", "UNKNOWN"),
n["rank"],
created_at,
file_path,
]
)
entities_context = list_of_list_to_csv(entites_section_list)
@@ -1260,6 +1297,7 @@ async def _get_node_data(
"weight",
"rank",
"created_at",
"file_path",
]
]
for i, e in enumerate(use_relations):
@@ -1267,6 +1305,10 @@ async def _get_node_data(
# Convert timestamp to readable format
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from edge data
file_path = e.get("file_path", "unknown_source")
relations_section_list.append(
[
i,
@@ -1277,6 +1319,7 @@ async def _get_node_data(
e["weight"],
e["rank"],
created_at,
file_path,
]
)
relations_context = list_of_list_to_csv(relations_section_list)
@@ -1492,6 +1535,7 @@ async def _get_edge_data(
"weight",
"rank",
"created_at",
"file_path",
]
]
for i, e in enumerate(edge_datas):
@@ -1499,6 +1543,10 @@ async def _get_edge_data(
# Convert timestamp to readable format
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from edge data
file_path = e.get("file_path", "unknown_source")
relations_section_list.append(
[
i,
@@ -1509,16 +1557,23 @@ async def _get_edge_data(
e["weight"],
e["rank"],
created_at,
file_path,
]
)
relations_context = list_of_list_to_csv(relations_section_list)
entites_section_list = [["id", "entity", "type", "description", "rank"]]
entites_section_list = [
["id", "entity", "type", "description", "rank", "created_at", "file_path"]
]
for i, n in enumerate(use_entities):
created_at = e.get("created_at", "Unknown")
created_at = n.get("created_at", "Unknown")
# Convert timestamp to readable format
if isinstance(created_at, (int, float)):
created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Get file path from node data
file_path = n.get("file_path", "unknown_source")
entites_section_list.append(
[
i,
@@ -1527,6 +1582,7 @@ async def _get_edge_data(
n.get("description", "UNKNOWN"),
n["rank"],
created_at,
file_path,
]
)
entities_context = list_of_list_to_csv(entites_section_list)
@@ -1882,13 +1938,14 @@ async def kg_query_with_keywords(
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
# 清理响应内容
# Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")

View File

@@ -61,7 +61,7 @@ Text:
```
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. "If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us."
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
@@ -92,7 +92,7 @@ Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after re
Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%, reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices continued their rally, climbing to $87.60 per barrel, supported by supply constraints and strong demand.
Financial experts are closely watching the Federal Reserves next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability.
Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability.
```
Output:
@@ -222,6 +222,7 @@ When handling relationships with timestamps:
- Use markdown formatting with appropriate section headings
- Please respond in the same language as the user's question.
- Ensure the response maintains continuity with the conversation history.
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- If you don't know the answer, just say so.
- Do not make anything up. Do not include information not provided by the Knowledge Base."""
@@ -319,6 +320,7 @@ When handling content with timestamps:
- Use markdown formatting with appropriate section headings
- Please respond in the same language as the user's question.
- Ensure the response maintains continuity with the conversation history.
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- If you don't know the answer, just say so.
- Do not include information not provided by the Document Chunks."""
@@ -378,8 +380,8 @@ When handling information with timestamps:
- Use markdown formatting with appropriate section headings
- Please respond in the same language as the user's question.
- Ensure the response maintains continuity with the conversation history.
- Organize answer in sesctions focusing on one main point or aspect of the answer
- Organize answer in sections focusing on one main point or aspect of the answer
- Use clear and descriptive section titles that reflect the content
- List up to 5 most important reference sources at the end under "References" sesction. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), in the following format: [KG/DC] Source content
- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
- If you don't know the answer, just say so. Do not make anything up.
- Do not include information not provided by the Data Sources."""

View File

@@ -109,15 +109,17 @@ def setup_logger(
logger_name: str,
level: str = "INFO",
add_filter: bool = False,
log_file_path: str = None,
log_file_path: str | None = None,
enable_file_logging: bool = True,
):
"""Set up a logger with console and file handlers
"""Set up a logger with console and optionally file handlers
Args:
logger_name: Name of the logger to set up
level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
add_filter: Whether to add LightragPathFilter to the logger
log_file_path: Path to the log file. If None, will use current directory/lightrag.log
log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
enable_file_logging: Whether to enable logging to a file (defaults to True)
"""
# Configure formatters
detailed_formatter = logging.Formatter(
@@ -125,18 +127,6 @@ def setup_logger(
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
# Get log file path
if log_file_path is None:
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
@@ -148,16 +138,34 @@ def setup_logger(
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add file handler by default unless explicitly disabled
if enable_file_logging:
# Get log file path
if log_file_path is None:
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Ensure log directory exists
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
try:
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
except PermissionError as e:
logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
logger.warning("Continuing with console logging only")
# Add path filter if requested
if add_filter: