Merge branch 'HKUDS:main' into main
@@ -271,3 +271,67 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
             raise
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Query the collection for a single vector by ID
+            result = self._collection.get(
+                ids=[id], include=["metadatas", "embeddings", "documents"]
+            )
+
+            if not result or not result["ids"] or len(result["ids"]) == 0:
+                return None
+
+            # Format the result to match the expected structure
+            return {
+                "id": result["ids"][0],
+                "vector": result["embeddings"][0],
+                "content": result["documents"][0],
+                **result["metadatas"][0],
+            }
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Query the collection for multiple vectors by IDs
+            result = self._collection.get(
+                ids=ids, include=["metadatas", "embeddings", "documents"]
+            )
+
+            if not result or not result["ids"] or len(result["ids"]) == 0:
+                return []
+
+            # Format the results to match the expected structure
+            return [
+                {
+                    "id": result["ids"][i],
+                    "vector": result["embeddings"][i],
+                    "content": result["documents"][i],
+                    **result["metadatas"][i],
+                }
+                for i in range(len(result["ids"]))
+            ]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
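For reference, a minimal usage sketch of the two new accessors; the storage instance `vdb` and the IDs are hypothetical, and both methods are async, so they must be awaited:

import asyncio

async def demo(vdb):
    # Single lookup: a dict with id/vector/content plus metadata, or None
    record = await vdb.get_by_id("chunk-abc")
    if record is not None:
        print(record["id"], record["content"][:40])

    # Batch lookup: returns only the records that were found,
    # possibly fewer than requested
    records = await vdb.get_by_ids(["chunk-abc", "chunk-def"])
    print(f"found {len(records)} of 2")

# asyncio.run(demo(vdb))  # given an initialized ChromaVectorDBStorage `vdb`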
@@ -394,3 +394,46 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
         return matching_records
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        # Find the Faiss internal ID for the custom ID
+        fid = self._find_faiss_id_by_custom_id(id)
+        if fid is None:
+            return None
+
+        # Get the metadata for the found ID
+        metadata = self._id_to_meta.get(fid, {})
+        if not metadata:
+            return None
+
+        return {**metadata, "id": metadata.get("__id__")}
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        results = []
+        for id in ids:
+            fid = self._find_faiss_id_by_custom_id(id)
+            if fid is not None:
+                metadata = self._id_to_meta.get(fid, {})
+                if metadata:
+                    results.append({**metadata, "id": metadata.get("__id__")})
+
+        return results
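The helper `_find_faiss_id_by_custom_id` is not shown in this diff; a plausible sketch, assuming it simply scans the `_id_to_meta` map for a matching `__id__` field:

def _find_faiss_id_by_custom_id(self, custom_id):
    # Linear scan over the internal-ID -> metadata map; returns the Faiss
    # internal ID whose metadata carries the matching custom "__id__", else None.
    for fid, meta in self._id_to_meta.items():
        if meta.get("__id__") == custom_id:
            return fid
    return None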
@@ -15,6 +15,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
@@ -27,21 +31,25 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None

     async def initialize(self):
         """Initialize storage data"""
-        # check need_init must before get_namespace_data
-        need_init = try_initialize_namespace(self.namespace)
-        self._data = await get_namespace_data(self.namespace)
-        if need_init:
-            loaded_data = load_json(self._file_name) or {}
-            async with self._storage_lock:
-                self._data.update(loaded_data)
-                logger.info(
-                    f"Loaded document status storage with {len(loaded_data)} records"
-                )
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
+            # The need_init check must come before get_namespace_data
+            need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
+            if need_init:
+                loaded_data = load_json(self._file_name) or {}
+                async with self._storage_lock:
+                    self._data.update(loaded_data)
+                    logger.info(
+                        f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
+                    )

     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
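The refactor above splits construction from data loading; a sketch of the intended call order (constructor arguments abbreviated, `cfg` is a hypothetical config dict):

storage = JsonDocStatusStorage(namespace="doc_status", global_config=cfg)
await storage.initialize()  # acquires get_data_init_lock(); only the first
                            # worker (need_init=True) loads the JSON file
doc = await storage.get_by_id("doc-1")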
@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-            write_json(data_dict, self._file_name)
+            if self.storage_updated.value:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+                logger.info(
+                    f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}"
+                )
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)

         await self.index_done_callback()

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()

     async def drop(self) -> None:
         """Drop the storage"""
         async with self._storage_lock:
             self._data.clear()
+            await set_all_update_flags(self.namespace)
+        await self.index_done_callback()
@@ -13,6 +13,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None

     async def initialize(self):
         """Initialize storage data"""
-        # check need_init must before get_namespace_data
-        need_init = try_initialize_namespace(self.namespace)
-        self._data = await get_namespace_data(self.namespace)
-        if need_init:
-            loaded_data = load_json(self._file_name) or {}
-            async with self._storage_lock:
-                self._data.update(loaded_data)
-                logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
+            # The need_init check must come before get_namespace_data
+            need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
+            if need_init:
+                loaded_data = load_json(self._file_name) or {}
+                async with self._storage_lock:
+                    self._data.update(loaded_data)
+
+                    # Calculate data count based on namespace
+                    if self.namespace.endswith("cache"):
+                        # For cache namespaces, sum the cache entries across all cache types
+                        data_count = sum(
+                            len(first_level_dict)
+                            for first_level_dict in loaded_data.values()
+                            if isinstance(first_level_dict, dict)
+                        )
+                    else:
+                        # For non-cache namespaces, use the original count method
+                        data_count = len(loaded_data)
+
+                    logger.info(
+                        f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                    )

     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-            write_json(data_dict, self._file_name)
+            if self.storage_updated.value:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(data_dict)
+
+                logger.info(
+                    f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
+                )
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)

     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
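A worked example of the cache-count branch above: cache namespaces hold a two-level dict (mode -> cache key -> entry), so records are counted at the second level rather than the first:

loaded_data = {
    "default": {"k1": {"return": "..."}, "k2": {"return": "..."}},
    "global": {"k3": {"return": "..."}},
}
data_count = sum(
    len(first_level_dict)
    for first_level_dict in loaded_data.values()
    if isinstance(first_level_dict, dict)
)
assert data_count == 3  # not 2, which len(loaded_data) would report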
@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
-            left_data = {k: v for k, v in data.items() if k not in self._data}
-            self._data.update(left_data)
+            self._data.update(data)
+            await set_all_update_flags(self.namespace)

     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
@@ -233,3 +233,57 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error searching for records with prefix '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Query Milvus for a specific ID
+            result = self._client.query(
+                collection_name=self.namespace,
+                filter=f'id == "{id}"',
+                output_fields=list(self.meta_fields) + ["id"],
+            )
+
+            if not result or len(result) == 0:
+                return None
+
+            return result[0]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Prepare the ID filter expression
+            id_list = '", "'.join(ids)
+            filter_expr = f'id in ["{id_list}"]'
+
+            # Query Milvus with the filter
+            result = self._client.query(
+                collection_name=self.namespace,
+                filter=filter_expr,
+                output_fields=list(self.meta_fields) + ["id"],
+            )
+
+            return result or []
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
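For clarity, the filter expression assembled above for a two-ID lookup evaluates to a quoted Milvus `in` clause:

ids = ["ent-1", "ent-2"]
id_list = '", "'.join(ids)
filter_expr = f'id in ["{id_list}"]'
assert filter_expr == 'id in ["ent-1", "ent-2"]'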
@@ -1073,6 +1073,59 @@ class MongoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Search for the specific ID in MongoDB
+            result = await self._data.find_one({"_id": id})
+            if result:
+                # Format the result to include the id field expected by the API
+                result_dict = dict(result)
+                if "_id" in result_dict and "id" not in result_dict:
+                    result_dict["id"] = result_dict["_id"]
+                return result_dict
+            return None
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Query MongoDB for multiple IDs
+            cursor = self._data.find({"_id": {"$in": ids}})
+            results = await cursor.to_list(length=None)
+
+            # Format results to include the id field expected by the API
+            formatted_results = []
+            for result in results:
+                result_dict = dict(result)
+                if "_id" in result_dict and "id" not in result_dict:
+                    result_dict["id"] = result_dict["_id"]
+                formatted_results.append(result_dict)
+
+            return formatted_results
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []


 async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
     collection_names = await db.list_collection_names()
@@ -258,3 +258,33 @@ class NanoVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
         return matching_records
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        client = await self._get_client()
+        result = client.get([id])
+        if result:
+            return result[0]
+        return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        client = await self._get_client()
+        return client.get(ids)
File diff suppressed because it is too large
@@ -531,6 +531,80 @@ class OracleVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error searching records with prefix '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Determine the table name based on namespace
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
+                return None
+
+            # Create the appropriate ID field name based on namespace
+            id_field = "entity_id" if "NODES" in table_name else "relation_id"
+            if "CHUNKS" in table_name:
+                id_field = "chunk_id"
+
+            # Prepare and execute the query
+            query = f"""
+            SELECT * FROM {table_name}
+            WHERE {id_field} = :id AND workspace = :workspace
+            """
+            params = {"id": id, "workspace": self.db.workspace}
+
+            result = await self.db.query(query, params)
+            return result
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Determine the table name based on namespace
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
+                return []
+
+            # Create the appropriate ID field name based on namespace
+            id_field = "entity_id" if "NODES" in table_name else "relation_id"
+            if "CHUNKS" in table_name:
+                id_field = "chunk_id"
+
+            # Format the list of IDs for the SQL IN clause
+            ids_list = ", ".join([f"'{id}'" for id in ids])
+
+            # Prepare and execute the query
+            query = f"""
+            SELECT * FROM {table_name}
+            WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
+            """
+            params = {"workspace": self.db.workspace}
+
+            results = await self.db.query(query, params, multirows=True)
+            return results or []
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []


 @final
 @dataclass
@@ -621,6 +621,60 @@ class PGVectorStorage(BaseVectorStorage):
             logger.error(f"Error during prefix search for '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        table_name = namespace_to_table_name(self.namespace)
+        if not table_name:
+            logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
+            return None
+
+        query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2"
+        params = {"workspace": self.db.workspace, "id": id}
+
+        try:
+            result = await self.db.query(query, params)
+            if result:
+                return dict(result)
+            return None
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        table_name = namespace_to_table_name(self.namespace)
+        if not table_name:
+            logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
+            return []
+
+        ids_str = ",".join([f"'{id}'" for id in ids])
+        query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
+        params = {"workspace": self.db.workspace}
+
+        try:
+            results = await self.db.query(query, params, multirows=True)
+            return [dict(record) for record in results]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []


 @final
 @dataclass
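The IN clause above is built by string formatting rather than bind parameters; for two IDs the generated SQL looks like this (the table name below is a hypothetical stand-in):

ids = ["chunk-1", "chunk-2"]
ids_str = ",".join([f"'{id}'" for id in ids])
# table name is illustrative, resolved via namespace_to_table_name in the real code
query = f"SELECT * FROM SOME_TABLE WHERE workspace=$1 AND id IN ({ids_str})"
# -> SELECT * FROM SOME_TABLE WHERE workspace=$1 AND id IN ('chunk-1','chunk-2')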
@@ -7,12 +7,18 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic
 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, level="INFO"):
+def direct_log(message, level="INFO", enable_output: bool = True):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.

     Args:
         message: The message to log
         level: Log level (default: "INFO")
+        enable_output: Whether to actually output the log (default: True)
     """
-    print(f"{level}: {message}", file=sys.stderr, flush=True)
+    if enable_output:
+        print(f"{level}: {message}", file=sys.stderr, flush=True)


 T = TypeVar("T")
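A short usage sketch of the extended direct_log signature:

direct_log("pipeline started")                     # INFO: pipeline started
direct_log("cannot open file", level="ERROR")      # ERROR: cannot open file
direct_log("lock acquired", enable_output=False)   # suppressed entirely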
@@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
 _storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
+_graph_db_lock: Optional[LockType] = None
+_data_init_lock: Optional[LockType] = None


 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
+        self._pid = os.getpid()  # for debug only
+        self._name = name  # for debug only
+        self._enable_logging = enable_logging  # for debug only

     async def __aenter__(self) -> "UnifiedLock[T]":
-        if self._is_async:
-            await self._lock.acquire()
-        else:
-            self._lock.acquire()
-        return self
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+            if self._is_async:
+                await self._lock.acquire()
+            else:
+                self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+            return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if self._is_async:
-            self._lock.release()
-        else:
-            self._lock.release()
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+            if self._is_async:
+                self._lock.release()
+            else:
+                self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     def __enter__(self) -> "UnifiedLock[T]":
         """For backward compatibility"""
-        if self._is_async:
-            raise RuntimeError("Use 'async with' for shared_storage lock")
-        self._lock.acquire()
-        return self
+        try:
+            if self._is_async:
+                raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
+            self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
+            return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     def __exit__(self, exc_type, exc_val, exc_tb):
         """For backward compatibility"""
-        if self._is_async:
-            raise RuntimeError("Use 'async with' for shared_storage lock")
-        self._lock.release()
+        try:
+            if self._is_async:
+                raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
+            self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise


-def get_internal_lock() -> UnifiedLock:
+def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )


-def get_storage_lock() -> UnifiedLock:
+def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )


-def get_pipeline_status_lock() -> UnifiedLock:
+def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )


+def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified graph database lock for ensuring atomic operations"""
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )
+
+
+def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified data initialization lock for ensuring atomic data initialization"""
+    return UnifiedLock(
+        lock=_data_init_lock,
+        is_async=not is_multiprocess,
+        name="data_init_lock",
+        enable_logging=enable_logging,
+    )


 def initialize_share_data(workers: int = 1):
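A sketch of how the named, logging-aware locks are meant to be used (the body is a placeholder):

async def mutate_graph():
    # enable_logging=True surfaces the acquire/release trace via direct_log
    async with get_graph_db_lock(enable_logging=True):
        ...  # atomic graph mutation protected across workers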
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -120,14 +238,16 @@ def initialize_share_data(workers: int = 1):
         )
         return

+    _manager = Manager()
     _workers = workers

     if workers > 1:
         is_multiprocess = True
-        _manager = Manager()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
+        _graph_db_lock = _manager.Lock()
+        _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
@@ -139,6 +259,8 @@ def initialize_share_data(workers: int = 1):
         _internal_lock = asyncio.Lock()
         _storage_lock = asyncio.Lock()
         _pipeline_status_lock = asyncio.Lock()
+        _graph_db_lock = asyncio.Lock()
+        _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
         _update_flags = {}
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
         history_messages = _manager.list() if is_multiprocess else []
         pipeline_namespace.update(
             {
+                "autoscanned": False,  # Auto-scan started
                 "busy": False,  # Control concurrent processes
                 "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
                 "job_start": None,  # Job start time
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
         if is_multiprocess and _manager is not None:
             new_update_flag = _manager.Value("b", False)
         else:
-            new_update_flag = False
+            # Create a simple mutable object to store a boolean value for compatibility with multiprocess mode
+            class MutableBoolean:
+                def __init__(self, initial_value=False):
+                    self.value = initial_value
+
+            new_update_flag = MutableBoolean(False)

         _update_flags[namespace].append(new_update_flag)
         return new_update_flag
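The wrapper exists so single-process mode mirrors the `.value` interface of Manager.Value, letting callers mutate a flag stored in a shared list in place:

flag = MutableBoolean(False)   # the class defined in the diff above
flags = [flag]
flags[0].value = True          # rebinding flags[0] = True would not be visible
                               # through the caller's `flag` reference
assert flag.value is True      # every holder of the object observes the change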
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
             if is_multiprocess:
                 _update_flags[namespace][i].value = True
             else:
-                _update_flags[namespace][i] = True
+                # Use the .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = True


+async def clear_all_update_flags(namespace: str):
+    """Clear all update flags of a namespace, indicating all workers need to reload data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = False
+            else:
+                # Use the .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = False


 async def get_all_update_flags_status() -> Dict[str, list]:
@@ -247,7 +394,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     return result


-def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(namespace: str) -> bool:
     """
     Returns True if the current worker (process) gets initialization permission for loading data later.
     A worker that does not get the permission is prohibited from loading data from files.
@@ -257,15 +404,17 @@ def try_initialize_namespace(namespace: str) -> bool:
     if _init_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

-    if namespace not in _init_flags:
-        _init_flags[namespace] = True
+    async with get_internal_lock():
+        if namespace not in _init_flags:
+            _init_flags[namespace] = True
+            direct_log(
+                f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+            )
+            return True
         direct_log(
-            f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+            f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
         )
-        return True
-    direct_log(
-        f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
-    )

     return False
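Taken together, the initialization handshake used by the storage classes above looks like this (the namespace and file name are illustrative):

async with get_data_init_lock():
    need_init = await try_initialize_namespace("doc_status")  # True for the first worker only
    data = await get_namespace_data("doc_status")
    if need_init:
        data.update(load_json("kv_store_doc_status.json") or {})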
@@ -304,6 +453,8 @@ def finalize_share_data():
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -369,6 +520,8 @@ def finalize_share_data():
     _storage_lock = None
     _internal_lock = None
     _pipeline_status_lock = None
+    _graph_db_lock = None
+    _data_init_lock = None
     _update_flags = None

     direct_log(f"Process {os.getpid()} storage data finalization complete")
@@ -465,6 +465,100 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error searching records with prefix '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Determine which table to query based on namespace
+            if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
+                sql_template = """
+                    SELECT entity_id as id, name as entity_name, entity_type, description, content
+                    FROM LIGHTRAG_GRAPH_NODES
+                    WHERE entity_id = :entity_id AND workspace = :workspace
+                """
+                params = {"entity_id": id, "workspace": self.db.workspace}
+            elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
+                sql_template = """
+                    SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
+                           keywords, description, content
+                    FROM LIGHTRAG_GRAPH_EDGES
+                    WHERE relation_id = :relation_id AND workspace = :workspace
+                """
+                params = {"relation_id": id, "workspace": self.db.workspace}
+            elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
+                sql_template = """
+                    SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
+                    FROM LIGHTRAG_DOC_CHUNKS
+                    WHERE chunk_id = :chunk_id AND workspace = :workspace
+                """
+                params = {"chunk_id": id, "workspace": self.db.workspace}
+            else:
+                logger.warning(
+                    f"Namespace {self.namespace} not supported for get_by_id"
+                )
+                return None
+
+            result = await self.db.query(sql_template, params=params)
+            return result
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Format IDs for the SQL IN clause
+            ids_str = ", ".join([f"'{id}'" for id in ids])
+
+            # Determine which table to query based on namespace
+            if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
+                sql_template = f"""
+                    SELECT entity_id as id, name as entity_name, entity_type, description, content
+                    FROM LIGHTRAG_GRAPH_NODES
+                    WHERE entity_id IN ({ids_str}) AND workspace = :workspace
+                """
+            elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
+                sql_template = f"""
+                    SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
+                           keywords, description, content
+                    FROM LIGHTRAG_GRAPH_EDGES
+                    WHERE relation_id IN ({ids_str}) AND workspace = :workspace
+                """
+            elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
+                sql_template = f"""
+                    SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id
+                    FROM LIGHTRAG_DOC_CHUNKS
+                    WHERE chunk_id IN ({ids_str}) AND workspace = :workspace
+                """
+            else:
+                logger.warning(
+                    f"Namespace {self.namespace} not supported for get_by_ids"
+                )
+                return []
+
+            params = {"workspace": self.db.workspace}
+            results = await self.db.query(sql_template, params=params, multirows=True)
+            return results if results else []
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []


 @final
 @dataclass