Merge branch 'HKUDS:main' into main
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

__version__ = "1.2.4"
__version__ = "1.2.5"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
@@ -50,9 +50,6 @@ from .auth import auth_handler
# This update allows the user to put a different.env file for each lightrag folder
load_dotenv(".env", override=True)

# Read entity extraction cache config
enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"

# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")

@@ -144,23 +141,25 @@ def create_app(args):
try:
# Initialize database connections
await rag.initialize_storages()
await initialize_pipeline_status()

# Auto scan documents if enabled
if args.auto_scan_at_startup:
# Check if a task is already running (with lock protection)
pipeline_status = await get_namespace_data("pipeline_status")
should_start_task = False
async with get_pipeline_status_lock():
if not pipeline_status.get("busy", False):
should_start_task = True
# Only start the task if no other task is running
if should_start_task:
# Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
logger.info("Auto scan task started at startup.")
await initialize_pipeline_status()
pipeline_status = await get_namespace_data("pipeline_status")

should_start_autoscan = False
async with get_pipeline_status_lock():
# Auto scan documents if enabled
if args.auto_scan_at_startup:
if not pipeline_status.get("autoscanned", False):
pipeline_status["autoscanned"] = True
should_start_autoscan = True

# Only run auto scan when no other process started it first
if should_start_autoscan:
# Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
logger.info(f"Process {os.getpid()} auto scan task started at startup.")

ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
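Note: the new startup flow replaces the old "busy" check with a dedicated "autoscanned" flag that is claimed under the pipeline-status lock, so only one worker launches the scan in multi-process deployments. A minimal sketch of that claim-then-start pattern, assuming the helpers shown in this hunk (get_namespace_data, get_pipeline_status_lock) are imported from lightrag.kg.shared_storage and run_scanning_process is in scope:

import asyncio
from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock

async def start_autoscan_once(rag, doc_manager, background_tasks: set) -> bool:
    # Claim the flag atomically; only the first worker to see it unset wins.
    pipeline_status = await get_namespace_data("pipeline_status")
    async with get_pipeline_status_lock():
        if pipeline_status.get("autoscanned", False):
            return False
        pipeline_status["autoscanned"] = True

    # Start the scan outside the lock and keep a strong reference to the task.
    task = asyncio.create_task(run_scanning_process(rag, doc_manager))
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return True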
@@ -326,7 +325,7 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=enable_llm_cache, # Read from environment variable
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,

@@ -355,7 +354,7 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=enable_llm_cache, # Read from environment variable
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,

@@ -419,6 +418,7 @@ def create_app(args):
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
},
"update_status": update_status,
}
@@ -16,7 +16,11 @@ from pydantic import BaseModel, Field, field_validator

from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency, get_auth_dependency
from lightrag.api.utils_api import (
get_api_key_dependency,
global_args,
get_auth_dependency,
)

router = APIRouter(
prefix="/documents",

@@ -240,54 +244,93 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
)
return False
case ".pdf":
if not pm.is_installed("pypdf2"): # type: ignore
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
if global_args["main_args"].document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter

pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
else:
if not pm.is_installed("pypdf2"): # type: ignore
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO

pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("python-docx"): # type: ignore
pm.install("docx")
from docx import Document # type: ignore
from io import BytesIO
if global_args["main_args"].document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter

docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
else:
if not pm.is_installed("python-docx"): # type: ignore
pm.install("docx")
from docx import Document # type: ignore
from io import BytesIO

docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("python-pptx"): # type: ignore
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
if global_args["main_args"].document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter

pptx_file = BytesIO(file)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
else:
if not pm.is_installed("python-pptx"): # type: ignore
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO

pptx_file = BytesIO(file)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case ".xlsx":
if not pm.is_installed("openpyxl"): # type: ignore
pm.install("openpyxl")
from openpyxl import load_workbook # type: ignore
from io import BytesIO
if global_args["main_args"].document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter

xlsx_file = BytesIO(file)
wb = load_workbook(xlsx_file)
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += (
"\t".join(
str(cell) if cell is not None else "" for cell in row
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
else:
if not pm.is_installed("openpyxl"): # type: ignore
pm.install("openpyxl")
from openpyxl import load_workbook # type: ignore
from io import BytesIO

xlsx_file = BytesIO(file)
wb = load_workbook(xlsx_file)
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += (
"\t".join(
str(cell) if cell is not None else ""
for cell in row
)
+ "\n"
)
+ "\n"
)
content += "\n"
content += "\n"
case _:
logger.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
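Note: every branch in the hunk above repeats the same shape: check the configured document loading engine, convert with Docling when it is "DOCLING", otherwise fall back to the per-format parser. A compact sketch of that selection logic for the PDF case, assuming the same global_args dictionary and the in-memory file bytes used above:

from io import BytesIO

def load_pdf_text(file: bytes, file_path) -> str:
    if global_args["main_args"].document_loading_engine == "DOCLING":
        # Docling converts the on-disk file and returns a structured document.
        from docling.document_converter import DocumentConverter

        converter = DocumentConverter()
        result = converter.convert(file_path)
        return result.document.export_to_markdown()

    # Default engine: extract plain text page by page with PyPDF2.
    from PyPDF2 import PdfReader

    reader = PdfReader(BytesIO(file))
    return "".join(page.extract_text() + "\n" for page in reader.pages)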
@@ -11,7 +11,7 @@ import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken
from ..utils_api import ollama_server_infos
from lightrag.api.utils_api import ollama_server_infos


# query mode according to query prefix (bypass is not LightRAG quer mode)
@@ -18,6 +18,8 @@ from .auth import auth_handler
# Load environment variables
load_dotenv(override=True)

global_args = {"main_args": None}


class OllamaServerInfos:
# Constants for emulated Ollama model information

@@ -360,8 +362,17 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
)

# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")

ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

global_args["main_args"] = args
return args
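Note: both new parse_args settings are read from the environment, so a deployment can toggle them from the server's .env file without code changes. A hypothetical .env fragment exercising the new settings (the values are examples, not the defaults):

ENABLE_LLM_CACHE_FOR_EXTRACT=true
DOCUMENT_LOADING_ENGINE=DOCLING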
@@ -451,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.history_turns}")
ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
ASCIIColors.yellow(f"{args.cosine_threshold}")
ASCIIColors.white(" └─ Top-K: ", end="")
ASCIIColors.white(" ├─ Top-K: ", end="")
ASCIIColors.yellow(f"{args.top_k}")
ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")

# System Configuration
ASCIIColors.magenta("\n💾 Storage Configuration:")
@@ -127,6 +127,30 @@ class BaseVectorStorage(StorageNameSpace, ABC):
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity."""

@abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
pass

@abstractmethod
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
pass


@dataclass
class BaseKVStorage(StorageNameSpace, ABC):
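Note: every concrete vector storage now has to implement these two getters. A minimal in-memory reference sketch, useful mainly for tests; the dict-backed _data attribute is hypothetical and not part of BaseVectorStorage, and the remaining abstract methods are omitted:

from typing import Any
from lightrag.base import BaseVectorStorage

class InMemoryVectorStorage(BaseVectorStorage):
    # _data maps id -> stored payload dict (hypothetical test fixture).
    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        return self._data.get(id)

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        # Return only the records that exist, matching the "found only" contract.
        return [self._data[i] for i in ids if i in self._data]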
@@ -271,3 +271,67 @@ class ChromaVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
raise

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
try:
# Query the collection for a single vector by ID
result = self._collection.get(
ids=[id], include=["metadatas", "embeddings", "documents"]
)

if not result or not result["ids"] or len(result["ids"]) == 0:
return None

# Format the result to match the expected structure
return {
"id": result["ids"][0],
"vector": result["embeddings"][0],
"content": result["documents"][0],
**result["metadatas"][0],
}
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

try:
# Query the collection for multiple vectors by IDs
result = self._collection.get(
ids=ids, include=["metadatas", "embeddings", "documents"]
)

if not result or not result["ids"] or len(result["ids"]) == 0:
return []

# Format the results to match the expected structure
return [
{
"id": result["ids"][i],
"vector": result["embeddings"][i],
"content": result["documents"][i],
**result["metadatas"][i],
}
for i in range(len(result["ids"]))
]
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
@@ -394,3 +394,46 @@ class FaissVectorDBStorage(BaseVectorStorage):

logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
# Find the Faiss internal ID for the custom ID
fid = self._find_faiss_id_by_custom_id(id)
if fid is None:
return None

# Get the metadata for the found ID
metadata = self._id_to_meta.get(fid, {})
if not metadata:
return None

return {**metadata, "id": metadata.get("__id__")}

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

results = []
for id in ids:
fid = self._find_faiss_id_by_custom_id(id)
if fid is not None:
metadata = self._id_to_meta.get(fid, {})
if metadata:
results.append({**metadata, "id": metadata.get("__id__")})

return results
@@ -15,6 +15,10 @@ from lightrag.utils import (
from .shared_storage import (
get_namespace_data,
get_storage_lock,
get_data_init_lock,
get_update_flag,
set_all_update_flags,
clear_all_update_flags,
try_initialize_namespace,
)

@@ -27,21 +31,25 @@ class JsonDocStatusStorage(DocStatusStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock()
self._data = None
self._storage_lock = None
self.storage_updated = None

async def initialize(self):
"""Initialize storage data"""
# check need_init must before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(
f"Loaded document status storage with {len(loaded_data)} records"
)
self._storage_lock = get_storage_lock()
self.storage_updated = await get_update_flag(self.namespace)
async with get_data_init_lock():
# check need_init must before get_namespace_data
need_init = await try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(
f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
)

async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""

@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):

async def index_done_callback(self) -> None:
async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
if self.storage_updated.value:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
logger.info(
f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
await clear_all_update_flags(self.namespace)

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

logger.info(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock:
self._data.update(data)
await set_all_update_flags(self.namespace)

await self.index_done_callback()

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:

@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
async with self._storage_lock:
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace)
await self.index_done_callback()

async def drop(self) -> None:
"""Drop the storage"""
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
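Note: initialization now follows a fixed order: create the storage lock, register this process's update flag, then (under the data-init lock) decide whether this process is the one that loads the JSON file into the shared namespace. A condensed sketch of that sequence, assuming the shared_storage helpers imported above:

async def init_shared_json_storage(self):
    self._storage_lock = get_storage_lock()
    self.storage_updated = await get_update_flag(self.namespace)
    async with get_data_init_lock():
        # Only the first process to claim the namespace loads data from disk.
        need_init = await try_initialize_namespace(self.namespace)
        self._data = await get_namespace_data(self.namespace)
        if need_init:
            loaded = load_json(self._file_name) or {}
            async with self._storage_lock:
                self._data.update(loaded)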
@@ -13,6 +13,10 @@ from lightrag.utils import (
from .shared_storage import (
get_namespace_data,
get_storage_lock,
get_data_init_lock,
get_update_flag,
set_all_update_flags,
clear_all_update_flags,
try_initialize_namespace,
)

@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock()
self._data = None
self._storage_lock = None
self.storage_updated = None

async def initialize(self):
"""Initialize storage data"""
# check need_init must before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
self._storage_lock = get_storage_lock()
self.storage_updated = await get_update_flag(self.namespace)
async with get_data_init_lock():
# check need_init must before get_namespace_data
need_init = await try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)

# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in loaded_data.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(loaded_data)

logger.info(
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
)

async def index_done_callback(self) -> None:
async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
if self.storage_updated.value:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)

# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in data_dict.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(data_dict)

logger.info(
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
await clear_all_update_flags(self.namespace)

async def get_all(self) -> dict[str, Any]:
"""Get all data from storage

@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
return set(keys) - set(self._data.keys())

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
logger.info(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock:
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
self._data.update(data)
await set_all_update_flags(self.namespace)

async def delete(self, ids: list[str]) -> None:
async with self._storage_lock:
for doc_id in ids:
self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace)
await self.index_done_callback()
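Note: the KV store now writes to disk only when its per-process update flag is set, and clears every process's flag afterwards, so a single writer flushes on behalf of all workers. A stripped-down sketch of that write-back step, assuming the same flag helpers as above:

async def flush_if_updated(self):
    async with self._storage_lock:
        if not self.storage_updated.value:
            return  # nothing changed since the last flush
        data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
        write_json(data_dict, self._file_name)
        # Tell every worker (including this one) that the on-disk copy is current.
        await clear_all_update_flags(self.namespace)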
@@ -233,3 +233,57 @@ class MilvusVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error searching for records with prefix '{prefix}': {e}")
return []

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
try:
# Query Milvus for a specific ID
result = self._client.query(
collection_name=self.namespace,
filter=f'id == "{id}"',
output_fields=list(self.meta_fields) + ["id"],
)

if not result or len(result) == 0:
return None

return result[0]
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

try:
# Prepare the ID filter expression
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'

# Query Milvus with the filter
result = self._client.query(
collection_name=self.namespace,
filter=filter_expr,
output_fields=list(self.meta_fields) + ["id"],
)

return result or []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
@@ -1073,6 +1073,59 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
return []

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
try:
# Search for the specific ID in MongoDB
result = await self._data.find_one({"_id": id})
if result:
# Format the result to include id field expected by API
result_dict = dict(result)
if "_id" in result_dict and "id" not in result_dict:
result_dict["id"] = result_dict["_id"]
return result_dict
return None
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

try:
# Query MongoDB for multiple IDs
cursor = self._data.find({"_id": {"$in": ids}})
results = await cursor.to_list(length=None)

# Format results to include id field expected by API
formatted_results = []
for result in results:
result_dict = dict(result)
if "_id" in result_dict and "id" not in result_dict:
result_dict["id"] = result_dict["_id"]
formatted_results.append(result_dict)

return formatted_results
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []


async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()
@@ -258,3 +258,33 @@ class NanoVectorDBStorage(BaseVectorStorage):

logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
return matching_records

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
client = await self._get_client()
result = client.get([id])
if result:
return result[0]
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

client = await self._get_client()
return client.get(ids)

File diff suppressed because it is too large
@@ -531,6 +531,80 @@ class OracleVectorDBStorage(BaseVectorStorage):
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
try:
# Determine the table name based on namespace
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
return None

# Create the appropriate ID field name based on namespace
id_field = "entity_id" if "NODES" in table_name else "relation_id"
if "CHUNKS" in table_name:
id_field = "chunk_id"

# Prepare and execute the query
query = f"""
SELECT * FROM {table_name}
WHERE {id_field} = :id AND workspace = :workspace
"""
params = {"id": id, "workspace": self.db.workspace}

result = await self.db.query(query, params)
return result
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

try:
# Determine the table name based on namespace
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
return []

# Create the appropriate ID field name based on namespace
id_field = "entity_id" if "NODES" in table_name else "relation_id"
if "CHUNKS" in table_name:
id_field = "chunk_id"

# Format the list of IDs for SQL IN clause
ids_list = ", ".join([f"'{id}'" for id in ids])

# Prepare and execute the query
query = f"""
SELECT * FROM {table_name}
WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
"""
params = {"workspace": self.db.workspace}

results = await self.db.query(query, params, multirows=True)
return results or []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []


@final
@dataclass
@@ -621,6 +621,60 @@ class PGVectorStorage(BaseVectorStorage):
logger.error(f"Error during prefix search for '{prefix}': {e}")
return []

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
return None

query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2"
params = {"workspace": self.db.workspace, "id": id}

try:
result = await self.db.query(query, params)
if result:
return dict(result)
return None
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
return []

ids_str = ",".join([f"'{id}'" for id in ids])
query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.db.workspace}

try:
results = await self.db.query(query, params, multirows=True)
return [dict(record) for record in results]
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []


@final
@dataclass
@@ -7,12 +7,18 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic


# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, level="INFO"):
def direct_log(message, level="INFO", enable_output: bool = True):
"""
Log a message directly to stderr to ensure visibility in all processes,
including the Gunicorn master process.

Args:
message: The message to log
level: Log level (default: "INFO")
enable_output: Whether to actually output the log (default: True)
"""
print(f"{level}: {message}", file=sys.stderr, flush=True)
if enable_output:
print(f"{level}: {message}", file=sys.stderr, flush=True)


T = TypeVar("T")

@@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
_storage_lock: Optional[LockType] = None
_internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None


class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
def __init__(
self,
lock: Union[ProcessLock, asyncio.Lock],
is_async: bool,
name: str = "unnamed",
enable_logging: bool = True,
):
self._lock = lock
self._is_async = is_async
self._pid = os.getpid() # for debug only
self._name = name # for debug only
self._enable_logging = enable_logging # for debug only

async def __aenter__(self) -> "UnifiedLock[T]":
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
return self
try:
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
enable_output=self._enable_logging,
)
return self
except Exception as e:
direct_log(
f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
level="ERROR",
enable_output=self._enable_logging,
)
raise

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._is_async:
self._lock.release()
else:
self._lock.release()
try:
direct_log(
f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
if self._is_async:
self._lock.release()
else:
self._lock.release()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
enable_output=self._enable_logging,
)
except Exception as e:
direct_log(
f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
level="ERROR",
enable_output=self._enable_logging,
)
raise

def __enter__(self) -> "UnifiedLock[T]":
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.acquire()
return self
try:
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
enable_output=self._enable_logging,
)
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
enable_output=self._enable_logging,
)
return self
except Exception as e:
direct_log(
f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
level="ERROR",
enable_output=self._enable_logging,
)
raise

def __exit__(self, exc_type, exc_val, exc_tb):
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.release()
try:
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
direct_log(
f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
enable_output=self._enable_logging,
)
self._lock.release()
direct_log(
f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
enable_output=self._enable_logging,
)
except Exception as e:
direct_log(
f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
level="ERROR",
enable_output=self._enable_logging,
)
raise


def get_internal_lock() -> UnifiedLock:
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
return UnifiedLock(
lock=_internal_lock,
is_async=not is_multiprocess,
name="internal_lock",
enable_logging=enable_logging,
)


def get_storage_lock() -> UnifiedLock:
def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
return UnifiedLock(
lock=_storage_lock,
is_async=not is_multiprocess,
name="storage_lock",
enable_logging=enable_logging,
)


def get_pipeline_status_lock() -> UnifiedLock:
def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
return UnifiedLock(
lock=_pipeline_status_lock,
is_async=not is_multiprocess,
name="pipeline_status_lock",
enable_logging=enable_logging,
)


def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified graph database lock for ensuring atomic operations"""
return UnifiedLock(
lock=_graph_db_lock,
is_async=not is_multiprocess,
name="graph_db_lock",
enable_logging=enable_logging,
)


def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
return UnifiedLock(
lock=_data_init_lock,
is_async=not is_multiprocess,
name="data_init_lock",
enable_logging=enable_logging,
)
|
||||
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
|
||||
_storage_lock, \
|
||||
_internal_lock, \
|
||||
_pipeline_status_lock, \
|
||||
_graph_db_lock, \
|
||||
_data_init_lock, \
|
||||
_shared_dicts, \
|
||||
_init_flags, \
|
||||
_initialized, \
|
||||
@@ -120,14 +238,16 @@ def initialize_share_data(workers: int = 1):
|
||||
)
|
||||
return
|
||||
|
||||
_manager = Manager()
|
||||
_workers = workers
|
||||
|
||||
if workers > 1:
|
||||
is_multiprocess = True
|
||||
_manager = Manager()
|
||||
_internal_lock = _manager.Lock()
|
||||
_storage_lock = _manager.Lock()
|
||||
_pipeline_status_lock = _manager.Lock()
|
||||
_graph_db_lock = _manager.Lock()
|
||||
_data_init_lock = _manager.Lock()
|
||||
_shared_dicts = _manager.dict()
|
||||
_init_flags = _manager.dict()
|
||||
_update_flags = _manager.dict()
|
||||
@@ -139,6 +259,8 @@ def initialize_share_data(workers: int = 1):
|
||||
_internal_lock = asyncio.Lock()
|
||||
_storage_lock = asyncio.Lock()
|
||||
_pipeline_status_lock = asyncio.Lock()
|
||||
_graph_db_lock = asyncio.Lock()
|
||||
_data_init_lock = asyncio.Lock()
|
||||
_shared_dicts = {}
|
||||
_init_flags = {}
|
||||
_update_flags = {}
|
||||
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
|
||||
history_messages = _manager.list() if is_multiprocess else []
|
||||
pipeline_namespace.update(
|
||||
{
|
||||
"autoscanned": False, # Auto-scan started
|
||||
"busy": False, # Control concurrent processes
|
||||
"job_name": "Default Job", # Current job name (indexing files/indexing texts)
|
||||
"job_start": None, # Job start time
|
||||
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
|
||||
if is_multiprocess and _manager is not None:
|
||||
new_update_flag = _manager.Value("b", False)
|
||||
else:
|
||||
new_update_flag = False
|
||||
# Create a simple mutable object to store boolean value for compatibility with mutiprocess
|
||||
class MutableBoolean:
|
||||
def __init__(self, initial_value=False):
|
||||
self.value = initial_value
|
||||
|
||||
new_update_flag = MutableBoolean(False)
|
||||
|
||||
_update_flags[namespace].append(new_update_flag)
|
||||
return new_update_flag
|
||||
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = True
|
||||
else:
|
||||
_update_flags[namespace][i] = True
|
||||
# Use .value attribute instead of direct assignment
|
||||
_update_flags[namespace][i].value = True
|
||||
|
||||
|
||||
async def clear_all_update_flags(namespace: str):
|
||||
"""Clear all update flag of namespace indicating all workers need to reload data from files"""
|
||||
global _update_flags
|
||||
if _update_flags is None:
|
||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _update_flags:
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for both modes
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = False
|
||||
else:
|
||||
# Use .value attribute instead of direct assignment
|
||||
_update_flags[namespace][i].value = False
|
||||
|
||||
|
||||
async def get_all_update_flags_status() -> Dict[str, list]:
|
||||
@@ -247,7 +394,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
|
||||
return result
|
||||
|
||||
|
||||
def try_initialize_namespace(namespace: str) -> bool:
|
||||
async def try_initialize_namespace(namespace: str) -> bool:
|
||||
"""
|
||||
Returns True if the current worker(process) gets initialization permission for loading data later.
|
||||
The worker does not get the permission is prohibited to load data from files.
|
||||
@@ -257,15 +404,17 @@ def try_initialize_namespace(namespace: str) -> bool:
|
||||
if _init_flags is None:
|
||||
raise ValueError("Try to create nanmespace before Shared-Data is initialized")
|
||||
|
||||
if namespace not in _init_flags:
|
||||
_init_flags[namespace] = True
|
||||
async with get_internal_lock():
|
||||
if namespace not in _init_flags:
|
||||
_init_flags[namespace] = True
|
||||
direct_log(
|
||||
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
|
||||
)
|
||||
return True
|
||||
direct_log(
|
||||
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
|
||||
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
|
||||
)
|
||||
return True
|
||||
direct_log(
|
||||
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
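Note: try_initialize_namespace is now a coroutine and takes the internal lock itself, so callers must await it; the first caller per namespace gets True and becomes responsible for loading data. A minimal caller sketch, where load_from_disk is a hypothetical callable supplied by the storage class:

async def ensure_namespace_loaded(namespace: str, load_from_disk) -> None:
    if await try_initialize_namespace(namespace):
        # This process won the initialization race; populate the shared dict.
        data = await get_namespace_data(namespace)
        data.update(load_from_disk() or {})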
@@ -304,6 +453,8 @@ def finalize_share_data():
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
_graph_db_lock, \
_data_init_lock, \
_shared_dicts, \
_init_flags, \
_initialized, \

@@ -369,6 +520,8 @@ def finalize_share_data():
_storage_lock = None
_internal_lock = None
_pipeline_status_lock = None
_graph_db_lock = None
_data_init_lock = None
_update_flags = None

direct_log(f"Process {os.getpid()} storage data finalization complete")
@@ -465,6 +465,100 @@ class TiDBVectorDBStorage(BaseVectorStorage):
logger.error(f"Error searching records with prefix '{prefix}': {e}")
return []

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID

Args:
id: The unique identifier of the vector

Returns:
The vector data if found, or None if not found
"""
try:
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = """
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id = :entity_id AND workspace = :workspace
"""
params = {"entity_id": id, "workspace": self.db.workspace}
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id = :relation_id AND workspace = :workspace
"""
params = {"relation_id": id, "workspace": self.db.workspace}
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id = :chunk_id AND workspace = :workspace
"""
params = {"chunk_id": id, "workspace": self.db.workspace}
else:
logger.warning(
f"Namespace {self.namespace} not supported for get_by_id"
)
return None

result = await self.db.query(sql_template, params=params)
return result
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs

Args:
ids: List of unique identifiers

Returns:
List of vector data objects that were found
"""
if not ids:
return []

try:
# Format IDs for SQL IN clause
ids_str = ", ".join([f"'{id}'" for id in ids])

# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = f"""
SELECT entity_id as id, name as entity_name, entity_type, description, content
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id IN ({ids_str}) AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = f"""
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id IN ({ids_str}) AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = f"""
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id IN ({ids_str}) AND workspace = :workspace
"""
else:
logger.warning(
f"Namespace {self.namespace} not supported for get_by_ids"
)
return []

params = {"workspace": self.db.workspace}
results = await self.db.query(sql_template, params=params, multirows=True)
return results if results else []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []


@final
@dataclass
@@ -30,11 +30,10 @@ from .namespace import NameSpace, make_namespace
from .operate import (
chunking_by_token_size,
extract_entities,
extract_keywords_only,
kg_query,
kg_query_with_keywords,
mix_kg_vector_query,
naive_query,
query_with_keywords,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
from .utils import (

@@ -45,6 +44,9 @@ from .utils import (
encode_string_by_tiktoken,
lazy_external_import,
limit_async_func_call,
get_content_summary,
clean_text,
check_storage_env_vars,
logger,
)
from .types import KnowledgeGraph

@@ -309,7 +311,7 @@ class LightRAG:
# Verify storage implementation compatibility
verify_storage_implementation(storage_type, storage_name)
# Check environment variables
# self.check_storage_env_vars(storage_name)
check_storage_env_vars(storage_name)

# Ensure vector_db_storage_cls_kwargs has required fields
self.vector_db_storage_cls_kwargs = {

@@ -354,6 +356,9 @@ class LightRAG:
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(
self
), # Add global_config to ensure cache works properly
embedding_func=self.embedding_func,
)

@@ -404,18 +409,8 @@ class LightRAG:
embedding_func=None,
)

if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
)
# Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache

self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(

@@ -543,11 +538,6 @@ class LightRAG:
storage_class = lazy_external_import(import_path, storage_name)
return storage_class

@staticmethod
def clean_text(text: str) -> str:
"""Clean text by removing null bytes (0x00) and whitespace"""
return text.strip().replace("\x00", "")

def insert(
self,
input: str | list[str],

@@ -590,6 +580,7 @@ class LightRAG:
split_by_character, split_by_character_only
)

# TODO: deprecated, use insert instead
def insert_custom_chunks(
self,
full_text: str,

@@ -601,14 +592,15 @@ class LightRAG:
self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
)

# TODO: deprecated, use ainsert instead
async def ainsert_custom_chunks(
self, full_text: str, text_chunks: list[str], doc_id: str | None = None
) -> None:
update_storage = False
try:
# Clean input texts
full_text = self.clean_text(full_text)
text_chunks = [self.clean_text(chunk) for chunk in text_chunks]
full_text = clean_text(full_text)
text_chunks = [clean_text(chunk) for chunk in text_chunks]

# Process cleaned texts
if doc_id is None:

@@ -687,7 +679,7 @@ class LightRAG:
contents = {id_: doc for id_, doc in zip(ids, input)}
else:
# Clean input text and remove duplicates
input = list(set(self.clean_text(doc) for doc in input))
input = list(set(clean_text(doc) for doc in input))
# Generate contents dict of MD5 hash IDs and documents
contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}

@@ -703,7 +695,7 @@ class LightRAG:
new_docs: dict[str, Any] = {
id_: {
"content": content,
"content_summary": self._get_content_summary(content),
"content_summary": get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),

@@ -892,7 +884,9 @@ class LightRAG:
self.chunks_vdb.upsert(chunks)
)
entity_relation_task = asyncio.create_task(
self._process_entity_relation_graph(chunks)
self._process_entity_relation_graph(
chunks, pipeline_status, pipeline_status_lock
)
)
full_docs_task = asyncio.create_task(
self.full_docs.upsert(

@@ -1007,21 +1001,27 @@ class LightRAG:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)

async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
async def _process_entity_relation_graph(
self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
) -> None:
try:
await extract_entities(
chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
)
except Exception as e:
logger.error("Failed to extract entities and relationships")
raise e

async def _insert_done(self) -> None:
async def _insert_done(
self, pipeline_status=None, pipeline_status_lock=None
) -> None:
tasks = [
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
@@ -1040,12 +1040,10 @@ class LightRAG:
|
||||
log_message = "All Insert done"
|
||||
logger.info(log_message)
|
||||
|
||||
# 获取 pipeline_status 并更新 latest_message 和 history_messages
|
||||
from lightrag.kg.shared_storage import get_namespace_data
|
||||
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
def insert_custom_kg(
|
||||
self, custom_kg: dict[str, Any], full_doc_id: str = None
|
||||
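Note: _insert_done and _process_entity_relation_graph now receive pipeline_status and its lock explicitly instead of re-importing shared storage, and they touch the status only when both are provided. A small hedged helper sketch capturing that guarded update (the function name is hypothetical):

async def report_progress(pipeline_status, pipeline_status_lock, message: str) -> None:
    # Status reporting is optional: skip silently when the caller passed nothing.
    if pipeline_status is None or pipeline_status_lock is None:
        return
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = message
        pipeline_status["history_messages"].append(message)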
@@ -1062,7 +1060,7 @@ class LightRAG:
|
||||
all_chunks_data: dict[str, dict[str, str]] = {}
|
||||
chunk_to_source_map: dict[str, str] = {}
|
||||
for chunk_data in custom_kg.get("chunks", []):
|
||||
chunk_content = self.clean_text(chunk_data["content"])
|
||||
chunk_content = clean_text(chunk_data["content"])
|
||||
source_id = chunk_data["source_id"]
|
||||
tokens = len(
|
||||
encode_string_by_tiktoken(
|
||||
@@ -1260,16 +1258,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
elif param.mode == "naive":
@@ -1279,16 +1268,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
elif param.mode == "mix":
@@ -1301,16 +1281,7 @@ class LightRAG:
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
else:
@@ -1322,8 +1293,17 @@ class LightRAG:
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
2. Then run the standard aquery() flow with the final prompt (formatted_question).
Query with separate keyword extraction step.

This method extracts keywords from the query first, then uses them for the query.

Args:
query: User query
prompt: Additional prompt for the query
param: Query parameters

Returns:
Query response
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
@@ -1334,100 +1314,29 @@ class LightRAG:
self, query: str, prompt: str, param: QueryParam = QueryParam()
) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
Async version of query_with_separate_keyword_extraction.

Args:
query: User query
prompt: Additional prompt for the query
param: Query parameters

Returns:
Query response or async iterator
"""
# ---------------------
# STEP 1: Keyword Extraction
# ---------------------
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
response = await query_with_keywords(
query=query,
prompt=prompt,
param=param,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entities_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
chunks_vdb=self.chunks_vdb,
text_chunks_db=self.text_chunks,
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
hashing_kv=self.llm_response_cache,
)

param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords

# ---------------------
# STEP 2: Final Query Logic
# ---------------------

# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"

if param.mode in ["local", "global", "hybrid"]:
response = await kg_query_with_keywords(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
response = await naive_query(
formatted_question,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
formatted_question,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
)
else:
raise ValueError(f"Unknown mode {param.mode}")

await self._query_done()
return response

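After this refactor the async method is a thin wrapper around query_with_keywords in operate.py rather than duplicating the keyword-extraction and dispatch logic. A hedged usage sketch (the instance setup and query text are assumptions, not shown in the diff):

    from lightrag import QueryParam

    # Assumed: `rag` is an initialized LightRAG instance with storages ready
    answer = rag.query_with_separate_keyword_extraction(
        query="How does the ingestion pipeline deduplicate documents?",
        prompt="Answer concisely and name the entities involved.",
        param=QueryParam(mode="hybrid"),
    )
    print(answer)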
@@ -1525,21 +1434,6 @@ class LightRAG:
]
)

def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content

Args:
content: Original document content
max_length: Maximum length of summary

Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."

async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts

@@ -1816,19 +1710,7 @@ class LightRAG:
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity

Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database

Returns:
dict: A dictionary containing entity information, including:
- entity_name: Entity name
- source_id: Source document ID
- graph_data: Complete node data from the graph database
- vector_data: (optional) Data from the vector database
"""
"""Get detailed information of an entity"""

# Get information from the graph
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
@@ -1843,29 +1725,15 @@ class LightRAG:
# Optional: Get vector database information
if include_vector_data:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
vector_data = self.entities_vdb._client.get([entity_id])
result["vector_data"] = vector_data[0] if vector_data else None
vector_data = await self.entities_vdb.get_by_id(entity_id)
result["vector_data"] = vector_data

return result

async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship

Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database

Returns:
dict: A dictionary containing relationship information, including:
- src_entity: Source entity name
- tgt_entity: Target entity name
- source_id: Source document ID
- graph_data: Complete edge data from the graph database
- vector_data: (optional) Data from the vector database
"""
"""Get detailed information of a relationship"""

# Get information from the graph
edge_data = await self.chunk_entity_relation_graph.get_edge(
@@ -1883,8 +1751,8 @@ class LightRAG:
# Optional: Get vector database information
if include_vector_data:
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
vector_data = self.relationships_vdb._client.get([rel_id])
result["vector_data"] = vector_data[0] if vector_data else None
vector_data = await self.relationships_vdb.get_by_id(rel_id)
result["vector_data"] = vector_data

return result

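With the vector lookups now going through get_by_id on the storage interface instead of the private _client attribute, both helpers work against any vector backend. A usage sketch (entity and relation names are illustrative, not from the diff):

    # Assumed: `rag` is an initialized LightRAG instance whose graph already contains these names
    entity = await rag.get_entity_info("Tesla", include_vector_data=True)
    print(entity["source_id"], entity["graph_data"])
    print(entity.get("vector_data"))  # None when no vector record exists for the entity

    relation = await rag.get_relation_info("Tesla", "Elon Musk", include_vector_data=True)
    print(relation["graph_data"])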
@@ -2682,6 +2550,12 @@ class LightRAG:

# 9. Delete source entities
for entity_name in source_entities:
if entity_name == target_entity:
logger.info(
f"Skipping deletion of '{entity_name}' as it's also the target entity"
)
continue

# Delete entity node from knowledge graph
await self.chunk_entity_relation_graph.delete_node(entity_name)

@@ -55,6 +55,7 @@ async def azure_openai_complete_if_cache(

openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
@@ -136,6 +137,7 @@ async def azure_openai_embed(

openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)

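Both Azure client constructors now pass azure_deployment=model alongside the endpoint, key, and API version read from the environment. A sketch of the variables they expect (all values are placeholders, not real settings):

    import os

    # Placeholders only; set these in the real environment or a .env file
    os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<resource>.openai.azure.com/")
    os.environ.setdefault("AZURE_OPENAI_API_KEY", "<key>")
    os.environ.setdefault("AZURE_OPENAI_API_VERSION", "<api-version>")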
@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import json
import re
import os
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

@@ -140,18 +141,36 @@ async def _handle_single_entity_extraction(
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G

# Clean and validate entity name
entity_name = clean_str(record_attributes[1]).strip('"')
if not entity_name.strip():
logger.warning(
f"Entity extraction error: empty entity name in: {record_attributes}"
)
return None

# Clean and validate entity type
entity_type = clean_str(record_attributes[2]).strip('"')
if not entity_type.strip() or entity_type.startswith('("'):
logger.warning(
f"Entity extraction error: invalid entity type in: {record_attributes}"
)
return None

# Clean and validate description
entity_description = clean_str(record_attributes[3]).strip('"')
entity_source_id = chunk_key
if not entity_description.strip():
logger.warning(
f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
)
return None

return dict(
entity_name=entity_name,
entity_type=entity_type,
description=entity_description,
source_id=entity_source_id,
source_id=chunk_key,
metadata={"created_at": time.time()},
)

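The added checks reject any record whose name, type, or description is empty after quote stripping, instead of inserting a malformed node. For illustration, record attribute lists of the kind this function receives (already split on the tuple delimiter; values are made up):

    # Accepted: four non-empty fields starting with the "entity" marker
    ok = ['"entity"', '"Tesla"', '"organization"', '"Electric vehicle manufacturer"']
    # Rejected by the new validation: empty entity name, and empty description
    bad_name = ['"entity"', '""', '"organization"', '"Some description"']
    bad_desc = ['"entity"', '"Tesla"', '"organization"', '""']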
@@ -220,6 +239,7 @@ async def _merge_nodes_then_upsert(
entity_name, description, global_config
)
node_data = dict(
entity_id=entity_name,
entity_type=entity_type,
description=description,
source_id=source_id,
@@ -301,6 +321,7 @@ async def _merge_edges_then_upsert(
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
@@ -337,11 +358,10 @@ async def extract_entities(
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict[str, str],
pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
) -> None:
from lightrag.kg.shared_storage import get_namespace_data

pipeline_status = await get_namespace_data("pipeline_status")
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
@@ -400,6 +420,7 @@ async def extract_entities(
else:
_prompt = input_text

# TODO: add cache_type="extract"
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache,
@@ -407,7 +428,6 @@ async def extract_entities(
_prompt,
"default",
cache_type="extract",
force_llm_cache=True,
)
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
@@ -436,47 +456,22 @@ async def extract_entities(
else:
return await use_llm_func(input_text)

async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
""" "Prpocess a single chunk
async def _process_extraction_result(result: str, chunk_key: str):
"""Process a single extraction result (either initial or gleaning)
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
result (str): The extraction result to process
chunk_key (str): The chunk key for source tracking
Returns:
tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships
"""
nonlocal processed_chunks
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
# hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
hint_prompt = entity_extract_prompt.format(
**context_base, input_text="{input_text}"
).format(**context_base, input_text=content)

final_result = await _user_llm_func_with_cache(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await _user_llm_func_with_cache(
continue_prompt, history_messages=history
)

history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break

if_loop_result: str = await _user_llm_func_with_cache(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)

records = split_string_by_multi_markers(
final_result,
result,
[context_base["record_delimiter"], context_base["completion_delimiter"]],
)

maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
@@ -485,6 +480,7 @@ async def extract_entities(
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)

if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key
)
@@ -499,13 +495,71 @@ async def extract_entities(
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)

return maybe_nodes, maybe_edges

async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
"""Process a single chunk
Args:
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
"""
nonlocal processed_chunks
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]

# Get initial extraction
hint_prompt = entity_extract_prompt.format(
**context_base, input_text="{input_text}"
).format(**context_base, input_text=content)

final_result = await _user_llm_func_with_cache(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

# Process initial extraction
maybe_nodes, maybe_edges = await _process_extraction_result(
final_result, chunk_key
)

# Process additional gleaning results
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await _user_llm_func_with_cache(
continue_prompt, history_messages=history
)

history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)

# Process gleaning result separately
glean_nodes, glean_edges = await _process_extraction_result(
glean_result, chunk_key
)

# Merge results
for entity_name, entities in glean_nodes.items():
maybe_nodes[entity_name].extend(entities)
for edge_key, edges in glean_edges.items():
maybe_edges[edge_key].extend(edges)

if now_glean_index == entity_extract_max_gleaning - 1:
break

if_loop_result: str = await _user_llm_func_with_cache(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break

processed_chunks += 1
entities_count = len(maybe_nodes)
relations_count = len(maybe_edges)
log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if pipeline_status is not None:
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = log_message
        pipeline_status["history_messages"].append(log_message)
return dict(maybe_nodes), dict(maybe_edges)

tasks = [_process_single_content(c) for c in ordered_chunks]
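With parsing factored into _process_extraction_result, each gleaning pass is parsed on its own and folded into the accumulated result with a plain per-key extend. A small standalone sketch of that merge step (the names here are illustrative, not from the diff):

    from collections import defaultdict

    def merge_extraction_passes(base_nodes: dict, glean_nodes: dict) -> dict:
        # Fold one gleaning pass into the accumulated nodes; edges are merged the same way
        merged = defaultdict(list, {k: list(v) for k, v in base_nodes.items()})
        for name, entities in glean_nodes.items():
            merged[name].extend(entities)
        return dict(merged)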
@@ -519,42 +573,58 @@ async def extract_entities(
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)

all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
from .kg.shared_storage import get_graph_db_lock

all_relationships_data = await asyncio.gather(
*[
_merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
graph_db_lock = get_graph_db_lock(enable_logging=False)

# Ensure that nodes and edges are merged and upserted atomically
async with graph_db_lock:
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )

    all_relationships_data = await asyncio.gather(
        *[
            _merge_edges_then_upsert(
                k[0], k[1], v, knowledge_graph_inst, global_config
            )
            for k, v in maybe_edges.items()
        ]
    )

if not (all_entities_data or all_relationships_data):
log_message = "Didn't extract any entities and relationships."
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if pipeline_status is not None:
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = log_message
        pipeline_status["history_messages"].append(log_message)
return

if not all_entities_data:
log_message = "Didn't extract any entities"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if pipeline_status is not None:
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = log_message
        pipeline_status["history_messages"].append(log_message)
if not all_relationships_data:
log_message = "Didn't extract any relationships"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if pipeline_status is not None:
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = log_message
        pipeline_status["history_messages"].append(log_message)

log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if pipeline_status is not None:
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = log_message
        pipeline_status["history_messages"].append(log_message)
verbose_debug(
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
)
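Moving the merge-and-upsert gathers inside the shared graph lock makes each batch commit atomically with respect to other pipelines. A minimal sketch of the locking pattern (the storage object and the batch dict are assumptions; only get_graph_db_lock and upsert_node come from the diff):

    from lightrag.kg.shared_storage import get_graph_db_lock

    async def upsert_batch_atomically(nodes: dict, graph) -> None:
        graph_db_lock = get_graph_db_lock(enable_logging=False)
        # All writes for this batch happen under one lock, so readers never see a half-merged graph
        async with graph_db_lock:
            for name, data in nodes.items():
                await graph.upsert_node(name, node_data=data)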
@@ -1020,6 +1090,7 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
):
logger.info(f"Process {os.getpid()} building query context...")
if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data(
ll_keywords,
@@ -1845,3 +1916,90 @@ async def kg_query_with_keywords(
)

return response


async def query_with_keywords(
query: str,
prompt: str,
param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
"""
Extract keywords from the query and then use them for retrieving information.

1. Extracts high-level and low-level keywords from the query
2. Formats the query with the extracted keywords and prompt
3. Uses the appropriate query method based on param.mode

Args:
query: The user's query
prompt: Additional prompt to prepend to the query
param: Query parameters
knowledge_graph_inst: Knowledge graph storage
entities_vdb: Entities vector database
relationships_vdb: Relationships vector database
chunks_vdb: Document chunks vector database
text_chunks_db: Text chunks storage
global_config: Global configuration
hashing_kv: Cache storage

Returns:
Query response or async iterator
"""
# Extract keywords
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
global_config=global_config,
hashing_kv=hashing_kv,
)

param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords

# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"

# Use appropriate query method based on mode
if param.mode in ["local", "global", "hybrid"]:
return await kg_query_with_keywords(
formatted_question,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
param,
global_config,
hashing_kv=hashing_kv,
)
elif param.mode == "naive":
return await naive_query(
formatted_question,
chunks_vdb,
text_chunks_db,
param,
global_config,
hashing_kv=hashing_kv,
)
elif param.mode == "mix":
return await mix_kg_vector_query(
formatted_question,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
chunks_vdb,
text_chunks_db,
param,
global_config,
hashing_kv=hashing_kv,
)
else:
raise ValueError(f"Unknown mode {param.mode}")

@@ -236,7 +236,7 @@ Given the query and conversation history, list both high-level and low-level keywords.
---Instructions---

- Consider both the current query and relevant conversation history when extracting keywords
- Output the keywords in JSON format
- Output the keywords in JSON format, it will be parsed by a JSON parser, do not add any extra content in output
- The JSON should have two keys:
- "high_level_keywords" for overarching concepts or themes
- "low_level_keywords" for specific entities or details

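The tightened wording exists because the model's reply is handed straight to a JSON parser. The shape the prompt asks for, with illustrative values:

    {
        "high_level_keywords": ["International trade", "Economic policy"],
        "low_level_keywords": ["Tariffs", "Trade agreements", "Exports"]
    }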
@@ -633,15 +633,15 @@ async def handle_cache(
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
if hashing_kv is None:
return None, None, None, None

if mode != "default":
if mode != "default": # handle cache for all types of query
if not hashing_kv.global_config.get("enable_llm_cache"):
return None, None, None, None

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
use_llm_check = embedding_cache_config.get("use_llm_check", False)

quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
if is_embedding_cache_enabled: # Use embedding similarity to match cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -667,24 +666,29 @@ async def handle_cache(
cache_type=cache_type,
)
if best_cached_response is not None:
logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
return best_cached_response, None, None, None
else:
# if caching keyword embedding is enabled, return the quantized embedding for saving it later
logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val

# For default mode or is_embedding_cache_enabled is False, use regular cache
# default mode is for extract_entities or naive query
else: # handle cache for entity extraction
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None

# Here are the conditions for code reaching this point:
# 1. All query mode: enable_llm_cache is True and embedding similarity is not enabled
# 2. Entity extract: enable_llm_cache_for_entity_extract is True
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None

logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
return None, None, None, None

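After this change the two branches of handle_cache are gated separately: query-mode caching by enable_llm_cache (optionally with embedding-similarity matching via embedding_cache_config), and entity-extraction caching by enable_llm_cache_for_entity_extract. A hedged configuration sketch (the two callables and the dictionary keys other than use_llm_check are assumptions; remaining constructor arguments are omitted):

    from lightrag import LightRAG

    rag = LightRAG(
        working_dir="./rag_storage",  # placeholder path
        llm_model_func=my_llm_func,  # assumed user-provided completion function
        embedding_func=my_embedding_func,  # assumed user-provided embedding function
        enable_llm_cache=True,  # gates the query-mode cache branch
        enable_llm_cache_for_entity_extract=True,  # gates the extraction cache branch
        embedding_cache_config={
            "enabled": True,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        },
    )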
@@ -701,9 +705,22 @@ class CacheData:


async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
"""Save data to cache, with improved handling for streaming responses and duplicate content.

Args:
hashing_kv: The key-value storage for caching
cache_data: The cache data to save
"""
# Skip if storage is None or content is a streaming response
if hashing_kv is None or not cache_data.content:
return

# If content is a streaming response, don't cache it
if hasattr(cache_data.content, "__aiter__"):
logger.debug("Streaming response detected, skipping cache")
return

# Get existing cache data
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -712,6 +729,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

# Check if we already have identical content cached
if cache_data.args_hash in mode_cache:
existing_content = mode_cache[cache_data.args_hash].get("return")
if existing_content == cache_data.content:
logger.info(
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
)
return

# Update cache with new content
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
@@ -726,6 +753,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"original_prompt": cache_data.prompt,
}

# Only upsert if there's actual new content
await hashing_kv.upsert({cache_data.mode: mode_cache})

@@ -862,3 +890,52 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
return cls(*args, **kwargs)

return import_class


def get_content_summary(content: str, max_length: int = 100) -> str:
"""Get summary of document content

Args:
content: Original document content
max_length: Maximum length of summary

Returns:
Truncated content with ellipsis if needed
"""
content = content.strip()
if len(content) <= max_length:
return content
return content[:max_length] + "..."


def clean_text(text: str) -> str:
"""Clean text by removing null bytes (0x00) and whitespace

Args:
text: Input text to clean

Returns:
Cleaned text
"""
return text.strip().replace("\x00", "")


def check_storage_env_vars(storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist

Args:
storage_name: Storage implementation name

Raises:
ValueError: If required environment variables are missing
"""
from lightrag.kg import STORAGE_ENV_REQUIREMENTS

required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]

if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
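The three helpers added above take over work that previously lived as LightRAG methods, so other modules can import them directly. A small usage sketch (the import path and the storage name are assumptions based on this diff):

    from lightrag.utils import clean_text, get_content_summary, check_storage_env_vars

    print(clean_text("  hello\x00 world  "))  # -> "hello world"
    print(get_content_summary("x" * 150))  # -> first 100 characters followed by "..."
    check_storage_env_vars("PGKVStorage")  # raises ValueError only if required env vars are missing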